mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-07 01:33:43 +08:00
Compare commits
157 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4e6f14cb9e | ||
|
|
8dc03a7509 | ||
|
|
57b1b44645 | ||
|
|
aa17a33093 | ||
|
|
80e27c40e9 | ||
|
|
8250e876a5 | ||
|
|
9f98491368 | ||
|
|
fe160f978b | ||
|
|
7da5b7163c | ||
|
|
cffc722622 | ||
|
|
a7baf1dc9e | ||
|
|
488169683f | ||
|
|
2ba3c52e6e | ||
|
|
18179613fc | ||
|
|
8af0fec8ec | ||
|
|
acee2d9d81 | ||
|
|
cbf06eea24 | ||
|
|
989b4a64d6 | ||
|
|
b01b10014a | ||
|
|
e857f98e5c | ||
|
|
274cff71b1 | ||
|
|
06573c5d12 | ||
|
|
937e5befa2 | ||
|
|
fb403bde8b | ||
|
|
ba174ef3ee | ||
|
|
b7b702862f | ||
|
|
6df2b5735b | ||
|
|
130e151a06 | ||
|
|
ab903e3cc1 | ||
|
|
237387b2ab | ||
|
|
0c1f650e9c | ||
|
|
357c77ef30 | ||
|
|
dc7c049a7b | ||
|
|
710b008453 | ||
|
|
6e7aecc568 | ||
|
|
b68f7e3fd1 | ||
|
|
d30d5585c6 | ||
|
|
1b7c7a0dc1 | ||
|
|
207f2b5ac4 | ||
|
|
d13fa1392f | ||
|
|
9bf886fe98 | ||
|
|
aeef77ac24 | ||
|
|
9a97a1ee72 | ||
|
|
6aaf607ed7 | ||
|
|
cff0397735 | ||
|
|
2aa0b51c09 | ||
|
|
ce8a2d0222 | ||
|
|
135755d21d | ||
|
|
5be4e83876 | ||
|
|
cbc9eb3a59 | ||
|
|
0593359ef8 | ||
|
|
2081d3ce29 | ||
|
|
41d9c097e8 | ||
|
|
1fe1e40a43 | ||
|
|
ad6e2dd370 | ||
|
|
bb63f23414 | ||
|
|
43f6bf74f2 | ||
|
|
662d7b099e | ||
|
|
d5eeeea764 | ||
|
|
43c507c597 | ||
|
|
e356771049 | ||
|
|
48139290ed | ||
|
|
bd852c82b7 | ||
|
|
13564993d7 | ||
|
|
bfc1e1bc2c | ||
|
|
ba20717a09 | ||
|
|
52e40daf23 | ||
|
|
430a7b2297 | ||
|
|
c91a38a882 | ||
|
|
6e02bee4b7 | ||
|
|
b62218110e | ||
|
|
e2960b2607 | ||
|
|
88e7c39066 | ||
|
|
2a6dd636fa | ||
|
|
6bf38f78d5 | ||
|
|
5a04a935be | ||
|
|
8923e938d2 | ||
|
|
1a1734abf0 | ||
|
|
8093a3eeb2 | ||
|
|
9edb3d0a82 | ||
|
|
d95fab11be | ||
|
|
6ef09c8ad5 | ||
|
|
283a023a06 | ||
|
|
d315edef5f | ||
|
|
5fa17b300e | ||
|
|
32919de7a7 | ||
|
|
7d126aab41 | ||
|
|
16ac57ced3 | ||
|
|
4976b967e7 | ||
|
|
e874178782 | ||
|
|
8cb66ad01b | ||
|
|
f887a39912 | ||
|
|
2beffd3dd3 | ||
|
|
d8cb92d8d4 | ||
|
|
158db83965 | ||
|
|
603bfa7def | ||
|
|
829fb879a6 | ||
|
|
0385e60ce1 | ||
|
|
aaea23f785 | ||
|
|
131efd6ba5 | ||
|
|
866564370d | ||
|
|
dcdc0d8918 | ||
|
|
6c7fa17e50 | ||
|
|
38a0d00142 | ||
|
|
5c77e67b0f | ||
|
|
961cee5e41 | ||
|
|
c9cc93be8c | ||
|
|
49f2e1a71e | ||
|
|
97eff6085a | ||
|
|
8b2e2d61af | ||
|
|
c096efb416 | ||
|
|
cdaf6fb9dc | ||
|
|
78f443ed6d | ||
|
|
54e8d72b10 | ||
|
|
05161f48fd | ||
|
|
e971bf6b88 | ||
|
|
55b979784c | ||
|
|
97aa922b5f | ||
|
|
11c760a4e8 | ||
|
|
87b03332d9 | ||
|
|
8b14eeadf4 | ||
|
|
e0ead127e0 | ||
|
|
0887bcdee0 | ||
|
|
67d83041d7 | ||
|
|
1350f388f0 | ||
|
|
65dde9e69d | ||
|
|
2e5bd238b7 | ||
|
|
8fc8fd6cba | ||
|
|
dfc6c87250 | ||
|
|
b63e01225e | ||
|
|
561b82027a | ||
|
|
f6d8fbf570 | ||
|
|
568201ebbb | ||
|
|
ab421f2185 | ||
|
|
f71a2f5263 | ||
|
|
d000cc5a67 | ||
|
|
04d6ba0853 | ||
|
|
8d7c028ca8 | ||
|
|
3ae7ebfeaf | ||
|
|
aa42d38387 | ||
|
|
43843b92f2 | ||
|
|
5da879600a | ||
|
|
87ed2064e3 | ||
|
|
34e96e91d4 | ||
|
|
8c4c2b89ce | ||
|
|
373021c191 | ||
|
|
740c3c1b00 | ||
|
|
67c7132e6b | ||
|
|
c77843424b | ||
|
|
2d4959aa7d | ||
|
|
167c59a159 | ||
|
|
1d0006ce59 | ||
|
|
6a8b4ee2f1 | ||
|
|
72b1515b68 | ||
|
|
3f0252b498 | ||
|
|
1d9d487f0e | ||
|
|
1bcbf74883 |
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
@@ -1,5 +1,5 @@
|
|||||||
name: Bug 报告 🐛
|
name: Bug 报告 🐛
|
||||||
description: 为 geekai 提交错误报告
|
description: 为 chatgpt-plus 提交错误报告
|
||||||
labels: ['Bug']
|
labels: ['Bug']
|
||||||
body:
|
body:
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
2
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
@@ -1,5 +1,5 @@
|
|||||||
name: 功能优化 🚀
|
name: 功能优化 🚀
|
||||||
description: 为 geekai 提交优化建议
|
description: 为 chatgpt-plus 提交优化建议
|
||||||
labels: ['feature']
|
labels: ['feature']
|
||||||
body:
|
body:
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
|
|||||||
531
CHANGELOG.md
531
CHANGELOG.md
@@ -1,290 +1,383 @@
|
|||||||
# 更新日志
|
# 更新日志
|
||||||
|
|
||||||
|
## v4.1.9
|
||||||
|
|
||||||
|
- 功能优化:优化系统配置,移除已废弃的配置项
|
||||||
|
- 功能优化:GPT-O1 模型支持流式输出
|
||||||
|
- 功能优化:优化代码引用快样式,支持主题切换
|
||||||
|
- 功能优化:登录,注册页面允许替换用户自己的 Logo 和 Title
|
||||||
|
- Bug 修复:修复 OpenAI 实时语音通话没有检测用户算力不足的 Bug
|
||||||
|
- 功能新增:管理后台增加算力日志查询功能,支持按用户,按模型,按日期,按类型查询算力日志
|
||||||
|
- 功能优化:支持为模型绑定 Dalle 和 chat 类型的 API KEY
|
||||||
|
- 功能新增:支持管理后台设置 ICP 备案号
|
||||||
|
|
||||||
|
## v4.1.8
|
||||||
|
|
||||||
|
- 功能优化:**UI 全新改版,支持主题切换**。 :rocket: :rocket: :rocket:
|
||||||
|
- 功能新增:Gitee AI API 接口接入,目前支持 Gitee 的 SD 绘图接口,支持 Gitee 的 AI 对话接口。:rocket: :rocket: :rocket:
|
||||||
|
- Bug 修复:修复音 Luma API 更新导致任务响应解析失败的错误
|
||||||
|
- 功能优化:支持 Suno v4.0 模型支持
|
||||||
|
- Bug 修复:修复 Suno 已完成任务删除失败的 错误
|
||||||
|
- 功能新增:支持 OpenAI 实时语音通话功能,目前已经支持按次收费,支持管理员设置每次实时语音通话的算力消耗
|
||||||
|
- 功能新增:生成提示词需要消耗算力,支持管理员设置每次生成提示词的算力消耗,防止被白嫖
|
||||||
|
- 功能新增:DALL-E-3 绘图支持 Flux 绘图模型,支持在管理后添加 Flux,SD 等绘图模型
|
||||||
|
- 功能优化:Markdown 支持解析 emoji 表情
|
||||||
|
- 功能优化:当管理后台禁用了某个绘图菜单的时候,移动端绘图菜单也会同步禁用(不显示该功能)
|
||||||
|
|
||||||
|
## v4.1.7
|
||||||
|
|
||||||
|
- Bug 修复:手机邮箱相关的注册问题 [#IB0HS5](https://gitee.com/blackfox/geekai/issues/IB0HS5)
|
||||||
|
- Bug 修复:音乐视频无法下载,思维导图下载后看不清文字[#IB0N2E](https://gitee.com/blackfox/geekai/issues/IB0N2E)
|
||||||
|
- 功能优化:保存所有 AIGC 任务的原始信息,程序启动之后自动将未执行的任务加入到 redis 队列
|
||||||
|
- 功能优化:失败的任务自动退回算力,而不需要在删除的时候再退回
|
||||||
|
- 功能新增:支持设置一个专门的模型来翻译提示词,提供 Mate 提示词生成功能
|
||||||
|
- Bug 修复:修复图片对话的时候,上下文不起作用的 Bug
|
||||||
|
- 功能新增:管理后台新增批量导出兑换码功能
|
||||||
|
|
||||||
|
## v4.1.6
|
||||||
|
|
||||||
|
- 功能新增:**支持 OpenAI 实时语音对话功能** :rocket: :rocket: :rocket:, Beta 版,目前没有做算力计费控制,目前只有 VIP 用户可以使用。
|
||||||
|
- 功能优化:优化 MysQL 容器配置文档,解决 MysQL 容器资源占用过高问题
|
||||||
|
- 功能新增:管理后台增加 AI 绘图任务管理,可在管理后台浏览和删除用户的绘图任务
|
||||||
|
- 功能新增:管理后台增加 Suno 和 Luma 任务管理功能
|
||||||
|
- Bug 修复:修复管理后台删除兑换码报 404 错误
|
||||||
|
- 功能优化:优化充值产品定价逻辑,可以设置原价和优惠价,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**。
|
||||||
|
|
||||||
|
## v4.1.5
|
||||||
|
|
||||||
|
- 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接
|
||||||
|
- Bug 修复:兼容手机端原生微信支付和支付宝支付渠道
|
||||||
|
- Bug 修复:修复删除绘图任务时候因为字段长度过短导致 SQL 执行失败问题
|
||||||
|
- 功能优化:优化 Vue 组件通信代码,使用共享数据来替换之前的事件订阅模式,效率更高一些
|
||||||
|
- 功能优化:优化思维导图生成功果页面,优化用户体验
|
||||||
|
|
||||||
|
## v4.1.4
|
||||||
|
|
||||||
|
- 功能优化:用户文件列表组件增加分页功能支持
|
||||||
|
- Bug 修复:修复用户注册失败 Bug,注册操作只弹出一次行为验证码
|
||||||
|
- 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
|
||||||
|
- 功能新增:给 AI 应用(角色)增加分类,前端支持分类筛选
|
||||||
|
- 功能优化:允许用户在聊天页面设置是否使用流式输出或者一次性输出,兼容 GPT-O1 模型。
|
||||||
|
- 功能优化:移除 PayJS 支付渠道支持,PayJs 已经关闭注册服务,请使用其他支付方式。
|
||||||
|
- 功能新增:新增 GeeK 易支付支付渠道,支持支付宝,微信支付,QQ 钱包,京东支付,抖音支付,Paypal 支付等支付方式
|
||||||
|
- Bug 修复:修复注册页面 tab 组件没有自动选中问题 [#6](https://github.com/yangjian102621/geekai-plus/issues/6)
|
||||||
|
- 功能优化:Luma 生成视频任务增加自动翻译功能
|
||||||
|
- Bug 修复:Suno 和 Luma 任务没有判断用户算力
|
||||||
|
- 功能新增:邮箱注册增加邮箱后缀白名单,防止使用某些垃圾邮箱注册薅羊毛
|
||||||
|
- 功能优化:清空未支付订单时,只清空超过 15 分钟未支付的订单
|
||||||
|
|
||||||
|
## v4.1.3
|
||||||
|
|
||||||
|
- 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
|
||||||
|
- 功能优化:重构找回密码模块,支持通过手机或者邮箱找回密码
|
||||||
|
- 功能优化:管理后台给可以拖动排序的组件添加拖动图标
|
||||||
|
- 功能优化:Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
|
||||||
|
- Bug 修复:手机端角色和模型选择不生效
|
||||||
|
- Bug 修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常
|
||||||
|
- 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
|
||||||
|
- 功能优化:给算力增减服务全部加上数据库事务和同步锁
|
||||||
|
- 功能优化:支持用户在前端对话界面选择插件
|
||||||
|
- 功能新增:支持 Luma 文生视频功能
|
||||||
|
|
||||||
|
## v4.1.2
|
||||||
|
|
||||||
|
- Bug 修复:修复思维导图页面获取模型失败的问题
|
||||||
|
- 功能优化:优化 MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力
|
||||||
|
- Bug 修复:修复后台拖动排序组件 Bug
|
||||||
|
- 功能优化:更新数据库失败时候显示具体的的报错信息
|
||||||
|
- Bug 修复:修复管理后台对话详情页内容显示异常问题
|
||||||
|
- 功能优化:管理后台新增清空所有未支付订单的功能
|
||||||
|
- 功能优化:给会话信息和系统配置数据加上缓存功能,减少 http 请求
|
||||||
|
- 功能新增:移除微信机器人收款功能,增加卡密功能,支持用户使用卡密兑换算力
|
||||||
|
|
||||||
## v4.1.1
|
## v4.1.1
|
||||||
* Bug修复:修复 GPT 模型 function call 调用后没有输出的问题
|
|
||||||
* 功能新增:允许获取 License 授权用户可以自定义版权信息
|
- Bug 修复:修复 GPT 模型 function call 调用后没有输出的问题
|
||||||
* 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件
|
- 功能新增:允许获取 License 授权用户可以自定义版权信息
|
||||||
* 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求
|
- 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件
|
||||||
* 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
|
- 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求
|
||||||
* 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
|
- 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
|
||||||
* 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式
|
- 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
|
||||||
* 功能新增:允许在管理后台设置首页显示的导航菜单
|
- 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式
|
||||||
* Bug修复:修复注册页面先显示关闭注册组件,然后再显示注册组件
|
- 功能新增:允许在管理后台设置首页显示的导航菜单
|
||||||
* 功能新增:增加 Suno 文生歌曲功能
|
- Bug 修复:修复注册页面先显示关闭注册组件,然后再显示注册组件
|
||||||
* 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
|
- 功能新增:增加 Suno 文生歌曲功能
|
||||||
* 功能优化:在所有列表页面增加返回顶部按钮
|
- 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
|
||||||
|
- 功能优化:在所有列表页面增加返回顶部按钮
|
||||||
|
|
||||||
## v4.1.0
|
## v4.1.0
|
||||||
* bug修复:修复移动端修改聊天标题不生效的问题
|
|
||||||
* Bug修复:修复用户注册不显示用户名的问题
|
|
||||||
* Bug修复:修复管理后台拖动排序不生效的问题
|
|
||||||
* 功能优化:允许用户设置自定义首页背景图片
|
|
||||||
* 功能新增:**支持AI解读 PDF, Word, Excel等文件**
|
|
||||||
* 功能优化:优化聊天界面的用户上传文件的列表样式
|
|
||||||
* 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
|
|
||||||
* 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
|
|
||||||
|
|
||||||
|
- bug 修复:修复移动端修改聊天标题不生效的问题
|
||||||
|
- Bug 修复:修复用户注册不显示用户名的问题
|
||||||
|
- Bug 修复:修复管理后台拖动排序不生效的问题
|
||||||
|
- 功能优化:允许用户设置自定义首页背景图片
|
||||||
|
- 功能新增:**支持 AI 解读 PDF, Word, Excel 等文件**
|
||||||
|
- 功能优化:优化聊天界面的用户上传文件的列表样式
|
||||||
|
- 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
|
||||||
|
- 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
|
||||||
|
|
||||||
## v4.0.9
|
## v4.0.9
|
||||||
* 环境升级:升级 Golang 到 go1.22.4
|
|
||||||
* 功能增加:接入微信商户号支付渠道
|
- 环境升级:升级 Golang 到 go1.22.4
|
||||||
* Bug修复:修复前端页面菜单把页面撑开,底部留白问题
|
- 功能增加:接入微信商户号支付渠道
|
||||||
* 功能优化:聊天页面自动根据内容调整输入框的高度
|
- Bug 修复:修复前端页面菜单把页面撑开,底部留白问题
|
||||||
* Bug修复:修复Dalle绘图失败退回算力的问题
|
- 功能优化:聊天页面自动根据内容调整输入框的高度
|
||||||
* 功能优化:邀请码注册时被邀请人也可以获得赠送的算力
|
- Bug 修复:修复 Dalle 绘图失败退回算力的问题
|
||||||
* 功能优化:允许设置邮件验证码的抬头
|
- 功能优化:邀请码注册时被邀请人也可以获得赠送的算力
|
||||||
* Bug修复:修复免费模型不会记录聊天记录的bug
|
- 功能优化:允许设置邮件验证码的抬头
|
||||||
* Bug修复:修复聊天输入公式显示异常的Bug
|
- Bug 修复:修复免费模型不会记录聊天记录的 bug
|
||||||
|
- Bug 修复:修复聊天输入公式显示异常的 Bug
|
||||||
|
|
||||||
## v4.0.8
|
## v4.0.8
|
||||||
* 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
|
|
||||||
* 功能优化:当数据库更新失败的时候记录错误日志
|
- 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
|
||||||
* 功能优化:聊天输入框会随着输入内容的增多自动调整高度
|
- 功能优化:当数据库更新失败的时候记录错误日志
|
||||||
* Bug修复:修复移动端聊天页面模型切换不生效的Bug
|
- 功能优化:聊天输入框会随着输入内容的增多自动调整高度
|
||||||
* 功能优化:给PC端扫码支付增加签名验证和有效期验证
|
- Bug 修复:修复移动端聊天页面模型切换不生效的 Bug
|
||||||
* Bug修复:修复支付码生成API权限控制的问题
|
- 功能优化:给 PC 端扫码支付增加签名验证和有效期验证
|
||||||
* Bug修复:模型算力设置为0时,不扣减用户算力,并且不记录算力消费日志
|
- Bug 修复:修复支付码生成 API 权限控制的问题
|
||||||
* 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
|
- Bug 修复:模型算力设置为 0 时,不扣减用户算力,并且不记录算力消费日志
|
||||||
* 功能新增:H5端支持 Dalle 绘图
|
- 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
|
||||||
|
- 功能新增:H5 端支持 Dalle 绘图
|
||||||
|
|
||||||
## v4.0.7
|
## v4.0.7
|
||||||
|
|
||||||
* 功能优化:升级quic-go,支持 Go1.21
|
- 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片
|
||||||
* 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片
|
- Bug 修复:修复弹窗等于图形验证码一直验证失败的问题
|
||||||
* Bug修复:修复弹窗等于图形验证码一直验证失败的问题
|
- 功能重构:重构前端 UI 页面,增加顶部导航
|
||||||
* 功能重构:重构前端 UI 页面,增加顶部导航
|
- 功能优化:优化 Vue 非父子组件之间的通信方式
|
||||||
* 功能优化:优化 Vue 非父子组件之间的通信方式
|
- 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
|
||||||
* 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
|
|
||||||
|
|
||||||
## v4.0.6
|
## v4.0.6
|
||||||
|
|
||||||
* Bug修复:修复PC端画廊页面的瀑布流组件样式错乱问题
|
- Bug 修复:修复 PC 端画廊页面的瀑布流组件样式错乱问题
|
||||||
* 功能新增:给思维导图增加 ToolBar,实现思维导图的放大缩小和定位
|
- 功能新增:给思维导图增加 ToolBar,实现思维导图的放大缩小和定位
|
||||||
* Bug修复:修复思维导图不扣费的Bug
|
- Bug 修复:修复思维导图不扣费的 Bug
|
||||||
* Bug修复:修复管理后台角色删除失败的Bug
|
- Bug 修复:修复管理后台角色删除失败的 Bug
|
||||||
* Bug修复:兼容最新版秋叶SD懒人包的 SD API,新增 scheduler 参数
|
- Bug 修复:兼容最新版秋叶 SD 懒人包的 SD API,新增 scheduler 参数
|
||||||
* 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
|
- 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
|
||||||
* Bug修复:修复注册用户提示注册人数达到上限的 Bug
|
- Bug 修复:修复注册用户提示注册人数达到上限的 Bug
|
||||||
* 功能优化:将MJ,SD,Dall绘画页面的任务列表全改成瀑布流组件
|
- 功能优化:将 MJ,SD,Dall 绘画页面的任务列表全改成瀑布流组件
|
||||||
|
|
||||||
## v4.0.5
|
## v4.0.5
|
||||||
|
|
||||||
* 功能优化:已授权系统在后台显示授权信息
|
- 功能优化:已授权系统在后台显示授权信息
|
||||||
* 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
|
- 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
|
||||||
* 功能优化:优化首页登录注册页面的 UI
|
- 功能优化:优化首页登录注册页面的 UI
|
||||||
* BUG修复:修复License验证的逻辑漏洞
|
- BUG 修复:修复 License 验证的逻辑漏洞
|
||||||
* Bug修复:后台添加用户的时候密码规则限制跟前台注册保持一致
|
- Bug 修复:后台添加用户的时候密码规则限制跟前台注册保持一致
|
||||||
* 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
|
- 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
|
||||||
* 功能新增:移动端新增 DALL-E 绘画功能
|
- 功能新增:移动端新增 DALL-E 绘画功能
|
||||||
* 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
|
- 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
|
||||||
* 功能新增:移动支持免登录预览功能
|
- 功能新增:移动支持免登录预览功能
|
||||||
* Bug修复:解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
|
- Bug 修复:解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
|
||||||
* Bug修复:修复部分中转 API 模型会出现第一输出的字符被淹没的Bug
|
- Bug 修复:修复部分中转 API 模型会出现第一输出的字符被淹没的 Bug
|
||||||
|
|
||||||
## v4.0.4
|
## v4.0.4
|
||||||
|
|
||||||
* Bug修复:修复统一千问第二句不回复的问题
|
- Bug 修复:修复统一千问第二句不回复的问题
|
||||||
* 功能优化:MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
|
- 功能优化:MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
|
||||||
* 功能新增:Dalle AI 绘画功能实现
|
- 功能新增:Dalle AI 绘画功能实现
|
||||||
* Bug修复:修复思维导图格式乱码问题
|
- Bug 修复:修复思维导图格式乱码问题
|
||||||
* 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
|
- 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
|
||||||
* 功能新增:支持从应用列表直接和某个应用对话
|
- 功能新增:支持从应用列表直接和某个应用对话
|
||||||
* 功能优化:优化算力日志的页面和首页的UI
|
- 功能优化:优化算力日志的页面和首页的 UI
|
||||||
* 功能新增:支持思维导图导出 PNG 图片下载
|
- 功能新增:支持思维导图导出 PNG 图片下载
|
||||||
|
|
||||||
## v4.0.3
|
## v4.0.3
|
||||||
|
|
||||||
* 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
|
- 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
|
||||||
* Bug修复:兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
|
- Bug 修复:兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
|
||||||
* Bug修复:修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题
|
- Bug 修复:修复 MidJourney 在任务超时后出现后面的任务覆盖前面任务的问题
|
||||||
* 功能新增:支持上传图片和视觉模型
|
- 功能新增:支持上传图片和视觉模型
|
||||||
* 功能优化:优化聊天页面的复制代码按钮样式乱码
|
- 功能优化:优化聊天页面的复制代码按钮样式乱码
|
||||||
* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
|
- 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
|
||||||
* 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用GPT3.5或者 GPT4
|
- 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用 GPT3.5 或者 GPT4
|
||||||
* 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
|
- 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
|
||||||
* 功能新增:支持管理后台 Logo 修改
|
- 功能新增:支持管理后台 Logo 修改
|
||||||
|
|
||||||
## 4.0.2
|
## 4.0.2
|
||||||
|
|
||||||
* 功能新增:支持前端菜单可以配置
|
- 功能新增:支持前端菜单可以配置
|
||||||
* 功能优化:在登录和注册界面标题显示软件版本号
|
- 功能优化:在登录和注册界面标题显示软件版本号
|
||||||
* 功能优化:MJ 绘画支持 --sref 和 --cref 图片一致性参数
|
- 功能优化:MJ 绘画支持 --sref 和 --cref 图片一致性参数
|
||||||
* 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
|
- 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
|
||||||
* Bug修复:解决因为图片上传使用相对路径而导致融图失败的问题。
|
- Bug 修复:解决因为图片上传使用相对路径而导致融图失败的问题。
|
||||||
* 功能新增:手机端支持 Stable-Diffusion 绘画
|
- 功能新增:手机端支持 Stable-Diffusion 绘画
|
||||||
* 功能新增:管理后台登录页面增加行为验证码,防止爆破
|
- 功能新增:管理后台登录页面增加行为验证码,防止爆破
|
||||||
|
|
||||||
## v4.0.1
|
## v4.0.1
|
||||||
|
|
||||||
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion
|
- 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion
|
||||||
发行版,稳定性更强一些
|
发行版,稳定性更强一些
|
||||||
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容
|
- 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容
|
||||||
MJ-Plus 中转
|
MJ-Plus 中转
|
||||||
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
|
- 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
|
||||||
* Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换
|
- Bug 修复:修复 iphone 手机无法通过图形验证码的 Bug,使用滑动验证码替换
|
||||||
* Bug修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
|
- Bug 修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的 Bug
|
||||||
|
|
||||||
## v4.0.0
|
## v4.0.0
|
||||||
|
|
||||||
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。
|
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI 对话,MJ 绘画,SD 绘画,DALL 绘画)全部使用算力来兑换。
|
||||||
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ
|
只要你的算力值余额不为 0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗 1 个单位算力,一次 GPT4 对话消耗 10 个算力。一次 MJ
|
||||||
对话消耗15个算力...
|
对话消耗 15 个算力...
|
||||||
|
|
||||||
* 功能重构:重构整体系统,全部采用算力来进行结算
|
- 功能重构:重构整体系统,全部采用算力来进行结算
|
||||||
* 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
|
- 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
|
||||||
* 功能优化:移动端聊天页面图片支持预览和放大功能
|
- 功能优化:移动端聊天页面图片支持预览和放大功能
|
||||||
* 功能优化:MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
|
- 功能优化:MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
|
||||||
* 功能优化:**PC端不登录也可以预览功能,只有在发起操作的时候才需要登录**
|
- 功能优化:**PC 端不登录也可以预览功能,只有在发起操作的时候才需要登录**
|
||||||
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
|
- 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
|
||||||
* 功能新增:支持H5支付
|
- 功能新增:支持 H5 支付
|
||||||
* 功能优化:支持数学公式的识别和美化输出
|
- 功能优化:支持数学公式的识别和美化输出
|
||||||
* 功能新增:新增算力消费日志功能
|
- 功能新增:新增算力消费日志功能
|
||||||
* 功能优化:整合 XXL-JOB 实现订单清理,每日算力派发,VIP 算力重置等任务
|
- 功能优化:整合 XXL-JOB 实现订单清理,每日算力派发,VIP 算力重置等任务
|
||||||
* 功能新增:管理后台新增7日内新增用户和新增订单统计
|
- 功能新增:管理后台新增 7 日内新增用户和新增订单统计
|
||||||
|
|
||||||
## v3.2.7
|
## v3.2.7
|
||||||
|
|
||||||
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
|
- 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
|
||||||
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
|
- 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
|
||||||
* Bug修复:修复 issue [
|
- Bug 修复:修复 issue [
|
||||||
管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532)
|
管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532)
|
||||||
* 功能优化:在对话和聊天记录表中新增冗余字段 model,存储对话模型
|
- 功能优化:在对话和聊天记录表中新增冗余字段 model,存储对话模型
|
||||||
* Bug修复:IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
|
- Bug 修复:IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
|
||||||
* Bug修复:重新生成按钮功能失效问题
|
- Bug 修复:重新生成按钮功能失效问题
|
||||||
* Bug修复:对话输入HTML标签不显示的问题
|
- Bug 修复:对话输入 HTML 标签不显示的问题
|
||||||
* 功能优化:gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
|
- 功能优化:gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
|
||||||
* 功能新增:新增删除文件功能
|
- 功能新增:新增删除文件功能
|
||||||
* Bug修复:解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
|
- Bug 修复:解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
|
||||||
* 功能新增:后台管理新怎对话查看和检索功能
|
- 功能新增:后台管理新怎对话查看和检索功能
|
||||||
|
|
||||||
## v3.2.6
|
## v3.2.6
|
||||||
|
|
||||||
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
|
- 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
|
||||||
* 功能优化:兼用旧版本微信收款消息解析
|
- 功能优化:兼用旧版本微信收款消息解析
|
||||||
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
|
- 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
|
||||||
* 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
|
- 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
|
||||||
* 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
|
- 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
|
||||||
* 功能新增:新增网站公告,可以在管理后台自定义配置
|
- 功能新增:新增网站公告,可以在管理后台自定义配置
|
||||||
* 功能新增:新增阿里通义千问大模型支持
|
- 功能新增:新增阿里通义千问大模型支持
|
||||||
* Bug修复:修复 MJ 放大任务失败时候 img_call 会增加的 Bug
|
- Bug 修复:修复 MJ 放大任务失败时候 img_call 会增加的 Bug
|
||||||
* 功能优化:新增虎皮椒和PayJS订单状态校验功能,增加安全性
|
- 功能优化:新增虎皮椒和 PayJS 订单状态校验功能,增加安全性
|
||||||
* Bug修复:修复微信转账交易 ID 提取失败 Bug
|
- Bug 修复:修复微信转账交易 ID 提取失败 Bug
|
||||||
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
|
- 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
|
||||||
* 功能新增:新增短信宝短信平台发送平台集成
|
- 功能新增:新增短信宝短信平台发送平台集成
|
||||||
|
|
||||||
## v3.2.5
|
## v3.2.5
|
||||||
|
|
||||||
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
|
- 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
|
||||||
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
|
- 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
|
||||||
Plus 账号了!!!
|
Plus 账号了!!!
|
||||||
* 功能优化:增强 markdown 图片和引用块解析。
|
- 功能优化:增强 markdown 图片和引用块解析。
|
||||||
* 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
|
- 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
|
||||||
* 功能优化:function call 兼用中转 API。
|
- 功能优化:function call 兼用中转 API。
|
||||||
* Bug修复:修复部分已知的 Bug。
|
- Bug 修复:修复部分已知的 Bug。
|
||||||
|
|
||||||
## v3.2.4.1
|
## v3.2.4.1
|
||||||
|
|
||||||
* 功能新增:新增 PayJs 支付通道
|
- 功能新增:新增 PayJs 支付通道
|
||||||
* Bug修复:紧急修复后台添加用户失败问题
|
- Bug 修复:紧急修复后台添加用户失败问题
|
||||||
* Bug修复:紧急修复使用中转 API-KEY 无法绘图的问题
|
- Bug 修复:紧急修复使用中转 API-KEY 无法绘图的问题
|
||||||
* Bug修复:允许用户关闭手机和邮箱注册通道,移除验证码依赖
|
- Bug 修复:允许用户关闭手机和邮箱注册通道,移除验证码依赖
|
||||||
|
|
||||||
## v3.2.4
|
## v3.2.4
|
||||||
|
|
||||||
* 功能新增:重磅更新,支持邮箱注册
|
- 功能新增:重磅更新,支持邮箱注册
|
||||||
* 功能优化:优化函数调用授权
|
- 功能优化:优化函数调用授权
|
||||||
* 功能优化:给用户表新增 nickname 字段
|
- 功能优化:给用户表新增 nickname 字段
|
||||||
* 功能优化:管理后台给聊天角色增加启用/禁用开关
|
- 功能优化:管理后台给聊天角色增加启用/禁用开关
|
||||||
* Bug修复:SD绘画出现重复扣减绘图次数
|
- Bug 修复:SD 绘画出现重复扣减绘图次数
|
||||||
* 功能优化:优化聊天对话导出样式,适应移动端
|
- 功能优化:优化聊天对话导出样式,适应移动端
|
||||||
* 功能新增:众筹核销可以选择兑换对话还是绘图的额度
|
- 功能新增:众筹核销可以选择兑换对话还是绘图的额度
|
||||||
* Bug修复:修复[从历史记录获取reply有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92)
|
- Bug 修复:修复[从历史记录获取 reply 有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92)
|
||||||
* Bug修复:修复 MidJourney 绘图任务调度Bug,为 task_id 建议唯一索引
|
- Bug 修复:修复 MidJourney 绘图任务调度 Bug,为 task_id 建议唯一索引
|
||||||
* 功能重构:重构了 API KEY模块,支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。
|
- 功能重构:重构了 API KEY 模块,支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。
|
||||||
|
|
||||||
## v3.2.3
|
## v3.2.3
|
||||||
|
|
||||||
* 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现
|
- 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现
|
||||||
* 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数
|
- 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数
|
||||||
* Bug修复:修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73)
|
- Bug 修复:修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73)
|
||||||
* Bug修复:确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道
|
- Bug 修复:确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道
|
||||||
* 功能新增:新增删除绘图任何和图片功能
|
- 功能新增:新增删除绘图任何和图片功能
|
||||||
* Bug修复:修复虎皮椒支付二维码重复扫码时报错问题
|
- Bug 修复:修复虎皮椒支付二维码重复扫码时报错问题
|
||||||
* 功能优化:自动将 AI 绘画中的中文提示词翻译成英文
|
- 功能优化:自动将 AI 绘画中的中文提示词翻译成英文
|
||||||
* 功能优化:优化AI绘画的大图压缩算法,新增图片缓存
|
- 功能优化:优化 AI 绘画的大图压缩算法,新增图片缓存
|
||||||
* 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率
|
- 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率
|
||||||
* Bug修复:修复[Azure Api 更换api-version参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71)
|
- Bug 修复:修复[Azure Api 更换 api-version 参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71)
|
||||||
* Bug修复:修复科大讯飞 V1.5 API 请求失败的问题
|
- Bug 修复:修复科大讯飞 V1.5 API 请求失败的问题
|
||||||
* Bug修复:绘图失败后,自动恢复用户的剩余绘图次数
|
- Bug 修复:绘图失败后,自动恢复用户的剩余绘图次数
|
||||||
* 功能新增:为移动端新增 SD 绘图功能,分享功能
|
- 功能新增:为移动端新增 SD 绘图功能,分享功能
|
||||||
|
|
||||||
## v3.2.2
|
## v3.2.2
|
||||||
|
|
||||||
* 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务
|
- 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务
|
||||||
* 功能新增:AI绘画页面增加翻译和重写提示词功能
|
- 功能新增:AI 绘画页面增加翻译和重写提示词功能
|
||||||
* 功能优化:OSS上传组件支持在 Bucket 下设置二级目录
|
- 功能优化:OSS 上传组件支持在 Bucket 下设置二级目录
|
||||||
* Bug修复:修复阿里云 OSS 访问路径错误
|
- Bug 修复:修复阿里云 OSS 访问路径错误
|
||||||
* 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket
|
- 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket
|
||||||
|
|
||||||
## v3.2.1
|
## v3.2.1
|
||||||
|
|
||||||
* 功能优化:切换角色和模型的时候自动创建新的对话
|
- 功能优化:切换角色和模型的时候自动创建新的对话
|
||||||
* Bug修复:修复文件上传失败No such file bug
|
- Bug 修复:修复文件上传失败 No such file bug
|
||||||
* 功能新增:MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数
|
- 功能新增:MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数
|
||||||
* Bug修复:[PC端对话在刷新后异常](https://github.com/yangjian102621/chatgpt-plus/issues/59)
|
- Bug 修复:[PC 端对话在刷新后异常](https://github.com/yangjian102621/chatgpt-plus/issues/59)
|
||||||
* 功能新增:增加 arm64 架构打包脚本
|
- 功能新增:增加 arm64 架构打包脚本
|
||||||
* 功能新增:支持 dall-e3 绘图的 API 地址自定义配置
|
- 功能新增:支持 dall-e3 绘图的 API 地址自定义配置
|
||||||
* 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道
|
- 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道
|
||||||
|
|
||||||
## v3.2.0
|
## v3.2.0
|
||||||
|
|
||||||
* 功能新增:新增邀请注册功能
|
- 功能新增:新增邀请注册功能
|
||||||
* 功能优化:增加中间件自动对HTTP请求的参数去掉首尾空格
|
- 功能优化:增加中间件自动对 HTTP 请求的参数去掉首尾空格
|
||||||
* 功能优化:增加中间件自动为大图片生成缩略图
|
- 功能优化:增加中间件自动为大图片生成缩略图
|
||||||
* 功能优化:MidJourney 页面图片加载优化,实现图片预览懒加载
|
- 功能优化:MidJourney 页面图片加载优化,实现图片预览懒加载
|
||||||
* 功能新增:新增 DALL-E-3 绘画支持,并作为对话页面默认绘画插件
|
- 功能新增:新增 DALL-E-3 绘画支持,并作为对话页面默认绘画插件
|
||||||
* Bug修复:修复阿里云 OSS 域名设置不起做用的bug
|
- Bug 修复:修复阿里云 OSS 域名设置不起做用的 bug
|
||||||
* Bug修复:修复MidJourney绘图失败后重复添加到队列的问题
|
- Bug 修复:修复 MidJourney 绘图失败后重复添加到队列的问题
|
||||||
|
|
||||||
## v3.1.9
|
## v3.1.9
|
||||||
|
|
||||||
* 功能新增:增加讯飞星火大模型 v3.0 支持
|
- 功能新增:增加讯飞星火大模型 v3.0 支持
|
||||||
* 功能新增:新增找回密码功能
|
- 功能新增:新增找回密码功能
|
||||||
* 功能新增:支持 Markdown 代码复制功能
|
- 功能新增:支持 Markdown 代码复制功能
|
||||||
* Bug修复: xxl-job 任务调度失败的 Bug
|
- Bug 修复: xxl-job 任务调度失败的 Bug
|
||||||
* 功能优化:优化前端页面菜单图标,使用自定义图标替换 icon-font
|
- 功能优化:优化前端页面菜单图标,使用自定义图标替换 icon-font
|
||||||
* Bug修复:Stable-Diffusion 绘画成功之后没有扣减用户画图次数
|
- Bug 修复:Stable-Diffusion 绘画成功之后没有扣减用户画图次数
|
||||||
* 功能优化:优化会员充值页面 ItemList 组件
|
- 功能优化:优化会员充值页面 ItemList 组件
|
||||||
* 功能优化:给首页 Logo 增加链接
|
- 功能优化:给首页 Logo 增加链接
|
||||||
* Bug修复:[新建会话时,提示"请输入合法的手机号" ](https://github.com/yangjian102621/chatgpt-plus/issues/51)
|
- Bug 修复:[新建会话时,提示"请输入合法的手机号" ](https://github.com/yangjian102621/chatgpt-plus/issues/51)
|
||||||
* Bug修复:聊天上下文失效问题
|
- Bug 修复:聊天上下文失效问题
|
||||||
* 功能优化:关闭注册时显示联系管理员二维码
|
- 功能优化:关闭注册时显示联系管理员二维码
|
||||||
* 功能优化:移除 leveldb 依赖,使用 redis 替换相应的功能
|
- 功能优化:移除 leveldb 依赖,使用 redis 替换相应的功能
|
||||||
* Bug修复:后台启用用户 VIP 不生效问题
|
- Bug 修复:后台启用用户 VIP 不生效问题
|
||||||
* 功能优化:充值支付页面的支付说明文字可以后台配置
|
- 功能优化:充值支付页面的支付说明文字可以后台配置
|
||||||
* Bug修复:ChatGLM,百度文心,科大讯飞模型输出代码不换行问题
|
- Bug 修复:ChatGLM,百度文心,科大讯飞模型输出代码不换行问题
|
||||||
|
|
||||||
## v3.1.8
|
## v3.1.8
|
||||||
|
|
||||||
1. 功能新增:新增会员套餐充值,点卡充值,订单系统,集成支付宝支付通道
|
1. 功能新增:新增会员套餐充值,点卡充值,订单系统,集成支付宝支付通道
|
||||||
2. Bug修复:修复 MidJourney API 参数版本更新导致调用失败问题
|
2. Bug 修复:修复 MidJourney API 参数版本更新导致调用失败问题
|
||||||
3. Bug修复:修复 Stable Diffusion 调用后没有更新绘图调用次数问题
|
3. Bug 修复:修复 Stable Diffusion 调用后没有更新绘图调用次数问题
|
||||||
4. Bug修复:修复七牛云上传报错 expired token
|
4. Bug 修复:修复七牛云上传报错 expired token
|
||||||
5. Bug修复:修复高权重模型导致的对话次数为负数的漏洞
|
5. Bug 修复:修复高权重模型导致的对话次数为负数的漏洞
|
||||||
6. 功能优化:将聊天报错信息定义为统一常量,方便修改
|
6. 功能优化:将聊天报错信息定义为统一常量,方便修改
|
||||||
7. 功能优化:优化 markdown 表格显示样式,覆写 Element-Plus 表格样式
|
7. 功能优化:优化 markdown 表格显示样式,覆写 Element-Plus 表格样式
|
||||||
8. 功能优化:增加倒数计时组件,定期自动清理未支付的订单
|
8. 功能优化:增加倒数计时组件,定期自动清理未支付的订单
|
||||||
|
|
||||||
## v3.1.7
|
## v3.1.7
|
||||||
|
|
||||||
1. 功能新增:支持文心4.0 AI 模型
|
1. 功能新增:支持文心 4.0 AI 模型
|
||||||
2. 功能新增:可以在管理后台为用户绑定指定的 AI 模型,如只给某个用户使用 GPT-4 模型
|
2. 功能新增:可以在管理后台为用户绑定指定的 AI 模型,如只给某个用户使用 GPT-4 模型
|
||||||
3. 功能新增:模型新增权重字段,不同的模型每次调用耗费的点数可以设置不同,比如GPT4是GPT3.5的10倍
|
3. 功能新增:模型新增权重字段,不同的模型每次调用耗费的点数可以设置不同,比如 GPT4 是 GPT3.5 的 10 倍
|
||||||
4. 功能新增:新增系统配置关闭 AI 模型的函数功能
|
4. 功能新增:新增系统配置关闭 AI 模型的函数功能
|
||||||
5. 功能优化:优化 MidJourney 专业绘画页面图片预览样式
|
5. 功能优化:优化 MidJourney 专业绘画页面图片预览样式
|
||||||
|
|
||||||
## v3.1.6
|
## v3.1.6
|
||||||
|
|
||||||
1. 功能新增:新增AI 绘画照片墙功能页面,供用户查看所有的 AI 绘画作品
|
1. 功能新增:新增 AI 绘画照片墙功能页面,供用户查看所有的 AI 绘画作品
|
||||||
2. 功能新增:新增 AI 角色应用功能页面,用户可以添加自己感兴趣的应用
|
2. 功能新增:新增 AI 角色应用功能页面,用户可以添加自己感兴趣的应用
|
||||||
3. 功能优化:优化瀑布流组件的页面布局
|
3. 功能优化:优化瀑布流组件的页面布局
|
||||||
4. 功能优化:新注册用户成功之后自动登录
|
4. 功能优化:新注册用户成功之后自动登录
|
||||||
@@ -296,55 +389,55 @@
|
|||||||
2. 功能新增:新增科大讯飞星火大模型 API 接入支持
|
2. 功能新增:新增科大讯飞星火大模型 API 接入支持
|
||||||
3. 功能重构:将 chat_handler 的所有功能实现放入单独的包中
|
3. 功能重构:将 chat_handler 的所有功能实现放入单独的包中
|
||||||
4. 功能新增:新增系统配置 `enabled_function` 用于启用和关闭函数功能
|
4. 功能新增:新增系统配置 `enabled_function` 用于启用和关闭函数功能
|
||||||
5. Bug修复:修复管理后台更新 API Key 失败的 Bug
|
5. Bug 修复:修复管理后台更新 API Key 失败的 Bug
|
||||||
6. Bug修复:修复新建的对话无法更新对话标题的 Bug
|
6. Bug 修复:修复新建的对话无法更新对话标题的 Bug
|
||||||
7. 功能优化:其他一些小的体验优化工作
|
7. 功能优化:其他一些小的体验优化工作
|
||||||
|
|
||||||
## v3.1.4
|
## v3.1.4
|
||||||
|
|
||||||
1. 功能新增:新增阿里云 OSS 图片上传实现,目前已支持本地存储,七牛云,Minio和阿里云 OSS 四种存储介质。
|
1. 功能新增:新增阿里云 OSS 图片上传实现,目前已支持本地存储,七牛云,Minio 和阿里云 OSS 四种存储介质。
|
||||||
2. 功能新增:**增加 Stable Diffusion 绘画功能页面**。
|
2. 功能新增:**增加 Stable Diffusion 绘画功能页面**。
|
||||||
3. 功能重构:将 [chatgpt-plus-exts](https://github.com/yangjian102621/chatgpt-plus-exts) 合并到本项目,部署更加简单,无需部署两个项目了。
|
3. 功能重构:将 [chatgpt-plus-exts](https://github.com/yangjian102621/chatgpt-plus-exts) 合并到本项目,部署更加简单,无需部署两个项目了。
|
||||||
4. Bug修复:修复[用户注册报错BUG #37](https://github.com/yangjian102621/chatgpt-plus/issues/37)。
|
4. Bug 修复:修复[用户注册报错 BUG #37](https://github.com/yangjian102621/chatgpt-plus/issues/37)。
|
||||||
5. Bug修复:修复 MidJourney API 接口升级导致图片文保存失败的 Bug。
|
5. Bug 修复:修复 MidJourney API 接口升级导致图片文保存失败的 Bug。
|
||||||
6. 功能优化:增加阿里云短信服务配置项 `Sign` 和 `CodeTempId` 用来配置自己的短信签名和短信验证码模版 ID。
|
6. 功能优化:增加阿里云短信服务配置项 `Sign` 和 `CodeTempId` 用来配置自己的短信签名和短信验证码模版 ID。
|
||||||
7. 功能优化:添加系统配置用来设置自定义的众筹微信收款二维码。
|
7. 功能优化:添加系统配置用来设置自定义的众筹微信收款二维码。
|
||||||
8. 功能优化:优化绘画页面的弹窗样式和页面布局。
|
8. 功能优化:优化绘画页面的弹窗样式和页面布局。
|
||||||
|
|
||||||
## v3.1.3
|
## v3.1.3
|
||||||
|
|
||||||
1. 页面重构:重后 Home 页面,拆分成聊天,MJ绘画,SD 绘画,应用广场等多个功能菜单。
|
1. 页面重构:重后 Home 页面,拆分成聊天,MJ 绘画,SD 绘画,应用广场等多个功能菜单。
|
||||||
2. 功能新增:新增 MidJourney 专业绘画页面,开放更高级的 MJ 绘画姿势。
|
2. 功能新增:新增 MidJourney 专业绘画页面,开放更高级的 MJ 绘画姿势。
|
||||||
3. 功能优化:采用队列的方式控制绘画任务并发,简化任务回调通知逻辑,给任务回调加锁。
|
3. 功能优化:采用队列的方式控制绘画任务并发,简化任务回调通知逻辑,给任务回调加锁。
|
||||||
4. 功能优化:精简用户表字段,删除用户名和昵称,只保留手机号。
|
4. 功能优化:精简用户表字段,删除用户名和昵称,只保留手机号。
|
||||||
5. 功能优化:优化文件上传服务工厂实现,只创建激活的 Uploader 服务,节省资源。
|
5. 功能优化:优化文件上传服务工厂实现,只创建激活的 Uploader 服务,节省资源。
|
||||||
6. Bug修复:修复 JWT token 有效期计算错误的 Bug。
|
6. Bug 修复:修复 JWT token 有效期计算错误的 Bug。
|
||||||
|
|
||||||
## v3.1.2
|
## v3.1.2
|
||||||
|
|
||||||
1. 功能新增:新增七牛云 OSS 实现,目前已支持三种文件上传服务:Local, Minio, QiNiu OSS。
|
1. 功能新增:新增七牛云 OSS 实现,目前已支持三种文件上传服务:Local, Minio, QiNiu OSS。
|
||||||
2. 功能新增:新增桌面版,使用 electron 套壳网页版。
|
2. 功能新增:新增桌面版,使用 electron 套壳网页版。
|
||||||
3. Bug修复:自动去除众筹核销时候转账单号中的空格,防止复制的时候多复制了空格。
|
3. Bug 修复:自动去除众筹核销时候转账单号中的空格,防止复制的时候多复制了空格。
|
||||||
4. 功能优化:ChatPlus.vue 页面支持通过 chat_id path variable 来定位到指定的聊天。
|
4. 功能优化:ChatPlus.vue 页面支持通过 chat_id path variable 来定位到指定的聊天。
|
||||||
5. 功能优化:取消导出聊天页面的授权验证
|
5. 功能优化:取消导出聊天页面的授权验证
|
||||||
6. 功能优化:所有路由跳转都使用绝对路径
|
6. 功能优化:所有路由跳转都使用绝对路径
|
||||||
|
|
||||||
## v3.1.1
|
## v3.1.1
|
||||||
|
|
||||||
紧急修复版本,采用弹窗的方式显示验证码,解决验证码在低分辨率下被掩盖的Bug
|
紧急修复版本,采用弹窗的方式显示验证码,解决验证码在低分辨率下被掩盖的 Bug
|
||||||
|
|
||||||
## v3.1.0(大版本更新)
|
## v3.1.0(大版本更新)
|
||||||
|
|
||||||
1. 功能重构:将聊天模型独立拆分,以便支持多平台模型,目前已经内置支持 OPenAI,Azure 以及
|
1. 功能重构:将聊天模型独立拆分,以便支持多平台模型,目前已经内置支持 OPenAI,Azure 以及
|
||||||
ChatGLM,用户可以在这两个平台的模型中随意切换,体验不同的模型聊天。
|
ChatGLM,用户可以在这两个平台的模型中随意切换,体验不同的模型聊天。
|
||||||
2. 功能重构:重写系统 API 授权机制,使用 JWT 替换传统的 session 会话授权,使得 API 授权变得更加灵活。
|
2. 功能重构:重写系统 API 授权机制,使用 JWT 替换传统的 session 会话授权,使得 API 授权变得更加灵活。
|
||||||
3. 功能重构:重构文件夹上传服务,支持多种文件上传存储handler,目前已经实现本地存储和 minio oss 存储。
|
3. 功能重构:重构文件夹上传服务,支持多种文件上传存储 handler,目前已经实现本地存储和 minio oss 存储。
|
||||||
4. 功能优化:更新头像自动删除旧的图片资源。
|
4. 功能优化:更新头像自动删除旧的图片资源。
|
||||||
5. 功能优化:将应用日志在终端输出的同时存盘,方便 docker 部署查看日志。
|
5. 功能优化:将应用日志在终端输出的同时存盘,方便 docker 部署查看日志。
|
||||||
6. 功能新增:允许用户配置自己的 OPenAI,Azure 以及 ChatGLM API KEY。
|
6. 功能新增:允许用户配置自己的 OPenAI,Azure 以及 ChatGLM API KEY。
|
||||||
7. 功能优化:优化移动版的行为验证码样式,修复低分辨率显示器验证码被遮挡的 Bug
|
7. 功能优化:优化移动版的行为验证码样式,修复低分辨率显示器验证码被遮挡的 Bug
|
||||||
8. 升级 gin, element-plus,redis 组件到最新版本。
|
8. 升级 gin, element-plus,redis 组件到最新版本。
|
||||||
9. Bug修复:修复若干已知的的 Bug
|
9. Bug 修复:修复若干已知的的 Bug
|
||||||
|
|
||||||
## v3.0.7
|
## v3.0.7
|
||||||
|
|
||||||
@@ -354,7 +447,7 @@
|
|||||||
4. 功能新增:支持导出聊天记录为 PDF 文件。
|
4. 功能新增:支持导出聊天记录为 PDF 文件。
|
||||||
5. 功能优化:在后台 dashboard 页面新增统计今日众筹收入。
|
5. 功能优化:在后台 dashboard 页面新增统计今日众筹收入。
|
||||||
6. 功能优化:支持用户设置默认的 GPT 模型
|
6. 功能优化:支持用户设置默认的 GPT 模型
|
||||||
7. Bug修复:修复若干已知的的 Bug
|
7. Bug 修复:修复若干已知的的 Bug
|
||||||
|
|
||||||
## v3.0.6
|
## v3.0.6
|
||||||
|
|
||||||
@@ -362,8 +455,8 @@
|
|||||||
2. 管理后台:新增重置用户密码功能
|
2. 管理后台:新增重置用户密码功能
|
||||||
3. 管理后台:支持关闭注册功能,新增添加用户功能,适用于内部使用场景
|
3. 管理后台:支持关闭注册功能,新增添加用户功能,适用于内部使用场景
|
||||||
4. 管理后台:新增仪表盘页面,统计当天的新增用户,新增会话数据,以及 Token 消耗
|
4. 管理后台:新增仪表盘页面,统计当天的新增用户,新增会话数据,以及 Token 消耗
|
||||||
5. Bug修复:修复注册页面验证码不显示 Bug
|
5. Bug 修复:修复注册页面验证码不显示 Bug
|
||||||
6. Bug修复:优化上下文 Token 计算算法,修复聊天上下文超出限制时循环发送消息的 Bug
|
6. Bug 修复:优化上下文 Token 计算算法,修复聊天上下文超出限制时循环发送消息的 Bug
|
||||||
7. 功能修正:允许用户使用手机号码登录
|
7. 功能修正:允许用户使用手机号码登录
|
||||||
8. 功能优化:更新系统配置后同步更新服务端内存变量数据
|
8. 功能优化:更新系统配置后同步更新服务端内存变量数据
|
||||||
9. 功能优化:优化打包脚本,减少容器镜像大小
|
9. 功能优化:优化打包脚本,减少容器镜像大小
|
||||||
@@ -421,5 +514,5 @@
|
|||||||
4. 新增聊天设置功能,用户可以导入自己的 API KEY
|
4. 新增聊天设置功能,用户可以导入自己的 API KEY
|
||||||
5. 保存聊天记录,支持聊天上下文。
|
5. 保存聊天记录,支持聊天上下文。
|
||||||
6. 重构后台管理模块,更友好,扩展性更好的后台管理系统。
|
6. 重构后台管理模块,更友好,扩展性更好的后台管理系统。
|
||||||
7. 引入 ip2region 组件,记录用户的登录IP和地址。
|
7. 引入 ip2region 组件,记录用户的登录 IP 和地址。
|
||||||
8. 支持会话搜索过滤。
|
8. 支持会话搜索过滤。
|
||||||
|
|||||||
97
README.md
97
README.md
@@ -1,92 +1,19 @@
|
|||||||
# GeekAI
|
# GeekAI-PLUS
|
||||||
> 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
基于 GeekAI 项目开发的高级版,增加了很多高级功能,比如思维导图,Dalle 绘画等。**高级版源码不会一次性开放,只提供镜像给大家免费使用**,源码会逐步逐步按照版同步迁移到[社区版(GeekAI)](https://github.com/yangjian102621/geekai)。所以如果大家想要二次开发,请移步去社区版。
|
||||||
|
|
||||||
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
|
## 演示站点
|
||||||
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
|
[Geek-AI 创作系统](https://www.geekai.me)
|
||||||
|
|
||||||
主要特性:
|
## 文档地址
|
||||||
|
[Geek-AI 文档](https://www.geekai.me/docs/)
|
||||||
|
|
||||||
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
## 部署
|
||||||
- 基于 Websocket 实现,完美的打字机体验。
|
1. 安装 docker 和 docker-compose 程序,这个自行解决。
|
||||||
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
|
2. 直接在项目根目录运行启动命令:
|
||||||
- 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
|
```shell
|
||||||
- 支持 Suno 文生音乐
|
docker-compose up -d
|
||||||
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
|
```
|
||||||
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
|
||||||
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
|
||||||
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
|
|
||||||
绘画函数插件。
|
|
||||||
|
|
||||||
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus)
|
|
||||||
|
|
||||||
- [x] 更友好的 UI 界面
|
|
||||||
- [x] 支持 Dall-E 文生图功能
|
|
||||||
- [x] 支持文生思维导图
|
|
||||||
- [x] 支持为模型绑定指定的 API KEY,支持为角色绑定指定的模型等功能
|
|
||||||
- [x] 支持网站 Logo 版权等信息的修改
|
|
||||||
|
|
||||||
## 功能截图
|
## 功能截图
|
||||||
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
|
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
|
||||||
|
|
||||||
### 体验地址
|
|
||||||
|
|
||||||
> 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/>
|
|
||||||
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
|
|
||||||
|
|
||||||
## 快速部署
|
|
||||||
|
|
||||||
请参考文档 [**GeekAI 快速部署**](https://docs.geekai.me/install/)。
|
|
||||||
|
|
||||||
## 使用须知
|
|
||||||
|
|
||||||
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
|
|
||||||
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
|
|
||||||
|
|
||||||
## 项目地址
|
|
||||||
|
|
||||||
* Github 地址:https://github.com/yangjian102621/geekai
|
|
||||||
* 码云地址:https://gitee.com/blackfox/geekai
|
|
||||||
|
|
||||||
## 客户端下载
|
|
||||||
|
|
||||||
目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
|
|
||||||
|
|
||||||
## TODOLIST
|
|
||||||
|
|
||||||
* [ ] 支持基于知识库的 AI 问答
|
|
||||||
* [ ] 文生视频,文生歌曲功能
|
|
||||||
* [ ] 微信支付功能
|
|
||||||
|
|
||||||
## 项目文档
|
|
||||||
|
|
||||||
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
|
|
||||||
|
|
||||||
详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
|
|
||||||
|
|
||||||
加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自Github!!!)。**
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
## 参与贡献
|
|
||||||
|
|
||||||
个人的力量始终有限,任何形式的贡献都是欢迎的,包括但不限于贡献代码,优化文档,提交 issue 和 PR 等。
|
|
||||||
|
|
||||||
#### 特此声明:由于个人时间有限,不接受在微信或者微信群给开发者提 Bug,有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合!
|
|
||||||
|
|
||||||
### Commit 类型
|
|
||||||
|
|
||||||
* feat: 新特性或功能
|
|
||||||
* fix: 缺陷修复
|
|
||||||
* docs: 文档更新
|
|
||||||
* style: 代码风格或者组件样式更新
|
|
||||||
* refactor: 代码重构,不引入新功能和缺陷修复
|
|
||||||
* opt: 性能优化
|
|
||||||
* chore: 一些不涉及到功能变动的小提交,比如修改文字表述,修改注释等
|
|
||||||
|
|
||||||
## 打赏
|
|
||||||
|
|
||||||
如果你觉得这个项目对你有帮助,并且情况允许的话,可以请作者喝杯咖啡,非常感谢你的支持~
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ NAME := geekai
|
|||||||
all: amd64 arm64
|
all: amd64 arm64
|
||||||
|
|
||||||
amd64:
|
amd64:
|
||||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-linux main.go
|
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "-s -w" -o bin/$(NAME)-linux main.go
|
||||||
.PHONY: amd64
|
.PHONY: amd64
|
||||||
|
|
||||||
arm64:
|
arm64:
|
||||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -o bin/$(NAME)-linux main.go
|
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -ldflags "-s -w" -o bin/$(NAME)-linux main.go
|
||||||
.PHONY: arm64
|
.PHONY: arm64
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ ProxyURL = "" # 如 http://127.0.0.1:7777
|
|||||||
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
|
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
|
||||||
StaticDir = "./static" # 静态资源的目录
|
StaticDir = "./static" # 静态资源的目录
|
||||||
StaticUrl = "/static" # 静态资源访问 URL
|
StaticUrl = "/static" # 静态资源访问 URL
|
||||||
AesEncryptKey = ""
|
|
||||||
WeChatBot = false
|
|
||||||
TikaHost = "http://tika:9998"
|
TikaHost = "http://tika:9998"
|
||||||
|
|
||||||
[Session]
|
[Session]
|
||||||
@@ -65,23 +63,6 @@ TikaHost = "http://tika:9998"
|
|||||||
SubDir = ""
|
SubDir = ""
|
||||||
Domain = ""
|
Domain = ""
|
||||||
|
|
||||||
[[MjProxyConfigs]]
|
|
||||||
Enabled = true
|
|
||||||
ApiURL = "http://midjourney-proxy:8082"
|
|
||||||
ApiKey = "sk-geekmaster"
|
|
||||||
|
|
||||||
[[MjPlusConfigs]]
|
|
||||||
Enabled = false
|
|
||||||
ApiURL = "https://api.chat-plus.net"
|
|
||||||
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
|
|
||||||
ApiKey = "sk-xxx"
|
|
||||||
|
|
||||||
[[SdConfigs]]
|
|
||||||
Enabled = false
|
|
||||||
ApiURL = ""
|
|
||||||
ApiKey = ""
|
|
||||||
Txt2ImgJsonPath = "res/sd/text2img.json"
|
|
||||||
|
|
||||||
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP,如果你没有启用支付服务,则该服务也无需启动
|
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP,如果你没有启用支付服务,则该服务也无需启动
|
||||||
Enabled = false # 是否启用 XXL JOB 服务
|
Enabled = false # 是否启用 XXL JOB 服务
|
||||||
ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
|
ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
|
||||||
@@ -90,6 +71,15 @@ TikaHost = "http://tika:9998"
|
|||||||
AccessToken = "xxl-job-api-token" # 执行器 API 通信 token
|
AccessToken = "xxl-job-api-token" # 执行器 API 通信 token
|
||||||
RegistryKey = "chatgpt-plus" # 任务注册 key
|
RegistryKey = "chatgpt-plus" # 任务注册 key
|
||||||
|
|
||||||
|
[SmtpConfig] # 注意,阿里云服务器禁用了25号端口,请使用 465 端口,并开启 TLS 连接
|
||||||
|
UseTls = false
|
||||||
|
Host = "smtp.163.com"
|
||||||
|
Port = 25
|
||||||
|
AppName = "极客学长"
|
||||||
|
From = "test@163.com" # 发件邮箱人地址
|
||||||
|
Password = "" #邮箱 stmp 服务授权码
|
||||||
|
|
||||||
|
# 支付宝商户支付
|
||||||
[AlipayConfig]
|
[AlipayConfig]
|
||||||
Enabled = false # 启用支付宝支付通道
|
Enabled = false # 启用支付宝支付通道
|
||||||
SandBox = false # 是否启用沙盒模式
|
SandBox = false # 是否启用沙盒模式
|
||||||
@@ -99,31 +89,13 @@ TikaHost = "http://tika:9998"
|
|||||||
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
|
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
|
||||||
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
|
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
|
||||||
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
|
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
|
||||||
NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
|
|
||||||
|
|
||||||
|
# 虎皮椒支付
|
||||||
[HuPiPayConfig]
|
[HuPiPayConfig]
|
||||||
Enabled = false
|
Enabled = false
|
||||||
Name = "wechat"
|
|
||||||
AppId = ""
|
AppId = ""
|
||||||
AppSecret = ""
|
AppSecret = ""
|
||||||
ApiURL = "https://api.xunhupay.com"
|
ApiURL = "https://api.xunhupay.com"
|
||||||
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
|
|
||||||
|
|
||||||
[SmtpConfig] # 注意,阿里云服务器禁用了25号端口,请使用 465 端口,并开启 TLS 连接
|
|
||||||
UseTls = false
|
|
||||||
Host = "smtp.163.com"
|
|
||||||
Port = 25
|
|
||||||
AppName = "极客学长"
|
|
||||||
From = "test@163.com" # 发件邮箱人地址
|
|
||||||
Password = "" #邮箱 stmp 服务授权码
|
|
||||||
|
|
||||||
[JPayConfig] # PayJs 支付配置
|
|
||||||
Enabled = false
|
|
||||||
Name = "wechat" # 请不要改动
|
|
||||||
AppId = "" # 商户 ID
|
|
||||||
PrivateKey = "" # 秘钥
|
|
||||||
ApiURL = "https://payjs.cn"
|
|
||||||
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的
|
|
||||||
|
|
||||||
# 微信商户支付
|
# 微信商户支付
|
||||||
[WechatPayConfig]
|
[WechatPayConfig]
|
||||||
@@ -133,6 +105,11 @@ TikaHost = "http://tika:9998"
|
|||||||
SerialNo = "" # API 证书序列号
|
SerialNo = "" # API 证书序列号
|
||||||
PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内
|
PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内
|
||||||
ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置的
|
ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置的
|
||||||
NotifyURL = "https://ai.r9it.com/api/payment/wechat/notify" # 支付成功异步回调地址,域名改成自己的
|
|
||||||
ReturnURL = "" # 支付成功同步回调地址
|
|
||||||
|
|
||||||
|
# 易支付
|
||||||
|
[GeekPayConfig]
|
||||||
|
Enabled = true
|
||||||
|
AppId = "" # 商户ID
|
||||||
|
PrivateKey = "" # 商户私钥
|
||||||
|
ApiURL = "https://pay.geekai.cn"
|
||||||
|
Methods = ["alipay", "wxpay", "qqpay", "jdpay", "douyin", "paypal"] # 支持的支付方式
|
||||||
|
|||||||
@@ -15,12 +15,6 @@ import (
|
|||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/nfnt/resize"
|
|
||||||
"golang.org/x/image/webp"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"image"
|
"image"
|
||||||
"image/jpeg"
|
"image/jpeg"
|
||||||
"io"
|
"io"
|
||||||
@@ -29,6 +23,13 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/nfnt/resize"
|
||||||
|
"golang.org/x/image/webp"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AppServer struct {
|
type AppServer struct {
|
||||||
@@ -51,9 +52,9 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
|||||||
func (s *AppServer) Init(debug bool, client *redis.Client) {
|
func (s *AppServer) Init(debug bool, client *redis.Client) {
|
||||||
if debug { // 调试模式允许跨域请求 API
|
if debug { // 调试模式允许跨域请求 API
|
||||||
s.Debug = debug
|
s.Debug = debug
|
||||||
|
s.Engine.Use(corsMiddleware())
|
||||||
logger.Info("Enabled debug mode")
|
logger.Info("Enabled debug mode")
|
||||||
}
|
}
|
||||||
s.Engine.Use(corsMiddleware())
|
|
||||||
s.Engine.Use(staticResourceMiddleware())
|
s.Engine.Use(staticResourceMiddleware())
|
||||||
s.Engine.Use(authorizeMiddleware(s, client))
|
s.Engine.Use(authorizeMiddleware(s, client))
|
||||||
s.Engine.Use(parameterHandlerMiddleware())
|
s.Engine.Use(parameterHandlerMiddleware())
|
||||||
@@ -65,13 +66,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
|
|||||||
func (s *AppServer) Run(db *gorm.DB) error {
|
func (s *AppServer) Run(db *gorm.DB) error {
|
||||||
// load system configs
|
// load system configs
|
||||||
var sysConfig model.Config
|
var sysConfig model.Config
|
||||||
res := db.Where("marker", "system").First(&sysConfig)
|
err := db.Where("marker", "system").First(&sysConfig).Error
|
||||||
if res.Error != nil {
|
|
||||||
return res.Error
|
|
||||||
}
|
|
||||||
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to load system config: %v", err)
|
||||||
|
}
|
||||||
|
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to decode system config: %v", err)
|
||||||
}
|
}
|
||||||
logger.Infof("http://%s", s.Config.Listen)
|
logger.Infof("http://%s", s.Config.Listen)
|
||||||
return s.Engine.Run(s.Config.Listen)
|
return s.Engine.Run(s.Config.Listen)
|
||||||
@@ -101,9 +102,9 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
c.Header("Access-Control-Allow-Origin", origin)
|
c.Header("Access-Control-Allow-Origin", origin)
|
||||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
||||||
//允许跨域设置可以返回其他子段,可以自定义字段
|
//允许跨域设置可以返回其他子段,可以自定义字段
|
||||||
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, Chat-Token, Admin-Authorization")
|
c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
|
||||||
// 允许浏览器(客户端)可以解析的头部 (重要)
|
// 允许浏览器(客户端)可以解析的头部 (重要)
|
||||||
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
||||||
//设置缓存时间
|
//设置缓存时间
|
||||||
c.Header("Access-Control-Max-Age", "172800")
|
c.Header("Access-Control-Max-Age", "172800")
|
||||||
//允许客户端传递校验信息比如 cookie (重要)
|
//允许客户端传递校验信息比如 cookie (重要)
|
||||||
@@ -127,12 +128,19 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
// 用户授权验证
|
// 用户授权验证
|
||||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||||
var tokenString string
|
var tokenString string
|
||||||
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
||||||
if isAdminApi { // 后台管理 API
|
if isAdminApi { // 后台管理 API
|
||||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||||
} else if c.Request.URL.Path == "/api/chat/new" {
|
} else if clientProtocols != "" { // Websocket 连接
|
||||||
tokenString = c.Query("token")
|
// 解析子协议内容
|
||||||
|
protocols := strings.Split(clientProtocols, ",")
|
||||||
|
if protocols[0] == "realtime" {
|
||||||
|
tokenString = strings.TrimSpace(protocols[1][25:])
|
||||||
|
} else if protocols[0] == "token" {
|
||||||
|
tokenString = strings.TrimSpace(protocols[1])
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||||
}
|
}
|
||||||
@@ -201,33 +209,29 @@ func needLogin(c *gin.Context) bool {
|
|||||||
c.Request.URL.Path == "/api/admin/logout" ||
|
c.Request.URL.Path == "/api/admin/logout" ||
|
||||||
c.Request.URL.Path == "/api/admin/login/captcha" ||
|
c.Request.URL.Path == "/api/admin/login/captcha" ||
|
||||||
c.Request.URL.Path == "/api/user/register" ||
|
c.Request.URL.Path == "/api/user/register" ||
|
||||||
c.Request.URL.Path == "/api/user/session" ||
|
|
||||||
c.Request.URL.Path == "/api/chat/history" ||
|
c.Request.URL.Path == "/api/chat/history" ||
|
||||||
c.Request.URL.Path == "/api/chat/detail" ||
|
c.Request.URL.Path == "/api/chat/detail" ||
|
||||||
c.Request.URL.Path == "/api/chat/list" ||
|
c.Request.URL.Path == "/api/chat/list" ||
|
||||||
c.Request.URL.Path == "/api/role/list" ||
|
c.Request.URL.Path == "/api/app/list" ||
|
||||||
|
c.Request.URL.Path == "/api/app/type/list" ||
|
||||||
|
c.Request.URL.Path == "/api/app/list/user" ||
|
||||||
c.Request.URL.Path == "/api/model/list" ||
|
c.Request.URL.Path == "/api/model/list" ||
|
||||||
c.Request.URL.Path == "/api/mj/imgWall" ||
|
c.Request.URL.Path == "/api/mj/imgWall" ||
|
||||||
c.Request.URL.Path == "/api/mj/client" ||
|
|
||||||
c.Request.URL.Path == "/api/mj/notify" ||
|
c.Request.URL.Path == "/api/mj/notify" ||
|
||||||
c.Request.URL.Path == "/api/invite/hits" ||
|
c.Request.URL.Path == "/api/invite/hits" ||
|
||||||
c.Request.URL.Path == "/api/sd/imgWall" ||
|
c.Request.URL.Path == "/api/sd/imgWall" ||
|
||||||
c.Request.URL.Path == "/api/sd/client" ||
|
|
||||||
c.Request.URL.Path == "/api/dall/imgWall" ||
|
c.Request.URL.Path == "/api/dall/imgWall" ||
|
||||||
c.Request.URL.Path == "/api/dall/client" ||
|
|
||||||
c.Request.URL.Path == "/api/product/list" ||
|
c.Request.URL.Path == "/api/product/list" ||
|
||||||
c.Request.URL.Path == "/api/menu/list" ||
|
c.Request.URL.Path == "/api/menu/list" ||
|
||||||
c.Request.URL.Path == "/api/markMap/client" ||
|
c.Request.URL.Path == "/api/markMap/client" ||
|
||||||
c.Request.URL.Path == "/api/payment/alipay/notify" ||
|
|
||||||
c.Request.URL.Path == "/api/payment/hupipay/notify" ||
|
|
||||||
c.Request.URL.Path == "/api/payment/payjs/notify" ||
|
|
||||||
c.Request.URL.Path == "/api/payment/wechat/notify" ||
|
|
||||||
c.Request.URL.Path == "/api/payment/doPay" ||
|
c.Request.URL.Path == "/api/payment/doPay" ||
|
||||||
c.Request.URL.Path == "/api/payment/payWays" ||
|
c.Request.URL.Path == "/api/payment/payWays" ||
|
||||||
c.Request.URL.Path == "/api/suno/client" ||
|
c.Request.URL.Path == "/api/suno/detail" ||
|
||||||
c.Request.URL.Path == "/api/suno/Detail" ||
|
|
||||||
c.Request.URL.Path == "/api/suno/play" ||
|
c.Request.URL.Path == "/api/suno/play" ||
|
||||||
|
c.Request.URL.Path == "/api/download" ||
|
||||||
|
c.Request.URL.Path == "/api/dall/models" ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||||
@@ -367,6 +371,7 @@ func staticResourceMiddleware() gin.HandlerFunc {
|
|||||||
// 直接输出图像数据流
|
// 直接输出图像数据流
|
||||||
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
|
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
|
||||||
c.Abort() // 中断请求
|
c.Abort() // 中断请求
|
||||||
|
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
BasePath: "./static/upload",
|
BasePath: "./static/upload",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
WeChatBot: false,
|
|
||||||
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,14 +9,15 @@ package types
|
|||||||
|
|
||||||
// ApiRequest API 请求实体
|
// ApiRequest API 请求实体
|
||||||
type ApiRequest struct {
|
type ApiRequest struct {
|
||||||
Model string `json:"model,omitempty"` // 兼容百度文心一言
|
Model string `json:"model,omitempty"`
|
||||||
Temperature float32 `json:"temperature"`
|
Temperature float32 `json:"temperature"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty"` // 兼容百度文心一言
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
Stream bool `json:"stream"`
|
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
|
||||||
Messages []interface{} `json:"messages,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
|
Messages []interface{} `json:"messages,omitempty"`
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
||||||
|
ResponseFormat interface{} `json:"response_format,omitempty"` // 响应格式
|
||||||
|
|
||||||
ToolChoice string `json:"tool_choice,omitempty"`
|
ToolChoice string `json:"tool_choice,omitempty"`
|
||||||
|
|
||||||
@@ -52,16 +53,17 @@ type Delta struct {
|
|||||||
|
|
||||||
// ChatSession 聊天会话对象
|
// ChatSession 聊天会话对象
|
||||||
type ChatSession struct {
|
type ChatSession struct {
|
||||||
SessionId string `json:"session_id"`
|
UserId uint `json:"user_id"`
|
||||||
UserId uint `json:"user_id"`
|
ClientIP string `json:"client_ip"` // 客户端 IP
|
||||||
ClientIP string `json:"client_ip"` // 客户端 IP
|
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
||||||
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
Model ChatModel `json:"model"` // GPT 模型
|
||||||
Model ChatModel `json:"model"` // GPT 模型
|
Start int64 `json:"start"` // 开始请求时间戳
|
||||||
|
Tools []int `json:"tools"` // 工具函数列表
|
||||||
|
Stream bool `json:"stream"` // 是否采用流式输出
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatModel struct {
|
type ChatModel struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Platform string `json:"platform"`
|
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
Power int `json:"power"`
|
Power int `json:"power"`
|
||||||
@@ -91,7 +93,7 @@ const (
|
|||||||
PowerConsume = PowerType(2) // 消费
|
PowerConsume = PowerType(2) // 消费
|
||||||
PowerRefund = PowerType(3) // 任务(SD,MJ)执行失败,退款
|
PowerRefund = PowerType(3) // 任务(SD,MJ)执行失败,退款
|
||||||
PowerInvite = PowerType(4) // 邀请奖励
|
PowerInvite = PowerType(4) // 邀请奖励
|
||||||
PowerReward = PowerType(5) // 众筹
|
PowerRedeem = PowerType(5) // 众筹
|
||||||
PowerGift = PowerType(6) // 系统赠送
|
PowerGift = PowerType(6) // 系统赠送
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -103,9 +105,12 @@ func (t PowerType) String() string {
|
|||||||
return "消费"
|
return "消费"
|
||||||
case PowerRefund:
|
case PowerRefund:
|
||||||
return "退款"
|
return "退款"
|
||||||
case PowerReward:
|
case PowerRedeem:
|
||||||
return "众筹"
|
return "兑换"
|
||||||
|
case PowerGift:
|
||||||
|
return "赠送"
|
||||||
|
case PowerInvite:
|
||||||
|
return "邀请"
|
||||||
}
|
}
|
||||||
return "其他"
|
return "其他"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,15 +17,17 @@ var ErrConClosed = errors.New("connection Closed")
|
|||||||
|
|
||||||
// WsClient websocket client
|
// WsClient websocket client
|
||||||
type WsClient struct {
|
type WsClient struct {
|
||||||
|
Id string
|
||||||
Conn *websocket.Conn
|
Conn *websocket.Conn
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
mt int
|
mt int
|
||||||
Closed bool
|
Closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWsClient(conn *websocket.Conn) *WsClient {
|
func NewWsClient(conn *websocket.Conn, id string) *WsClient {
|
||||||
return &WsClient{
|
return &WsClient{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
|
Id: id,
|
||||||
lock: sync.Mutex{},
|
lock: sync.Mutex{},
|
||||||
mt: 2, // fixed bug for 'Invalid UTF-8 in text frame'
|
mt: 2, // fixed bug for 'Invalid UTF-8 in text frame'
|
||||||
Closed: false,
|
Closed: false,
|
||||||
|
|||||||
@@ -12,28 +12,23 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AppConfig struct {
|
type AppConfig struct {
|
||||||
Path string `toml:"-"`
|
Path string `toml:"-"`
|
||||||
Listen string
|
Listen string
|
||||||
Session Session
|
Session Session
|
||||||
AdminSession Session
|
AdminSession Session
|
||||||
ProxyURL string
|
ProxyURL string
|
||||||
MysqlDns string // mysql 连接地址
|
MysqlDns string // mysql 连接地址
|
||||||
StaticDir string // 静态资源目录
|
StaticDir string // 静态资源目录
|
||||||
StaticUrl string // 静态资源 URL
|
StaticUrl string // 静态资源 URL
|
||||||
Redis RedisConfig // redis 连接信息
|
Redis RedisConfig // redis 连接信息
|
||||||
ApiConfig ApiConfig // ChatPlus API authorization configs
|
ApiConfig ApiConfig // ChatPlus API authorization configs
|
||||||
SMS SMSConfig // send mobile message config
|
SMS SMSConfig // send mobile message config
|
||||||
OSS OSSConfig // OSS config
|
OSS OSSConfig // OSS config
|
||||||
MjProxyConfigs []MjProxyConfig // MJ proxy config
|
SmtpConfig SmtpConfig // 邮件发送配置
|
||||||
MjPlusConfigs []MjPlusConfig // MJ plus config
|
|
||||||
WeChatBot bool // 是否启用微信机器人
|
|
||||||
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
|
||||||
|
|
||||||
XXLConfig XXLConfig
|
XXLConfig XXLConfig
|
||||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||||
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
|
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
|
||||||
SmtpConfig SmtpConfig // 邮件发送配置
|
GeekPayConfig GeekPayConfig // GEEK 支付配置
|
||||||
JPayConfig JPayConfig // payjs 支付配置
|
|
||||||
WechatPayConfig WechatPayConfig // 微信支付渠道配置
|
WechatPayConfig WechatPayConfig // 微信支付渠道配置
|
||||||
TikaHost string // TiKa 服务器地址
|
TikaHost string // TiKa 服务器地址
|
||||||
}
|
}
|
||||||
@@ -53,27 +48,6 @@ type ApiConfig struct {
|
|||||||
Token string
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
type MjProxyConfig struct {
|
|
||||||
Enabled bool
|
|
||||||
ApiURL string // api 地址
|
|
||||||
Mode string // 绘画模式,可选值:fast/turbo/relax
|
|
||||||
ApiKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type StableDiffusionConfig struct {
|
|
||||||
Enabled bool
|
|
||||||
Model string // 模型名称
|
|
||||||
ApiURL string
|
|
||||||
ApiKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type MjPlusConfig struct {
|
|
||||||
Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
|
|
||||||
ApiURL string // api 地址
|
|
||||||
Mode string // 绘画模式,可选值:fast/turbo/relax
|
|
||||||
ApiKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type AlipayConfig struct {
|
type AlipayConfig struct {
|
||||||
Enabled bool // 是否启用该支付通道
|
Enabled bool // 是否启用该支付通道
|
||||||
SandBox bool // 是否沙盒环境
|
SandBox bool // 是否沙盒环境
|
||||||
@@ -83,8 +57,8 @@ type AlipayConfig struct {
|
|||||||
PublicKey string // 用户公钥文件路径
|
PublicKey string // 用户公钥文件路径
|
||||||
AlipayPublicKey string // 支付宝公钥文件路径
|
AlipayPublicKey string // 支付宝公钥文件路径
|
||||||
RootCert string // Root 秘钥路径
|
RootCert string // Root 秘钥路径
|
||||||
NotifyURL string // 异步通知回调
|
NotifyURL string // 异步通知地址
|
||||||
ReturnURL string // 支付成功返回地址
|
ReturnURL string // 同步通知地址
|
||||||
}
|
}
|
||||||
|
|
||||||
type WechatPayConfig struct {
|
type WechatPayConfig struct {
|
||||||
@@ -94,29 +68,27 @@ type WechatPayConfig struct {
|
|||||||
SerialNo string // 商户证书的证书序列号
|
SerialNo string // 商户证书的证书序列号
|
||||||
PrivateKey string // 用户私钥文件路径
|
PrivateKey string // 用户私钥文件路径
|
||||||
ApiV3Key string // API V3 秘钥
|
ApiV3Key string // API V3 秘钥
|
||||||
NotifyURL string // 异步通知回调
|
NotifyURL string // 异步通知地址
|
||||||
ReturnURL string // 支付成功返回地址
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
||||||
Enabled bool // 是否启用该支付通道
|
Enabled bool // 是否启用该支付通道
|
||||||
Name string // 支付名称,如:wechat/alipay
|
|
||||||
AppId string // App ID
|
AppId string // App ID
|
||||||
AppSecret string // app 密钥
|
AppSecret string // app 密钥
|
||||||
ApiURL string // 支付网关
|
ApiURL string // 支付网关
|
||||||
NotifyURL string // 异步通知回调
|
NotifyURL string // 异步通知地址
|
||||||
ReturnURL string // 支付成功返回地址
|
ReturnURL string // 同步通知地址
|
||||||
}
|
}
|
||||||
|
|
||||||
// JPayConfig PayJs 支付配置
|
// GeekPayConfig GEEK支付配置
|
||||||
type JPayConfig struct {
|
type GeekPayConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Name string // 支付名称,默认 wechat
|
AppId string // 商户 ID
|
||||||
AppId string // 商户 ID
|
PrivateKey string // 私钥
|
||||||
PrivateKey string // 私钥
|
ApiURL string // API 网关
|
||||||
ApiURL string // API 网关
|
NotifyURL string // 异步通知地址
|
||||||
NotifyURL string // 异步回调地址
|
ReturnURL string // 同步通知地址
|
||||||
ReturnURL string // 支付成功返回地址
|
Methods []string // 支付方式
|
||||||
}
|
}
|
||||||
|
|
||||||
type XXLConfig struct { // XXL 任务调度配置
|
type XXLConfig struct { // XXL 任务调度配置
|
||||||
@@ -156,31 +128,30 @@ func (c RedisConfig) Url() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SystemConfig struct {
|
type SystemConfig struct {
|
||||||
Title string `json:"title,omitempty"` // 网站标题
|
Title string `json:"title,omitempty"` // 网站标题
|
||||||
Slogan string `json:"slogan,omitempty"` // 网站 slogan
|
Slogan string `json:"slogan,omitempty"` // 网站 slogan
|
||||||
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
|
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
|
||||||
Logo string `json:"logo,omitempty"`
|
Logo string `json:"logo,omitempty"` // 圆形 Logo
|
||||||
|
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
|
||||||
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||||
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
|
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
|
||||||
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||||
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
|
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
|
||||||
|
|
||||||
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
|
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
|
||||||
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
|
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
|
||||||
|
|
||||||
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
|
|
||||||
EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
|
|
||||||
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
|
|
||||||
|
|
||||||
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
|
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
|
||||||
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
|
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
|
||||||
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
|
|
||||||
|
|
||||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||||
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
|
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
|
||||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||||
|
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||||
|
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
|
||||||
|
PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力
|
||||||
|
|
||||||
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
||||||
|
|
||||||
@@ -188,8 +159,15 @@ type SystemConfig struct {
|
|||||||
ContextDeep int `json:"context_deep,omitempty"`
|
ContextDeep int `json:"context_deep,omitempty"`
|
||||||
|
|
||||||
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
|
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
|
||||||
|
MjMode string `json:"mj_mode"` // midjourney 默认的API模式,relax, fast, turbo
|
||||||
|
|
||||||
|
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
||||||
|
Copyright string `json:"copyright"` // 版权信息
|
||||||
|
ICP string `json:"icp"` // ICP 备案号
|
||||||
|
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
|
||||||
|
|
||||||
|
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
|
||||||
|
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
|
||||||
|
TranslateModelId int `json:"translate_model_id"` // 用来做提示词翻译的大模型 id
|
||||||
|
|
||||||
IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片
|
|
||||||
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
|
||||||
Copyright string `json:"copyright"` // 版权信息
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ type MKey interface {
|
|||||||
string | int | uint
|
string | int | uint
|
||||||
}
|
}
|
||||||
type MValue interface {
|
type MValue interface {
|
||||||
*WsClient | *ChatSession | context.CancelFunc | []Message
|
*WsClient | *ChatSession | context.CancelFunc | []interface{}
|
||||||
}
|
}
|
||||||
type LMap[K MKey, T MValue] struct {
|
type LMap[K MKey, T MValue] struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
|
|||||||
@@ -22,3 +22,18 @@ type OrderRemark struct {
|
|||||||
Price float64 `json:"price"`
|
Price float64 `json:"price"`
|
||||||
Discount float64 `json:"discount"`
|
Discount float64 `json:"discount"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var PayMethods = map[string]string{
|
||||||
|
"alipay": "支付宝商号",
|
||||||
|
"wechat": "微信商号",
|
||||||
|
"hupi": "虎皮椒",
|
||||||
|
"geek": "易支付",
|
||||||
|
}
|
||||||
|
var PayNames = map[string]string{
|
||||||
|
"alipay": "支付宝",
|
||||||
|
"wxpay": "微信支付",
|
||||||
|
"qqpay": "QQ钱包",
|
||||||
|
"jdpay": "京东支付",
|
||||||
|
"douyin": "抖音支付",
|
||||||
|
"paypal": "PayPal支付",
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,30 +24,35 @@ const (
|
|||||||
|
|
||||||
// MjTask MidJourney 任务
|
// MjTask MidJourney 任务
|
||||||
type MjTask struct {
|
type MjTask struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"` // 任务ID
|
||||||
TaskId string `json:"task_id"`
|
TaskId string `json:"task_id"` // 中转任务ID
|
||||||
ImgArr []string `json:"img_arr"`
|
ClientId string `json:"client_id"`
|
||||||
ChannelId string `json:"channel_id"`
|
ImgArr []string `json:"img_arr"`
|
||||||
Type TaskType `json:"type"`
|
Type TaskType `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
NegPrompt string `json:"neg_prompt,omitempty"`
|
NegPrompt string `json:"neg_prompt,omitempty"`
|
||||||
Params string `json:"full_prompt"`
|
Params string `json:"full_prompt"`
|
||||||
Index int `json:"index,omitempty"`
|
Index int `json:"index,omitempty"`
|
||||||
MessageId string `json:"message_id,omitempty"`
|
MessageId string `json:"message_id,omitempty"`
|
||||||
MessageHash string `json:"message_hash,omitempty"`
|
MessageHash string `json:"message_hash,omitempty"`
|
||||||
RetryCount int `json:"retry_count"`
|
ChannelId string `json:"channel_id"` // 渠道ID,用来区分是哪个渠道创建的任务,一个任务的 create 和 action 操作必须要再同一个渠道
|
||||||
|
Mode string `json:"mode"` // 绘画模式,relax, fast, turbo
|
||||||
|
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type SdTask struct {
|
type SdTask struct {
|
||||||
Id int `json:"id"` // job 数据库ID
|
Id int `json:"id"` // job 数据库ID
|
||||||
Type TaskType `json:"type"`
|
Type TaskType `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
ClientId string `json:"client_id"`
|
||||||
Params SdTaskParams `json:"params"`
|
UserId int `json:"user_id"`
|
||||||
RetryCount int `json:"retry_count"`
|
Params SdTaskParams `json:"params"`
|
||||||
|
RetryCount int `json:"retry_count"`
|
||||||
|
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type SdTaskParams struct {
|
type SdTaskParams struct {
|
||||||
|
ClientId string `json:"client_id"` // 客户端ID
|
||||||
TaskId string `json:"task_id"`
|
TaskId string `json:"task_id"`
|
||||||
Prompt string `json:"prompt"` // 提示词
|
Prompt string `json:"prompt"` // 提示词
|
||||||
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
||||||
@@ -68,29 +73,63 @@ type SdTaskParams struct {
|
|||||||
|
|
||||||
// DallTask DALL-E task
|
// DallTask DALL-E task
|
||||||
type DallTask struct {
|
type DallTask struct {
|
||||||
JobId uint `json:"job_id"`
|
ClientId string `json:"client_id"`
|
||||||
UserId uint `json:"user_id"`
|
ModelId uint `json:"model_id"`
|
||||||
Prompt string `json:"prompt"`
|
ModelName string `json:"model_name"`
|
||||||
N int `json:"n"`
|
Id uint `json:"id"`
|
||||||
Quality string `json:"quality"`
|
UserId uint `json:"user_id"`
|
||||||
Size string `json:"size"`
|
Prompt string `json:"prompt"`
|
||||||
Style string `json:"style"`
|
N int `json:"n"`
|
||||||
|
Quality string `json:"quality"`
|
||||||
Power int `json:"power"`
|
Size string `json:"size"`
|
||||||
|
Style string `json:"style"`
|
||||||
|
Power int `json:"power"`
|
||||||
|
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type SunoTask struct {
|
type SunoTask struct {
|
||||||
|
ClientId string `json:"client_id"`
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Channel string `json:"channel"`
|
Channel string `json:"channel"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
Type int `json:"type"`
|
Type int `json:"type"`
|
||||||
TaskId string `json:"task_id"`
|
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
RefTaskId string `json:"ref_task_id"`
|
RefTaskId string `json:"ref_task_id,omitempty"`
|
||||||
RefSongId string `json:"ref_song_id"`
|
RefSongId string `json:"ref_song_id,omitempty"`
|
||||||
Prompt string `json:"prompt"` // 提示词/歌词
|
Prompt string `json:"prompt"` // 提示词/歌词
|
||||||
Tags string `json:"tags"`
|
Tags string `json:"tags"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Instrumental bool `json:"instrumental"` // 是否纯音乐
|
Instrumental bool `json:"instrumental"` // 是否纯音乐
|
||||||
ExtendSecs int `json:"extend_secs"` // 延长秒杀
|
ExtendSecs int `json:"extend_secs,omitempty"` // 延长秒杀
|
||||||
|
SongId string `json:"song_id,omitempty"` // 合并歌曲ID
|
||||||
|
AudioURL string `json:"audio_url"` // 用户上传音频地址
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
VideoLuma = "luma"
|
||||||
|
VideoRunway = "runway"
|
||||||
|
VideoCog = "cog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type VideoTask struct {
|
||||||
|
ClientId string `json:"client_id"`
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Channel string `json:"channel"`
|
||||||
|
UserId int `json:"user_id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
Prompt string `json:"prompt"` // 提示词
|
||||||
|
Params VideoParams `json:"params"`
|
||||||
|
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||||
|
}
|
||||||
|
|
||||||
|
type VideoParams struct {
|
||||||
|
PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词
|
||||||
|
Loop bool `json:"loop"` // 是否循环参考图
|
||||||
|
StartImgURL string `json:"start_img_url"` // 第一帧参考图地址
|
||||||
|
EndImgURL string `json:"end_img_url"` // 最后一帧参考图地址
|
||||||
|
Model string `json:"model"` // 使用哪个模型生成视频
|
||||||
|
Radio string `json:"radio"` // 视频尺寸
|
||||||
|
Style string `json:"style"` // 风格
|
||||||
|
Duration int `json:"duration"` // 视频时长(秒)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,21 +17,48 @@ type BizVo struct {
|
|||||||
Data interface{} `json:"data,omitempty"`
|
Data interface{} `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// WsMessage Websocket message
|
// ReplyMessage 对话回复消息结构
|
||||||
type WsMessage struct {
|
type ReplyMessage struct {
|
||||||
Type WsMsgType `json:"type"` // 消息类别,start, end, img
|
Channel WsChannel `json:"channel"` // 消息频道,目前只有 chat
|
||||||
Content interface{} `json:"content"`
|
ClientId string `json:"clientId"` // 客户端ID
|
||||||
|
Type WsMsgType `json:"type"` // 消息类别
|
||||||
|
Body interface{} `json:"body"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type WsMsgType string
|
type WsMsgType string
|
||||||
|
type WsChannel string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
WsStart = WsMsgType("start")
|
MsgTypeText = WsMsgType("text") // 输出内容
|
||||||
WsMiddle = WsMsgType("middle")
|
MsgTypeEnd = WsMsgType("end")
|
||||||
WsEnd = WsMsgType("end")
|
MsgTypeErr = WsMsgType("error")
|
||||||
WsErr = WsMsgType("error")
|
MsgTypePing = WsMsgType("ping") // 心跳消息
|
||||||
|
|
||||||
|
ChPing = WsChannel("ping")
|
||||||
|
ChChat = WsChannel("chat")
|
||||||
|
ChMj = WsChannel("mj")
|
||||||
|
ChSd = WsChannel("sd")
|
||||||
|
ChDall = WsChannel("dall")
|
||||||
|
ChSuno = WsChannel("suno")
|
||||||
|
ChLuma = WsChannel("luma")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// InputMessage 对话输入消息结构
|
||||||
|
type InputMessage struct {
|
||||||
|
Channel WsChannel `json:"channel"` // 消息频道
|
||||||
|
Type WsMsgType `json:"type"` // 消息类别
|
||||||
|
Body interface{} `json:"body"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatMessage struct {
|
||||||
|
Tools []int `json:"tools,omitempty"` // 允许调用工具列表
|
||||||
|
Stream bool `json:"stream,omitempty"` // 是否采用流式输出
|
||||||
|
RoleId int `json:"role_id"`
|
||||||
|
ModelId int `json:"model_id"`
|
||||||
|
ChatId string `json:"chat_id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
type BizCode int
|
type BizCode int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ require (
|
|||||||
github.com/BurntSushi/toml v1.1.0
|
github.com/BurntSushi/toml v1.1.0
|
||||||
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
|
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
|
||||||
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
|
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
|
||||||
github.com/eatmoreapple/openwechat v1.2.1
|
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/go-redis/redis/v8 v8.11.5
|
github.com/go-redis/redis/v8 v8.11.5
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0
|
github.com/golang-jwt/jwt/v5 v5.0.0
|
||||||
@@ -30,7 +29,6 @@ require (
|
|||||||
github.com/go-pay/gopay v1.5.101
|
github.com/go-pay/gopay v1.5.101
|
||||||
github.com/google/go-tika v0.3.1
|
github.com/google/go-tika v0.3.1
|
||||||
github.com/microcosm-cc/bluemonday v1.0.26
|
github.com/microcosm-cc/bluemonday v1.0.26
|
||||||
github.com/mojocn/base64Captcha v1.3.6
|
|
||||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||||
github.com/shopspring/decimal v1.3.1
|
github.com/shopspring/decimal v1.3.1
|
||||||
github.com/syndtr/goleveldb v1.0.0
|
github.com/syndtr/goleveldb v1.0.0
|
||||||
@@ -45,9 +43,13 @@ require (
|
|||||||
github.com/go-pay/util v0.0.2 // indirect
|
github.com/go-pay/util v0.0.2 // indirect
|
||||||
github.com/go-pay/xlog v0.0.2 // indirect
|
github.com/go-pay/xlog v0.0.2 // indirect
|
||||||
github.com/go-pay/xtime v0.0.2 // indirect
|
github.com/go-pay/xtime v0.0.2 // indirect
|
||||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
|
||||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
||||||
github.com/gorilla/css v1.0.0 // indirect
|
github.com/gorilla/css v1.0.0 // indirect
|
||||||
|
github.com/gravityblast/fresh v0.0.0-20240621171608-8d1fef547a99 // indirect
|
||||||
|
github.com/howeyc/fsnotify v0.9.0 // indirect
|
||||||
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
|
github.com/pilu/config v0.0.0-20131214182432-3eb99e6c0b9a // indirect
|
||||||
|
github.com/pilu/fresh v0.0.0-20240621171608-8d1fef547a99 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
||||||
github.com/tklauser/numcpus v0.7.0 // indirect
|
github.com/tklauser/numcpus v0.7.0 // indirect
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
|
|||||||
26
api/go.sum
26
api/go.sum
@@ -28,8 +28,6 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
|
|||||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
|
|
||||||
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
|
|
||||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||||
@@ -84,8 +82,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
|||||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
|
||||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
|
||||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||||
@@ -104,11 +100,15 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
|
|||||||
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
|
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
|
||||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
|
github.com/gravityblast/fresh v0.0.0-20240621171608-8d1fef547a99 h1:A6qlLfihaWef15viqtecCz4XknZcgjgD7mEuhu7bHEc=
|
||||||
|
github.com/gravityblast/fresh v0.0.0-20240621171608-8d1fef547a99/go.mod h1:ukFDwXV66bGV7JnfyxFKuKiVp4zH4orBKXML+VCSrhI=
|
||||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||||
|
github.com/howeyc/fsnotify v0.9.0 h1:0gtV5JmOKH4A8SsFxG2BczSeXWWPvcMT0euZt5gDAxY=
|
||||||
|
github.com/howeyc/fsnotify v0.9.0/go.mod h1:41HzSPxBGeFRQKEEwgh49TRw/nKBsYZ2cF1OzPjSJsA=
|
||||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||||
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
|
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
|
||||||
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
|
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
|
||||||
@@ -141,6 +141,9 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
|||||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||||
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0 h1:LgmjED/yQILqmUED4GaXjrINWe7YJh4HM6z2EvEINPs=
|
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0 h1:LgmjED/yQILqmUED4GaXjrINWe7YJh4HM6z2EvEINPs=
|
||||||
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs=
|
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs=
|
||||||
|
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||||
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
|
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
|
||||||
@@ -157,8 +160,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
|||||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
github.com/mojocn/base64Captcha v1.3.6 h1:gZEKu1nsKpttuIAQgWHO+4Mhhls8cAKyiV2Ew03H+Tw=
|
|
||||||
github.com/mojocn/base64Captcha v1.3.6/go.mod h1:i5CtHvm+oMbj1UzEPXaA8IH/xHFZ3DGY3Wh3dBpZ28E=
|
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||||
@@ -176,6 +177,10 @@ github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:Ff
|
|||||||
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
|
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||||
|
github.com/pilu/config v0.0.0-20131214182432-3eb99e6c0b9a h1:Tg4E4cXPZSZyd3H1tJlYo6ZreXV0ZJvE/lorNqyw1AU=
|
||||||
|
github.com/pilu/config v0.0.0-20131214182432-3eb99e6c0b9a/go.mod h1:9Or9aIl95Kp43zONcHd5tLZGKXb9iLx0pZjau0uJ5zg=
|
||||||
|
github.com/pilu/fresh v0.0.0-20240621171608-8d1fef547a99 h1:+X7Gb40b5Bl3v5+3MiGK8Jhemjp65MHc+nkVCfq1Yfc=
|
||||||
|
github.com/pilu/fresh v0.0.0-20240621171608-8d1fef547a99/go.mod h1:2LLTtftTZSdAPR/iVyennXZDLZOYzyDn+T0qEKJ8eSw=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
@@ -220,12 +225,10 @@ github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gt
|
|||||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
|
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
|
||||||
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
|
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
|
||||||
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
|
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
|
||||||
github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU=
|
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
|
||||||
github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=
|
|
||||||
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
|
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
|
||||||
github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY=
|
|
||||||
github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
|
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
|
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
|
||||||
@@ -267,7 +270,6 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
|||||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||||
golang.org/x/image v0.13.0/go.mod h1:6mmbMOeV28HuMTgA6OSRkdXKYw/t5W9Uwn2Yv1r3Yxk=
|
|
||||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
@@ -300,6 +302,7 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
@@ -323,7 +326,6 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
|||||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
|
||||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
logger2 "geekai/logger"
|
logger2 "geekai/logger"
|
||||||
|
"geekai/service"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
@@ -28,33 +29,49 @@ import (
|
|||||||
|
|
||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
// Manager 管理员
|
|
||||||
type Manager struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
Captcha string `json:"captcha"` // 验证码
|
|
||||||
CaptchaId string `json:"captcha_id"` // 验证码id
|
|
||||||
}
|
|
||||||
|
|
||||||
const SuperManagerID = 1
|
const SuperManagerID = 1
|
||||||
|
|
||||||
type ManagerHandler struct {
|
type ManagerHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
|
captcha *service.CaptchaService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
|
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, captcha *service.CaptchaService) *ManagerHandler {
|
||||||
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
|
return &ManagerHandler{
|
||||||
|
BaseHandler: handler.BaseHandler{DB: db, App: app},
|
||||||
|
redis: client,
|
||||||
|
captcha: captcha,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login 登录
|
// Login 登录
|
||||||
func (h *ManagerHandler) Login(c *gin.Context) {
|
func (h *ManagerHandler) Login(c *gin.Context) {
|
||||||
var data Manager
|
var data struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
Key string `json:"key,omitempty"`
|
||||||
|
Dots string `json:"dots,omitempty"`
|
||||||
|
X int `json:"x,omitempty"`
|
||||||
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.App.SysConfig.EnabledVerify {
|
||||||
|
var check bool
|
||||||
|
if data.X != 0 {
|
||||||
|
check = h.captcha.SlideCheck(data)
|
||||||
|
} else {
|
||||||
|
check = h.captcha.Check(data)
|
||||||
|
}
|
||||||
|
if !check {
|
||||||
|
resp.ERROR(c, "请先完人机验证")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var manager model.AdminUser
|
var manager model.AdminUser
|
||||||
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
|
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ package admin
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -53,17 +55,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
|||||||
apiKey.Enabled = data.Enabled
|
apiKey.Enabled = data.Enabled
|
||||||
apiKey.ProxyURL = data.ProxyURL
|
apiKey.ProxyURL = data.ProxyURL
|
||||||
apiKey.Name = data.Name
|
apiKey.Name = data.Name
|
||||||
res := h.DB.Save(&apiKey)
|
err := h.DB.Save(&apiKey).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var keyVo vo.ApiKey
|
var keyVo vo.ApiKey
|
||||||
err := utils.CopyObject(apiKey, &keyVo)
|
err = utils.CopyObject(apiKey, &keyVo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "数据拷贝失败!")
|
resp.ERROR(c, fmt.Sprintf("拷贝数据失败:%v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
keyVo.Id = apiKey.Id
|
keyVo.Id = apiKey.Id
|
||||||
@@ -71,20 +72,18 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, keyVo)
|
resp.SUCCESS(c, keyVo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// List 获取 API KEY 列表
|
||||||
func (h *ApiKeyHandler) List(c *gin.Context) {
|
func (h *ApiKeyHandler) List(c *gin.Context) {
|
||||||
status := h.GetBool(c, "status")
|
status := h.GetBool(c, "status")
|
||||||
t := h.GetTrim(c, "type")
|
t := c.Query("type")
|
||||||
platform := h.GetTrim(c, "platform")
|
|
||||||
|
|
||||||
session := h.DB.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if status {
|
if status {
|
||||||
session = session.Where("enabled", true)
|
session = session.Where("enabled", true)
|
||||||
}
|
}
|
||||||
if t != "" {
|
if t != "" {
|
||||||
session = session.Where("type", t)
|
types := strings.Split(t, "|")
|
||||||
}
|
session = session.Where("type IN ?", types)
|
||||||
if platform != "" {
|
|
||||||
session = session.Where("platform", platform)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var items []model.ApiKey
|
var items []model.ApiKey
|
||||||
@@ -119,10 +118,9 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
err := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
@@ -135,10 +133,9 @@ func (h *ApiKeyHandler) Remove(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Where("id", id).Delete(&model.ApiKey{})
|
err := h.DB.Where("id", id).Delete(&model.ApiKey{}).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ package admin
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
@@ -21,16 +22,16 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatRoleHandler struct {
|
type ChatAppHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
|
||||||
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save 创建或者更新某个角色
|
// Save 创建或者更新某个角色
|
||||||
func (h *ChatRoleHandler) Save(c *gin.Context) {
|
func (h *ChatAppHandler) Save(c *gin.Context) {
|
||||||
var data vo.ChatRole
|
var data vo.ChatRole
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -45,11 +46,16 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
|||||||
role.Id = data.Id
|
role.Id = data.Id
|
||||||
if data.CreatedAt > 0 {
|
if data.CreatedAt > 0 {
|
||||||
role.CreatedAt = time.Unix(data.CreatedAt, 0)
|
role.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||||
|
} else {
|
||||||
|
err = h.DB.Where("marker", data.Key).First(&role).Error
|
||||||
|
if err == nil {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("角色 %s 已存在", data.Key))
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
res := h.DB.Save(&role)
|
err = h.DB.Save(&role).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 填充 ID 数据
|
// 填充 ID 数据
|
||||||
@@ -58,7 +64,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, data)
|
resp.SUCCESS(c, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
func (h *ChatAppHandler) List(c *gin.Context) {
|
||||||
var items []model.ChatRole
|
var items []model.ChatRole
|
||||||
var roles = make([]vo.ChatRole, 0)
|
var roles = make([]vo.ChatRole, 0)
|
||||||
res := h.DB.Order("sort_num ASC").Find(&items)
|
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||||
@@ -69,13 +75,18 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
|
|
||||||
// initialize model mane for role
|
// initialize model mane for role
|
||||||
modelIds := make([]int, 0)
|
modelIds := make([]int, 0)
|
||||||
|
typeIds := make([]int, 0)
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
if v.ModelId > 0 {
|
if v.ModelId > 0 {
|
||||||
modelIds = append(modelIds, v.ModelId)
|
modelIds = append(modelIds, v.ModelId)
|
||||||
}
|
}
|
||||||
|
if v.Tid > 0 {
|
||||||
|
typeIds = append(typeIds, v.Tid)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modelNameMap := make(map[int]string)
|
modelNameMap := make(map[int]string)
|
||||||
|
typeNameMap := make(map[int]string)
|
||||||
if len(modelIds) > 0 {
|
if len(modelIds) > 0 {
|
||||||
var models []model.ChatModel
|
var models []model.ChatModel
|
||||||
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
|
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
|
||||||
@@ -85,6 +96,15 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(typeIds) > 0 {
|
||||||
|
var appTypes []model.AppType
|
||||||
|
tx := h.DB.Where("id IN ?", typeIds).Find(&appTypes)
|
||||||
|
if tx.Error == nil {
|
||||||
|
for _, m := range appTypes {
|
||||||
|
typeNameMap[int(m.Id)] = m.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
var role vo.ChatRole
|
var role vo.ChatRole
|
||||||
@@ -94,6 +114,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
role.CreatedAt = v.CreatedAt.Unix()
|
role.CreatedAt = v.CreatedAt.Unix()
|
||||||
role.UpdatedAt = v.UpdatedAt.Unix()
|
role.UpdatedAt = v.UpdatedAt.Unix()
|
||||||
role.ModelName = modelNameMap[role.ModelId]
|
role.ModelName = modelNameMap[role.ModelId]
|
||||||
|
role.TypeName = typeNameMap[role.Tid]
|
||||||
roles = append(roles, role)
|
roles = append(roles, role)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -102,7 +123,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sort 更新角色排序
|
// Sort 更新角色排序
|
||||||
func (h *ChatRoleHandler) Sort(c *gin.Context) {
|
func (h *ChatAppHandler) Sort(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Ids []uint `json:"ids"`
|
Ids []uint `json:"ids"`
|
||||||
Sorts []int `json:"sorts"`
|
Sorts []int `json:"sorts"`
|
||||||
@@ -114,10 +135,9 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -125,7 +145,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
|
|||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatRoleHandler) Set(c *gin.Context) {
|
func (h *ChatAppHandler) Set(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Filed string `json:"filed"`
|
Filed string `json:"filed"`
|
||||||
@@ -137,16 +157,15 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatRoleHandler) Remove(c *gin.Context) {
|
func (h *ChatAppHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
if id <= 0 {
|
if id <= 0 {
|
||||||
148
api/handler/admin/chat_app_type_handler.go
Normal file
148
api/handler/admin/chat_app_type_handler.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatAppTypeHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
|
||||||
|
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save 创建或更新App类型
|
||||||
|
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
SortNum int `json:"sort_num"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Id == 0 { // for add
|
||||||
|
err := h.DB.Where("name", data.Name).First(&model.AppType{}).Error
|
||||||
|
if err == nil {
|
||||||
|
resp.ERROR(c, "当前分类已经存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = h.DB.Create(&model.AppType{
|
||||||
|
Name: data.Name,
|
||||||
|
Icon: data.Icon,
|
||||||
|
Enabled: data.Enabled,
|
||||||
|
SortNum: data.SortNum,
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else { // for update
|
||||||
|
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).Updates(map[string]interface{}{
|
||||||
|
"name": data.Name,
|
||||||
|
"icon": data.Icon,
|
||||||
|
"enabled": data.Enabled,
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 获取App类型列表
|
||||||
|
func (h *ChatAppTypeHandler) List(c *gin.Context) {
|
||||||
|
var items []model.AppType
|
||||||
|
var appTypes = make([]vo.AppType, 0)
|
||||||
|
err := h.DB.Order("sort_num ASC").Find(&items).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range items {
|
||||||
|
var appType vo.AppType
|
||||||
|
err = utils.CopyObject(v, &appType)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
appType.Id = v.Id
|
||||||
|
appType.CreatedAt = v.CreatedAt.Unix()
|
||||||
|
appTypes = append(appTypes, appType)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, appTypes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove 删除App类型
|
||||||
|
func (h *ChatAppTypeHandler) Remove(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
|
if id <= 0 {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := h.DB.Where("id", id).Delete(&model.AppType{}).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable 启用|禁用
|
||||||
|
func (h *ChatAppTypeHandler) Enable(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort 更新排序
|
||||||
|
func (h *ChatAppTypeHandler) Sort(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Ids []uint `json:"ids"`
|
||||||
|
Sorts []int `json:"sorts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for index, id := range data.Ids {
|
||||||
|
err := h.DB.Model(&model.AppType{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
@@ -259,10 +259,9 @@ func (h *ChatHandler) RemoveChat(c *gin.Context) {
|
|||||||
// RemoveMessage 删除聊天记录
|
// RemoveMessage 删除聊天记录
|
||||||
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
|
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
|
err := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}).Error
|
||||||
if tx.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", tx.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
|||||||
Temperature float32 `json:"temperature"` // 模型温度
|
Temperature float32 `json:"temperature"` // 模型温度
|
||||||
KeyId int `json:"key_id,omitempty"`
|
KeyId int `json:"key_id,omitempty"`
|
||||||
CreatedAt int64 `json:"created_at"`
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -65,7 +66,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
|||||||
item.MaxContext = data.MaxContext
|
item.MaxContext = data.MaxContext
|
||||||
item.Temperature = data.Temperature
|
item.Temperature = data.Temperature
|
||||||
item.KeyId = data.KeyId
|
item.KeyId = data.KeyId
|
||||||
|
item.Type = data.Type
|
||||||
var res *gorm.DB
|
var res *gorm.DB
|
||||||
if data.Id > 0 {
|
if data.Id > 0 {
|
||||||
res = h.DB.Save(&item)
|
res = h.DB.Save(&item)
|
||||||
@@ -147,10 +148,9 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
err := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
@@ -168,10 +168,9 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
err := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -186,10 +185,9 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
|
err := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ import (
|
|||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/mj"
|
|
||||||
"geekai/service/sd"
|
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
@@ -28,16 +26,12 @@ type ConfigHandler struct {
|
|||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
levelDB *store.LevelDB
|
levelDB *store.LevelDB
|
||||||
licenseService *service.LicenseService
|
licenseService *service.LicenseService
|
||||||
mjServicePool *mj.ServicePool
|
|
||||||
sdServicePool *sd.ServicePool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler {
|
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
|
||||||
return &ConfigHandler{
|
return &ConfigHandler{
|
||||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||||
levelDB: levelDB,
|
levelDB: levelDB,
|
||||||
mjServicePool: mjPool,
|
|
||||||
sdServicePool: sdPool,
|
|
||||||
licenseService: licenseService,
|
licenseService: licenseService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -147,57 +141,69 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, license)
|
resp.SUCCESS(c, license)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAppConfig 获取内置配置
|
// FixData 修复数据
|
||||||
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
|
func (h *ConfigHandler) FixData(c *gin.Context) {
|
||||||
resp.SUCCESS(c, gin.H{
|
resp.ERROR(c, "当前升级版本没有数据需要修正!")
|
||||||
"mj_plus": h.App.Config.MjPlusConfigs,
|
return
|
||||||
"mj_proxy": h.App.Config.MjProxyConfigs,
|
//var fixed bool
|
||||||
"sd": h.App.Config.SdConfigs,
|
//version := "data_fix_4.1.4"
|
||||||
})
|
//err := h.levelDB.Get(version, &fixed)
|
||||||
}
|
//if err == nil || fixed {
|
||||||
|
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
|
||||||
// SaveDrawingConfig 保存AI绘画配置
|
// return
|
||||||
func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) {
|
//}
|
||||||
var data struct {
|
//tx := h.DB.Begin()
|
||||||
Sd []types.StableDiffusionConfig `json:"sd"`
|
//var users []model.User
|
||||||
MjPlus []types.MjPlusConfig `json:"mj_plus"`
|
//err = tx.Find(&users).Error
|
||||||
MjProxy []types.MjProxyConfig `json:"mj_proxy"`
|
//if err != nil {
|
||||||
}
|
// resp.ERROR(c, err.Error())
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
// return
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
//}
|
||||||
return
|
//for _, user := range users {
|
||||||
}
|
// if user.Email != "" || user.Mobile != "" {
|
||||||
|
// continue
|
||||||
changed := false
|
// }
|
||||||
if configChanged(data.Sd, h.App.Config.SdConfigs) {
|
// if utils.IsValidEmail(user.Username) {
|
||||||
logger.Debugf("SD 配置变动了")
|
// user.Email = user.Username
|
||||||
h.App.Config.SdConfigs = data.Sd
|
// } else if utils.IsValidMobile(user.Username) {
|
||||||
h.sdServicePool.InitServices(data.Sd)
|
// user.Mobile = user.Username
|
||||||
changed = true
|
// }
|
||||||
}
|
// err = tx.Save(&user).Error
|
||||||
|
// if err != nil {
|
||||||
if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) {
|
// resp.ERROR(c, err.Error())
|
||||||
logger.Debugf("MidJourney 配置变动了")
|
// tx.Rollback()
|
||||||
h.App.Config.MjPlusConfigs = data.MjPlus
|
// return
|
||||||
h.App.Config.MjProxyConfigs = data.MjProxy
|
// }
|
||||||
h.mjServicePool.InitServices(data.MjPlus, data.MjProxy)
|
//}
|
||||||
changed = true
|
//
|
||||||
}
|
//var orders []model.Order
|
||||||
|
//err = h.DB.Find(&orders).Error
|
||||||
if changed {
|
//if err != nil {
|
||||||
err := core.SaveConfig(h.App.Config)
|
// resp.ERROR(c, err.Error())
|
||||||
if err != nil {
|
// return
|
||||||
resp.ERROR(c, "更新配置文档失败!")
|
//}
|
||||||
return
|
//for _, order := range orders {
|
||||||
}
|
// if order.PayWay == "支付宝" {
|
||||||
}
|
// order.PayWay = "alipay"
|
||||||
|
// order.PayType = "alipay"
|
||||||
resp.SUCCESS(c)
|
// } else if order.PayWay == "微信支付" {
|
||||||
|
// order.PayWay = "wechat"
|
||||||
}
|
// order.PayType = "wxpay"
|
||||||
|
// } else if order.PayWay == "hupi" {
|
||||||
func configChanged(c1 interface{}, c2 interface{}) bool {
|
// order.PayType = "wxpay"
|
||||||
encode1 := utils.JsonEncode(c1)
|
// }
|
||||||
encode2 := utils.JsonEncode(c2)
|
// err = tx.Save(&order).Error
|
||||||
return utils.Md5(encode1) != utils.Md5(encode2)
|
// if err != nil {
|
||||||
|
// resp.ERROR(c, err.Error())
|
||||||
|
// tx.Rollback()
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
//tx.Commit()
|
||||||
|
//err = h.levelDB.Put(version, true)
|
||||||
|
//if err != nil {
|
||||||
|
// resp.ERROR(c, err.Error())
|
||||||
|
// return
|
||||||
|
//}
|
||||||
|
//resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,13 +60,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
|
|||||||
stats.Tokens += item.Tokens
|
stats.Tokens += item.Tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
// 众筹收入
|
|
||||||
var rewards []model.Reward
|
|
||||||
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
|
|
||||||
for _, item := range rewards {
|
|
||||||
stats.Income += item.Amount
|
|
||||||
}
|
|
||||||
|
|
||||||
// 订单收入
|
// 订单收入
|
||||||
var orders []model.Order
|
var orders []model.Order
|
||||||
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
|
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
|
||||||
@@ -101,13 +94,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
|
|||||||
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
|
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 浮点数相加?
|
|
||||||
// 统计最近7天的众筹
|
|
||||||
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
|
|
||||||
for _, item := range rewards {
|
|
||||||
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 统计最近7天的订单
|
// 统计最近7天的订单
|
||||||
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
|
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
|
||||||
for _, item := range orders {
|
for _, item := range orders {
|
||||||
|
|||||||
@@ -69,10 +69,9 @@ func (h *FunctionHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
err := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
@@ -102,10 +101,9 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
|
|||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
res := h.DB.Delete(&model.Function{Id: uint(id)})
|
err := h.DB.Delete(&model.Function{Id: uint(id)}).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
254
api/handler/admin/image_handler.go
Normal file
254
api/handler/admin/image_handler.go
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
userService *service.UserService
|
||||||
|
uploader *oss.UploaderManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *ImageHandler {
|
||||||
|
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageQuery struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
CreatedAt []string `json:"created_at"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MjList Midjourney 任务列表
|
||||||
|
func (h *ImageHandler) MjList(c *gin.Context) {
|
||||||
|
var data imageQuery
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Username != "" {
|
||||||
|
var user model.User
|
||||||
|
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||||
|
if err == nil {
|
||||||
|
session = session.Where("user_id", user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if data.Prompt != "" {
|
||||||
|
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||||
|
}
|
||||||
|
if len(data.CreatedAt) == 2 {
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.MidJourneyJob{}).Count(&total)
|
||||||
|
var list []model.MidJourneyJob
|
||||||
|
var items = make([]vo.MidJourneyJob, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||||
|
if err == nil {
|
||||||
|
// 填充数据
|
||||||
|
for _, item := range list {
|
||||||
|
var job vo.MidJourneyJob
|
||||||
|
err = utils.CopyObject(item, &job)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
job.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
items = append(items, job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SdList Stable Diffusion 任务列表
|
||||||
|
func (h *ImageHandler) SdList(c *gin.Context) {
|
||||||
|
var data imageQuery
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Username != "" {
|
||||||
|
var user model.User
|
||||||
|
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||||
|
if err == nil {
|
||||||
|
session = session.Where("user_id", user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if data.Prompt != "" {
|
||||||
|
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||||
|
}
|
||||||
|
if len(data.CreatedAt) == 2 {
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.SdJob{}).Count(&total)
|
||||||
|
var list []model.SdJob
|
||||||
|
var items = make([]vo.SdJob, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||||
|
if err == nil {
|
||||||
|
// 填充数据
|
||||||
|
for _, item := range list {
|
||||||
|
var job vo.SdJob
|
||||||
|
err = utils.CopyObject(item, &job)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
job.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
items = append(items, job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DallList DALL-E 任务列表
|
||||||
|
func (h *ImageHandler) DallList(c *gin.Context) {
|
||||||
|
var data imageQuery
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Username != "" {
|
||||||
|
var user model.User
|
||||||
|
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||||
|
if err == nil {
|
||||||
|
session = session.Where("user_id", user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if data.Prompt != "" {
|
||||||
|
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||||
|
}
|
||||||
|
if len(data.CreatedAt) == 2 {
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.DallJob{}).Count(&total)
|
||||||
|
var list []model.DallJob
|
||||||
|
var items = make([]vo.DallJob, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||||
|
if err == nil {
|
||||||
|
// 填充数据
|
||||||
|
for _, item := range list {
|
||||||
|
var job vo.DallJob
|
||||||
|
err = utils.CopyObject(item, &job)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
job.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
items = append(items, job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ImageHandler) Remove(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
tab := c.Query("tab")
|
||||||
|
|
||||||
|
tx := h.DB.Begin()
|
||||||
|
var md, remark, imgURL string
|
||||||
|
var power, userId, progress int
|
||||||
|
switch tab {
|
||||||
|
case "mj":
|
||||||
|
var job model.MidJourneyJob
|
||||||
|
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
|
||||||
|
resp.ERROR(c, "记录不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Delete(&job)
|
||||||
|
md = "mid-journey"
|
||||||
|
power = job.Power
|
||||||
|
userId = job.UserId
|
||||||
|
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||||
|
progress = job.Progress
|
||||||
|
imgURL = job.ImgURL
|
||||||
|
break
|
||||||
|
case "sd":
|
||||||
|
var job model.SdJob
|
||||||
|
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||||
|
resp.ERROR(c, "记录不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除任务
|
||||||
|
tx.Delete(&job)
|
||||||
|
md = "stable-diffusion"
|
||||||
|
power = job.Power
|
||||||
|
userId = job.UserId
|
||||||
|
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||||
|
progress = job.Progress
|
||||||
|
imgURL = job.ImgURL
|
||||||
|
break
|
||||||
|
case "dall":
|
||||||
|
var job model.DallJob
|
||||||
|
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||||
|
resp.ERROR(c, "记录不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除任务
|
||||||
|
tx.Delete(&job)
|
||||||
|
md = "dall-e-3"
|
||||||
|
power = job.Power
|
||||||
|
userId = int(job.UserId)
|
||||||
|
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||||
|
progress = job.Progress
|
||||||
|
imgURL = job.ImgURL
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if progress != 100 {
|
||||||
|
err := h.userService.IncreasePower(userId, power, model.PowerLog{
|
||||||
|
Type: types.PowerRefund,
|
||||||
|
Model: md,
|
||||||
|
Remark: remark,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
// remove image
|
||||||
|
err := h.uploader.GetUploadHandler().Delete(imgURL)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("remove image failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
200
api/handler/admin/media_handler.go
Normal file
200
api/handler/admin/media_handler.go
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MediaHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
userService *service.UserService
|
||||||
|
uploader *oss.UploaderManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *MediaHandler {
|
||||||
|
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mediaQuery struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
CreatedAt []string `json:"created_at"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SunoList Suno 任务列表
|
||||||
|
func (h *MediaHandler) SunoList(c *gin.Context) {
|
||||||
|
var data mediaQuery
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Username != "" {
|
||||||
|
var user model.User
|
||||||
|
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||||
|
if err == nil {
|
||||||
|
session = session.Where("user_id", user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if data.Prompt != "" {
|
||||||
|
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||||
|
}
|
||||||
|
if len(data.CreatedAt) == 2 {
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.SunoJob{}).Count(&total)
|
||||||
|
var list []model.SunoJob
|
||||||
|
var items = make([]vo.SunoJob, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||||
|
if err == nil {
|
||||||
|
// 填充数据
|
||||||
|
for _, item := range list {
|
||||||
|
var job vo.SunoJob
|
||||||
|
err = utils.CopyObject(item, &job)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
job.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
items = append(items, job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LumaList Luma 视频任务列表
|
||||||
|
func (h *MediaHandler) LumaList(c *gin.Context) {
|
||||||
|
var data mediaQuery
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Username != "" {
|
||||||
|
var user model.User
|
||||||
|
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||||
|
if err == nil {
|
||||||
|
session = session.Where("user_id", user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if data.Prompt != "" {
|
||||||
|
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||||
|
}
|
||||||
|
if len(data.CreatedAt) == 2 {
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.VideoJob{}).Count(&total)
|
||||||
|
var list []model.VideoJob
|
||||||
|
var items = make([]vo.VideoJob, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||||
|
if err == nil {
|
||||||
|
// 填充数据
|
||||||
|
for _, item := range list {
|
||||||
|
var job vo.VideoJob
|
||||||
|
err = utils.CopyObject(item, &job)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
job.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
if job.VideoURL == "" {
|
||||||
|
job.VideoURL = job.WaterURL
|
||||||
|
}
|
||||||
|
items = append(items, job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MediaHandler) Remove(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
tab := c.Query("tab")
|
||||||
|
|
||||||
|
tx := h.DB.Begin()
|
||||||
|
var md, remark, fileURL string
|
||||||
|
var power, userId, progress int
|
||||||
|
switch tab {
|
||||||
|
case "suno":
|
||||||
|
var job model.SunoJob
|
||||||
|
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
|
||||||
|
resp.ERROR(c, "记录不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Delete(&job)
|
||||||
|
md = "suno"
|
||||||
|
power = job.Power
|
||||||
|
userId = job.UserId
|
||||||
|
remark = fmt.Sprintf("SUNO 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||||
|
progress = job.Progress
|
||||||
|
fileURL = job.AudioURL
|
||||||
|
break
|
||||||
|
case "luma":
|
||||||
|
var job model.VideoJob
|
||||||
|
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||||
|
resp.ERROR(c, "记录不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除任务
|
||||||
|
tx.Delete(&job)
|
||||||
|
md = job.Type
|
||||||
|
power = job.Power
|
||||||
|
userId = job.UserId
|
||||||
|
remark = fmt.Sprintf("LUMA 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||||
|
progress = job.Progress
|
||||||
|
fileURL = job.VideoURL
|
||||||
|
if fileURL == "" {
|
||||||
|
fileURL = job.WaterURL
|
||||||
|
}
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if progress != 100 {
|
||||||
|
err := h.userService.IncreasePower(userId, power, model.PowerLog{
|
||||||
|
Type: types.PowerRefund,
|
||||||
|
Model: md,
|
||||||
|
Remark: remark,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
// remove image
|
||||||
|
err := h.uploader.GetUploadHandler().Delete(fileURL)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("remove image failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
@@ -41,17 +41,16 @@ func (h *MenuHandler) Save(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Save(&model.Menu{
|
err := h.DB.Save(&model.Menu{
|
||||||
Id: data.Id,
|
Id: data.Id,
|
||||||
Name: data.Name,
|
Name: data.Name,
|
||||||
Icon: data.Icon,
|
Icon: data.Icon,
|
||||||
URL: data.URL,
|
URL: data.URL,
|
||||||
SortNum: data.SortNum,
|
SortNum: data.SortNum,
|
||||||
Enabled: data.Enabled,
|
Enabled: data.Enabled,
|
||||||
})
|
}).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
@@ -85,10 +84,9 @@ func (h *MenuHandler) Enable(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
|
err := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
@@ -106,10 +104,9 @@ func (h *MenuHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index])
|
err := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -121,10 +118,9 @@ func (h *MenuHandler) Remove(c *gin.Context) {
|
|||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
res := h.DB.Where("id", id).Delete(&model.Menu{})
|
err := h.DB.Where("id", id).Delete(&model.Menu{}).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -67,6 +68,16 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
order.Id = item.Id
|
order.Id = item.Id
|
||||||
order.CreatedAt = item.CreatedAt.Unix()
|
order.CreatedAt = item.CreatedAt.Unix()
|
||||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||||
|
payMethod, ok := types.PayMethods[item.PayWay]
|
||||||
|
if !ok {
|
||||||
|
payMethod = item.PayWay
|
||||||
|
}
|
||||||
|
payName, ok := types.PayNames[item.PayType]
|
||||||
|
if !ok {
|
||||||
|
payName = item.PayWay
|
||||||
|
}
|
||||||
|
order.PayMethod = payMethod
|
||||||
|
order.PayName = payName
|
||||||
list = append(list, order)
|
list = append(list, order)
|
||||||
} else {
|
} else {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -92,12 +103,33 @@ func (h *OrderHandler) Remove(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
|
err := h.DB.Where("id = ?", id).Delete(&model.Order{}).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *OrderHandler) Clear(c *gin.Context) {
|
||||||
|
var orders []model.Order
|
||||||
|
err := h.DB.Where("status <> ?", 2).Where("pay_time", 0).Find(&orders).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
deleteIds := make([]uint, 0)
|
||||||
|
for _, order := range orders {
|
||||||
|
// 只删除 15 分钟内的未支付订单
|
||||||
|
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
|
||||||
|
deleteIds = append(deleteIds, order.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = h.DB.Where("id IN ?", deleteIds).Delete(&model.Order{}).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
|||||||
func (h *PowerLogHandler) List(c *gin.Context) {
|
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
UserId uint `json:"userid"`
|
||||||
Type int `json:"type"`
|
Type int `json:"type"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Date []string `json:"date"`
|
Date []string `json:"date"`
|
||||||
@@ -49,6 +50,12 @@ func (h *PowerLogHandler) List(c *gin.Context) {
|
|||||||
if data.Type > 0 {
|
if data.Type > 0 {
|
||||||
session = session.Where("type", data.Type)
|
session = session.Where("type", data.Type)
|
||||||
}
|
}
|
||||||
|
if data.UserId > 0 {
|
||||||
|
session = session.Where("user_id", data.UserId)
|
||||||
|
}
|
||||||
|
if data.Username != "" {
|
||||||
|
session = session.Where("username", data.Username)
|
||||||
|
}
|
||||||
if len(data.Date) == 2 {
|
if len(data.Date) == 2 {
|
||||||
start := data.Date[0] + " 00:00:00"
|
start := data.Date[0] + " 00:00:00"
|
||||||
end := data.Date[1] + " 00:00:00"
|
end := data.Date[1] + " 00:00:00"
|
||||||
|
|||||||
@@ -55,17 +55,16 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
|||||||
if item.Id > 0 {
|
if item.Id > 0 {
|
||||||
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||||
}
|
}
|
||||||
res := h.DB.Save(&item)
|
err := h.DB.Save(&item).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var itemVo vo.Product
|
var itemVo vo.Product
|
||||||
err := utils.CopyObject(item, &itemVo)
|
err = utils.CopyObject(item, &itemVo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "数据拷贝失败!")
|
resp.ERROR(c, "数据拷贝失败: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
itemVo.Id = item.Id
|
itemVo.Id = item.Id
|
||||||
@@ -106,10 +105,9 @@ func (h *ProductHandler) Enable(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
|
err := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
@@ -127,10 +125,9 @@ func (h *ProductHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index])
|
err := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -142,10 +139,9 @@ func (h *ProductHandler) Remove(c *gin.Context) {
|
|||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
res := h.DB.Where("id", id).Delete(&model.Product{})
|
err := h.DB.Where("id", id).Delete(&model.Product{}).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
219
api/handler/admin/redeem_handler.go
Normal file
219
api/handler/admin/redeem_handler.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/csv"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RedeemHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
|
||||||
|
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RedeemHandler) List(c *gin.Context) {
|
||||||
|
page := h.GetInt(c, "page", 1)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 20)
|
||||||
|
code := c.Query("code")
|
||||||
|
status := h.GetInt(c, "status", -1)
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if code != "" {
|
||||||
|
session = session.Where("code LIKE ?", "%"+code+"%")
|
||||||
|
}
|
||||||
|
if status >= 0 {
|
||||||
|
session = session.Where("redeemed_at", status)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.Redeem{}).Count(&total)
|
||||||
|
var redeems []model.Redeem
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&redeems).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var items = make([]vo.Redeem, 0)
|
||||||
|
userIds := make([]uint, 0)
|
||||||
|
for _, v := range redeems {
|
||||||
|
userIds = append(userIds, v.UserId)
|
||||||
|
}
|
||||||
|
var users []model.User
|
||||||
|
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||||
|
var userMap = make(map[uint]model.User)
|
||||||
|
for _, u := range users {
|
||||||
|
userMap[u.Id] = u
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range redeems {
|
||||||
|
var r vo.Redeem
|
||||||
|
err = utils.CopyObject(v, &r)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Id = v.Id
|
||||||
|
r.Username = userMap[v.UserId].Username
|
||||||
|
r.CreatedAt = v.CreatedAt.Unix()
|
||||||
|
items = append(items, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Export 导出 CVS 文件
|
||||||
|
func (h *RedeemHandler) Export(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Status int `json:"status"`
|
||||||
|
Ids []int `json:"ids"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Status >= 0 {
|
||||||
|
session = session.Where("redeemed_at", data.Status)
|
||||||
|
}
|
||||||
|
if len(data.Ids) > 0 {
|
||||||
|
session = session.Where("id IN ?", data.Ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
var items []model.Redeem
|
||||||
|
err := session.Order("id DESC").Find(&items).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置响应头,告诉浏览器这是一个附件,需要下载
|
||||||
|
c.Header("Content-Disposition", "attachment; filename=output.csv")
|
||||||
|
c.Header("Content-Type", "text/csv")
|
||||||
|
|
||||||
|
// 创建一个 CSV writer
|
||||||
|
writer := csv.NewWriter(c.Writer)
|
||||||
|
|
||||||
|
// 写入 CSV 文件的标题行
|
||||||
|
headers := []string{"名称", "兑换码", "算力", "创建时间"}
|
||||||
|
if err := writer.Write(headers); err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入数据行
|
||||||
|
records := make([][]string, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
records = append(records, []string{item.Name, item.Code, fmt.Sprintf("%d", item.Power), item.CreatedAt.Format("2006-01-02 15:04:05")})
|
||||||
|
}
|
||||||
|
for _, record := range records {
|
||||||
|
if err := writer.Write(record); err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保所有数据都已写入响应
|
||||||
|
writer.Flush()
|
||||||
|
if err := writer.Error(); err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RedeemHandler) Create(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Power int `json:"power"`
|
||||||
|
Num int `json:"num"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
counter := 0
|
||||||
|
codes := make([]string, 0)
|
||||||
|
var errMsg = ""
|
||||||
|
if data.Num > 0 {
|
||||||
|
for i := 0; i < data.Num; i++ {
|
||||||
|
code, err := utils.GenRedeemCode(32)
|
||||||
|
if err != nil {
|
||||||
|
errMsg = err.Error()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = h.DB.Create(&model.Redeem{
|
||||||
|
Code: code,
|
||||||
|
Name: data.Name,
|
||||||
|
Power: data.Power,
|
||||||
|
Enabled: true,
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
errMsg = err.Error()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
codes = append(codes, code)
|
||||||
|
counter++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if counter == 0 {
|
||||||
|
resp.ERROR(c, errMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, gin.H{
|
||||||
|
"counter": counter,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RedeemHandler) Set(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Filed string `json:"filed"`
|
||||||
|
Value interface{} `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.DB.Model(&model.Redeem{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RedeemHandler) Remove(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
if id <= 0 {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := h.DB.Where("id", id).Delete(&model.Redeem{}).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
package admin
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"geekai/core"
|
|
||||||
"geekai/core/types"
|
|
||||||
"geekai/handler"
|
|
||||||
"geekai/store/model"
|
|
||||||
"geekai/store/vo"
|
|
||||||
"geekai/utils"
|
|
||||||
"geekai/utils/resp"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RewardHandler struct {
|
|
||||||
handler.BaseHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
|
|
||||||
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *RewardHandler) List(c *gin.Context) {
|
|
||||||
var items []model.Reward
|
|
||||||
res := h.DB.Order("id DESC").Find(&items)
|
|
||||||
var rewards = make([]vo.Reward, 0)
|
|
||||||
if res.Error == nil {
|
|
||||||
userIds := make([]uint, 0)
|
|
||||||
for _, v := range items {
|
|
||||||
userIds = append(userIds, v.UserId)
|
|
||||||
}
|
|
||||||
var users []model.User
|
|
||||||
h.DB.Where("id IN ?", userIds).Find(&users)
|
|
||||||
var userMap = make(map[uint]model.User)
|
|
||||||
for _, u := range users {
|
|
||||||
userMap[u.Id] = u
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range items {
|
|
||||||
var r vo.Reward
|
|
||||||
err := utils.CopyObject(v, &r)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Id = v.Id
|
|
||||||
r.Username = userMap[v.UserId].Username
|
|
||||||
r.CreatedAt = v.CreatedAt.Unix()
|
|
||||||
r.UpdatedAt = v.UpdatedAt.Unix()
|
|
||||||
rewards = append(rewards, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, rewards)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *RewardHandler) Remove(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
Id uint
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if data.Id > 0 {
|
|
||||||
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("error with update database:", res.Error)
|
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
|
||||||
@@ -19,6 +19,8 @@ import (
|
|||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -26,10 +28,11 @@ import (
|
|||||||
type UserHandler struct {
|
type UserHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
licenseService *service.LicenseService
|
licenseService *service.LicenseService
|
||||||
|
redis *redis.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler {
|
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService, redisCli *redis.Client) *UserHandler {
|
||||||
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 用户列表
|
// List 用户列表
|
||||||
@@ -37,6 +40,8 @@ func (h *UserHandler) List(c *gin.Context) {
|
|||||||
page := h.GetInt(c, "page", 1)
|
page := h.GetInt(c, "page", 1)
|
||||||
pageSize := h.GetInt(c, "page_size", 20)
|
pageSize := h.GetInt(c, "page_size", 20)
|
||||||
username := h.GetTrim(c, "username")
|
username := h.GetTrim(c, "username")
|
||||||
|
mobile := h.GetTrim(c, "mobile")
|
||||||
|
email := h.GetTrim(c, "email")
|
||||||
|
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
var items []model.User
|
var items []model.User
|
||||||
@@ -47,9 +52,15 @@ func (h *UserHandler) List(c *gin.Context) {
|
|||||||
if username != "" {
|
if username != "" {
|
||||||
session = session.Where("username LIKE ?", "%"+username+"%")
|
session = session.Where("username LIKE ?", "%"+username+"%")
|
||||||
}
|
}
|
||||||
|
if mobile != "" {
|
||||||
|
session = session.Where("mobile LIKE ?", "%"+mobile+"%")
|
||||||
|
}
|
||||||
|
if email != "" {
|
||||||
|
session = session.Where("email LIKE ?", "%"+email+"%")
|
||||||
|
}
|
||||||
|
|
||||||
session.Model(&model.User{}).Count(&total)
|
session.Model(&model.User{}).Count(&total)
|
||||||
res := session.Offset(offset).Limit(pageSize).Find(&items)
|
res := session.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var user vo.User
|
var user vo.User
|
||||||
@@ -73,6 +84,8 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
Mobile string `json:"mobile"`
|
||||||
|
Email string `json:"email"`
|
||||||
ChatRoles []string `json:"chat_roles"`
|
ChatRoles []string `json:"chat_roles"`
|
||||||
ChatModels []int `json:"chat_models"`
|
ChatModels []int `json:"chat_models"`
|
||||||
ExpiredTime string `json:"expired_time"`
|
ExpiredTime string `json:"expired_time"`
|
||||||
@@ -102,6 +115,8 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var oldPower = user.Power
|
var oldPower = user.Power
|
||||||
user.Username = data.Username
|
user.Username = data.Username
|
||||||
|
user.Email = data.Email
|
||||||
|
user.Mobile = data.Mobile
|
||||||
user.Status = data.Status
|
user.Status = data.Status
|
||||||
user.Vip = data.Vip
|
user.Vip = data.Vip
|
||||||
user.Power = data.Power
|
user.Power = data.Power
|
||||||
@@ -109,7 +124,8 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
user.ChatModels = utils.JsonEncode(data.ChatModels)
|
user.ChatModels = utils.JsonEncode(data.ChatModels)
|
||||||
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
|
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
|
||||||
|
|
||||||
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
|
res = h.DB.Select("username", "mobile", "email", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
|
||||||
|
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
@@ -135,6 +151,13 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
// 如果禁用了用户,则将用户踢下线
|
||||||
|
if user.Status == false {
|
||||||
|
key := fmt.Sprintf("users/%v", user.Id)
|
||||||
|
if _, err := h.redis.Del(c, key).Result(); err != nil {
|
||||||
|
logger.Error("error with delete session: ", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 检查用户是否已经存在
|
// 检查用户是否已经存在
|
||||||
h.DB.Where("username", data.Username).First(&user)
|
h.DB.Where("username", data.Username).First(&user)
|
||||||
@@ -147,6 +170,8 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
u := model.User{
|
u := model.User{
|
||||||
Username: data.Username,
|
Username: data.Username,
|
||||||
Password: utils.GenPassword(data.Password, salt),
|
Password: utils.GenPassword(data.Password, salt),
|
||||||
|
Mobile: data.Mobile,
|
||||||
|
Email: data.Email,
|
||||||
Avatar: "/images/avatar/user.png",
|
Avatar: "/images/avatar/user.png",
|
||||||
Salt: salt,
|
Salt: salt,
|
||||||
Power: data.Power,
|
Power: data.Power,
|
||||||
@@ -168,8 +193,7 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, res.Error.Error())
|
||||||
resp.ERROR(c, "更新数据库失败")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,33 +229,69 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *UserHandler) Remove(c *gin.Context) {
|
func (h *UserHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := c.Query("id")
|
||||||
if id <= 0 {
|
ids := c.QueryArray("ids[]")
|
||||||
|
if id != "" {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 删除用户
|
|
||||||
res := h.DB.Where("id = ?", id).Delete(&model.User{})
|
tx := h.DB.Begin()
|
||||||
if res.Error != nil {
|
var err error
|
||||||
resp.ERROR(c, "删除失败")
|
for _, id = range ids {
|
||||||
|
// 删除用户
|
||||||
|
if err = tx.Where("id", id).Delete(&model.User{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// 删除聊天记录
|
||||||
|
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatItem{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// 删除聊天历史记录
|
||||||
|
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatMessage{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// 删除登录日志
|
||||||
|
if err = tx.Where("user_id = ?", id).Delete(&model.UserLoginLog{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// 删除算力日志
|
||||||
|
if err = tx.Where("user_id = ?", id).Delete(&model.PowerLog{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err = tx.Where("user_id = ?", id).Delete(&model.InviteLog{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// 删除众筹日志
|
||||||
|
if err = tx.Where("user_id = ?", id).Delete(&model.Redeem{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// 删除绘图任务
|
||||||
|
if err = tx.Where("user_id = ?", id).Delete(&model.MidJourneyJob{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err = tx.Where("user_id = ?", id).Delete(&model.SdJob{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err = tx.Where("user_id = ?", id).Delete(&model.DallJob{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err = tx.Where("user_id = ?", id).Delete(&model.SunoJob{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err = tx.Where("user_id = ?", id).Delete(&model.VideoJob{}).Error; err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
tx.Rollback()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
tx.Commit()
|
||||||
// 删除聊天记录
|
|
||||||
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
|
|
||||||
// 删除聊天历史记录
|
|
||||||
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
|
|
||||||
// 删除登录日志
|
|
||||||
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
|
|
||||||
// 删除算力日志
|
|
||||||
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
|
|
||||||
// 删除众筹日志
|
|
||||||
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
|
|
||||||
// 删除绘图任务
|
|
||||||
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
|
|
||||||
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
|
|
||||||
// 删除订单
|
|
||||||
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,13 +8,13 @@ package handler
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
logger2 "geekai/logger"
|
logger2 "geekai/logger"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.DB.First(&user, userId)
|
res := h.DB.Where("id", userId).First(&user)
|
||||||
// 更新缓存
|
// 更新缓存
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
c.Set(types.LoginUserCache, user)
|
c.Set(types.LoginUserCache, user)
|
||||||
|
|||||||
44
api/handler/chat_app_type_handler.go
Normal file
44
api/handler/chat_app_type_handler.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatAppTypeHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
|
||||||
|
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 获取App类型列表
|
||||||
|
func (h *ChatAppTypeHandler) List(c *gin.Context) {
|
||||||
|
var items []model.AppType
|
||||||
|
var appTypes = make([]vo.AppType, 0)
|
||||||
|
err := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range items {
|
||||||
|
var appType vo.AppType
|
||||||
|
err = utils.CopyObject(v, &appType)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
appType.Id = v.Id
|
||||||
|
appType.CreatedAt = v.CreatedAt.Unix()
|
||||||
|
appTypes = append(appTypes, appType)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, appTypes)
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package chatimpl
|
package handler
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
@@ -15,8 +15,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
|
||||||
logger2 "geekai/logger"
|
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -33,136 +31,31 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
|
||||||
|
|
||||||
type ChatHandler struct {
|
type ChatHandler struct {
|
||||||
handler.BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
licenseService *service.LicenseService
|
licenseService *service.LicenseService
|
||||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||||
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
|
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
|
||||||
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
|
||||||
return &ChatHandler{
|
return &ChatHandler{
|
||||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
BaseHandler: BaseHandler{App: app, DB: db},
|
||||||
redis: redis,
|
redis: redis,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
licenseService: licenseService,
|
licenseService: licenseService,
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
ChatContexts: types.NewLMap[string, []types.Message](),
|
ChatContexts: types.NewLMap[string, []interface{}](),
|
||||||
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatHandle 处理聊天 WebSocket 请求
|
|
||||||
func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionId := c.Query("session_id")
|
|
||||||
roleId := h.GetInt(c, "role_id", 0)
|
|
||||||
chatId := c.Query("chat_id")
|
|
||||||
modelId := h.GetInt(c, "model_id", 0)
|
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
|
||||||
var chatRole model.ChatRole
|
|
||||||
res := h.DB.First(&chatRole, roleId)
|
|
||||||
if res.Error != nil || !chatRole.Enable {
|
|
||||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// if the role bind a model_id, use role's bind model_id
|
|
||||||
if chatRole.ModelId > 0 {
|
|
||||||
modelId = chatRole.ModelId
|
|
||||||
}
|
|
||||||
// get model info
|
|
||||||
var chatModel model.ChatModel
|
|
||||||
res = h.DB.First(&chatModel, modelId)
|
|
||||||
if res.Error != nil || chatModel.Enabled == false {
|
|
||||||
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
session := &types.ChatSession{
|
|
||||||
SessionId: sessionId,
|
|
||||||
ClientIP: c.ClientIP(),
|
|
||||||
UserId: h.GetLoginUserId(c),
|
|
||||||
}
|
|
||||||
|
|
||||||
// use old chat data override the chat model and role ID
|
|
||||||
var chat model.ChatItem
|
|
||||||
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
|
|
||||||
if res.Error == nil {
|
|
||||||
chatModel.Id = chat.ModelId
|
|
||||||
roleId = int(chat.RoleId)
|
|
||||||
}
|
|
||||||
|
|
||||||
session.ChatId = chatId
|
|
||||||
session.Model = types.ChatModel{
|
|
||||||
Id: chatModel.Id,
|
|
||||||
Name: chatModel.Name,
|
|
||||||
Value: chatModel.Value,
|
|
||||||
Power: chatModel.Power,
|
|
||||||
MaxTokens: chatModel.MaxTokens,
|
|
||||||
MaxContext: chatModel.MaxContext,
|
|
||||||
Temperature: chatModel.Temperature,
|
|
||||||
KeyId: chatModel.KeyId}
|
|
||||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
_, msg, err := client.Receive()
|
|
||||||
if err != nil {
|
|
||||||
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
|
||||||
client.Close()
|
|
||||||
cancelFunc := h.ReqCancelFunc.Get(sessionId)
|
|
||||||
if cancelFunc != nil {
|
|
||||||
cancelFunc()
|
|
||||||
h.ReqCancelFunc.Delete(sessionId)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var message types.WsMessage
|
|
||||||
err = utils.JsonDecode(string(msg), &message)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 心跳消息
|
|
||||||
if message.Type == "heartbeat" {
|
|
||||||
logger.Debug("收到 Chat 心跳消息:", message.Content)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Receive a message: ", message.Content)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
h.ReqCancelFunc.Put(sessionId, cancel)
|
|
||||||
// 回复消息
|
|
||||||
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
utils.ReplyMessage(client, err.Error())
|
|
||||||
} else {
|
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
|
||||||
logger.Infof("回答完毕: %v", message.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
|
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
|
||||||
if !h.App.Debug {
|
if !h.App.Debug {
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -204,45 +97,54 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
|
|
||||||
var req = types.ApiRequest{
|
var req = types.ApiRequest{
|
||||||
Model: session.Model.Value,
|
Model: session.Model.Value,
|
||||||
Stream: true,
|
}
|
||||||
|
// 兼容 GPT-O1 模型
|
||||||
|
if strings.HasPrefix(session.Model.Value, "o1-") {
|
||||||
|
utils.SendChunkMsg(ws, "> AI 正在思考...\n")
|
||||||
|
req.Stream = session.Stream
|
||||||
|
session.Start = time.Now().Unix()
|
||||||
|
} else {
|
||||||
|
req.MaxTokens = session.Model.MaxTokens
|
||||||
|
req.Temperature = session.Model.Temperature
|
||||||
|
req.Stream = session.Stream
|
||||||
}
|
}
|
||||||
req.Temperature = session.Model.Temperature
|
|
||||||
req.MaxTokens = session.Model.MaxTokens
|
|
||||||
// OpenAI 支持函数功能
|
|
||||||
var items []model.Function
|
|
||||||
res = h.DB.Where("enabled", true).Find(&items)
|
|
||||||
if res.Error == nil {
|
|
||||||
var tools = make([]types.Tool, 0)
|
|
||||||
for _, v := range items {
|
|
||||||
var parameters map[string]interface{}
|
|
||||||
err = utils.JsonDecode(v.Parameters, ¶meters)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
tool := types.Tool{
|
|
||||||
Type: "function",
|
|
||||||
Function: types.Function{
|
|
||||||
Name: v.Name,
|
|
||||||
Description: v.Description,
|
|
||||||
Parameters: parameters,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if v, ok := parameters["required"]; v == nil || !ok {
|
|
||||||
tool.Function.Parameters["required"] = []string{}
|
|
||||||
}
|
|
||||||
tools = append(tools, tool)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tools) > 0 {
|
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
|
||||||
req.Tools = tools
|
var items []model.Function
|
||||||
req.ToolChoice = "auto"
|
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
|
||||||
|
if res.Error == nil {
|
||||||
|
var tools = make([]types.Tool, 0)
|
||||||
|
for _, v := range items {
|
||||||
|
var parameters map[string]interface{}
|
||||||
|
err = utils.JsonDecode(v.Parameters, ¶meters)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tool := types.Tool{
|
||||||
|
Type: "function",
|
||||||
|
Function: types.Function{
|
||||||
|
Name: v.Name,
|
||||||
|
Description: v.Description,
|
||||||
|
Parameters: parameters,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if v, ok := parameters["required"]; v == nil || !ok {
|
||||||
|
tool.Function.Parameters["required"] = []string{}
|
||||||
|
}
|
||||||
|
tools = append(tools, tool)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
req.Tools = tools
|
||||||
|
req.ToolChoice = "auto"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 加载聊天上下文
|
// 加载聊天上下文
|
||||||
chatCtx := make([]types.Message, 0)
|
chatCtx := make([]interface{}, 0)
|
||||||
messages := make([]types.Message, 0)
|
messages := make([]interface{}, 0)
|
||||||
if h.App.SysConfig.EnableContext {
|
if h.App.SysConfig.EnableContext {
|
||||||
if h.ChatContexts.Has(session.ChatId) {
|
if h.ChatContexts.Has(session.ChatId) {
|
||||||
messages = h.ChatContexts.Get(session.ChatId)
|
messages = h.ChatContexts.Get(session.ChatId)
|
||||||
@@ -270,8 +172,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
|
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
|
||||||
tokens += tks + promptTokens
|
tokens += tks + promptTokens
|
||||||
|
|
||||||
for _, v := range messages {
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
v := messages[i]
|
||||||
|
tks, _ = utils.CalcTokens(utils.JsonEncode(v), req.Model)
|
||||||
// 上下文 token 超出了模型的最大上下文长度
|
// 上下文 token 超出了模型的最大上下文长度
|
||||||
if tokens+tks >= session.Model.MaxContext {
|
if tokens+tks >= session.Model.MaxContext {
|
||||||
break
|
break
|
||||||
@@ -289,8 +192,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
logger.Debugf("聊天上下文:%+v", chatCtx)
|
logger.Debugf("聊天上下文:%+v", chatCtx)
|
||||||
}
|
}
|
||||||
reqMgs := make([]interface{}, 0)
|
reqMgs := make([]interface{}, 0)
|
||||||
for _, m := range chatCtx {
|
|
||||||
reqMgs = append(reqMgs, m)
|
for i := len(chatCtx) - 1; i >= 0; i-- {
|
||||||
|
reqMgs = append(reqMgs, chatCtx[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
fullPrompt := prompt
|
fullPrompt := prompt
|
||||||
@@ -355,7 +259,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
|
|
||||||
logger.Debugf("%+v", req.Messages)
|
logger.Debugf("%+v", req.Messages)
|
||||||
|
|
||||||
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tokens 统计 token 数量
|
// Tokens 统计 token 数量
|
||||||
@@ -442,7 +346,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
logger.Debugf(utils.JsonEncode(req))
|
logger.Debugf("对话请求消息体:%+v", req)
|
||||||
|
|
||||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||||
// 创建 HttpClient 请求对象
|
// 创建 HttpClient 请求对象
|
||||||
@@ -468,7 +372,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
|
|||||||
} else {
|
} else {
|
||||||
client = http.DefaultClient
|
client = http.DefaultClient
|
||||||
}
|
}
|
||||||
logger.Debugf("Sending %s request, Channel:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
|
logger.Infof("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
|
||||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||||
// 更新API KEY 最后使用时间
|
// 更新API KEY 最后使用时间
|
||||||
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
|
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
@@ -481,115 +385,112 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
|
|||||||
if session.Model.Power > 0 {
|
if session.Model.Power > 0 {
|
||||||
power = session.Model.Power
|
power = session.Model.Power
|
||||||
}
|
}
|
||||||
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
|
|
||||||
if res.Error == nil {
|
|
||||||
// 记录算力消费日志
|
|
||||||
var u model.User
|
|
||||||
h.DB.Where("id", userVo.Id).First(&u)
|
|
||||||
h.DB.Create(&model.PowerLog{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
Username: userVo.Username,
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Amount: power,
|
|
||||||
Mark: types.PowerSub,
|
|
||||||
Balance: u.Power,
|
|
||||||
Model: session.Model.Value,
|
|
||||||
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
|
err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Model: session.Model.Value,
|
||||||
|
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) saveChatHistory(
|
func (h *ChatHandler) saveChatHistory(
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
prompt string,
|
usage Usage,
|
||||||
contents []string,
|
|
||||||
message types.Message,
|
message types.Message,
|
||||||
chatCtx []types.Message,
|
|
||||||
session *types.ChatSession,
|
session *types.ChatSession,
|
||||||
role model.ChatRole,
|
role model.ChatRole,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
promptCreatedAt time.Time,
|
promptCreatedAt time.Time,
|
||||||
replyCreatedAt time.Time) {
|
replyCreatedAt time.Time) {
|
||||||
if message.Role == "" {
|
|
||||||
message.Role = "assistant"
|
|
||||||
}
|
|
||||||
message.Content = strings.Join(contents, "")
|
|
||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息
|
||||||
if h.App.SysConfig.EnableContext {
|
if h.App.SysConfig.EnableContext {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx := req.Messages // 提问消息
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
h.ChatContexts.Put(session.ChatId, chatCtx)
|
h.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
// for prompt
|
// for prompt
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
var promptTokens, replyTokens, totalTokens int
|
||||||
if err != nil {
|
if usage.PromptTokens > 0 {
|
||||||
logger.Error(err)
|
promptTokens = usage.PromptTokens
|
||||||
|
} else {
|
||||||
|
promptTokens, _ = utils.CalcTokens(usage.Content, req.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
historyUserMsg := model.ChatMessage{
|
historyUserMsg := model.ChatMessage{
|
||||||
UserId: userVo.Id,
|
UserId: userVo.Id,
|
||||||
ChatId: session.ChatId,
|
ChatId: session.ChatId,
|
||||||
RoleId: role.Id,
|
RoleId: role.Id,
|
||||||
Type: types.PromptMsg,
|
Type: types.PromptMsg,
|
||||||
Icon: userVo.Avatar,
|
Icon: userVo.Avatar,
|
||||||
Content: template.HTMLEscapeString(prompt),
|
Content: template.HTMLEscapeString(usage.Prompt),
|
||||||
Tokens: promptToken,
|
Tokens: promptTokens,
|
||||||
UseContext: true,
|
TotalTokens: promptTokens,
|
||||||
Model: req.Model,
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
}
|
}
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
res := h.DB.Save(&historyUserMsg)
|
err := h.DB.Save(&historyUserMsg).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
logger.Error("failed to save prompt history message: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// for reply
|
// for reply
|
||||||
// 计算本次对话消耗的总 token 数量
|
// 计算本次对话消耗的总 token 数量
|
||||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
if usage.CompletionTokens > 0 {
|
||||||
totalTokens := replyTokens + getTotalTokens(req)
|
replyTokens = usage.CompletionTokens
|
||||||
|
totalTokens = usage.TotalTokens
|
||||||
|
} else {
|
||||||
|
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||||
|
totalTokens = replyTokens + getTotalTokens(req)
|
||||||
|
}
|
||||||
historyReplyMsg := model.ChatMessage{
|
historyReplyMsg := model.ChatMessage{
|
||||||
UserId: userVo.Id,
|
UserId: userVo.Id,
|
||||||
ChatId: session.ChatId,
|
ChatId: session.ChatId,
|
||||||
RoleId: role.Id,
|
RoleId: role.Id,
|
||||||
Type: types.ReplyMsg,
|
Type: types.ReplyMsg,
|
||||||
Icon: role.Icon,
|
Icon: role.Icon,
|
||||||
Content: message.Content,
|
Content: usage.Content,
|
||||||
Tokens: totalTokens,
|
Tokens: replyTokens,
|
||||||
UseContext: true,
|
TotalTokens: totalTokens,
|
||||||
Model: req.Model,
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
}
|
}
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
res = h.DB.Create(&historyReplyMsg)
|
err = h.DB.Create(&historyReplyMsg).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
logger.Error("failed to save reply history message: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新用户算力
|
// 更新用户算力
|
||||||
if session.Model.Power > 0 {
|
if session.Model.Power > 0 {
|
||||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
h.subUserPower(userVo, session, promptTokens, replyTokens)
|
||||||
}
|
}
|
||||||
// 保存当前会话
|
// 保存当前会话
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
chatItem.ChatId = session.ChatId
|
chatItem.ChatId = session.ChatId
|
||||||
chatItem.UserId = userVo.Id
|
chatItem.UserId = userVo.Id
|
||||||
chatItem.RoleId = role.Id
|
chatItem.RoleId = role.Id
|
||||||
chatItem.ModelId = session.Model.Id
|
chatItem.ModelId = session.Model.Id
|
||||||
if utf8.RuneCountInString(prompt) > 30 {
|
if utf8.RuneCountInString(usage.Prompt) > 30 {
|
||||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
|
||||||
} else {
|
} else {
|
||||||
chatItem.Title = prompt
|
chatItem.Title = usage.Prompt
|
||||||
}
|
}
|
||||||
chatItem.Model = req.Model
|
chatItem.Model = req.Model
|
||||||
h.DB.Create(&chatItem)
|
err = h.DB.Create(&chatItem).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to save chat item: ", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package chatimpl
|
package handler
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
@@ -28,31 +28,40 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
var items = make([]vo.ChatItem, 0)
|
var items = make([]vo.ChatItem, 0)
|
||||||
var chats []model.ChatItem
|
var chats []model.ChatItem
|
||||||
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
|
h.DB.Where("user_id", userId).Order("id DESC").Find(&chats)
|
||||||
if res.Error == nil {
|
if len(chats) == 0 {
|
||||||
var roleIds = make([]uint, 0)
|
resp.SUCCESS(c, items)
|
||||||
for _, chat := range chats {
|
return
|
||||||
roleIds = append(roleIds, chat.RoleId)
|
}
|
||||||
}
|
|
||||||
var roles []model.ChatRole
|
|
||||||
res = h.DB.Find(&roles, roleIds)
|
|
||||||
if res.Error == nil {
|
|
||||||
roleMap := make(map[uint]model.ChatRole)
|
|
||||||
for _, role := range roles {
|
|
||||||
roleMap[role.Id] = role
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chat := range chats {
|
var roleIds = make([]uint, 0)
|
||||||
var item vo.ChatItem
|
var modelValues = make([]string, 0)
|
||||||
err := utils.CopyObject(chat, &item)
|
for _, chat := range chats {
|
||||||
if err == nil {
|
roleIds = append(roleIds, chat.RoleId)
|
||||||
item.Id = chat.Id
|
modelValues = append(modelValues, chat.Model)
|
||||||
item.Icon = roleMap[chat.RoleId].Icon
|
}
|
||||||
items = append(items, item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
var roles []model.ChatRole
|
||||||
|
var models []model.ChatModel
|
||||||
|
roleMap := make(map[uint]model.ChatRole)
|
||||||
|
modelMap := make(map[string]model.ChatModel)
|
||||||
|
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||||
|
h.DB.Where("value IN ?", modelValues).Find(&models)
|
||||||
|
for _, role := range roles {
|
||||||
|
roleMap[role.Id] = role
|
||||||
|
}
|
||||||
|
for _, m := range models {
|
||||||
|
modelMap[m.Value] = m
|
||||||
|
}
|
||||||
|
for _, chat := range chats {
|
||||||
|
var item vo.ChatItem
|
||||||
|
err := utils.CopyObject(chat, &item)
|
||||||
|
if err == nil {
|
||||||
|
item.Id = chat.Id
|
||||||
|
item.Icon = roleMap[chat.RoleId].Icon
|
||||||
|
item.ModelId = modelMap[chat.Model].Id
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, items)
|
resp.SUCCESS(c, items)
|
||||||
}
|
}
|
||||||
@@ -30,29 +30,25 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
|||||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||||
var items []model.ChatModel
|
var items []model.ChatModel
|
||||||
var chatModels = make([]vo.ChatModel, 0)
|
var chatModels = make([]vo.ChatModel, 0)
|
||||||
var res *gorm.DB
|
session := h.DB.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true)
|
||||||
session := h.DB.Session(&gorm.Session{}).Where("enabled", true)
|
|
||||||
t := c.Query("type")
|
t := c.Query("type")
|
||||||
if t != "" {
|
if t != "" {
|
||||||
session = session.Where("type", t)
|
session = session.Where("type", t)
|
||||||
}
|
}
|
||||||
// 如果用户没有登录,则加载所有开放模型
|
|
||||||
if !h.IsLogin(c) {
|
session = session.Where("open", true)
|
||||||
res = session.Where("open", true).Order("sort_num ASC").Find(&items)
|
if h.IsLogin(c) {
|
||||||
} else {
|
|
||||||
user, _ := h.GetLoginUser(c)
|
user, _ := h.GetLoginUser(c)
|
||||||
var models []int
|
var models []int
|
||||||
err := utils.JsonDecode(user.ChatModels, &models)
|
err := utils.JsonDecode(user.ChatModels, &models)
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "当前用户没有订阅任何模型")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 查询用户有权限访问的模型以及所有开放的模型
|
// 查询用户有权限访问的模型以及所有开放的模型
|
||||||
res = h.DB.Where("enabled = ?", true).Where(
|
if err == nil {
|
||||||
h.DB.Where("id IN ?", models).Or("open", true),
|
session = session.Or("id IN ?", models)
|
||||||
).Order("sort_num ASC").Find(&items)
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res := session.Order("sort_num ASC").Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var cm vo.ChatModel
|
var cm vo.ChatModel
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package chatimpl
|
package handler
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
@@ -17,15 +17,41 @@ import (
|
|||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
req2 "github.com/imroc/req/v3"
|
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
req2 "github.com/imroc/req/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Usage struct {
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIResVo struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
Choices []struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
Logprobs interface{} `json:"logprobs"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
} `json:"choices"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
// OPenAI 消息发送实现
|
// OPenAI 消息发送实现
|
||||||
func (h *ChatHandler) sendOpenAiMessage(
|
func (h *ChatHandler) sendOpenAiMessage(
|
||||||
chatCtx []types.Message,
|
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -37,7 +63,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
var apiKey = model.ApiKey{}
|
var apiKey = model.ApiKey{}
|
||||||
response, err := h.doRequest(ctx, req, session, &apiKey)
|
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
logger.Info("HTTP请求完成,耗时:", time.Since(start))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "context canceled") {
|
if strings.Contains(err.Error(), "context canceled") {
|
||||||
return fmt.Errorf("用户取消了请求:%s", prompt)
|
return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||||
@@ -49,17 +75,29 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if response.StatusCode != 200 {
|
||||||
|
body, _ := io.ReadAll(response.Body)
|
||||||
|
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
contentType := response.Header.Get("Content-Type")
|
contentType := response.Header.Get("Content-Type")
|
||||||
if strings.Contains(contentType, "text/event-stream") {
|
if strings.Contains(contentType, "text/event-stream") {
|
||||||
replyCreatedAt := time.Now() // 记录回复时间
|
replyCreatedAt := time.Now() // 记录回复时间
|
||||||
// 循环读取 Chunk 消息
|
// 循环读取 Chunk 消息
|
||||||
var message = types.Message{}
|
var message = types.Message{Role: "assistant"}
|
||||||
var contents = make([]string, 0)
|
var contents = make([]string, 0)
|
||||||
var function model.Function
|
var function model.Function
|
||||||
var toolCall = false
|
var toolCall = false
|
||||||
var arguments = make([]string, 0)
|
var arguments = make([]string, 0)
|
||||||
|
|
||||||
|
if strings.HasPrefix(req.Model, "o1-") {
|
||||||
|
content := fmt.Sprintf("AI 思考结束,耗时:%d 秒。\n\n", time.Now().Unix()-session.Start)
|
||||||
|
contents = append(contents, "> AI 正在思考中...\n")
|
||||||
|
contents = append(contents, content)
|
||||||
|
utils.SendChunkMsg(ws, content)
|
||||||
|
}
|
||||||
|
|
||||||
scanner := bufio.NewScanner(response.Body)
|
scanner := bufio.NewScanner(response.Body)
|
||||||
var isNew = true
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||||
@@ -78,7 +116,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
||||||
utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
utils.SendChunkMsg(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,8 +144,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
toolCall = true
|
toolCall = true
|
||||||
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.SendChunkMsg(ws, callMsg)
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
|
|
||||||
contents = append(contents, callMsg)
|
contents = append(contents, callMsg)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -121,17 +158,10 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
// output stopped
|
// output stopped
|
||||||
if responseBody.Choices[0].FinishReason != "" {
|
if responseBody.Choices[0].FinishReason != "" {
|
||||||
break // 输出完成或者输出中断了
|
break // 输出完成或者输出中断了
|
||||||
} else {
|
} else { // 正常输出结果
|
||||||
content := responseBody.Choices[0].Delta.Content
|
content := responseBody.Choices[0].Delta.Content
|
||||||
contents = append(contents, utils.InterfaceToString(content))
|
contents = append(contents, utils.InterfaceToString(content))
|
||||||
if isNew {
|
utils.SendChunkMsg(ws, content)
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
|
||||||
isNew = false
|
|
||||||
}
|
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
|
||||||
Type: types.WsMiddle,
|
|
||||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
} // end for
|
} // end for
|
||||||
|
|
||||||
@@ -149,39 +179,62 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
||||||
params["user_id"] = userVo.Id
|
params["user_id"] = userVo.Id
|
||||||
var apiRes types.BizVo
|
var apiRes types.BizVo
|
||||||
r, err := req2.C().R().SetHeader("Content-Type", "application/json").
|
r, err := req2.C().R().SetHeader("Body-Type", "application/json").
|
||||||
SetHeader("Authorization", function.Token).
|
SetHeader("Authorization", function.Token).
|
||||||
SetBody(params).
|
SetBody(params).Post(function.Action)
|
||||||
SetSuccessResult(&apiRes).Post(function.Action)
|
|
||||||
errMsg := ""
|
errMsg := ""
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg = err.Error()
|
errMsg = err.Error()
|
||||||
} else if r.IsErrorState() {
|
|
||||||
errMsg = r.Status
|
|
||||||
}
|
|
||||||
if errMsg != "" || apiRes.Code != types.Success {
|
|
||||||
msg := "调用函数工具出错:" + apiRes.Message + errMsg
|
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
|
||||||
Type: types.WsMiddle,
|
|
||||||
Content: msg,
|
|
||||||
})
|
|
||||||
contents = append(contents, msg)
|
|
||||||
} else {
|
} else {
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
all, _ := io.ReadAll(r.Body)
|
||||||
Type: types.WsMiddle,
|
err = json.Unmarshal(all, &apiRes)
|
||||||
Content: apiRes.Data,
|
if err != nil {
|
||||||
})
|
errMsg = err.Error()
|
||||||
contents = append(contents, utils.InterfaceToString(apiRes.Data))
|
} else if apiRes.Code != types.Success {
|
||||||
|
errMsg = apiRes.Message
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if errMsg != "" {
|
||||||
|
errMsg = "调用函数工具出错:" + errMsg
|
||||||
|
contents = append(contents, errMsg)
|
||||||
|
} else {
|
||||||
|
errMsg = utils.InterfaceToString(apiRes.Data)
|
||||||
|
contents = append(contents, errMsg)
|
||||||
|
}
|
||||||
|
utils.SendChunkMsg(ws, errMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
usage := Usage{
|
||||||
|
Prompt: prompt,
|
||||||
|
Content: strings.Join(contents, ""),
|
||||||
|
PromptTokens: 0,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: 0,
|
||||||
|
}
|
||||||
|
message.Content = usage.Content
|
||||||
|
h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||||
}
|
}
|
||||||
} else {
|
} else { // 非流式输出
|
||||||
body, _ := io.ReadAll(response.Body)
|
var respVo OpenAIResVo
|
||||||
return fmt.Errorf("请求 OpenAI API 失败:%s", body)
|
body, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("读取响应失败:%v", body)
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(body, &respVo)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("解析响应失败:%v", body)
|
||||||
|
}
|
||||||
|
content := respVo.Choices[0].Message.Content
|
||||||
|
if strings.HasPrefix(req.Model, "o1-") {
|
||||||
|
content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
|
||||||
|
}
|
||||||
|
utils.SendChunkMsg(ws, content)
|
||||||
|
respVo.Usage.Prompt = prompt
|
||||||
|
respVo.Usage.Content = content
|
||||||
|
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -29,10 +29,37 @@ func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
|||||||
|
|
||||||
// List 获取用户聊天应用列表
|
// List 获取用户聊天应用列表
|
||||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||||
|
tid := h.GetInt(c, "tid", 0)
|
||||||
|
var roles []model.ChatRole
|
||||||
|
session := h.DB.Where("enable", true)
|
||||||
|
if tid > 0 {
|
||||||
|
session = session.Where("tid", tid)
|
||||||
|
}
|
||||||
|
err := session.Order("sort_num ASC").Find(&roles).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var roleVos = make([]vo.ChatRole, 0)
|
||||||
|
for _, r := range roles {
|
||||||
|
var v vo.ChatRole
|
||||||
|
err := utils.CopyObject(r, &v)
|
||||||
|
if err == nil {
|
||||||
|
v.Id = r.Id
|
||||||
|
roleVos = append(roleVos, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, roleVos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListByUser 获取用户添加的角色列表
|
||||||
|
func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
var roles []model.ChatRole
|
var roles []model.ChatRole
|
||||||
query := h.DB.Where("enable", true)
|
session := h.DB.Where("enable", true)
|
||||||
|
// 如果用户没登录,则获取所有角色
|
||||||
if userId > 0 {
|
if userId > 0 {
|
||||||
var user model.User
|
var user model.User
|
||||||
h.DB.First(&user, userId)
|
h.DB.First(&user, userId)
|
||||||
@@ -42,12 +69,16 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
resp.ERROR(c, "角色解析失败!")
|
resp.ERROR(c, "角色解析失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
query = query.Where("marker IN ?", roleKeys)
|
// 保证用户至少有一个角色可用
|
||||||
|
if len(roleKeys) > 0 {
|
||||||
|
session = session.Where("marker IN ?", roleKeys)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
query = query.Or("id", id)
|
session = session.Or("id", id)
|
||||||
}
|
}
|
||||||
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
|
res := session.Order("sort_num ASC").Find(&roles)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -81,10 +112,9 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
|
err = h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,34 +8,33 @@ package handler
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
"geekai/service/dalle"
|
"geekai/service/dalle"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DallJobHandler struct {
|
type DallJobHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
dallService *dalle.Service
|
||||||
service *dalle.Service
|
uploader *oss.UploaderManager
|
||||||
uploader *oss.UploaderManager
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
|
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
|
||||||
return &DallJobHandler{
|
return &DallJobHandler{
|
||||||
service: service,
|
dallService: service,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
|
userService: userService,
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: app,
|
App: app,
|
||||||
DB: db,
|
DB: db,
|
||||||
@@ -43,82 +42,50 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client WebSocket 客户端,用于通知任务状态变更
|
|
||||||
func (h *DallJobHandler) Client(c *gin.Context) {
|
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
|
||||||
if userId == 0 {
|
|
||||||
logger.Info("Invalid user ID")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
|
||||||
h.service.Clients.Put(uint(userId), client)
|
|
||||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
_, msg, err := client.Receive()
|
|
||||||
if err != nil {
|
|
||||||
client.Close()
|
|
||||||
h.service.Clients.Delete(uint(userId))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var message types.WsMessage
|
|
||||||
err = utils.JsonDecode(string(msg), &message)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 心跳消息
|
|
||||||
if message.Type == "heartbeat" {
|
|
||||||
logger.Debug("收到 DallE 心跳消息:", message.Content)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
|
|
||||||
user, err := h.GetLoginUser(c)
|
|
||||||
if err != nil {
|
|
||||||
resp.NotAuth(c)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if user.Power < h.App.SysConfig.DallPower {
|
|
||||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Image 创建一个绘画任务
|
// Image 创建一个绘画任务
|
||||||
func (h *DallJobHandler) Image(c *gin.Context) {
|
func (h *DallJobHandler) Image(c *gin.Context) {
|
||||||
if !h.preCheck(c) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var data types.DallTask
|
var data types.DallTask
|
||||||
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var chatModel model.ChatModel
|
||||||
|
if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil {
|
||||||
|
resp.ERROR(c, "模型不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查用户剩余算力
|
||||||
|
user, err := h.GetLoginUser(c)
|
||||||
|
if err != nil {
|
||||||
|
resp.NotAuth(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if user.Power < chatModel.Power {
|
||||||
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
|
task := types.DallTask{
|
||||||
|
ClientId: data.ClientId,
|
||||||
|
UserId: uint(userId),
|
||||||
|
ModelId: chatModel.Id,
|
||||||
|
ModelName: chatModel.Value,
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
Quality: data.Quality,
|
||||||
|
Size: data.Size,
|
||||||
|
Style: data.Style,
|
||||||
|
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||||
|
Power: chatModel.Power,
|
||||||
|
}
|
||||||
job := model.DallJob{
|
job := model.DallJob{
|
||||||
UserId: uint(userId),
|
UserId: uint(userId),
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Power: h.App.SysConfig.DallPower,
|
Power: chatModel.Power,
|
||||||
|
TaskInfo: utils.JsonEncode(task),
|
||||||
}
|
}
|
||||||
res := h.DB.Create(&job)
|
res := h.DB.Create(&job)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
@@ -126,19 +93,18 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.service.PushTask(types.DallTask{
|
task.Id = job.Id
|
||||||
JobId: job.Id,
|
h.dallService.PushTask(task)
|
||||||
UserId: uint(userId),
|
|
||||||
Prompt: data.Prompt,
|
|
||||||
Quality: data.Quality,
|
|
||||||
Size: data.Size,
|
|
||||||
Style: data.Style,
|
|
||||||
Power: job.Power,
|
|
||||||
})
|
|
||||||
|
|
||||||
client := h.service.Clients.Get(job.UserId)
|
// 扣减算力
|
||||||
if client != nil {
|
err = h.userService.DecreasePower(int(user.Id), chatModel.Power, model.PowerLog{
|
||||||
_ = client.Send([]byte("Task Updated"))
|
Type: types.PowerConsume,
|
||||||
|
Model: chatModel.Value,
|
||||||
|
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "error with decrease power: "+err.Error())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
@@ -174,11 +140,11 @@ func (h *DallJobHandler) JobList(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// JobList 获取任务列表
|
// JobList 获取任务列表
|
||||||
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
|
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
|
||||||
|
|
||||||
session := h.DB.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if finish {
|
if finish {
|
||||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
session = session.Where("progress >= ?", 100).Order("id DESC")
|
||||||
} else {
|
} else {
|
||||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||||
}
|
}
|
||||||
@@ -192,11 +158,14 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
|
|||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
session = session.Offset(offset).Limit(pageSize)
|
session = session.Offset(offset).Limit(pageSize)
|
||||||
}
|
}
|
||||||
|
// 统计总数
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.DallJob{}).Count(&total)
|
||||||
|
|
||||||
var items []model.DallJob
|
var items []model.DallJob
|
||||||
res := session.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error, nil
|
return res.Error, vo.Page{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs = make([]vo.DallJob, 0)
|
var jobs = make([]vo.DallJob, 0)
|
||||||
@@ -209,28 +178,28 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
|
|||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, jobs
|
return nil, vo.NewPage(total, page, pageSize, jobs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove task image
|
// Remove remove task image
|
||||||
func (h *DallJobHandler) Remove(c *gin.Context) {
|
func (h *DallJobHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
var job model.DallJob
|
var job model.DallJob
|
||||||
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
|
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
|
||||||
resp.ERROR(c, "记录不存在")
|
resp.ERROR(c, "记录不存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove job recode
|
// 删除任务
|
||||||
res := h.DB.Delete(&model.DallJob{Id: job.Id})
|
err := h.DB.Delete(&job).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove image
|
// remove image
|
||||||
err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("remove image failed: ", err)
|
logger.Error("remove image failed: ", err)
|
||||||
}
|
}
|
||||||
@@ -241,15 +210,36 @@ func (h *DallJobHandler) Remove(c *gin.Context) {
|
|||||||
// Publish 发布/取消发布图片到画廊显示
|
// Publish 发布/取消发布图片到画廊显示
|
||||||
func (h *DallJobHandler) Publish(c *gin.Context) {
|
func (h *DallJobHandler) Publish(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||||
|
|
||||||
res := h.DB.Model(&model.DallJob{Id: uint(id), UserId: uint(userId)}).UpdateColumn("publish", action)
|
err := h.DB.Model(&model.DallJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *DallJobHandler) GetModels(c *gin.Context) {
|
||||||
|
var models []model.ChatModel
|
||||||
|
err := h.DB.Where("type", "img").Where("enabled", true).Find(&models).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelVos []vo.ChatModel
|
||||||
|
for _, v := range models {
|
||||||
|
var modelVo vo.ChatModel
|
||||||
|
err := utils.CopyObject(v, &modelVo)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modelVo.Id = v.Id
|
||||||
|
modelVos = append(modelVos, modelVo)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, modelVos)
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,15 +8,17 @@ package handler
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
"geekai/service/dalle"
|
"geekai/service/dalle"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -31,6 +33,7 @@ type FunctionHandler struct {
|
|||||||
config types.ApiConfig
|
config types.ApiConfig
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
dallService *dalle.Service
|
dallService *dalle.Service
|
||||||
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFunctionHandler(
|
func NewFunctionHandler(
|
||||||
@@ -38,7 +41,8 @@ func NewFunctionHandler(
|
|||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
config *types.AppConfig,
|
config *types.AppConfig,
|
||||||
manager *oss.UploaderManager,
|
manager *oss.UploaderManager,
|
||||||
dallService *dalle.Service) *FunctionHandler {
|
dallService *dalle.Service,
|
||||||
|
userService *service.UserService) *FunctionHandler {
|
||||||
return &FunctionHandler{
|
return &FunctionHandler{
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: server,
|
App: server,
|
||||||
@@ -47,6 +51,7 @@ func NewFunctionHandler(
|
|||||||
config: config.ApiConfig,
|
config: config.ApiConfig,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
dallService: dallService,
|
dallService: dallService,
|
||||||
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,10 +117,13 @@ func (h *FunctionHandler) WeiBo(c *gin.Context) {
|
|||||||
SetHeader("AppId", h.config.AppId).
|
SetHeader("AppId", h.config.AppId).
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||||||
SetSuccessResult(&res).Get(url)
|
SetSuccessResult(&res).Get(url)
|
||||||
if err != nil || r.IsErrorState() {
|
if err != nil {
|
||||||
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
|
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if r.IsErrorState() {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("error http code status: %v", r.Status))
|
||||||
|
}
|
||||||
|
|
||||||
if res.Code != types.Success {
|
if res.Code != types.Success {
|
||||||
resp.ERROR(c, res.Message)
|
resp.ERROR(c, res.Message)
|
||||||
@@ -148,8 +156,12 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
|||||||
SetHeader("AppId", h.config.AppId).
|
SetHeader("AppId", h.config.AppId).
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||||||
SetSuccessResult(&res).Get(url)
|
SetSuccessResult(&res).Get(url)
|
||||||
if err != nil || r.IsErrorState() {
|
if err != nil {
|
||||||
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
|
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.IsErrorState() {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("%v", r.Err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +175,7 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
|||||||
for _, v := range res.Data.Items {
|
for _, v := range res.Data.Items {
|
||||||
builder = append(builder, v.Title)
|
builder = append(builder, v.Title)
|
||||||
}
|
}
|
||||||
builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
|
builder = append(builder, res.Data.Title)
|
||||||
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,32 +207,71 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
|||||||
|
|
||||||
// create dall task
|
// create dall task
|
||||||
prompt := utils.InterfaceToString(params["prompt"])
|
prompt := utils.InterfaceToString(params["prompt"])
|
||||||
job := model.DallJob{
|
task := types.DallTask{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Power: h.App.SysConfig.DallPower,
|
ModelId: 0,
|
||||||
|
ModelName: "dall-e-3",
|
||||||
|
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||||
|
N: 1,
|
||||||
|
Quality: "standard",
|
||||||
|
Size: "1024x1024",
|
||||||
|
Style: "vivid",
|
||||||
|
Power: h.App.SysConfig.DallPower,
|
||||||
}
|
}
|
||||||
res = h.DB.Create(&job)
|
job := model.DallJob{
|
||||||
|
UserId: user.Id,
|
||||||
if res.Error != nil {
|
Prompt: prompt,
|
||||||
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error())
|
Power: h.App.SysConfig.DallPower,
|
||||||
|
TaskInfo: utils.JsonEncode(task),
|
||||||
|
}
|
||||||
|
err := h.DB.Create(&job).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
content, err := h.dallService.Image(types.DallTask{
|
task.Id = job.Id
|
||||||
JobId: job.Id,
|
content, err := h.dallService.Image(task, true)
|
||||||
UserId: user.Id,
|
|
||||||
Prompt: job.Prompt,
|
|
||||||
N: 1,
|
|
||||||
Quality: "standard",
|
|
||||||
Size: "1024x1024",
|
|
||||||
Style: "vivid",
|
|
||||||
Power: job.Power,
|
|
||||||
}, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "任务执行失败:"+err.Error())
|
resp.ERROR(c, "任务执行失败:"+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 扣减算力
|
||||||
|
err = h.userService.DecreasePower(int(user.Id), job.Power, model.PowerLog{
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Model: task.ModelName,
|
||||||
|
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(job.Prompt, 10)),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "扣减算力失败:"+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c, content)
|
resp.SUCCESS(c, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// List 获取所有的工具函数列表
|
||||||
|
func (h *FunctionHandler) List(c *gin.Context) {
|
||||||
|
var items []model.Function
|
||||||
|
err := h.DB.Where("enabled", true).Find(&items).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := make([]vo.Function, 0)
|
||||||
|
for _, v := range items {
|
||||||
|
var f vo.Function
|
||||||
|
err = utils.CopyObject(v, &f)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
f.Action = ""
|
||||||
|
f.Token = ""
|
||||||
|
tools = append(tools, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, tools)
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
@@ -59,23 +58,16 @@ func (h *InviteHandler) Code(c *gin.Context) {
|
|||||||
|
|
||||||
// List Log 用户邀请记录
|
// List Log 用户邀请记录
|
||||||
func (h *InviteHandler) List(c *gin.Context) {
|
func (h *InviteHandler) List(c *gin.Context) {
|
||||||
|
page := h.GetInt(c, "page", 1)
|
||||||
var data struct {
|
pageSize := h.GetInt(c, "page_size", 20)
|
||||||
Page int `json:"page"`
|
|
||||||
PageSize int `json:"page_size"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
|
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
|
||||||
var total int64
|
var total int64
|
||||||
session.Model(&model.InviteLog{}).Count(&total)
|
session.Model(&model.InviteLog{}).Count(&total)
|
||||||
var items []model.InviteLog
|
var items []model.InviteLog
|
||||||
var list = make([]vo.InviteLog, 0)
|
var list = make([]vo.InviteLog, 0)
|
||||||
offset := (data.Page - 1) * data.PageSize
|
offset := (page - 1) * pageSize
|
||||||
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var v vo.InviteLog
|
var v vo.InviteLog
|
||||||
@@ -89,7 +81,7 @@ func (h *InviteHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hits 访问邀请码
|
// Hits 访问邀请码
|
||||||
|
|||||||
@@ -8,110 +8,66 @@ package handler
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MarkMapHandler 生成思维导图
|
// MarkMapHandler 生成思维导图
|
||||||
type MarkMapHandler struct {
|
type MarkMapHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
clients *types.LMap[int, *types.WsClient]
|
clients *types.LMap[int, *types.WsClient]
|
||||||
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler {
|
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *MarkMapHandler {
|
||||||
return &MarkMapHandler{
|
return &MarkMapHandler{
|
||||||
BaseHandler: BaseHandler{App: app, DB: db},
|
BaseHandler: BaseHandler{App: app, DB: db},
|
||||||
clients: types.NewLMap[int, *types.WsClient](),
|
clients: types.NewLMap[int, *types.WsClient](),
|
||||||
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MarkMapHandler) Client(c *gin.Context) {
|
// Generate 生成思维导图
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||||
if err != nil {
|
var data struct {
|
||||||
logger.Error(err)
|
Prompt string `json:"prompt"`
|
||||||
|
ModelId int `json:"model_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
modelId := h.GetInt(c, "model_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
|
||||||
h.clients.Put(userId, client)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
_, msg, err := client.Receive()
|
|
||||||
if err != nil {
|
|
||||||
client.Close()
|
|
||||||
h.clients.Delete(userId)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var message types.WsMessage
|
|
||||||
err = utils.JsonDecode(string(msg), &message)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 心跳消息
|
|
||||||
if message.Type == "heartbeat" {
|
|
||||||
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// change model
|
|
||||||
if message.Type == "model_id" {
|
|
||||||
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Receive a message: ", message.Content)
|
|
||||||
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.DB.Model(&model.User{}).First(&user, userId)
|
err := h.DB.Where("id", userId).First(&user, userId).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error with query user info: %v", res.Error)
|
resp.ERROR(c, "error with query user info")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
var chatModel model.ChatModel
|
var chatModel model.ChatModel
|
||||||
res = h.DB.Where("id", modelId).First(&chatModel)
|
err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error with query chat model: %v", res.Error)
|
resp.ERROR(c, "error with query chat model")
|
||||||
}
|
return
|
||||||
|
|
||||||
if user.Status == false {
|
|
||||||
return errors.New("当前用户被禁用")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Power < chatModel.Power {
|
if user.Power < chatModel.Power {
|
||||||
return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)
|
resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := make([]interface{}, 0)
|
messages := make([]interface{}, 0)
|
||||||
messages = append(messages, types.Message{Role: "system", Content: `
|
messages = append(messages, types.Message{Role: "system", Content: `
|
||||||
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
|
你是一位非常优秀的思维导图助手, 你能帮助用户整理思路,根据用户提供的主题或内容,快速生成结构清晰,有条理的思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
|
||||||
# Geek-AI 助手
|
# Geek-AI 助手
|
||||||
|
|
||||||
## 完整的开源系统
|
## 完整的开源系统
|
||||||
@@ -128,130 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
|
|||||||
### 支付宝
|
### 支付宝
|
||||||
### 微信
|
### 微信
|
||||||
|
|
||||||
另外,除此之外不要任何解释性语句。
|
请直接生成结果,不要任何解释性语句。
|
||||||
`})
|
`})
|
||||||
messages = append(messages, types.Message{Role: "user", Content: prompt})
|
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", data.Prompt)})
|
||||||
var req = types.ApiRequest{
|
content, err := utils.SendOpenAIMessage(h.DB, messages, data.ModelId)
|
||||||
Model: chatModel.Value,
|
|
||||||
Stream: true,
|
|
||||||
Messages: messages,
|
|
||||||
}
|
|
||||||
|
|
||||||
var apiKey model.ApiKey
|
|
||||||
response, err := h.doRequest(req, chatModel, &apiKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
|
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
|
||||||
}
|
return
|
||||||
|
|
||||||
defer response.Body.Close()
|
|
||||||
|
|
||||||
contentType := response.Header.Get("Content-Type")
|
|
||||||
if strings.Contains(contentType, "text/event-stream") {
|
|
||||||
// 循环读取 Chunk 消息
|
|
||||||
scanner := bufio.NewScanner(response.Body)
|
|
||||||
var isNew = true
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var responseBody = types.ApiResponse{}
|
|
||||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
|
||||||
if err != nil { // 数据解析出错
|
|
||||||
return fmt.Errorf("error with decode data: %v", line)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if responseBody.Choices[0].FinishReason == "stop" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if isNew {
|
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
|
|
||||||
isNew = false
|
|
||||||
}
|
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{
|
|
||||||
Type: types.WsMiddle,
|
|
||||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
|
||||||
})
|
|
||||||
} // end for
|
|
||||||
|
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
|
||||||
|
|
||||||
} else {
|
|
||||||
body, _ := io.ReadAll(response.Body)
|
|
||||||
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 扣减算力
|
// 扣减算力
|
||||||
if chatModel.Power > 0 {
|
if chatModel.Power > 0 {
|
||||||
res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power))
|
err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
|
||||||
if res.Error == nil {
|
Type: types.PowerConsume,
|
||||||
// 记录算力消费日志
|
Model: chatModel.Value,
|
||||||
var u model.User
|
Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
|
||||||
h.DB.Where("id", userId).First(&u)
|
})
|
||||||
h.DB.Create(&model.PowerLog{
|
if err != nil {
|
||||||
UserId: u.Id,
|
resp.ERROR(c, "error with save power log, "+err.Error())
|
||||||
Username: u.Username,
|
return
|
||||||
Type: types.PowerConsume,
|
|
||||||
Amount: chatModel.Power,
|
|
||||||
Mark: types.PowerSub,
|
|
||||||
Balance: u.Power,
|
|
||||||
Model: chatModel.Value,
|
|
||||||
Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
resp.SUCCESS(c, content)
|
||||||
}
|
|
||||||
|
|
||||||
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
|
|
||||||
|
|
||||||
session := h.DB.Session(&gorm.Session{})
|
|
||||||
// if the chat model bind a KEY, use it directly
|
|
||||||
if chatModel.KeyId > 0 {
|
|
||||||
session = session.Where("id", chatModel.KeyId)
|
|
||||||
} else { // use the last unused key
|
|
||||||
session = session.Where("type", "chat").
|
|
||||||
Where("enabled", true).Order("last_used_at ASC")
|
|
||||||
}
|
|
||||||
|
|
||||||
res := session.First(apiKey)
|
|
||||||
if res.Error != nil {
|
|
||||||
return nil, errors.New("no available key, please import key")
|
|
||||||
}
|
|
||||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
|
||||||
// 更新 API KEY 的最后使用时间
|
|
||||||
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
|
||||||
|
|
||||||
// 创建 HttpClient 请求对象
|
|
||||||
var client *http.Client
|
|
||||||
requestBody, err := json.Marshal(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Header.Set("Content-Type", "application/json")
|
|
||||||
if len(apiKey.ProxyURL) > 5 { // 使用代理
|
|
||||||
proxy, _ := url.Parse(apiKey.ProxyURL)
|
|
||||||
client = &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
Proxy: http.ProxyURL(proxy),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
client = http.DefaultClient
|
|
||||||
}
|
|
||||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
|
||||||
return client.Do(request)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ package handler
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
@@ -19,27 +18,27 @@ import (
|
|||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MidJourneyHandler struct {
|
type MidJourneyHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
pool *mj.ServicePool
|
mjService *mj.Service
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
|
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
|
||||||
return &MidJourneyHandler{
|
return &MidJourneyHandler{
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
pool: pool,
|
mjService: service,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
|
userService: userService,
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: app,
|
App: app,
|
||||||
DB: db,
|
DB: db,
|
||||||
@@ -59,40 +58,15 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !h.pool.HasAvailableService() {
|
|
||||||
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client WebSocket 客户端,用于通知任务状态变更
|
|
||||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
|
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
|
||||||
if userId == 0 {
|
|
||||||
logger.Info("Invalid user ID")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
|
||||||
h.pool.Clients.Put(uint(userId), client)
|
|
||||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Image 创建一个绘画任务
|
// Image 创建一个绘画任务
|
||||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
TaskType string `json:"task_type"`
|
TaskType string `json:"task_type"`
|
||||||
|
ClientId string `json:"client_id"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
NegPrompt string `json:"neg_prompt"`
|
NegPrompt string `json:"neg_prompt"`
|
||||||
Rate string `json:"rate"`
|
Rate string `json:"rate"`
|
||||||
@@ -178,10 +152,23 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
resp.ERROR(c, "error with generate task id: "+err.Error())
|
resp.ERROR(c, "error with generate task id: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
task := types.MjTask{
|
||||||
|
ClientId: data.ClientId,
|
||||||
|
TaskId: taskId,
|
||||||
|
Type: types.TaskType(data.TaskType),
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
NegPrompt: data.NegPrompt,
|
||||||
|
Params: params,
|
||||||
|
UserId: userId,
|
||||||
|
ImgArr: data.ImgArr,
|
||||||
|
Mode: h.App.SysConfig.MjMode,
|
||||||
|
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||||
|
}
|
||||||
job := model.MidJourneyJob{
|
job := model.MidJourneyJob{
|
||||||
Type: data.TaskType,
|
Type: data.TaskType,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
|
TaskInfo: utils.JsonEncode(task),
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
|
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
|
||||||
Power: h.App.SysConfig.MjPower,
|
Power: h.App.SysConfig.MjPower,
|
||||||
@@ -201,44 +188,26 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.pool.PushTask(types.MjTask{
|
task.Id = job.Id
|
||||||
Id: job.Id,
|
h.mjService.PushTask(task)
|
||||||
TaskId: taskId,
|
|
||||||
Type: types.TaskType(data.TaskType),
|
|
||||||
Prompt: data.Prompt,
|
|
||||||
NegPrompt: data.NegPrompt,
|
|
||||||
Params: params,
|
|
||||||
UserId: userId,
|
|
||||||
ImgArr: data.ImgArr,
|
|
||||||
})
|
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte("Task Updated"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// update user's power
|
// update user's power
|
||||||
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
// 记录算力变化日志
|
Type: types.PowerConsume,
|
||||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
Model: "mid-journey",
|
||||||
user, _ := h.GetLoginUser(c)
|
Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId),
|
||||||
h.DB.Create(&model.PowerLog{
|
})
|
||||||
UserId: user.Id,
|
if err != nil {
|
||||||
Username: user.Username,
|
resp.ERROR(c, err.Error())
|
||||||
Type: types.PowerConsume,
|
return
|
||||||
Amount: job.Power,
|
|
||||||
Balance: user.Power - job.Power,
|
|
||||||
Mark: types.PowerSub,
|
|
||||||
Model: "mid-journey",
|
|
||||||
Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
type reqVo struct {
|
type reqVo struct {
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
|
ClientId string `json:"client_id"`
|
||||||
ChannelId string `json:"channel_id"`
|
ChannelId string `json:"channel_id"`
|
||||||
MessageId string `json:"message_id"`
|
MessageId string `json:"message_id"`
|
||||||
MessageHash string `json:"message_hash"`
|
MessageHash string `json:"message_hash"`
|
||||||
@@ -259,51 +228,44 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
taskId, _ := h.snowflake.Next(true)
|
taskId, _ := h.snowflake.Next(true)
|
||||||
job := model.MidJourneyJob{
|
task := types.MjTask{
|
||||||
Type: types.TaskUpscale.String(),
|
ClientId: data.ClientId,
|
||||||
ReferenceId: data.MessageId,
|
|
||||||
UserId: userId,
|
|
||||||
TaskId: taskId,
|
|
||||||
Progress: 0,
|
|
||||||
Power: h.App.SysConfig.MjActionPower,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.pool.PushTask(types.MjTask{
|
|
||||||
Id: job.Id,
|
|
||||||
Type: types.TaskUpscale,
|
Type: types.TaskUpscale,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
ChannelId: data.ChannelId,
|
ChannelId: data.ChannelId,
|
||||||
Index: data.Index,
|
Index: data.Index,
|
||||||
MessageId: data.MessageId,
|
MessageId: data.MessageId,
|
||||||
MessageHash: data.MessageHash,
|
MessageHash: data.MessageHash,
|
||||||
})
|
Mode: h.App.SysConfig.MjMode,
|
||||||
|
}
|
||||||
|
job := model.MidJourneyJob{
|
||||||
|
Type: types.TaskUpscale.String(),
|
||||||
|
UserId: userId,
|
||||||
|
TaskId: taskId,
|
||||||
|
TaskInfo: utils.JsonEncode(task),
|
||||||
|
Progress: 0,
|
||||||
|
Power: h.App.SysConfig.MjActionPower,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
task.Id = job.Id
|
||||||
|
h.mjService.PushTask(task)
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte("Task Updated"))
|
|
||||||
}
|
|
||||||
// update user's power
|
// update user's power
|
||||||
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
// 记录算力变化日志
|
Type: types.PowerConsume,
|
||||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
Model: "mid-journey",
|
||||||
user, _ := h.GetLoginUser(c)
|
Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId),
|
||||||
h.DB.Create(&model.PowerLog{
|
})
|
||||||
UserId: user.Id,
|
if err != nil {
|
||||||
Username: user.Username,
|
resp.ERROR(c, err.Error())
|
||||||
Type: types.PowerConsume,
|
return
|
||||||
Amount: job.Power,
|
|
||||||
Balance: user.Power - job.Power,
|
|
||||||
Mark: types.PowerSub,
|
|
||||||
Model: "mid-journey",
|
|
||||||
Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -322,53 +284,44 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
taskId, _ := h.snowflake.Next(true)
|
taskId, _ := h.snowflake.Next(true)
|
||||||
job := model.MidJourneyJob{
|
task := types.MjTask{
|
||||||
Type: types.TaskVariation.String(),
|
Type: types.TaskVariation,
|
||||||
ChannelId: data.ChannelId,
|
ClientId: data.ClientId,
|
||||||
ReferenceId: data.MessageId,
|
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
TaskId: taskId,
|
Index: data.Index,
|
||||||
Progress: 0,
|
ChannelId: data.ChannelId,
|
||||||
Power: h.App.SysConfig.MjActionPower,
|
MessageId: data.MessageId,
|
||||||
CreatedAt: time.Now(),
|
MessageHash: data.MessageHash,
|
||||||
|
Mode: h.App.SysConfig.MjMode,
|
||||||
|
}
|
||||||
|
job := model.MidJourneyJob{
|
||||||
|
Type: types.TaskVariation.String(),
|
||||||
|
ChannelId: data.ChannelId,
|
||||||
|
UserId: userId,
|
||||||
|
TaskId: taskId,
|
||||||
|
TaskInfo: utils.JsonEncode(task),
|
||||||
|
Progress: 0,
|
||||||
|
Power: h.App.SysConfig.MjActionPower,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.pool.PushTask(types.MjTask{
|
task.Id = job.Id
|
||||||
Id: job.Id,
|
h.mjService.PushTask(task)
|
||||||
Type: types.TaskVariation,
|
|
||||||
UserId: userId,
|
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
Index: data.Index,
|
Type: types.PowerConsume,
|
||||||
ChannelId: data.ChannelId,
|
Model: "mid-journey",
|
||||||
MessageId: data.MessageId,
|
Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId),
|
||||||
MessageHash: data.MessageHash,
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
resp.ERROR(c, err.Error())
|
||||||
if client != nil {
|
return
|
||||||
_ = client.Send([]byte("Task Updated"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// update user's power
|
|
||||||
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
|
||||||
// 记录算力变化日志
|
|
||||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
|
||||||
user, _ := h.GetLoginUser(c)
|
|
||||||
h.DB.Create(&model.PowerLog{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Amount: job.Power,
|
|
||||||
Balance: user.Power - job.Power,
|
|
||||||
Mark: types.PowerSub,
|
|
||||||
Model: "mid-journey",
|
|
||||||
Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -403,7 +356,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// JobList 获取 MJ 任务列表
|
// JobList 获取 MJ 任务列表
|
||||||
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
|
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
|
||||||
session := h.DB.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if finish {
|
if finish {
|
||||||
session = session.Where("progress >= ?", 100).Order("id DESC")
|
session = session.Where("progress >= ?", 100).Order("id DESC")
|
||||||
@@ -421,10 +374,14 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
|
|||||||
session = session.Offset(offset).Limit(pageSize)
|
session = session.Offset(offset).Limit(pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 统计总数
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.MidJourneyJob{}).Count(&total)
|
||||||
|
|
||||||
var items []model.MidJourneyJob
|
var items []model.MidJourneyJob
|
||||||
res := session.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error, nil
|
return res.Error, vo.Page{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs = make([]vo.MidJourneyJob, 0)
|
var jobs = make([]vo.MidJourneyJob, 0)
|
||||||
@@ -434,17 +391,9 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
|
|
||||||
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
|
|
||||||
if err == nil {
|
|
||||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
return nil, jobs
|
return nil, vo.NewPage(total, page, pageSize, jobs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove task image
|
// Remove remove task image
|
||||||
@@ -457,40 +406,12 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove job recode
|
// remove job
|
||||||
tx := h.DB.Begin()
|
err := h.DB.Delete(&job).Error
|
||||||
if err := tx.Delete(&job).Error; err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// refund power
|
|
||||||
err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var user model.User
|
|
||||||
h.DB.Where("id = ?", job.UserId).First(&user)
|
|
||||||
err = tx.Create(&model.PowerLog{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Amount: job.Power,
|
|
||||||
Balance: user.Power + job.Power,
|
|
||||||
Mark: types.PowerAdd,
|
|
||||||
Model: "mid-journey",
|
|
||||||
Remark: fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}).Error
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
|
|
||||||
// remove image
|
// remove image
|
||||||
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||||
@@ -498,11 +419,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
|||||||
logger.Error("remove image failed: ", err)
|
logger.Error("remove image failed: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte("Task Updated"))
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -511,10 +427,9 @@ func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
|||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetInt(c, "user_id", 0)
|
||||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||||
res := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action)
|
err := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,19 +17,21 @@ import (
|
|||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UploadHandler struct {
|
type NetHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
uploaderManager *oss.UploaderManager
|
uploaderManager *oss.UploaderManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
|
func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *NetHandler {
|
||||||
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
|
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UploadHandler) Upload(c *gin.Context) {
|
func (h *NetHandler) Upload(c *gin.Context) {
|
||||||
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
@@ -60,9 +62,11 @@ func (h *UploadHandler) Upload(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, file)
|
resp.SUCCESS(c, file)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UploadHandler) List(c *gin.Context) {
|
func (h *NetHandler) List(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Urls []string `json:"urls,omitempty"`
|
Urls []string `json:"urls,omitempty"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -77,25 +81,36 @@ func (h *UploadHandler) List(c *gin.Context) {
|
|||||||
if len(data.Urls) > 0 {
|
if len(data.Urls) > 0 {
|
||||||
session = session.Where("url IN ?", data.Urls)
|
session = session.Where("url IN ?", data.Urls)
|
||||||
}
|
}
|
||||||
session.Find(&items)
|
// 统计总数
|
||||||
if len(items) > 0 {
|
var total int64
|
||||||
for _, v := range items {
|
session.Model(&model.File{}).Count(&total)
|
||||||
var file vo.File
|
|
||||||
err := utils.CopyObject(v, &file)
|
if data.Page > 0 && data.PageSize > 0 {
|
||||||
if err != nil {
|
offset := (data.Page - 1) * data.PageSize
|
||||||
logger.Error(err)
|
session = session.Offset(offset).Limit(data.PageSize)
|
||||||
continue
|
}
|
||||||
}
|
err := session.Order("id desc").Find(&items).Error
|
||||||
file.CreatedAt = v.CreatedAt.Unix()
|
if err != nil {
|
||||||
files = append(files, file)
|
resp.ERROR(c, err.Error())
|
||||||
}
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c, files)
|
for _, v := range items {
|
||||||
|
var file vo.File
|
||||||
|
err := utils.CopyObject(v, &file)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
file.CreatedAt = v.CreatedAt.Unix()
|
||||||
|
files = append(files, file)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, files))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove files
|
// Remove remove files
|
||||||
func (h *UploadHandler) Remove(c *gin.Context) {
|
func (h *NetHandler) Remove(c *gin.Context) {
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
var file model.File
|
var file model.File
|
||||||
@@ -119,3 +134,28 @@ func (h *UploadHandler) Remove(c *gin.Context) {
|
|||||||
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
|
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *NetHandler) Download(c *gin.Context) {
|
||||||
|
fileUrl := c.Query("url")
|
||||||
|
// 使用http工具下载文件
|
||||||
|
if fileUrl == "" {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 使用http.Get下载文件
|
||||||
|
r, err := http.Get(fileUrl)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
resp.ERROR(c, "error status:"+r.Status)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
// 将下载的文件内容写入响应
|
||||||
|
_, _ = io.Copy(c.Writer, r.Body)
|
||||||
|
}
|
||||||
@@ -48,6 +48,16 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
order.Id = item.Id
|
order.Id = item.Id
|
||||||
order.CreatedAt = item.CreatedAt.Unix()
|
order.CreatedAt = item.CreatedAt.Unix()
|
||||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||||
|
payMethod, ok := types.PayMethods[item.PayWay]
|
||||||
|
if !ok {
|
||||||
|
payMethod = item.PayWay
|
||||||
|
}
|
||||||
|
payName, ok := types.PayNames[item.PayType]
|
||||||
|
if !ok {
|
||||||
|
payName = item.PayWay
|
||||||
|
}
|
||||||
|
order.PayMethod = payMethod
|
||||||
|
order.PayName = payName
|
||||||
list = append(list, order)
|
list = append(list, order)
|
||||||
} else {
|
} else {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
@@ -18,10 +17,7 @@ import (
|
|||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/shopspring/decimal"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -34,21 +30,15 @@ type PayWay struct {
|
|||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
PayWayAlipay = PayWay{Name: "支付宝", Value: "alipay"}
|
|
||||||
PayWayXunHu = PayWay{Name: "虎皮椒", Value: "hupi"}
|
|
||||||
PayWayJs = PayWay{Name: "PayJS", Value: "payjs"}
|
|
||||||
PayWayWechat = PayWay{Name: "微信支付", Value: "wechat"}
|
|
||||||
)
|
|
||||||
|
|
||||||
// PaymentHandler 支付服务回调 handler
|
// PaymentHandler 支付服务回调 handler
|
||||||
type PaymentHandler struct {
|
type PaymentHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
alipayService *payment.AlipayService
|
alipayService *payment.AlipayService
|
||||||
huPiPayService *payment.HuPiPayService
|
huPiPayService *payment.HuPiPayService
|
||||||
jsPayService *payment.JPayService
|
geekPayService *payment.GeekPayService
|
||||||
wechatPayService *payment.WechatPayService
|
wechatPayService *payment.WechatPayService
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
|
userService *service.UserService
|
||||||
fs embed.FS
|
fs embed.FS
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
signKey string // 用来签名的随机秘钥
|
signKey string // 用来签名的随机秘钥
|
||||||
@@ -58,17 +48,19 @@ func NewPaymentHandler(
|
|||||||
server *core.AppServer,
|
server *core.AppServer,
|
||||||
alipayService *payment.AlipayService,
|
alipayService *payment.AlipayService,
|
||||||
huPiPayService *payment.HuPiPayService,
|
huPiPayService *payment.HuPiPayService,
|
||||||
jsPayService *payment.JPayService,
|
geekPayService *payment.GeekPayService,
|
||||||
wechatPayService *payment.WechatPayService,
|
wechatPayService *payment.WechatPayService,
|
||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
|
userService *service.UserService,
|
||||||
snowflake *service.Snowflake,
|
snowflake *service.Snowflake,
|
||||||
fs embed.FS) *PaymentHandler {
|
fs embed.FS) *PaymentHandler {
|
||||||
return &PaymentHandler{
|
return &PaymentHandler{
|
||||||
alipayService: alipayService,
|
alipayService: alipayService,
|
||||||
huPiPayService: huPiPayService,
|
huPiPayService: huPiPayService,
|
||||||
jsPayService: jsPayService,
|
geekPayService: geekPayService,
|
||||||
wechatPayService: wechatPayService,
|
wechatPayService: wechatPayService,
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
|
userService: userService,
|
||||||
fs: fs,
|
fs: fs,
|
||||||
lock: sync.Mutex{},
|
lock: sync.Mutex{},
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
@@ -79,309 +71,167 @@ func NewPaymentHandler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PaymentHandler) DoPay(c *gin.Context) {
|
func (h *PaymentHandler) Pay(c *gin.Context) {
|
||||||
orderNo := h.GetTrim(c, "order_no")
|
var data struct {
|
||||||
payWay := h.GetTrim(c, "pay_way")
|
PayWay string `json:"pay_way"`
|
||||||
t := h.GetInt(c, "t", 0)
|
PayType string `json:"pay_type"`
|
||||||
sign := h.GetTrim(c, "sign")
|
ProductId int `json:"product_id"`
|
||||||
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, payWay, t, h.signKey)
|
UserId int `json:"user_id"`
|
||||||
newSign := utils.Sha256(signStr)
|
Device string `json:"device"`
|
||||||
if newSign != sign {
|
Host string `json:"host"`
|
||||||
resp.ERROR(c, "订单签名错误!")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
// 检查二维码是否过期
|
|
||||||
if time.Now().Unix()-int64(t) > int64(h.App.SysConfig.OrderPayTimeout) {
|
|
||||||
resp.ERROR(c, "支付二维码已过期,请重新生成!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if orderNo == "" {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var order model.Order
|
var product model.Product
|
||||||
res := h.DB.Where("order_no = ?", orderNo).First(&order)
|
err := h.DB.Where("id", data.ProductId).First(&product).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "Order not found")
|
resp.ERROR(c, "Product not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
|
orderNo, err := h.snowflake.Next(false)
|
||||||
if order.Status == types.OrderPaidSuccess {
|
if err != nil {
|
||||||
resp.ERROR(c, "订单已支付成功,无需重复支付!")
|
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var user model.User
|
||||||
|
err = h.DB.Where("id", data.UserId).First(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新扫码状态
|
amount := product.Discount
|
||||||
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
|
var payURL, returnURL, notifyURL string
|
||||||
|
switch data.PayWay {
|
||||||
|
case "alipay":
|
||||||
|
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
|
||||||
|
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||||
|
} else {
|
||||||
|
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
|
||||||
|
}
|
||||||
|
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
|
||||||
|
returnURL = h.App.Config.AlipayConfig.ReturnURL
|
||||||
|
} else {
|
||||||
|
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||||
|
}
|
||||||
|
money := fmt.Sprintf("%.2f", amount)
|
||||||
|
if data.Device == "wechat" {
|
||||||
|
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
|
||||||
|
OutTradeNo: orderNo,
|
||||||
|
Subject: product.Name,
|
||||||
|
TotalFee: money,
|
||||||
|
ReturnURL: returnURL,
|
||||||
|
NotifyURL: notifyURL,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
|
||||||
|
OutTradeNo: orderNo,
|
||||||
|
Subject: product.Name,
|
||||||
|
TotalFee: money,
|
||||||
|
ReturnURL: returnURL,
|
||||||
|
NotifyURL: notifyURL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if payWay == "alipay" { // 支付宝
|
|
||||||
amount := fmt.Sprintf("%.2f", order.Amount)
|
|
||||||
uri, err := h.alipayService.PayUrlMobile(order.OrderNo, amount, order.Subject)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
break
|
||||||
c.Redirect(302, uri)
|
case "wechat":
|
||||||
return
|
if h.App.Config.WechatPayConfig.NotifyURL != "" {
|
||||||
} else if payWay == "hupi" { // 虎皮椒支付
|
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
|
||||||
params := payment.HuPiPayReq{
|
} else {
|
||||||
Version: "1.1",
|
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
|
||||||
TradeOrderId: orderNo,
|
}
|
||||||
TotalFee: fmt.Sprintf("%f", order.Amount),
|
if data.Device == "wechat" {
|
||||||
Title: order.Subject,
|
payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
|
||||||
NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL,
|
OutTradeNo: orderNo,
|
||||||
WapName: "极客学长",
|
TotalFee: int(amount * 100),
|
||||||
|
Subject: product.Name,
|
||||||
|
NotifyURL: notifyURL,
|
||||||
|
ClientIP: c.ClientIP(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
|
||||||
|
OutTradeNo: orderNo,
|
||||||
|
TotalFee: int(amount * 100),
|
||||||
|
Subject: product.Name,
|
||||||
|
NotifyURL: notifyURL,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
r, err := h.huPiPayService.Pay(params)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
break
|
||||||
c.Redirect(302, r.URL)
|
|
||||||
}
|
|
||||||
resp.ERROR(c, "Invalid operations")
|
|
||||||
}
|
|
||||||
|
|
||||||
// PayQrcode 生成支付 URL 二维码
|
|
||||||
func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
PayWay string `json:"pay_way"` // 支付方式
|
|
||||||
ProductId uint `json:"product_id"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var product model.Product
|
|
||||||
res := h.DB.First(&product, data.ProductId)
|
|
||||||
if res.Error != nil {
|
|
||||||
resp.ERROR(c, "Product not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
orderNo, err := h.snowflake.Next(false)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
user, err := h.GetLoginUser(c)
|
|
||||||
if err != nil {
|
|
||||||
resp.NotAuth(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var payWay string
|
|
||||||
var notifyURL string
|
|
||||||
switch data.PayWay {
|
|
||||||
case "hupi":
|
case "hupi":
|
||||||
payWay = PayWayXunHu.Value
|
if h.App.Config.HuPiPayConfig.NotifyURL != "" {
|
||||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||||
break
|
|
||||||
case "payjs":
|
|
||||||
payWay = PayWayJs.Value
|
|
||||||
notifyURL = h.App.Config.JPayConfig.NotifyURL
|
|
||||||
break
|
|
||||||
case "alipay":
|
|
||||||
payWay = PayWayAlipay.Value
|
|
||||||
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
payWay = PayWayWechat.Value
|
|
||||||
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
|
|
||||||
|
|
||||||
}
|
|
||||||
// 创建订单
|
|
||||||
remark := types.OrderRemark{
|
|
||||||
Days: product.Days,
|
|
||||||
Power: product.Power,
|
|
||||||
Name: product.Name,
|
|
||||||
Price: product.Price,
|
|
||||||
Discount: product.Discount,
|
|
||||||
}
|
|
||||||
|
|
||||||
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
|
||||||
order := model.Order{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
ProductId: product.Id,
|
|
||||||
OrderNo: orderNo,
|
|
||||||
Subject: product.Name,
|
|
||||||
Amount: amount,
|
|
||||||
Status: types.OrderNotPaid,
|
|
||||||
PayWay: payWay,
|
|
||||||
Remark: utils.JsonEncode(remark),
|
|
||||||
}
|
|
||||||
res = h.DB.Create(&order)
|
|
||||||
if res.Error != nil || res.RowsAffected == 0 {
|
|
||||||
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// PayJs 单独处理,只能用官方生成的二维码
|
|
||||||
if data.PayWay == "payjs" {
|
|
||||||
params := payment.JPayReq{
|
|
||||||
TotalFee: int(math.Ceil(order.Amount * 100)),
|
|
||||||
OutTradeNo: order.OrderNo,
|
|
||||||
Subject: product.Name,
|
|
||||||
}
|
|
||||||
r := h.jsPayService.Pay(params)
|
|
||||||
if r.IsOK() {
|
|
||||||
resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
|
|
||||||
return
|
|
||||||
} else {
|
} else {
|
||||||
resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg)
|
notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
if h.App.Config.HuPiPayConfig.ReturnURL != "" {
|
||||||
|
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
||||||
var logo string
|
|
||||||
if data.PayWay == "alipay" {
|
|
||||||
logo = "res/img/alipay.jpg"
|
|
||||||
} else if data.PayWay == "hupi" {
|
|
||||||
if h.App.Config.HuPiPayConfig.Name == "wechat" {
|
|
||||||
logo = "res/img/wechat-pay.jpg"
|
|
||||||
} else {
|
} else {
|
||||||
logo = "res/img/alipay.jpg"
|
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||||
}
|
}
|
||||||
} else if data.PayWay == "wechat" {
|
r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
|
||||||
logo = "res/img/wechat-pay.jpg"
|
|
||||||
}
|
|
||||||
|
|
||||||
file, err := h.fs.Open(logo)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "error with open qrcode log file: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
parse, err := url.Parse(notifyURL)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
timestamp := time.Now().Unix()
|
|
||||||
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, data.PayWay, timestamp, h.signKey)
|
|
||||||
sign := utils.Sha256(signStr)
|
|
||||||
var imageURL string
|
|
||||||
if data.PayWay == "wechat" {
|
|
||||||
payUrl, err := h.wechatPayService.PayUrlNative(order.OrderNo, int(math.Floor(order.Amount*100)), product.Name)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "error with generating wechat payment qrcode: "+err.Error())
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
imageURL = payUrl
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
imageURL = fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign)
|
|
||||||
}
|
|
||||||
imgData, err := utils.GenQrcode(imageURL, 400, file)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
imgDataBase64 := base64.StdEncoding.EncodeToString(imgData)
|
|
||||||
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mobile 移动端支付
|
|
||||||
func (h *PaymentHandler) Mobile(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
PayWay string `json:"pay_way"` // 支付方式
|
|
||||||
ProductId uint `json:"product_id"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var product model.Product
|
|
||||||
res := h.DB.First(&product, data.ProductId)
|
|
||||||
if res.Error != nil {
|
|
||||||
resp.ERROR(c, "Product not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
orderNo, err := h.snowflake.Next(false)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
user, err := h.GetLoginUser(c)
|
|
||||||
if err != nil {
|
|
||||||
resp.NotAuth(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
|
||||||
var payWay string
|
|
||||||
var notifyURL, returnURL string
|
|
||||||
var payURL string
|
|
||||||
switch data.PayWay {
|
|
||||||
case "hupi":
|
|
||||||
payWay = PayWayXunHu.Name
|
|
||||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
|
||||||
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
|
||||||
parse, _ := url.Parse(h.App.Config.HuPiPayConfig.ReturnURL)
|
|
||||||
baseURL := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
|
|
||||||
params := payment.HuPiPayReq{
|
|
||||||
Version: "1.1",
|
Version: "1.1",
|
||||||
TradeOrderId: orderNo,
|
TradeOrderId: orderNo,
|
||||||
TotalFee: fmt.Sprintf("%f", amount),
|
TotalFee: fmt.Sprintf("%f", amount),
|
||||||
Title: product.Name,
|
Title: product.Name,
|
||||||
NotifyURL: notifyURL,
|
NotifyURL: notifyURL,
|
||||||
ReturnURL: returnURL,
|
ReturnURL: returnURL,
|
||||||
CallbackURL: returnURL,
|
WapName: "GeekAI助手",
|
||||||
WapName: "极客学长",
|
})
|
||||||
WapUrl: baseURL,
|
|
||||||
Type: "WAP",
|
|
||||||
}
|
|
||||||
r, err := h.huPiPayService.Pay(params)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := "error with generating Pay Hupi URL: " + err.Error()
|
resp.ERROR(c, err.Error())
|
||||||
logger.Error(errMsg)
|
|
||||||
resp.ERROR(c, errMsg)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payURL = r.URL
|
payURL = r.URL
|
||||||
case "payjs":
|
break
|
||||||
payWay = PayWayJs.Name
|
case "geek":
|
||||||
notifyURL = h.App.Config.JPayConfig.NotifyURL
|
if h.App.Config.GeekPayConfig.NotifyURL != "" {
|
||||||
returnURL = h.App.Config.JPayConfig.ReturnURL
|
notifyURL = h.App.Config.GeekPayConfig.NotifyURL
|
||||||
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
|
} else {
|
||||||
params := url.Values{}
|
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
|
||||||
params.Add("total_fee", fmt.Sprintf("%d", totalFee))
|
}
|
||||||
params.Add("out_trade_no", orderNo)
|
if h.App.Config.GeekPayConfig.ReturnURL != "" {
|
||||||
params.Add("body", product.Name)
|
data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
|
||||||
params.Add("notify_url", notifyURL)
|
}
|
||||||
params.Add("auto", "0")
|
if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
|
||||||
payURL = h.jsPayService.PayH5(params)
|
returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
|
||||||
case "alipay":
|
} else {
|
||||||
payWay = PayWayAlipay.Name
|
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||||
payURL, err = h.alipayService.PayUrlMobile(orderNo, fmt.Sprintf("%.2f", amount), product.Name)
|
}
|
||||||
|
params := payment.GeekPayParams{
|
||||||
|
OutTradeNo: orderNo,
|
||||||
|
Method: "web",
|
||||||
|
Name: product.Name,
|
||||||
|
Money: fmt.Sprintf("%f", amount),
|
||||||
|
ClientIP: c.ClientIP(),
|
||||||
|
Device: data.Device,
|
||||||
|
Type: data.PayType,
|
||||||
|
ReturnURL: returnURL,
|
||||||
|
NotifyURL: notifyURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := h.geekPayService.Pay(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := "error with generating Alipay URL: " + err.Error()
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, errMsg)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case "wechat":
|
|
||||||
payWay = PayWayWechat.Name
|
|
||||||
payURL, err = h.wechatPayService.PayUrlH5(orderNo, int(amount*100), product.Name, c.ClientIP())
|
|
||||||
if err != nil {
|
|
||||||
errMsg := "error with generating Wechat URL: " + err.Error()
|
|
||||||
logger.Error(errMsg)
|
|
||||||
resp.ERROR(c, errMsg)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
payURL = res.PayURL
|
||||||
default:
|
default:
|
||||||
resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
|
resp.ERROR(c, "不支持的支付渠道")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建订单
|
// 创建订单
|
||||||
remark := types.OrderRemark{
|
remark := types.OrderRemark{
|
||||||
Days: product.Days,
|
Days: product.Days,
|
||||||
@@ -390,7 +240,6 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
|||||||
Price: product.Price,
|
Price: product.Price,
|
||||||
Discount: product.Discount,
|
Discount: product.Discount,
|
||||||
}
|
}
|
||||||
|
|
||||||
order := model.Order{
|
order := model.Order{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
@@ -399,26 +248,24 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
|||||||
Subject: product.Name,
|
Subject: product.Name,
|
||||||
Amount: amount,
|
Amount: amount,
|
||||||
Status: types.OrderNotPaid,
|
Status: types.OrderNotPaid,
|
||||||
PayWay: payWay,
|
PayWay: data.PayWay,
|
||||||
|
PayType: data.PayType,
|
||||||
Remark: utils.JsonEncode(remark),
|
Remark: utils.JsonEncode(remark),
|
||||||
}
|
}
|
||||||
res = h.DB.Create(&order)
|
err = h.DB.Create(&order).Error
|
||||||
if res.Error != nil || res.RowsAffected == 0 {
|
if err != nil {
|
||||||
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
resp.ERROR(c, "error with create order: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
resp.SUCCESS(c, payURL)
|
||||||
resp.SUCCESS(c, gin.H{"url": payURL, "order_no": orderNo})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 异步通知回调公共逻辑
|
// 异步通知回调公共逻辑
|
||||||
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||||
var order model.Order
|
var order model.Order
|
||||||
res := h.DB.Where("order_no = ?", orderNo).First(&order)
|
err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
err := fmt.Errorf("error with fetch order: %v", res.Error)
|
return fmt.Errorf("error with fetch order: %v", err)
|
||||||
logger.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
h.lock.Lock()
|
h.lock.Lock()
|
||||||
@@ -430,45 +277,24 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res = h.DB.First(&user, order.UserId)
|
err = h.DB.First(&user, order.UserId).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
err := fmt.Errorf("error with fetch user info: %v", res.Error)
|
return fmt.Errorf("error with fetch user info: %v", err)
|
||||||
logger.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var remark types.OrderRemark
|
var remark types.OrderRemark
|
||||||
err := utils.JsonDecode(order.Remark, &remark)
|
err = utils.JsonDecode(order.Remark, &remark)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := fmt.Errorf("error with decode order remark: %v", err)
|
return fmt.Errorf("error with decode order remark: %v", err)
|
||||||
logger.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var opt string
|
// 增加用户算力
|
||||||
var power int
|
err = h.userService.IncreasePower(int(order.UserId), remark.Power, model.PowerLog{
|
||||||
if remark.Days > 0 { // VIP 充值
|
Type: types.PowerRecharge,
|
||||||
if user.ExpiredTime >= time.Now().Unix() {
|
Model: order.PayWay,
|
||||||
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
|
Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo),
|
||||||
opt = "VIP充值,VIP 没到期,只延期不增加算力"
|
})
|
||||||
} else {
|
if err != nil {
|
||||||
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
|
|
||||||
user.Power += h.App.SysConfig.VipMonthPower
|
|
||||||
power = h.App.SysConfig.VipMonthPower
|
|
||||||
opt = "VIP充值"
|
|
||||||
}
|
|
||||||
user.Vip = true
|
|
||||||
} else { // 充值点卡,直接增加次数即可
|
|
||||||
user.Power += remark.Power
|
|
||||||
opt = "点卡充值"
|
|
||||||
power = remark.Power
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新用户信息
|
|
||||||
res = h.DB.Updates(&user)
|
|
||||||
if res.Error != nil {
|
|
||||||
err := fmt.Errorf("error with update user info: %v", res.Error)
|
|
||||||
logger.Error(err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -476,29 +302,16 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
|||||||
order.PayTime = time.Now().Unix()
|
order.PayTime = time.Now().Unix()
|
||||||
order.Status = types.OrderPaidSuccess
|
order.Status = types.OrderPaidSuccess
|
||||||
order.TradeNo = tradeNo
|
order.TradeNo = tradeNo
|
||||||
res = h.DB.Updates(&order)
|
err = h.DB.Updates(&order).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
err := fmt.Errorf("error with update order info: %v", res.Error)
|
return fmt.Errorf("error with update order info: %v", err)
|
||||||
logger.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新产品销量
|
// 更新产品销量
|
||||||
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
|
err = h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).
|
||||||
|
UpdateColumn("sales", gorm.Expr("sales + ?", 1)).Error
|
||||||
// 记录算力充值日志
|
if err != nil {
|
||||||
if power > 0 {
|
return fmt.Errorf("error with update product sales: %v", err)
|
||||||
h.DB.Create(&model.PowerLog{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
Type: types.PowerRecharge,
|
|
||||||
Amount: power,
|
|
||||||
Balance: user.Power,
|
|
||||||
Mark: types.PowerAdd,
|
|
||||||
Model: order.PayWay,
|
|
||||||
Remark: fmt.Sprintf("%s,金额:%f,订单号:%s", opt, order.Amount, order.OrderNo),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -506,20 +319,22 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
|||||||
|
|
||||||
// GetPayWays 获取支付方式
|
// GetPayWays 获取支付方式
|
||||||
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
|
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
|
||||||
data := gin.H{}
|
payWays := make([]gin.H, 0)
|
||||||
if h.App.Config.AlipayConfig.Enabled {
|
if h.App.Config.AlipayConfig.Enabled {
|
||||||
data["alipay"] = gin.H{"name": "alipay"}
|
payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
|
||||||
}
|
}
|
||||||
if h.App.Config.HuPiPayConfig.Enabled {
|
if h.App.Config.HuPiPayConfig.Enabled {
|
||||||
data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name}
|
payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
|
||||||
}
|
}
|
||||||
if h.App.Config.JPayConfig.Enabled {
|
if h.App.Config.GeekPayConfig.Enabled {
|
||||||
data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name}
|
for _, v := range h.App.Config.GeekPayConfig.Methods {
|
||||||
|
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if h.App.Config.WechatPayConfig.Enabled {
|
if h.App.Config.WechatPayConfig.Enabled {
|
||||||
data["wechat"] = gin.H{"name": "wechat"}
|
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, data)
|
resp.SUCCESS(c, payWays)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HuPiPayNotify 虎皮椒支付异步回调
|
// HuPiPayNotify 虎皮椒支付异步回调
|
||||||
@@ -532,15 +347,17 @@ func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
|
|||||||
|
|
||||||
orderNo := c.Request.Form.Get("trade_order_id")
|
orderNo := c.Request.Form.Get("trade_order_id")
|
||||||
tradeNo := c.Request.Form.Get("open_order_id")
|
tradeNo := c.Request.Form.Get("open_order_id")
|
||||||
logger.Infof("收到虎皮椒订单支付回调,订单 NO:%s,交易流水号:%s", orderNo, tradeNo)
|
logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
|
||||||
|
|
||||||
if err = h.huPiPayService.Check(tradeNo); err != nil {
|
if err = h.huPiPayService.Check(orderNo); err != nil {
|
||||||
logger.Error("订单校验失败:", err)
|
logger.Error("订单校验失败:", err)
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.notify(orderNo, tradeNo)
|
err = h.notify(orderNo, tradeNo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -556,18 +373,18 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:验证交易签名
|
result := h.alipayService.TradeVerify(c.Request)
|
||||||
res := h.alipayService.TradeVerify(c.Request)
|
logger.Infof("收到支付宝商号订单支付回调:%+v", result)
|
||||||
logger.Infof("验证支付结果:%+v", res)
|
if !result.Success() {
|
||||||
if !res.Success() {
|
logger.Error("订单校验失败:", result.Message)
|
||||||
logger.Error("订单校验失败:", res.Message)
|
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tradeNo := c.Request.Form.Get("trade_no")
|
tradeNo := c.Request.Form.Get("trade_no")
|
||||||
err = h.notify(res.OutTradeNo, tradeNo)
|
err = h.notify(result.OutTradeNo, tradeNo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -575,33 +392,30 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
|||||||
c.String(http.StatusOK, "success")
|
c.String(http.StatusOK, "success")
|
||||||
}
|
}
|
||||||
|
|
||||||
// PayJsNotify PayJs 支付异步回调
|
// GeekPayNotify 支付异步回调
|
||||||
func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
|
func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
|
||||||
err := c.Request.ParseForm()
|
var params = make(map[string]string)
|
||||||
if err != nil {
|
for k := range c.Request.URL.Query() {
|
||||||
|
params[k] = c.Query(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("收到GeekPay订单支付回调:%+v", params)
|
||||||
|
// 检查支付状态
|
||||||
|
if params["trade_status"] != "TRADE_SUCCESS" {
|
||||||
|
c.String(http.StatusOK, "success")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sign := h.geekPayService.Sign(params)
|
||||||
|
if sign != c.Query("sign") {
|
||||||
|
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
orderNo := c.Request.Form.Get("out_trade_no")
|
err := h.notify(params["out_trade_no"], params["trade_no"])
|
||||||
returnCode := c.Request.Form.Get("return_code")
|
|
||||||
logger.Infof("收到PayJs订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode)
|
|
||||||
// 支付失败
|
|
||||||
if returnCode != "1" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 校验订单支付状态
|
|
||||||
tradeNo := c.Request.Form.Get("payjs_order_id")
|
|
||||||
err = h.jsPayService.TradeVerify(tradeNo)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("订单校验失败:", err)
|
|
||||||
c.String(http.StatusOK, "fail")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = h.notify(orderNo, tradeNo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -618,6 +432,7 @@ func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
result := h.wechatPayService.TradeVerify(c.Request)
|
result := h.wechatPayService.TradeVerify(c.Request)
|
||||||
|
logger.Infof("收到微信商号订单支付回调:%+v", result)
|
||||||
if !result.Success() {
|
if !result.Success() {
|
||||||
logger.Error("订单校验失败:", err)
|
logger.Error("订单校验失败:", err)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
@@ -629,6 +444,7 @@ func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
|
|||||||
|
|
||||||
err = h.notify(result.OutTradeNo, result.TradeId)
|
err = h.notify(result.OutTradeNo, result.TradeId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
155
api/handler/prompt_handler.go
Normal file
155
api/handler/prompt_handler.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 提示词生成 handler
|
||||||
|
// 使用 AI 生成绘画指令,歌词,视频生成指令等
|
||||||
|
|
||||||
|
type PromptHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
userService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPromptHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *PromptHandler {
|
||||||
|
return &PromptHandler{
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: app,
|
||||||
|
DB: db,
|
||||||
|
},
|
||||||
|
userService: userService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lyric 生成歌词
|
||||||
|
func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.App.SysConfig.PromptPower > 0 {
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
h.userService.DecreasePower(int(userId), h.App.SysConfig.PromptPower, model.PowerLog{
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Model: h.getPromptModel(),
|
||||||
|
Remark: "生成歌词",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image 生成 AI 绘画提示词
|
||||||
|
func (h *PromptHandler) Image(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.App.SysConfig.PromptPower > 0 {
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
h.userService.DecreasePower(int(userId), h.App.SysConfig.PromptPower, model.PowerLog{
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Model: h.getPromptModel(),
|
||||||
|
Remark: "生成绘画提示词",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Video 生成视频提示词
|
||||||
|
func (h *PromptHandler) Video(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.App.SysConfig.PromptPower > 0 {
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
h.userService.DecreasePower(int(userId), h.App.SysConfig.PromptPower, model.PowerLog{
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Model: h.getPromptModel(),
|
||||||
|
Remark: "生成视频脚本",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetaPrompt 生成元提示词
|
||||||
|
func (h *PromptHandler) MetaPrompt(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages := make([]interface{}, 0)
|
||||||
|
messages = append(messages, types.Message{
|
||||||
|
Role: "system",
|
||||||
|
Content: service.MetaPromptTemplate,
|
||||||
|
})
|
||||||
|
messages = append(messages, types.Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Task, Goal, or the Role to actor is:\n" + data.Prompt,
|
||||||
|
})
|
||||||
|
content, err := utils.SendOpenAIMessage(h.DB, messages, 0)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PromptHandler) getPromptModel() string {
|
||||||
|
if h.App.SysConfig.TranslateModelId > 0 {
|
||||||
|
var chatModel model.ChatModel
|
||||||
|
h.DB.Where("id", h.App.SysConfig.TranslateModelId).First(&chatModel)
|
||||||
|
return chatModel.Value
|
||||||
|
}
|
||||||
|
return "gpt-4o"
|
||||||
|
}
|
||||||
221
api/handler/realtime_handler.go
Normal file
221
api/handler/realtime_handler.go
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
// OpenAI Realtime API Relay Server
|
||||||
|
|
||||||
|
type RealtimeHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
userService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *service.UserService) *RealtimeHandler {
|
||||||
|
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RealtimeHandler) Connection(c *gin.Context) {
|
||||||
|
// 获取客户端请求中指定的子协议
|
||||||
|
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||||
|
md := c.Query("model")
|
||||||
|
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
var user model.User
|
||||||
|
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将 HTTP 协议升级为 Websocket 协议
|
||||||
|
subProtocols := strings.Split(clientProtocols, ",")
|
||||||
|
ws, err := (&websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Subprotocols: subProtocols,
|
||||||
|
}).Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
// 目前只针对 VIP 用户可以访问
|
||||||
|
if !user.Vip {
|
||||||
|
sendError(ws, "当前功能只针对 VIP 用户开放")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
h.DB.Where("type", "realtime").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
|
||||||
|
if apiKey.Id == 0 {
|
||||||
|
sendError(ws, "管理员未配置 Realtime API KEY")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/v1/realtime?model=%s", apiKey.ApiURL, md)
|
||||||
|
// 连接到真实的后端服务器,传入相同的子协议
|
||||||
|
headers := http.Header{}
|
||||||
|
// 修正子协议内容
|
||||||
|
subProtocols[1] = "openai-insecure-api-key." + apiKey.Value
|
||||||
|
if clientProtocols != "" {
|
||||||
|
headers.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ","))
|
||||||
|
}
|
||||||
|
backendConn, _, err := websocket.DefaultDialer.Dial(apiURL, headers)
|
||||||
|
if err != nil {
|
||||||
|
sendError(ws, "桥接后端 API 失败:"+err.Error())
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer backendConn.Close()
|
||||||
|
|
||||||
|
// 确保协议一致性,如果失败返回
|
||||||
|
if ws.Subprotocol() != backendConn.Subprotocol() {
|
||||||
|
sendError(ws, "Websocket 子协议不匹配")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新API KEY 最后使用时间
|
||||||
|
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
|
|
||||||
|
// 开始双向转发
|
||||||
|
errorChan := make(chan error, 2)
|
||||||
|
go relay(ws, backendConn, errorChan)
|
||||||
|
go relay(backendConn, ws, errorChan)
|
||||||
|
|
||||||
|
// 等待其中一个连接关闭
|
||||||
|
err = <-errorChan
|
||||||
|
logger.Infof("Relay ended: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func relay(src, dst *websocket.Conn, errorChan chan error) {
|
||||||
|
for {
|
||||||
|
messageType, message, err := src.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
errorChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = dst.WriteMessage(messageType, message)
|
||||||
|
if err != nil {
|
||||||
|
errorChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendError(ws *websocket.Conn, message string) {
|
||||||
|
err := ws.WriteJSON(map[string]string{"event_id": "event_01", "type": "error", "error": message})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI 实时语音对话,一次性对话
|
||||||
|
func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
err := h.DB.Session(&gorm.Session{}).Where("type", "realtime").Where("enabled", true).First(&apiKey).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("error with fetch OpenAI API KEY:%v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查用户是否还有算力
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
var user model.User
|
||||||
|
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("error with fetch user:%v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Power < h.App.SysConfig.AdvanceVoicePower {
|
||||||
|
resp.ERROR(c, "当前用户算力不足,无法使用该功能")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var response utils.OpenAIResponse
|
||||||
|
client := req.C()
|
||||||
|
if len(apiKey.ProxyURL) > 5 {
|
||||||
|
client.SetProxyURL(apiKey.ApiURL)
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||||
|
logger.Infof("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, "advanced-voice")
|
||||||
|
r, err := client.R().SetHeader("Body-Type", "application/json").
|
||||||
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
|
SetBody(types.ApiRequest{
|
||||||
|
Model: "advanced-voice",
|
||||||
|
Temperature: 0.9,
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Stream: false,
|
||||||
|
Messages: []interface{}{types.Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: "实时语音通话",
|
||||||
|
}},
|
||||||
|
}).Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败:%v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败:%v", r.Status))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
err = json.Unmarshal(body, &response)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("解析API数据失败:%v, %s", err, string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新 API KEY 的最后使用时间
|
||||||
|
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
|
|
||||||
|
// 扣减算力
|
||||||
|
err = h.userService.DecreasePower(int(userId), h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Model: "advanced-voice",
|
||||||
|
Remark: "实时语音通话",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("Response: %v", response.Choices[0].Message.Content)
|
||||||
|
|
||||||
|
// 提取链接
|
||||||
|
re := regexp.MustCompile(`\[(.*?)\]\((.*?)\)`)
|
||||||
|
links := re.FindAllStringSubmatch(response.Choices[0].Message.Content, -1)
|
||||||
|
var url = ""
|
||||||
|
if len(links) > 0 {
|
||||||
|
url = links[0][2]
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, url)
|
||||||
|
}
|
||||||
88
api/handler/redeem_handler.go
Normal file
88
api/handler/redeem_handler.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RedeemHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
lock sync.Mutex
|
||||||
|
userService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *RedeemHandler {
|
||||||
|
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RedeemHandler) Verify(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
var item model.Redeem
|
||||||
|
res := h.DB.Where("code", data.Code).First(&item)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "无效的兑换码!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !item.Enabled {
|
||||||
|
resp.ERROR(c, "当前兑换码已被禁用!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if item.RedeemedAt > 0 {
|
||||||
|
resp.ERROR(c, "当前兑换码已使用,请勿重复使用!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := h.DB.Begin()
|
||||||
|
err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{
|
||||||
|
Type: types.PowerRedeem,
|
||||||
|
Model: "兑换码",
|
||||||
|
Remark: fmt.Sprintf("兑换码核销,算力:%d,兑换码:%s...", item.Power, item.Code[:10]),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新核销状态
|
||||||
|
item.RedeemedAt = time.Now().Unix()
|
||||||
|
item.UserId = userId
|
||||||
|
err = tx.Updates(&item).Error
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"geekai/core"
|
|
||||||
"geekai/core/types"
|
|
||||||
"geekai/store/model"
|
|
||||||
"geekai/store/vo"
|
|
||||||
"geekai/utils"
|
|
||||||
"geekai/utils/resp"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"math"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RewardHandler struct {
|
|
||||||
BaseHandler
|
|
||||||
lock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
|
|
||||||
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify 打赏码核销
|
|
||||||
func (h *RewardHandler) Verify(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
TxId string `json:"tx_id"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := h.GetLoginUser(c)
|
|
||||||
if err != nil {
|
|
||||||
resp.HACKER(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 移除转账单号中间的空格,防止有人复制的时候多复制了空格
|
|
||||||
data.TxId = strings.ReplaceAll(data.TxId, " ", "")
|
|
||||||
|
|
||||||
h.lock.Lock()
|
|
||||||
defer h.lock.Unlock()
|
|
||||||
|
|
||||||
var item model.Reward
|
|
||||||
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
|
|
||||||
if res.Error != nil {
|
|
||||||
resp.ERROR(c, "无效的交易流水号!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if item.Status {
|
|
||||||
resp.ERROR(c, "当前交易流水号已经被核销,请不要重复核销!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tx := h.DB.Begin()
|
|
||||||
exchange := vo.RewardExchange{}
|
|
||||||
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
|
|
||||||
exchange.Power = int(power)
|
|
||||||
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
logger.Error("添加应用失败:", res.Error)
|
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新核销状态
|
|
||||||
item.Status = true
|
|
||||||
item.UserId = user.Id
|
|
||||||
item.Exchange = utils.JsonEncode(exchange)
|
|
||||||
res = tx.Updates(&item)
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
logger.Error("添加应用失败:", res.Error)
|
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 记录算力充值日志
|
|
||||||
h.DB.Create(&model.PowerLog{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
Type: types.PowerReward,
|
|
||||||
Amount: exchange.Power,
|
|
||||||
Balance: user.Power + exchange.Power,
|
|
||||||
Mark: types.PowerAdd,
|
|
||||||
Model: "众筹支付",
|
|
||||||
Remark: fmt.Sprintf("充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
tx.Commit()
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -19,11 +19,8 @@ import (
|
|||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"net/http"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -31,19 +28,27 @@ import (
|
|||||||
|
|
||||||
type SdJobHandler struct {
|
type SdJobHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
pool *sd.ServicePool
|
sdService *sd.Service
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
leveldb *store.LevelDB
|
leveldb *store.LevelDB
|
||||||
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
|
func NewSdJobHandler(app *core.AppServer,
|
||||||
|
db *gorm.DB,
|
||||||
|
service *sd.Service,
|
||||||
|
manager *oss.UploaderManager,
|
||||||
|
snowflake *service.Snowflake,
|
||||||
|
userService *service.UserService,
|
||||||
|
levelDB *store.LevelDB) *SdJobHandler {
|
||||||
return &SdJobHandler{
|
return &SdJobHandler{
|
||||||
pool: pool,
|
sdService: service,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
leveldb: levelDB,
|
leveldb: levelDB,
|
||||||
|
userService: userService,
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: app,
|
App: app,
|
||||||
DB: db,
|
DB: db,
|
||||||
@@ -51,27 +56,6 @@ func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, man
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client WebSocket 客户端,用于通知任务状态变更
|
|
||||||
func (h *SdJobHandler) Client(c *gin.Context) {
|
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
|
||||||
if userId == 0 {
|
|
||||||
logger.Info("Invalid user ID")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
|
||||||
h.pool.Clients.Put(uint(userId), client)
|
|
||||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||||
user, err := h.GetLoginUser(c)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -79,11 +63,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !h.pool.HasAvailableService() {
|
|
||||||
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Power < h.App.SysConfig.SdPower {
|
if user.Power < h.App.SysConfig.SdPower {
|
||||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
return false
|
return false
|
||||||
@@ -130,29 +109,37 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
resp.ERROR(c, "error with generate task id: "+err.Error())
|
resp.ERROR(c, "error with generate task id: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
params := types.SdTaskParams{
|
|
||||||
TaskId: taskId,
|
task := types.SdTask{
|
||||||
Prompt: data.Prompt,
|
ClientId: data.ClientId,
|
||||||
NegPrompt: data.NegPrompt,
|
Type: types.TaskImage,
|
||||||
Steps: data.Steps,
|
Params: types.SdTaskParams{
|
||||||
Sampler: data.Sampler,
|
TaskId: taskId,
|
||||||
FaceFix: data.FaceFix,
|
Prompt: data.Prompt,
|
||||||
CfgScale: data.CfgScale,
|
NegPrompt: data.NegPrompt,
|
||||||
Seed: data.Seed,
|
Steps: data.Steps,
|
||||||
Height: data.Height,
|
Sampler: data.Sampler,
|
||||||
Width: data.Width,
|
FaceFix: data.FaceFix,
|
||||||
HdFix: data.HdFix,
|
CfgScale: data.CfgScale,
|
||||||
HdRedrawRate: data.HdRedrawRate,
|
Seed: data.Seed,
|
||||||
HdScale: data.HdScale,
|
Height: data.Height,
|
||||||
HdScaleAlg: data.HdScaleAlg,
|
Width: data.Width,
|
||||||
HdSteps: data.HdSteps,
|
HdFix: data.HdFix,
|
||||||
|
HdRedrawRate: data.HdRedrawRate,
|
||||||
|
HdScale: data.HdScale,
|
||||||
|
HdScaleAlg: data.HdScaleAlg,
|
||||||
|
HdSteps: data.HdSteps,
|
||||||
|
},
|
||||||
|
UserId: userId,
|
||||||
|
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||||
}
|
}
|
||||||
|
|
||||||
job := model.SdJob{
|
job := model.SdJob{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Type: types.TaskImage.String(),
|
Type: types.TaskImage.String(),
|
||||||
TaskId: params.TaskId,
|
TaskId: taskId,
|
||||||
Params: utils.JsonEncode(params),
|
Params: utils.JsonEncode(task.Params),
|
||||||
|
TaskInfo: utils.JsonEncode(task),
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Power: h.App.SysConfig.SdPower,
|
Power: h.App.SysConfig.SdPower,
|
||||||
@@ -164,34 +151,18 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.pool.PushTask(types.SdTask{
|
task.Id = int(job.Id)
|
||||||
Id: int(job.Id),
|
h.sdService.PushTask(task)
|
||||||
Type: types.TaskImage,
|
|
||||||
Params: params,
|
|
||||||
UserId: userId,
|
|
||||||
})
|
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte("Task Updated"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// update user's power
|
// update user's power
|
||||||
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
// 记录算力变化日志
|
Type: types.PowerConsume,
|
||||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
Model: "stable-diffusion",
|
||||||
user, _ := h.GetLoginUser(c)
|
Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId),
|
||||||
h.DB.Create(&model.PowerLog{
|
})
|
||||||
UserId: user.Id,
|
if err != nil {
|
||||||
Username: user.Username,
|
resp.ERROR(c, err.Error())
|
||||||
Type: types.PowerConsume,
|
return
|
||||||
Amount: job.Power,
|
|
||||||
Balance: user.Power - job.Power,
|
|
||||||
Mark: types.PowerSub,
|
|
||||||
Model: "stable-diffusion",
|
|
||||||
Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
@@ -228,11 +199,11 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// JobList 获取 MJ 任务列表
|
// JobList 获取 MJ 任务列表
|
||||||
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
|
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
|
||||||
|
|
||||||
session := h.DB.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if finish {
|
if finish {
|
||||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
session = session.Where("progress >= ?", 100).Order("id DESC")
|
||||||
} else {
|
} else {
|
||||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||||
}
|
}
|
||||||
@@ -247,10 +218,14 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
|
|||||||
session = session.Offset(offset).Limit(pageSize)
|
session = session.Offset(offset).Limit(pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 统计总数
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.SdJob{}).Count(&total)
|
||||||
|
|
||||||
var items []model.SdJob
|
var items []model.SdJob
|
||||||
res := session.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error, nil
|
return res.Error, vo.Page{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs = make([]vo.SdJob, 0)
|
var jobs = make([]vo.SdJob, 0)
|
||||||
@@ -260,62 +235,47 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if item.Progress < 100 {
|
|
||||||
// 从 leveldb 中获取图片预览数据
|
|
||||||
var imageData string
|
|
||||||
err = h.leveldb.Get(item.TaskId, &imageData)
|
|
||||||
if err == nil {
|
|
||||||
job.ImgURL = "data:image/png;base64," + imageData
|
|
||||||
}
|
|
||||||
}
|
|
||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, jobs
|
return nil, vo.NewPage(total, page, pageSize, jobs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove task image
|
// Remove remove task image
|
||||||
func (h *SdJobHandler) Remove(c *gin.Context) {
|
func (h *SdJobHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
var job model.SdJob
|
var job model.SdJob
|
||||||
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
|
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
|
||||||
resp.ERROR(c, "记录不存在")
|
resp.ERROR(c, "记录不存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove job recode
|
// 删除任务
|
||||||
res := h.DB.Delete(&model.SdJob{Id: job.Id})
|
err := h.DB.Delete(&job).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove image
|
// remove image
|
||||||
err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("remove image failed: ", err)
|
logger.Error("remove image failed: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte(sd.Finished))
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish 发布/取消发布图片到画廊显示
|
// Publish 发布/取消发布图片到画廊显示
|
||||||
func (h *SdJobHandler) Publish(c *gin.Context) {
|
func (h *SdJobHandler) Publish(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||||
|
|
||||||
res := h.DB.Model(&model.SdJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action)
|
err := h.DB.Model(&model.SdJob{Id: uint(id), UserId: int(userId)}).UpdateColumn("publish", action).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,15 +56,17 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var check bool
|
if h.App.SysConfig.EnabledVerify {
|
||||||
if data.X != 0 {
|
var check bool
|
||||||
check = h.captcha.SlideCheck(data)
|
if data.X != 0 {
|
||||||
} else {
|
check = h.captcha.SlideCheck(data)
|
||||||
check = h.captcha.Check(data)
|
} else {
|
||||||
}
|
check = h.captcha.Check(data)
|
||||||
if !check {
|
}
|
||||||
resp.ERROR(c, "验证码错误,请先完人机验证")
|
if !check {
|
||||||
return
|
resp.ERROR(c, "请先完人机验证")
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
code := utils.RandomNumber(6)
|
code := utils.RandomNumber(6)
|
||||||
@@ -74,6 +76,20 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
|||||||
resp.ERROR(c, "系统已禁用邮箱注册!")
|
resp.ERROR(c, "系统已禁用邮箱注册!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 检查邮箱后缀是否在白名单
|
||||||
|
if len(h.App.SysConfig.EmailWhiteList) > 0 {
|
||||||
|
inWhiteList := false
|
||||||
|
for _, suffix := range h.App.SysConfig.EmailWhiteList {
|
||||||
|
if strings.HasSuffix(data.Receiver, suffix) {
|
||||||
|
inWhiteList = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !inWhiteList {
|
||||||
|
resp.ERROR(c, "邮箱后缀不在白名单中")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
||||||
} else {
|
} else {
|
||||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
|
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/service/suno"
|
"geekai/service/suno"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -18,53 +19,33 @@ import (
|
|||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"net/http"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SunoHandler struct {
|
type SunoHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
service *suno.Service
|
sunoService *suno.Service
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler {
|
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
|
||||||
return &SunoHandler{
|
return &SunoHandler{
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: app,
|
App: app,
|
||||||
DB: db,
|
DB: db,
|
||||||
},
|
},
|
||||||
service: service,
|
sunoService: service,
|
||||||
uploader: uploader,
|
uploader: uploader,
|
||||||
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client WebSocket 客户端,用于通知任务状态变更
|
|
||||||
func (h *SunoHandler) Client(c *gin.Context) {
|
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
|
||||||
if userId == 0 {
|
|
||||||
logger.Info("Invalid user ID")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
|
||||||
h.service.Clients.Put(uint(userId), client)
|
|
||||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SunoHandler) Create(c *gin.Context) {
|
func (h *SunoHandler) Create(c *gin.Context) {
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
|
ClientId string `json:"client_id"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Instrumental bool `json:"instrumental"`
|
Instrumental bool `json:"instrumental"`
|
||||||
Lyrics string `json:"lyrics"`
|
Lyrics string `json:"lyrics"`
|
||||||
@@ -72,21 +53,65 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
|||||||
Tags string `json:"tags"`
|
Tags string `json:"tags"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Type int `json:"type"`
|
Type int `json:"type"`
|
||||||
RefTaskId string `json:"ref_task_id"` // 续写的任务id
|
RefTaskId string `json:"ref_task_id"` // 续写的任务id
|
||||||
ExtendSecs int `json:"extend_secs"` // 续写秒数
|
ExtendSecs int `json:"extend_secs"` // 续写秒数
|
||||||
RefSongId string `json:"ref_song_id"` // 续写的歌曲id
|
RefSongId string `json:"ref_song_id"` // 续写的歌曲id
|
||||||
|
SongId string `json:"song_id,omitempty"` // 要拼接的歌曲id
|
||||||
|
AudioURL string `json:"audio_url,omitempty"` // 上传自己创作的歌曲
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user, err := h.GetLoginUser(c)
|
||||||
|
if err != nil {
|
||||||
|
resp.NotAuth(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Power < h.App.SysConfig.SunoPower {
|
||||||
|
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 歌曲拼接
|
||||||
|
if data.SongId != "" && data.Type == 3 {
|
||||||
|
var song model.SunoJob
|
||||||
|
if err := h.DB.Where("song_id = ?", data.SongId).First(&song).Error; err == nil {
|
||||||
|
data.Instrumental = song.Instrumental
|
||||||
|
data.Model = song.ModelName
|
||||||
|
data.Tags = song.Tags
|
||||||
|
}
|
||||||
|
// 拼接歌词
|
||||||
|
var refSong model.SunoJob
|
||||||
|
if err := h.DB.Where("song_id = ?", data.RefSongId).First(&refSong).Error; err == nil {
|
||||||
|
data.Prompt = fmt.Sprintf("%s\n%s", song.Prompt, refSong.Prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
task := types.SunoTask{
|
||||||
|
ClientId: data.ClientId,
|
||||||
|
UserId: int(h.GetLoginUserId(c)),
|
||||||
|
Type: data.Type,
|
||||||
|
Title: data.Title,
|
||||||
|
RefTaskId: data.RefTaskId,
|
||||||
|
RefSongId: data.RefSongId,
|
||||||
|
ExtendSecs: data.ExtendSecs,
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
Tags: data.Tags,
|
||||||
|
Model: data.Model,
|
||||||
|
Instrumental: data.Instrumental,
|
||||||
|
SongId: data.SongId,
|
||||||
|
AudioURL: data.AudioURL,
|
||||||
|
}
|
||||||
|
|
||||||
// 插入数据库
|
// 插入数据库
|
||||||
job := model.SunoJob{
|
job := model.SunoJob{
|
||||||
UserId: int(h.GetLoginUserId(c)),
|
UserId: task.UserId,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Instrumental: data.Instrumental,
|
Instrumental: data.Instrumental,
|
||||||
ModelName: data.Model,
|
ModelName: data.Model,
|
||||||
|
TaskInfo: utils.JsonEncode(task),
|
||||||
Tags: data.Tags,
|
Tags: data.Tags,
|
||||||
Title: data.Title,
|
Title: data.Title,
|
||||||
Type: data.Type,
|
Type: data.Type,
|
||||||
@@ -94,6 +119,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
|||||||
RefTaskId: data.RefTaskId,
|
RefTaskId: data.RefTaskId,
|
||||||
ExtendSecs: data.ExtendSecs,
|
ExtendSecs: data.ExtendSecs,
|
||||||
Power: h.App.SysConfig.SunoPower,
|
Power: h.App.SysConfig.SunoPower,
|
||||||
|
SongId: utils.RandString(32),
|
||||||
}
|
}
|
||||||
if data.Lyrics != "" {
|
if data.Lyrics != "" {
|
||||||
job.Prompt = data.Lyrics
|
job.Prompt = data.Lyrics
|
||||||
@@ -105,49 +131,28 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建任务
|
// 创建任务
|
||||||
h.service.PushTask(types.SunoTask{
|
task.Id = job.Id
|
||||||
Id: job.Id,
|
h.sunoService.PushTask(task)
|
||||||
UserId: job.UserId,
|
|
||||||
Type: job.Type,
|
|
||||||
Title: job.Title,
|
|
||||||
RefTaskId: data.RefTaskId,
|
|
||||||
RefSongId: data.RefSongId,
|
|
||||||
ExtendSecs: data.ExtendSecs,
|
|
||||||
Prompt: job.Prompt,
|
|
||||||
Tags: data.Tags,
|
|
||||||
Model: data.Model,
|
|
||||||
Instrumental: data.Instrumental,
|
|
||||||
})
|
|
||||||
|
|
||||||
// update user's power
|
// update user's power
|
||||||
tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
// 记录算力变化日志
|
Type: types.PowerConsume,
|
||||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
Model: job.ModelName,
|
||||||
user, _ := h.GetLoginUser(c)
|
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
|
||||||
h.DB.Create(&model.PowerLog{
|
CreatedAt: time.Now(),
|
||||||
UserId: user.Id,
|
})
|
||||||
Username: user.Username,
|
if err != nil {
|
||||||
Type: types.PowerConsume,
|
resp.ERROR(c, err.Error())
|
||||||
Amount: job.Power,
|
return
|
||||||
Balance: user.Power - job.Power,
|
|
||||||
Mark: types.PowerSub,
|
|
||||||
Model: job.ModelName,
|
|
||||||
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
client := h.service.Clients.Get(uint(job.UserId))
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte("Task Updated"))
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SunoHandler) List(c *gin.Context) {
|
func (h *SunoHandler) List(c *gin.Context) {
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
page := h.GetInt(c, "page", 0)
|
page := h.GetInt(c, "page", 1)
|
||||||
pageSize := h.GetInt(c, "page_size", 0)
|
pageSize := h.GetInt(c, "page_size", 20)
|
||||||
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
|
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
|
||||||
|
|
||||||
// 统计总数
|
// 统计总数
|
||||||
@@ -209,8 +214,20 @@ func (h *SunoHandler) Remove(c *gin.Context) {
|
|||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 只有失败或者已完成的任务可以删除
|
||||||
|
if !(job.Progress == service.FailTaskProgress || job.Progress == 100) {
|
||||||
|
resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 删除任务
|
// 删除任务
|
||||||
h.DB.Delete(&job)
|
err = h.DB.Delete(&job).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 删除文件
|
// 删除文件
|
||||||
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
|
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
|
||||||
_ = h.uploader.GetUploadHandler().Delete(job.AudioURL)
|
_ = h.uploader.GetUploadHandler().Delete(job.AudioURL)
|
||||||
@@ -306,40 +323,3 @@ func (h *SunoHandler) Play(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1))
|
h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
const genLyricTemplate = `
|
|
||||||
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
|
|
||||||
请以【%s】为主题创作一首歌曲,歌曲时间不要太短,3分钟左右,不要输出任何解释性的内容。
|
|
||||||
输出格式如下:
|
|
||||||
歌曲名称
|
|
||||||
第一节:
|
|
||||||
{{歌词内容}}
|
|
||||||
副歌:
|
|
||||||
{{歌词内容}}
|
|
||||||
|
|
||||||
第二节:
|
|
||||||
{{歌词内容}}
|
|
||||||
副歌:
|
|
||||||
{{歌词内容}}
|
|
||||||
|
|
||||||
尾声:
|
|
||||||
{{歌词内容}}
|
|
||||||
`
|
|
||||||
|
|
||||||
// Lyric 生成歌词
|
|
||||||
func (h *SunoHandler) Lyric(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, content)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,15 +3,52 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/payment"
|
"geekai/service/payment"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TestHandler struct {
|
type TestHandler struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
js *payment.JPayService
|
js *payment.GeekPayService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.JPayService) *TestHandler {
|
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
|
||||||
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *TestHandler) SseTest(c *gin.Context) {
|
||||||
|
//c.Header("Body-Type", "text/event-stream")
|
||||||
|
//c.Header("Cache-Control", "no-cache")
|
||||||
|
//c.Header("Connection", "keep-alive")
|
||||||
|
//
|
||||||
|
//
|
||||||
|
//// 模拟实时数据更新
|
||||||
|
//for i := 0; i < 10; i++ {
|
||||||
|
// // 发送 SSE 数据
|
||||||
|
// _, err := fmt.Fprintf(c.Writer, "data: %v\n\n", data)
|
||||||
|
// if err != nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// c.Writer.Flush() // 确保立即发送数据
|
||||||
|
// time.Sleep(1 * time.Second) // 每秒发送一次数据
|
||||||
|
//}
|
||||||
|
//c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TestHandler) PostTest(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将参数存储在上下文中
|
||||||
|
c.Set("data", data)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ type UserHandler struct {
|
|||||||
searcher *xdb.Searcher
|
searcher *xdb.Searcher
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
licenseService *service.LicenseService
|
licenseService *service.LicenseService
|
||||||
|
captcha *service.CaptchaService
|
||||||
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserHandler(
|
func NewUserHandler(
|
||||||
@@ -40,12 +42,16 @@ func NewUserHandler(
|
|||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
searcher *xdb.Searcher,
|
searcher *xdb.Searcher,
|
||||||
client *redis.Client,
|
client *redis.Client,
|
||||||
|
captcha *service.CaptchaService,
|
||||||
|
userService *service.UserService,
|
||||||
licenseService *service.LicenseService) *UserHandler {
|
licenseService *service.LicenseService) *UserHandler {
|
||||||
return &UserHandler{
|
return &UserHandler{
|
||||||
BaseHandler: BaseHandler{DB: db, App: app},
|
BaseHandler: BaseHandler{DB: db, App: app},
|
||||||
searcher: searcher,
|
searcher: searcher,
|
||||||
redis: client,
|
redis: client,
|
||||||
|
captcha: captcha,
|
||||||
licenseService: licenseService,
|
licenseService: licenseService,
|
||||||
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,14 +61,33 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
var data struct {
|
var data struct {
|
||||||
RegWay string `json:"reg_way"`
|
RegWay string `json:"reg_way"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
Mobile string `json:"mobile"`
|
||||||
|
Email string `json:"email"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
InviteCode string `json:"invite_code"`
|
InviteCode string `json:"invite_code"`
|
||||||
|
Key string `json:"key,omitempty"`
|
||||||
|
Dots string `json:"dots,omitempty"`
|
||||||
|
X int `json:"x,omitempty"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
|
||||||
|
var check bool
|
||||||
|
if data.X != 0 {
|
||||||
|
check = h.captcha.SlideCheck(data)
|
||||||
|
} else {
|
||||||
|
check = h.captcha.Check(data)
|
||||||
|
}
|
||||||
|
if !check {
|
||||||
|
resp.ERROR(c, "请先完人机验证")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
data.Password = strings.TrimSpace(data.Password)
|
data.Password = strings.TrimSpace(data.Password)
|
||||||
if len(data.Password) < 8 {
|
if len(data.Password) < 8 {
|
||||||
resp.ERROR(c, "密码长度不能少于8个字符")
|
resp.ERROR(c, "密码长度不能少于8个字符")
|
||||||
@@ -79,8 +104,15 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
|
|
||||||
// 检查验证码
|
// 检查验证码
|
||||||
var key string
|
var key string
|
||||||
if data.RegWay == "email" || data.RegWay == "mobile" {
|
if data.RegWay == "email" {
|
||||||
key = CodeStorePrefix + data.Username
|
key = CodeStorePrefix + data.Email
|
||||||
|
code, err := h.redis.Get(c, key).Result()
|
||||||
|
if err != nil || code != data.Code {
|
||||||
|
resp.ERROR(c, "验证码错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if data.RegWay == "mobile" {
|
||||||
|
key = CodeStorePrefix + data.Mobile
|
||||||
code, err := h.redis.Get(c, key).Result()
|
code, err := h.redis.Get(c, key).Result()
|
||||||
if err != nil || code != data.Code {
|
if err != nil || code != data.Code {
|
||||||
resp.ERROR(c, "验证码错误")
|
resp.ERROR(c, "验证码错误")
|
||||||
@@ -98,26 +130,37 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
salt := utils.RandString(8)
|
||||||
|
user := model.User{
|
||||||
|
Username: data.Username,
|
||||||
|
Password: utils.GenPassword(data.Password, salt),
|
||||||
|
Avatar: "/images/avatar/user.png",
|
||||||
|
Salt: salt,
|
||||||
|
Status: true,
|
||||||
|
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||||
|
Power: h.App.SysConfig.InitPower,
|
||||||
|
}
|
||||||
|
|
||||||
// check if the username is existing
|
// check if the username is existing
|
||||||
var item model.User
|
var item model.User
|
||||||
res := h.DB.Where("username = ?", data.Username).First(&item)
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Mobile != "" {
|
||||||
|
session = session.Where("mobile = ?", data.Mobile)
|
||||||
|
user.Username = data.Mobile
|
||||||
|
user.Mobile = data.Mobile
|
||||||
|
} else if data.Email != "" {
|
||||||
|
session = session.Where("email = ?", data.Email)
|
||||||
|
user.Username = data.Email
|
||||||
|
user.Email = data.Email
|
||||||
|
} else if data.Username != "" {
|
||||||
|
session = session.Where("username = ?", data.Username)
|
||||||
|
}
|
||||||
|
session.First(&item)
|
||||||
if item.Id > 0 {
|
if item.Id > 0 {
|
||||||
resp.ERROR(c, "该用户名已经被注册")
|
resp.ERROR(c, "该用户名已经被注册")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
salt := utils.RandString(8)
|
|
||||||
user := model.User{
|
|
||||||
Username: data.Username,
|
|
||||||
Password: utils.GenPassword(data.Password, salt),
|
|
||||||
Avatar: "/images/avatar/user.png",
|
|
||||||
Salt: salt,
|
|
||||||
Status: true,
|
|
||||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
|
||||||
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
|
|
||||||
Power: h.App.SysConfig.InitPower,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 被邀请人也获得赠送算力
|
// 被邀请人也获得赠送算力
|
||||||
if data.InviteCode != "" {
|
if data.InviteCode != "" {
|
||||||
user.Power += h.App.SysConfig.InvitePower
|
user.Power += h.App.SysConfig.InvitePower
|
||||||
@@ -128,10 +171,9 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
|
user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.DB.Create(&user)
|
tx := h.DB.Begin()
|
||||||
if res.Error != nil {
|
if err := tx.Create(&user).Error; err != nil {
|
||||||
resp.ERROR(c, "保存数据失败")
|
resp.ERROR(c, err.Error())
|
||||||
logger.Error(res.Error)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,35 +182,35 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
// 增加邀请数量
|
// 增加邀请数量
|
||||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||||||
if h.App.SysConfig.InvitePower > 0 {
|
if h.App.SysConfig.InvitePower > 0 {
|
||||||
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
|
err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{
|
||||||
// 记录邀请算力充值日志
|
Type: types.PowerInvite,
|
||||||
var inviter model.User
|
Model: "Invite",
|
||||||
h.DB.Where("id", inviteCode.UserId).First(&inviter)
|
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
|
||||||
h.DB.Create(&model.PowerLog{
|
|
||||||
UserId: inviter.Id,
|
|
||||||
Username: inviter.Username,
|
|
||||||
Type: types.PowerInvite,
|
|
||||||
Amount: h.App.SysConfig.InvitePower,
|
|
||||||
Balance: inviter.Power,
|
|
||||||
Mark: types.PowerAdd,
|
|
||||||
Model: "",
|
|
||||||
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加邀请记录
|
// 添加邀请记录
|
||||||
h.DB.Create(&model.InviteLog{
|
err := tx.Create(&model.InviteLog{
|
||||||
InviterId: inviteCode.UserId,
|
InviterId: inviteCode.UserId,
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
InviteCode: inviteCode.Code,
|
InviteCode: inviteCode.Code,
|
||||||
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
|
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
|
||||||
})
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
tx.Commit()
|
||||||
|
|
||||||
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
|
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
|
||||||
|
|
||||||
// 自动登录创建 token
|
// 自动登录创建 token
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
"user_id": user.Id,
|
"user_id": user.Id,
|
||||||
@@ -193,20 +235,41 @@ func (h *UserHandler) Login(c *gin.Context) {
|
|||||||
var data struct {
|
var data struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
|
Key string `json:"key,omitempty"`
|
||||||
|
Dots string `json:"dots,omitempty"`
|
||||||
|
X int `json:"x,omitempty"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
|
||||||
|
needVerify, err := h.redis.Get(c, verifyKey).Bool()
|
||||||
|
|
||||||
|
if h.App.SysConfig.EnabledVerify && needVerify {
|
||||||
|
var check bool
|
||||||
|
if data.X != 0 {
|
||||||
|
check = h.captcha.SlideCheck(data)
|
||||||
|
} else {
|
||||||
|
check = h.captcha.Check(data)
|
||||||
|
}
|
||||||
|
if !check {
|
||||||
|
resp.ERROR(c, "请先完人机验证")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.DB.Where("username = ?", data.Username).First(&user)
|
res := h.DB.Where("username = ?", data.Username).First(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
h.redis.Set(c, verifyKey, true, 0)
|
||||||
resp.ERROR(c, "用户名不存在")
|
resp.ERROR(c, "用户名不存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
password := utils.GenPassword(data.Password, user.Salt)
|
password := utils.GenPassword(data.Password, user.Salt)
|
||||||
if password != user.Password {
|
if password != user.Password {
|
||||||
|
h.redis.Set(c, verifyKey, true, 0)
|
||||||
resp.ERROR(c, "用户名或密码错误")
|
resp.ERROR(c, "用户名或密码错误")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -239,11 +302,13 @@ func (h *UserHandler) Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 保存到 redis
|
// 保存到 redis
|
||||||
key := fmt.Sprintf("users/%d", user.Id)
|
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
||||||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
|
||||||
resp.ERROR(c, "error with save token: "+err.Error())
|
resp.ERROR(c, "error with save token: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 移除登录行为验证码
|
||||||
|
h.redis.Del(c, verifyKey)
|
||||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,8 +350,10 @@ func (h *UserHandler) CLogin(c *gin.Context) {
|
|||||||
|
|
||||||
// CLoginCallback 第三方登录回调
|
// CLoginCallback 第三方登录回调
|
||||||
func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
||||||
loginType := h.GetTrim(c, "login_type")
|
loginType := c.Query("login_type")
|
||||||
code := h.GetTrim(c, "code")
|
code := c.Query("code")
|
||||||
|
userId := h.GetInt(c, "user_id", 0)
|
||||||
|
action := c.Query("action")
|
||||||
|
|
||||||
var res types.BizVo
|
var res types.BizVo
|
||||||
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
|
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
|
||||||
@@ -311,11 +378,34 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
|||||||
|
|
||||||
// login successfully
|
// login successfully
|
||||||
data := res.Data.(map[string]interface{})
|
data := res.Data.(map[string]interface{})
|
||||||
session := gin.H{}
|
|
||||||
var user model.User
|
var user model.User
|
||||||
tx := h.DB.Debug().Where("openid", data["openid"]).First(&user)
|
if action == "bind" && userId > 0 {
|
||||||
if tx.Error != nil { // user not exist, create new user
|
err = h.DB.Where("openid", data["openid"]).First(&user).Error
|
||||||
// 检测最大注册人数
|
if err == nil {
|
||||||
|
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.DB.Where("id", userId).First(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "绑定用户不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "更新用户信息失败,"+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, gin.H{"token": ""})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := gin.H{}
|
||||||
|
tx := h.DB.Where("openid", data["openid"]).First(&user)
|
||||||
|
if tx.Error != nil {
|
||||||
|
// create new user
|
||||||
var totalUser int64
|
var totalUser int64
|
||||||
h.DB.Model(&model.User{}).Count(&totalUser)
|
h.DB.Model(&model.User{}).Count(&totalUser)
|
||||||
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
||||||
@@ -326,16 +416,15 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
|||||||
salt := utils.RandString(8)
|
salt := utils.RandString(8)
|
||||||
password := fmt.Sprintf("%d", utils.RandomNumber(8))
|
password := fmt.Sprintf("%d", utils.RandomNumber(8))
|
||||||
user = model.User{
|
user = model.User{
|
||||||
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
|
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
|
||||||
Password: utils.GenPassword(password, salt),
|
Password: utils.GenPassword(password, salt),
|
||||||
Avatar: fmt.Sprintf("%s", data["avatar"]),
|
Avatar: fmt.Sprintf("%s", data["avatar"]),
|
||||||
Salt: salt,
|
Salt: salt,
|
||||||
Status: true,
|
Status: true,
|
||||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||||
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
|
Power: h.App.SysConfig.InitPower,
|
||||||
Power: h.App.SysConfig.InitPower,
|
OpenId: fmt.Sprintf("%s", data["openid"]),
|
||||||
OpenId: fmt.Sprintf("%s", data["openid"]),
|
Nickname: fmt.Sprintf("%s", data["nickname"]),
|
||||||
Nickname: fmt.Sprintf("%s", data["nickname"]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tx = h.DB.Create(&user)
|
tx = h.DB.Create(&user)
|
||||||
@@ -383,18 +472,24 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
|||||||
// Session 获取/验证会话
|
// Session 获取/验证会话
|
||||||
func (h *UserHandler) Session(c *gin.Context) {
|
func (h *UserHandler) Session(c *gin.Context) {
|
||||||
user, err := h.GetLoginUser(c)
|
user, err := h.GetLoginUser(c)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
var userVo vo.User
|
resp.NotAuth(c, err.Error())
|
||||||
err := utils.CopyObject(user, &userVo)
|
return
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c)
|
|
||||||
}
|
|
||||||
userVo.Id = user.Id
|
|
||||||
resp.SUCCESS(c, userVo)
|
|
||||||
} else {
|
|
||||||
resp.NotAuth(c)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var userVo vo.User
|
||||||
|
err = utils.CopyObject(user, &userVo)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 用户 VIP 到期
|
||||||
|
if user.ExpiredTime > 0 && user.ExpiredTime < time.Now().Unix() {
|
||||||
|
h.DB.Model(&user).UpdateColumn("vip", false)
|
||||||
|
}
|
||||||
|
userVo.Id = user.Id
|
||||||
|
resp.SUCCESS(c, userVo)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type userProfile struct {
|
type userProfile struct {
|
||||||
@@ -481,20 +576,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
newPass := utils.GenPassword(data.Password, user.Salt)
|
newPass := utils.GenPassword(data.Password, user.Salt)
|
||||||
res := h.DB.Model(&user).UpdateColumn("password", newPass)
|
err = h.DB.Model(&user).UpdateColumn("password", newPass).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c, "更新数据库失败")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetPass 重置密码
|
// ResetPass 找回密码
|
||||||
func (h *UserHandler) ResetPass(c *gin.Context) {
|
func (h *UserHandler) ResetPass(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Username string `json:"username"`
|
Type string `json:"type"` // 验证类别:mobile, email
|
||||||
|
Mobile string `json:"mobile"` // 手机号
|
||||||
|
Email string `json:"email"` // 邮箱地址
|
||||||
Code string `json:"code"` // 验证码
|
Code string `json:"code"` // 验证码
|
||||||
Password string `json:"password"` // 新密码
|
Password string `json:"password"` // 新密码
|
||||||
}
|
}
|
||||||
@@ -503,37 +599,47 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
var key string
|
||||||
|
if data.Type == "email" {
|
||||||
|
session = session.Where("email", data.Email)
|
||||||
|
key = CodeStorePrefix + data.Email
|
||||||
|
} else if data.Type == "mobile" {
|
||||||
|
session = session.Where("mobile", data.Mobile)
|
||||||
|
key = CodeStorePrefix + data.Mobile
|
||||||
|
} else {
|
||||||
|
resp.ERROR(c, "验证类别错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.DB.Where("username", data.Username).First(&user)
|
err := session.First(&user).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "用户不存在!")
|
resp.ERROR(c, "用户不存在!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查验证码
|
// 检查验证码
|
||||||
key := CodeStorePrefix + data.Username
|
|
||||||
code, err := h.redis.Get(c, key).Result()
|
code, err := h.redis.Get(c, key).Result()
|
||||||
if err != nil || code != data.Code {
|
if err != nil || code != data.Code {
|
||||||
resp.ERROR(c, "短信验证码错误")
|
resp.ERROR(c, "验证码错误")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
password := utils.GenPassword(data.Password, user.Salt)
|
password := utils.GenPassword(data.Password, user.Salt)
|
||||||
user.Password = password
|
err = h.DB.Model(&user).UpdateColumn("password", password).Error
|
||||||
res = h.DB.Updates(&user)
|
if err != nil {
|
||||||
if res.Error != nil {
|
resp.ERROR(c, err.Error())
|
||||||
resp.ERROR(c)
|
|
||||||
} else {
|
} else {
|
||||||
h.redis.Del(c, key)
|
h.redis.Del(c, key)
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BindUsername 重置账号
|
// BindMobile 绑定手机号
|
||||||
func (h *UserHandler) BindUsername(c *gin.Context) {
|
func (h *UserHandler) BindMobile(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Username string `json:"username"`
|
Mobile string `json:"mobile"`
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -541,7 +647,7 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查验证码
|
// 检查验证码
|
||||||
key := CodeStorePrefix + data.Username
|
key := CodeStorePrefix + data.Mobile
|
||||||
code, err := h.redis.Get(c, key).Result()
|
code, err := h.redis.Get(c, key).Result()
|
||||||
if err != nil || code != data.Code {
|
if err != nil || code != data.Code {
|
||||||
resp.ERROR(c, "验证码错误")
|
resp.ERROR(c, "验证码错误")
|
||||||
@@ -550,22 +656,56 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
|
|||||||
|
|
||||||
// 检查手机号是否被其他账号绑定
|
// 检查手机号是否被其他账号绑定
|
||||||
var item model.User
|
var item model.User
|
||||||
res := h.DB.Where("username = ?", data.Username).First(&item)
|
res := h.DB.Where("mobile", data.Mobile).First(&item)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
resp.ERROR(c, "该账号已经被其他账号绑定")
|
resp.ERROR(c, "该手机号已经绑定了其他账号,请更换手机号")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := h.GetLoginUser(c)
|
userId := h.GetLoginUserId(c)
|
||||||
|
|
||||||
|
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("mobile", data.Mobile).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
|
_ = h.redis.Del(c, key) // 删除短信验证码
|
||||||
if res.Error != nil {
|
resp.SUCCESS(c)
|
||||||
logger.Error(res.Error)
|
}
|
||||||
resp.ERROR(c, "更新数据库失败")
|
|
||||||
|
// BindEmail 绑定邮箱
|
||||||
|
func (h *UserHandler) BindEmail(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Code string `json:"code"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查验证码
|
||||||
|
key := CodeStorePrefix + data.Email
|
||||||
|
code, err := h.redis.Get(c, key).Result()
|
||||||
|
if err != nil || code != data.Code {
|
||||||
|
resp.ERROR(c, "验证码错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查手机号是否被其他账号绑定
|
||||||
|
var item model.User
|
||||||
|
res := h.DB.Where("email", data.Email).First(&item)
|
||||||
|
if res.Error == nil {
|
||||||
|
resp.ERROR(c, "该邮箱地址已经绑定了其他账号,请更邮箱地址")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
|
||||||
|
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("email", data.Email).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
215
api/handler/video_handler.go
Normal file
215
api/handler/video_handler.go
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/service/video"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type VideoHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
videoService *video.Service
|
||||||
|
uploader *oss.UploaderManager
|
||||||
|
userService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
|
||||||
|
return &VideoHandler{
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: app,
|
||||||
|
DB: db,
|
||||||
|
},
|
||||||
|
videoService: service,
|
||||||
|
uploader: uploader,
|
||||||
|
userService: userService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||||
|
|
||||||
|
var data struct {
|
||||||
|
ClientId string `json:"client_id"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
FirstFrameImg string `json:"first_frame_img,omitempty"`
|
||||||
|
EndFrameImg string `json:"end_frame_img,omitempty"`
|
||||||
|
ExpandPrompt bool `json:"expand_prompt,omitempty"`
|
||||||
|
Loop bool `json:"loop,omitempty"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.GetLoginUser(c)
|
||||||
|
if err != nil {
|
||||||
|
resp.NotAuth(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Power < h.App.SysConfig.LumaPower {
|
||||||
|
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Prompt == "" {
|
||||||
|
resp.ERROR(c, "prompt is needed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := int(h.GetLoginUserId(c))
|
||||||
|
params := types.VideoParams{
|
||||||
|
PromptOptimize: data.ExpandPrompt,
|
||||||
|
Loop: data.Loop,
|
||||||
|
StartImgURL: data.FirstFrameImg,
|
||||||
|
EndImgURL: data.EndFrameImg,
|
||||||
|
}
|
||||||
|
task := types.VideoTask{
|
||||||
|
ClientId: data.ClientId,
|
||||||
|
UserId: userId,
|
||||||
|
Type: types.VideoLuma,
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
Params: params,
|
||||||
|
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||||
|
}
|
||||||
|
// 插入数据库
|
||||||
|
job := model.VideoJob{
|
||||||
|
UserId: userId,
|
||||||
|
Type: types.VideoLuma,
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
Power: h.App.SysConfig.LumaPower,
|
||||||
|
TaskInfo: utils.JsonEncode(task),
|
||||||
|
}
|
||||||
|
tx := h.DB.Create(&job)
|
||||||
|
if tx.Error != nil {
|
||||||
|
resp.ERROR(c, tx.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建任务
|
||||||
|
task.Id = job.Id
|
||||||
|
h.videoService.PushTask(task)
|
||||||
|
|
||||||
|
// update user's power
|
||||||
|
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Model: "luma",
|
||||||
|
Remark: fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *VideoHandler) List(c *gin.Context) {
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
t := c.Query("type")
|
||||||
|
page := h.GetInt(c, "page", 1)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 20)
|
||||||
|
all := h.GetBool(c, "all")
|
||||||
|
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
|
||||||
|
if t != "" {
|
||||||
|
session = session.Where("type", t)
|
||||||
|
}
|
||||||
|
if all {
|
||||||
|
session = session.Where("publish", 0).Where("progress", 100)
|
||||||
|
} else {
|
||||||
|
session = session.Where("user_id", h.GetLoginUserId(c))
|
||||||
|
}
|
||||||
|
// 统计总数
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.VideoJob{}).Count(&total)
|
||||||
|
|
||||||
|
if page > 0 && pageSize > 0 {
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
session = session.Offset(offset).Limit(pageSize)
|
||||||
|
}
|
||||||
|
var list []model.VideoJob
|
||||||
|
err := session.Order("id desc").Find(&list).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为 VO
|
||||||
|
items := make([]vo.VideoJob, 0)
|
||||||
|
for _, v := range list {
|
||||||
|
var item vo.VideoJob
|
||||||
|
err = utils.CopyObject(v, &item)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
item.CreatedAt = v.CreatedAt.Unix()
|
||||||
|
if item.VideoURL == "" {
|
||||||
|
item.VideoURL = v.WaterURL
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *VideoHandler) Remove(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
var job model.VideoJob
|
||||||
|
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 只有失败或者超时的任务才能删除
|
||||||
|
if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) {
|
||||||
|
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除任务
|
||||||
|
err = h.DB.Delete(&job).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除文件
|
||||||
|
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
|
||||||
|
_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *VideoHandler) Publish(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
publish := h.GetBool(c, "publish")
|
||||||
|
var job model.VideoJob
|
||||||
|
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.DB.Model(&job).UpdateColumn("publish", publish).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
151
api/handler/ws_handler.go
Normal file
151
api/handler/ws_handler.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Websocket 连接处理 handler
|
||||||
|
|
||||||
|
type WebsocketHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
wsService *service.WebsocketService
|
||||||
|
chatHandler *ChatHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
|
||||||
|
return &WebsocketHandler{
|
||||||
|
BaseHandler: BaseHandler{App: app, DB: db},
|
||||||
|
chatHandler: chatHandler,
|
||||||
|
wsService: s,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WebsocketHandler) Client(c *gin.Context) {
|
||||||
|
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||||
|
ws, err := (&websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Subprotocols: strings.Split(clientProtocols, ","),
|
||||||
|
}).Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientId := c.Query("client_id")
|
||||||
|
client := types.NewWsClient(ws, clientId)
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
if userId == 0 {
|
||||||
|
_ = client.Send([]byte("Invalid user_id"))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var user model.User
|
||||||
|
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||||
|
_ = client.Send([]byte("Invalid user_id"))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.wsService.Clients.Put(clientId, client)
|
||||||
|
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
_, msg, err := client.Receive()
|
||||||
|
if err != nil {
|
||||||
|
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
||||||
|
client.Close()
|
||||||
|
h.wsService.Clients.Delete(clientId)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var message types.InputMessage
|
||||||
|
err = utils.JsonDecode(string(msg), &message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("Receive a message:%+v", message)
|
||||||
|
if message.Type == types.MsgTypePing {
|
||||||
|
utils.SendChannelMsg(client, types.ChPing, "pong")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 当前只处理聊天消息,其他消息全部丢弃
|
||||||
|
var chatMessage types.ChatMessage
|
||||||
|
err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
|
||||||
|
if err != nil || message.Channel != types.ChChat {
|
||||||
|
logger.Warnf("invalid message body:%+v", message.Body)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var chatRole model.ChatRole
|
||||||
|
err = h.DB.First(&chatRole, chatMessage.RoleId).Error
|
||||||
|
if err != nil || !chatRole.Enable {
|
||||||
|
utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// if the role bind a model_id, use role's bind model_id
|
||||||
|
if chatRole.ModelId > 0 {
|
||||||
|
chatMessage.RoleId = chatRole.ModelId
|
||||||
|
}
|
||||||
|
// get model info
|
||||||
|
var chatModel model.ChatModel
|
||||||
|
err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
|
||||||
|
if err != nil || chatModel.Enabled == false {
|
||||||
|
utils.SendAndFlush(client, "当前AI模型暂未启用,请更换模型后再发起对话!!!")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
session := &types.ChatSession{
|
||||||
|
ClientIP: c.ClientIP(),
|
||||||
|
UserId: userId,
|
||||||
|
}
|
||||||
|
|
||||||
|
// use old chat data override the chat model and role ID
|
||||||
|
var chat model.ChatItem
|
||||||
|
h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
|
||||||
|
if chat.Id > 0 {
|
||||||
|
chatModel.Id = chat.ModelId
|
||||||
|
chatMessage.RoleId = int(chat.RoleId)
|
||||||
|
}
|
||||||
|
|
||||||
|
session.ChatId = chatMessage.ChatId
|
||||||
|
session.Tools = chatMessage.Tools
|
||||||
|
session.Stream = chatMessage.Stream
|
||||||
|
// 复制模型数据
|
||||||
|
err = utils.CopyObject(chatModel, &session.Model)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err, chatModel)
|
||||||
|
}
|
||||||
|
session.Model.Id = chatModel.Id
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
|
||||||
|
err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
utils.SendAndFlush(client, err.Error())
|
||||||
|
} else {
|
||||||
|
utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
|
||||||
|
logger.Infof("回答完毕: %v", message.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
199
api/main.go
199
api/main.go
@@ -14,7 +14,6 @@ import (
|
|||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/handler/admin"
|
"geekai/handler/admin"
|
||||||
"geekai/handler/chatimpl"
|
|
||||||
logger2 "geekai/logger"
|
logger2 "geekai/logger"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/dalle"
|
"geekai/service/dalle"
|
||||||
@@ -24,7 +23,7 @@ import (
|
|||||||
"geekai/service/sd"
|
"geekai/service/sd"
|
||||||
"geekai/service/sms"
|
"geekai/service/sms"
|
||||||
"geekai/service/suno"
|
"geekai/service/suno"
|
||||||
"geekai/service/wx"
|
"geekai/service/video"
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@@ -128,10 +127,10 @@ func main() {
|
|||||||
// 创建控制器
|
// 创建控制器
|
||||||
fx.Provide(handler.NewChatRoleHandler),
|
fx.Provide(handler.NewChatRoleHandler),
|
||||||
fx.Provide(handler.NewUserHandler),
|
fx.Provide(handler.NewUserHandler),
|
||||||
fx.Provide(chatimpl.NewChatHandler),
|
fx.Provide(handler.NewChatHandler),
|
||||||
fx.Provide(handler.NewUploadHandler),
|
fx.Provide(handler.NewNetHandler),
|
||||||
fx.Provide(handler.NewSmsHandler),
|
fx.Provide(handler.NewSmsHandler),
|
||||||
fx.Provide(handler.NewRewardHandler),
|
fx.Provide(handler.NewRedeemHandler),
|
||||||
fx.Provide(handler.NewCaptchaHandler),
|
fx.Provide(handler.NewCaptchaHandler),
|
||||||
fx.Provide(handler.NewMidJourneyHandler),
|
fx.Provide(handler.NewMidJourneyHandler),
|
||||||
fx.Provide(handler.NewChatModelHandler),
|
fx.Provide(handler.NewChatModelHandler),
|
||||||
@@ -146,8 +145,8 @@ func main() {
|
|||||||
fx.Provide(admin.NewAdminHandler),
|
fx.Provide(admin.NewAdminHandler),
|
||||||
fx.Provide(admin.NewApiKeyHandler),
|
fx.Provide(admin.NewApiKeyHandler),
|
||||||
fx.Provide(admin.NewUserHandler),
|
fx.Provide(admin.NewUserHandler),
|
||||||
fx.Provide(admin.NewChatRoleHandler),
|
fx.Provide(admin.NewChatAppHandler),
|
||||||
fx.Provide(admin.NewRewardHandler),
|
fx.Provide(admin.NewRedeemHandler),
|
||||||
fx.Provide(admin.NewDashboardHandler),
|
fx.Provide(admin.NewDashboardHandler),
|
||||||
fx.Provide(admin.NewChatModelHandler),
|
fx.Provide(admin.NewChatModelHandler),
|
||||||
fx.Provide(admin.NewProductHandler),
|
fx.Provide(admin.NewProductHandler),
|
||||||
@@ -161,13 +160,12 @@ func main() {
|
|||||||
return service.NewCaptchaService(config.ApiConfig)
|
return service.NewCaptchaService(config.ApiConfig)
|
||||||
}),
|
}),
|
||||||
fx.Provide(oss.NewUploaderManager),
|
fx.Provide(oss.NewUploaderManager),
|
||||||
fx.Provide(mj.NewService),
|
|
||||||
fx.Provide(dalle.NewService),
|
fx.Provide(dalle.NewService),
|
||||||
fx.Invoke(func(service *dalle.Service) {
|
fx.Invoke(func(s *dalle.Service) {
|
||||||
service.Run()
|
s.Run()
|
||||||
service.CheckTaskNotify()
|
s.CheckTaskNotify()
|
||||||
service.DownloadImages()
|
s.DownloadImages()
|
||||||
service.CheckTaskStatus()
|
s.CheckTaskStatus()
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 邮件服务
|
// 邮件服务
|
||||||
@@ -178,36 +176,22 @@ func main() {
|
|||||||
licenseService.SyncLicense()
|
licenseService.SyncLicense()
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 微信机器人服务
|
|
||||||
fx.Provide(wx.NewWeChatBot),
|
|
||||||
fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) {
|
|
||||||
if config.WeChatBot {
|
|
||||||
err := bot.Run()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("微信登录失败:", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
|
|
||||||
// MidJourney service pool
|
// MidJourney service pool
|
||||||
fx.Provide(mj.NewServicePool),
|
fx.Provide(mj.NewService),
|
||||||
fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) {
|
fx.Provide(mj.NewClient),
|
||||||
pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs)
|
fx.Invoke(func(s *mj.Service) {
|
||||||
if pool.HasAvailableService() {
|
s.Run()
|
||||||
pool.DownloadImages()
|
s.SyncTaskProgress()
|
||||||
pool.CheckTaskNotify()
|
s.CheckTaskNotify()
|
||||||
pool.SyncTaskProgress()
|
s.DownloadImages()
|
||||||
}
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// Stable Diffusion 机器人
|
// Stable Diffusion 机器人
|
||||||
fx.Provide(sd.NewServicePool),
|
fx.Provide(sd.NewService),
|
||||||
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
|
fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
|
||||||
pool.InitServices(config.SdConfigs)
|
s.Run()
|
||||||
if pool.HasAvailableService() {
|
s.CheckTaskStatus()
|
||||||
pool.CheckTaskNotify()
|
s.CheckTaskNotify()
|
||||||
pool.CheckTaskStatus()
|
|
||||||
}
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
fx.Provide(suno.NewService),
|
fx.Provide(suno.NewService),
|
||||||
@@ -215,9 +199,16 @@ func main() {
|
|||||||
s.Run()
|
s.Run()
|
||||||
s.SyncTaskProgress()
|
s.SyncTaskProgress()
|
||||||
s.CheckTaskNotify()
|
s.CheckTaskNotify()
|
||||||
s.DownloadImages()
|
s.DownloadFiles()
|
||||||
}),
|
}),
|
||||||
|
fx.Provide(video.NewService),
|
||||||
|
fx.Invoke(func(s *video.Service) {
|
||||||
|
s.Run()
|
||||||
|
s.SyncTaskProgress()
|
||||||
|
s.CheckTaskNotify()
|
||||||
|
s.DownloadFiles()
|
||||||
|
}),
|
||||||
|
fx.Provide(service.NewUserService),
|
||||||
fx.Provide(payment.NewAlipayService),
|
fx.Provide(payment.NewAlipayService),
|
||||||
fx.Provide(payment.NewHuPiPay),
|
fx.Provide(payment.NewHuPiPay),
|
||||||
fx.Provide(payment.NewJPayService),
|
fx.Provide(payment.NewJPayService),
|
||||||
@@ -234,8 +225,9 @@ func main() {
|
|||||||
|
|
||||||
// 注册路由
|
// 注册路由
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
||||||
group := s.Engine.Group("/api/role/")
|
group := s.Engine.Group("/api/app/")
|
||||||
group.GET("list", h.List)
|
group.GET("list", h.List)
|
||||||
|
group.GET("list/user", h.ListByUser)
|
||||||
group.POST("update", h.UpdateRole)
|
group.POST("update", h.UpdateRole)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
|
||||||
@@ -247,14 +239,14 @@ func main() {
|
|||||||
group.GET("profile", h.Profile)
|
group.GET("profile", h.Profile)
|
||||||
group.POST("profile/update", h.ProfileUpdate)
|
group.POST("profile/update", h.ProfileUpdate)
|
||||||
group.POST("password", h.UpdatePass)
|
group.POST("password", h.UpdatePass)
|
||||||
group.POST("bind/username", h.BindUsername)
|
group.POST("bind/mobile", h.BindMobile)
|
||||||
|
group.POST("bind/email", h.BindEmail)
|
||||||
group.POST("resetPass", h.ResetPass)
|
group.POST("resetPass", h.ResetPass)
|
||||||
group.GET("clogin", h.CLogin)
|
group.GET("clogin", h.CLogin)
|
||||||
group.GET("clogin/callback", h.CLoginCallback)
|
group.GET("clogin/callback", h.CLoginCallback)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
|
||||||
group := s.Engine.Group("/api/chat/")
|
group := s.Engine.Group("/api/chat/")
|
||||||
group.Any("new", h.ChatHandle)
|
|
||||||
group.GET("list", h.List)
|
group.GET("list", h.List)
|
||||||
group.GET("detail", h.Detail)
|
group.GET("detail", h.Detail)
|
||||||
group.POST("update", h.Update)
|
group.POST("update", h.Update)
|
||||||
@@ -264,10 +256,11 @@ func main() {
|
|||||||
group.POST("tokens", h.Tokens)
|
group.POST("tokens", h.Tokens)
|
||||||
group.GET("stop", h.StopGenerate)
|
group.GET("stop", h.StopGenerate)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
|
||||||
s.Engine.POST("/api/upload", h.Upload)
|
s.Engine.POST("/api/upload", h.Upload)
|
||||||
s.Engine.POST("/api/upload/list", h.List)
|
s.Engine.POST("/api/upload/list", h.List)
|
||||||
s.Engine.GET("/api/upload/remove", h.Remove)
|
s.Engine.GET("/api/upload/remove", h.Remove)
|
||||||
|
s.Engine.GET("/api/download", h.Download)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
|
||||||
group := s.Engine.Group("/api/sms/")
|
group := s.Engine.Group("/api/sms/")
|
||||||
@@ -280,13 +273,12 @@ func main() {
|
|||||||
group.GET("slide/get", h.SlideGet)
|
group.GET("slide/get", h.SlideGet)
|
||||||
group.POST("slide/check", h.SlideCheck)
|
group.POST("slide/check", h.SlideCheck)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
|
||||||
group := s.Engine.Group("/api/reward/")
|
group := s.Engine.Group("/api/redeem/")
|
||||||
group.POST("verify", h.Verify)
|
group.POST("verify", h.Verify)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
||||||
group := s.Engine.Group("/api/mj/")
|
group := s.Engine.Group("/api/mj/")
|
||||||
group.Any("client", h.Client)
|
|
||||||
group.POST("image", h.Image)
|
group.POST("image", h.Image)
|
||||||
group.POST("upscale", h.Upscale)
|
group.POST("upscale", h.Upscale)
|
||||||
group.POST("variation", h.Variation)
|
group.POST("variation", h.Variation)
|
||||||
@@ -297,7 +289,6 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||||
group := s.Engine.Group("/api/sd")
|
group := s.Engine.Group("/api/sd")
|
||||||
group.Any("client", h.Client)
|
|
||||||
group.POST("image", h.Image)
|
group.POST("image", h.Image)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
group.GET("imgWall", h.ImgWall)
|
group.GET("imgWall", h.ImgWall)
|
||||||
@@ -312,13 +303,12 @@ func main() {
|
|||||||
|
|
||||||
// 管理后台控制器
|
// 管理后台控制器
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||||
group := s.Engine.Group("/api/admin/")
|
group := s.Engine.Group("/api/admin/config")
|
||||||
group.POST("config/update", h.Update)
|
group.POST("update", h.Update)
|
||||||
group.GET("config/get", h.Get)
|
group.GET("get", h.Get)
|
||||||
group.POST("active", h.Active)
|
group.POST("active", h.Active)
|
||||||
group.GET("config/get/license", h.GetLicense)
|
group.GET("fixData", h.FixData)
|
||||||
group.GET("config/get/app", h.GetAppConfig)
|
group.GET("license", h.GetLicense)
|
||||||
group.POST("config/update/draw", h.SaveDrawingConfig)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
||||||
group := s.Engine.Group("/api/admin/")
|
group := s.Engine.Group("/api/admin/")
|
||||||
@@ -346,7 +336,7 @@ func main() {
|
|||||||
group.GET("loginLog", h.LoginLog)
|
group.GET("loginLog", h.LoginLog)
|
||||||
group.POST("resetPass", h.ResetPass)
|
group.POST("resetPass", h.ResetPass)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatRoleHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
|
||||||
group := s.Engine.Group("/api/admin/role/")
|
group := s.Engine.Group("/api/admin/role/")
|
||||||
group.GET("list", h.List)
|
group.GET("list", h.List)
|
||||||
group.POST("save", h.Save)
|
group.POST("save", h.Save)
|
||||||
@@ -354,10 +344,13 @@ func main() {
|
|||||||
group.POST("set", h.Set)
|
group.POST("set", h.Set)
|
||||||
group.GET("remove", h.Remove)
|
group.GET("remove", h.Remove)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
|
||||||
group := s.Engine.Group("/api/admin/reward/")
|
group := s.Engine.Group("/api/admin/redeem/")
|
||||||
group.GET("list", h.List)
|
group.GET("list", h.List)
|
||||||
group.POST("remove", h.Remove)
|
group.POST("create", h.Create)
|
||||||
|
group.POST("set", h.Set)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.POST("export", h.Export)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
||||||
group := s.Engine.Group("/api/admin/dashboard/")
|
group := s.Engine.Group("/api/admin/dashboard/")
|
||||||
@@ -377,14 +370,12 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
|
||||||
group := s.Engine.Group("/api/payment/")
|
group := s.Engine.Group("/api/payment/")
|
||||||
group.GET("doPay", h.DoPay)
|
group.POST("doPay", h.Pay)
|
||||||
group.GET("payWays", h.GetPayWays)
|
group.GET("payWays", h.GetPayWays)
|
||||||
group.POST("qrcode", h.PayQrcode)
|
group.POST("notify/alipay", h.AlipayNotify)
|
||||||
group.POST("mobile", h.Mobile)
|
group.GET("notify/geek", h.GeekPayNotify)
|
||||||
group.POST("alipay/notify", h.AlipayNotify)
|
group.POST("notify/wechat", h.WechatPayNotify)
|
||||||
group.POST("hupipay/notify", h.HuPiPayNotify)
|
group.POST("notify/hupi", h.HuPiPayNotify)
|
||||||
group.POST("payjs/notify", h.PayJsNotify)
|
|
||||||
group.POST("wechat/notify", h.WechatPayNotify)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
||||||
group := s.Engine.Group("/api/admin/product/")
|
group := s.Engine.Group("/api/admin/product/")
|
||||||
@@ -398,6 +389,7 @@ func main() {
|
|||||||
group := s.Engine.Group("/api/admin/order/")
|
group := s.Engine.Group("/api/admin/order/")
|
||||||
group.POST("list", h.List)
|
group.POST("list", h.List)
|
||||||
group.GET("remove", h.Remove)
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("clear", h.Clear)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
|
||||||
group := s.Engine.Group("/api/order/")
|
group := s.Engine.Group("/api/order/")
|
||||||
@@ -413,7 +405,7 @@ func main() {
|
|||||||
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
|
||||||
group := s.Engine.Group("/api/invite/")
|
group := s.Engine.Group("/api/invite/")
|
||||||
group.GET("code", h.Code)
|
group.GET("code", h.Code)
|
||||||
group.POST("list", h.List)
|
group.GET("list", h.List)
|
||||||
group.GET("hits", h.Hits)
|
group.GET("hits", h.Hits)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
@@ -438,6 +430,7 @@ func main() {
|
|||||||
group.POST("weibo", h.WeiBo)
|
group.POST("weibo", h.WeiBo)
|
||||||
group.POST("zaobao", h.ZaoBao)
|
group.POST("zaobao", h.ZaoBao)
|
||||||
group.POST("dalle3", h.Dall3)
|
group.POST("dalle3", h.Dall3)
|
||||||
|
group.GET("list", h.List)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
|
||||||
group := s.Engine.Group("/api/admin/chat/")
|
group := s.Engine.Group("/api/admin/chat/")
|
||||||
@@ -471,23 +464,21 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewMarkMapHandler),
|
fx.Provide(handler.NewMarkMapHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
|
||||||
group := s.Engine.Group("/api/markMap/")
|
s.Engine.POST("/api/markMap/gen", h.Generate)
|
||||||
group.Any("client", h.Client)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewDallJobHandler),
|
fx.Provide(handler.NewDallJobHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
|
||||||
group := s.Engine.Group("/api/dall")
|
group := s.Engine.Group("/api/dall")
|
||||||
group.Any("client", h.Client)
|
|
||||||
group.POST("image", h.Image)
|
group.POST("image", h.Image)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
group.GET("imgWall", h.ImgWall)
|
group.GET("imgWall", h.ImgWall)
|
||||||
group.GET("remove", h.Remove)
|
group.GET("remove", h.Remove)
|
||||||
group.GET("publish", h.Publish)
|
group.GET("publish", h.Publish)
|
||||||
|
group.GET("models", h.GetModels)
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewSunoHandler),
|
fx.Provide(handler.NewSunoHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
|
||||||
group := s.Engine.Group("/api/suno")
|
group := s.Engine.Group("/api/suno")
|
||||||
group.Any("client", h.Client)
|
|
||||||
group.POST("create", h.Create)
|
group.POST("create", h.Create)
|
||||||
group.GET("list", h.List)
|
group.GET("list", h.List)
|
||||||
group.GET("remove", h.Remove)
|
group.GET("remove", h.Remove)
|
||||||
@@ -495,13 +486,53 @@ func main() {
|
|||||||
group.POST("update", h.Update)
|
group.POST("update", h.Update)
|
||||||
group.GET("detail", h.Detail)
|
group.GET("detail", h.Detail)
|
||||||
group.GET("play", h.Play)
|
group.GET("play", h.Play)
|
||||||
group.POST("lyric", h.Lyric)
|
}),
|
||||||
|
fx.Provide(handler.NewVideoHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
|
||||||
|
group := s.Engine.Group("/api/video")
|
||||||
|
group.POST("luma/create", h.LumaCreate)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("publish", h.Publish)
|
||||||
|
}),
|
||||||
|
fx.Provide(admin.NewChatAppTypeHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
|
||||||
|
group := s.Engine.Group("/api/admin/app/type")
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.POST("enable", h.Enable)
|
||||||
|
group.POST("sort", h.Sort)
|
||||||
|
}),
|
||||||
|
fx.Provide(handler.NewChatAppTypeHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
|
||||||
|
group := s.Engine.Group("/api/app/type")
|
||||||
|
group.GET("list", h.List)
|
||||||
|
}),
|
||||||
|
fx.Provide(handler.NewTestHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
||||||
|
group := s.Engine.Group("/api/test")
|
||||||
|
group.Any("sse", h.PostTest, h.SseTest)
|
||||||
|
}),
|
||||||
|
fx.Provide(service.NewWebsocketService),
|
||||||
|
fx.Provide(handler.NewWebsocketHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) {
|
||||||
|
s.Engine.Any("/api/ws", h.Client)
|
||||||
|
}),
|
||||||
|
fx.Provide(handler.NewPromptHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
||||||
|
group := s.Engine.Group("/api/prompt")
|
||||||
|
group.POST("/lyric", h.Lyric)
|
||||||
|
group.POST("/image", h.Image)
|
||||||
|
group.POST("/video", h.Video)
|
||||||
|
group.POST("/meta", h.MetaPrompt)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||||
go func() {
|
go func() {
|
||||||
err := s.Run(db)
|
err := s.Run(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
logger.Error(err)
|
||||||
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}),
|
}),
|
||||||
@@ -517,6 +548,26 @@ func main() {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}),
|
}),
|
||||||
|
fx.Provide(admin.NewImageHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) {
|
||||||
|
group := s.Engine.Group("/api/admin/image")
|
||||||
|
group.POST("/list/mj", h.MjList)
|
||||||
|
group.POST("/list/sd", h.SdList)
|
||||||
|
group.POST("/list/dall", h.DallList)
|
||||||
|
group.GET("/remove", h.Remove)
|
||||||
|
}),
|
||||||
|
fx.Provide(admin.NewMediaHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
|
||||||
|
group := s.Engine.Group("/api/admin/media")
|
||||||
|
group.POST("/list/suno", h.SunoList)
|
||||||
|
group.POST("/list/luma", h.LumaList)
|
||||||
|
group.GET("/remove", h.Remove)
|
||||||
|
}),
|
||||||
|
fx.Provide(handler.NewRealtimeHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
|
||||||
|
s.Engine.Any("/api/realtime", h.Connection)
|
||||||
|
s.Engine.POST("/api/realtime/voice", h.VoiceChat)
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
// 启动应用程序
|
// 启动应用程序
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
BIN
api/res/img/geek-pay.jpg
Normal file
BIN
api/res/img/geek-pay.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 27 KiB |
BIN
api/res/img/qq-pay.jpg
Normal file
BIN
api/res/img/qq-pay.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
@@ -8,19 +8,19 @@ package dalle
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
logger2 "geekai/logger"
|
logger2 "geekai/logger"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/service/sd"
|
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"github.com/go-redis/redis/v8"
|
"io"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -35,17 +35,21 @@ type Service struct {
|
|||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
userService *service.UserService
|
||||||
|
wsService *service.WebsocketService
|
||||||
|
clientIds map[uint]string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||||
db: db,
|
db: db,
|
||||||
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
||||||
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
wsService: wsService,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
|
userService: userService,
|
||||||
|
clientIds: map[uint]string{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,6 +60,20 @@ func (s *Service) PushTask(task types.DallTask) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
|
// 将数据库中未提交的人物加载到队列
|
||||||
|
var jobs []model.DallJob
|
||||||
|
s.db.Where("progress", 0).Find(&jobs)
|
||||||
|
for _, v := range jobs {
|
||||||
|
var task types.DallTask
|
||||||
|
err := utils.JsonDecode(v.TaskInfo, &task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("decode task info with error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
task.Id = v.Id
|
||||||
|
s.PushTask(task)
|
||||||
|
}
|
||||||
|
|
||||||
logger.Info("Starting DALL-E job consumer...")
|
logger.Info("Starting DALL-E job consumer...")
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
@@ -66,14 +84,15 @@ func (s *Service) Run() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
logger.Infof("handle a new DALL-E task: %+v", task)
|
logger.Infof("handle a new DALL-E task: %+v", task)
|
||||||
|
s.clientIds[task.Id] = task.ClientId
|
||||||
_, err = s.Image(task, false)
|
_, err = s.Image(task, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("error with image task: %v", err)
|
logger.Errorf("error with image task: %v", err)
|
||||||
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
|
s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||||
"progress": -1,
|
"progress": service.FailTaskProgress,
|
||||||
"err_msg": err.Error(),
|
"err_msg": err.Error(),
|
||||||
})
|
})
|
||||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -82,17 +101,18 @@ func (s *Service) Run() {
|
|||||||
type imgReq struct {
|
type imgReq struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
N int `json:"n"`
|
N int `json:"n,omitempty"`
|
||||||
Size string `json:"size"`
|
Size string `json:"size,omitempty"`
|
||||||
Quality string `json:"quality"`
|
Quality string `json:"quality,omitempty"`
|
||||||
Style string `json:"style"`
|
Style string `json:"style,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type imgRes struct {
|
type imgRes struct {
|
||||||
Created int64 `json:"created"`
|
Created int64 `json:"created"`
|
||||||
Data []struct {
|
Data []struct {
|
||||||
RevisedPrompt string `json:"revised_prompt"`
|
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||||
Url string `json:"url"`
|
Url string `json:"url,omitempty"`
|
||||||
|
B64Json string `json:"b64_json,omitempty"`
|
||||||
} `json:"data"`
|
} `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,45 +130,27 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
|||||||
prompt := task.Prompt
|
prompt := task.Prompt
|
||||||
// translate prompt
|
// translate prompt
|
||||||
if utils.HasChinese(prompt) {
|
if utils.HasChinese(prompt) {
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, prompt), task.TranslateModelId)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
prompt = content
|
prompt = content
|
||||||
logger.Debugf("重写后提示词:%s", prompt)
|
logger.Debugf("重写后提示词:%s", prompt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var chatModel model.ChatModel
|
||||||
s.db.Where("id", task.UserId).First(&user)
|
s.db.Where("id = ?", task.ModelId).First(&chatModel)
|
||||||
if user.Power < task.Power {
|
|
||||||
return "", errors.New("insufficient of power")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新用户算力
|
|
||||||
tx := s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
|
|
||||||
// 记录算力变化日志
|
|
||||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
|
||||||
var u model.User
|
|
||||||
s.db.Where("id", user.Id).First(&u)
|
|
||||||
s.db.Create(&model.PowerLog{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Amount: task.Power,
|
|
||||||
Balance: u.Power,
|
|
||||||
Mark: types.PowerSub,
|
|
||||||
Model: "dall-e-3",
|
|
||||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// get image generation API KEY
|
// get image generation API KEY
|
||||||
var apiKey model.ApiKey
|
var apiKey model.ApiKey
|
||||||
tx = s.db.Where("type", "dalle").
|
session := s.db.Where("enabled", true)
|
||||||
Where("enabled", true).
|
if chatModel.KeyId > 0 {
|
||||||
Order("last_used_at ASC").First(&apiKey)
|
session = session.Where("id = ?", chatModel.KeyId)
|
||||||
if tx.Error != nil {
|
} else {
|
||||||
return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
|
session = session.Where("type = ?", "dalle")
|
||||||
|
}
|
||||||
|
err := session.Order("last_used_at ASC").First(&apiKey).Error
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("no available Image Generation api key: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var res imgRes
|
var res imgRes
|
||||||
@@ -158,7 +160,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
|||||||
}
|
}
|
||||||
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
|
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
|
||||||
reqBody := imgReq{
|
reqBody := imgReq{
|
||||||
Model: "dall-e-3",
|
Model: chatModel.Value,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
N: 1,
|
N: 1,
|
||||||
Size: task.Size,
|
Size: task.Size,
|
||||||
@@ -166,35 +168,54 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
|||||||
Quality: task.Quality,
|
Quality: task.Quality,
|
||||||
}
|
}
|
||||||
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
|
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
|
||||||
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
|
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
|
||||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
SetBody(reqBody).
|
SetBody(reqBody).
|
||||||
SetErrorResult(&errRes).
|
SetErrorResult(&errRes).
|
||||||
SetSuccessResult(&res).
|
SetSuccessResult(&res).
|
||||||
Post(apiURL)
|
Post(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Errorf("error with send request: %v", err)
|
||||||
return "", fmt.Errorf("error with send request: %v", err)
|
return "", fmt.Errorf("error with send request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.IsErrorState() {
|
if r.IsErrorState() {
|
||||||
|
logger.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
|
||||||
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
|
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
all, _ := io.ReadAll(r.Body)
|
||||||
|
logger.Debugf("response: %+v", string(all))
|
||||||
|
|
||||||
// update the api key last use time
|
// update the api key last use time
|
||||||
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
// update task progress
|
var imgURL string
|
||||||
tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
|
var data = map[string]interface{}{
|
||||||
"progress": 100,
|
"progress": 100,
|
||||||
"org_url": res.Data[0].Url,
|
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
})
|
}
|
||||||
if tx.Error != nil {
|
// 如果返回的是base64,则需要上传到oss
|
||||||
return "", fmt.Errorf("err with update database: %v", tx.Error)
|
if res.Data[0].B64Json != "" {
|
||||||
|
imgURL, err = s.uploadManager.GetUploadHandler().PutBase64(res.Data[0].B64Json)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with upload image: %v", err)
|
||||||
|
}
|
||||||
|
logger.Infof("upload image to oss: %s", imgURL)
|
||||||
|
data["img_url"] = imgURL
|
||||||
|
} else {
|
||||||
|
imgURL = res.Data[0].Url
|
||||||
|
}
|
||||||
|
data["org_url"] = imgURL
|
||||||
|
// update task progress
|
||||||
|
err = s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(data).Error
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("err with update database: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||||
var content string
|
var content string
|
||||||
if sync {
|
if sync {
|
||||||
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
|
imgURL, err := s.downloadImage(task.Id, int(task.UserId), res.Data[0].Url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error with download image: %v", err)
|
return "", fmt.Errorf("error with download image: %v", err)
|
||||||
}
|
}
|
||||||
@@ -208,19 +229,58 @@ func (s *Service) CheckTaskNotify() {
|
|||||||
go func() {
|
go func() {
|
||||||
logger.Info("Running DALL-E task notify checking ...")
|
logger.Info("Running DALL-E task notify checking ...")
|
||||||
for {
|
for {
|
||||||
var message sd.NotifyMessage
|
var message service.NotifyMessage
|
||||||
err := s.notifyQueue.LPop(&message)
|
err := s.notifyQueue.LPop(&message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
client := s.Clients.Get(uint(message.UserId))
|
|
||||||
|
logger.Debugf("notify message: %+v", message)
|
||||||
|
client := s.wsService.Clients.Get(message.ClientId)
|
||||||
if client == nil {
|
if client == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = client.Send([]byte(message.Message))
|
utils.SendChannelMsg(client, types.ChDall, message.Message)
|
||||||
if err != nil {
|
}
|
||||||
continue
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) CheckTaskStatus() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running DALL-E task status checking ...")
|
||||||
|
for {
|
||||||
|
// 检查未完成任务进度
|
||||||
|
var jobs []model.DallJob
|
||||||
|
s.db.Where("progress < ?", 100).Find(&jobs)
|
||||||
|
for _, job := range jobs {
|
||||||
|
// 超时的任务标记为失败
|
||||||
|
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
|
||||||
|
job.Progress = service.FailTaskProgress
|
||||||
|
job.ErrMsg = "任务超时"
|
||||||
|
s.db.Updates(&job)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 找出失败的任务,并恢复其扣减算力
|
||||||
|
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||||||
|
for _, job := range jobs {
|
||||||
|
var task types.DallTask
|
||||||
|
err := utils.JsonDecode(job.TaskInfo, &task)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = s.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
|
||||||
|
Type: types.PowerRefund,
|
||||||
|
Model: task.ModelName,
|
||||||
|
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 更新任务状态
|
||||||
|
s.db.Model(&job).UpdateColumn("power", 0)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -268,47 +328,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
|
|||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||||
return imgURL, nil
|
return imgURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
|
||||||
func (s *Service) CheckTaskStatus() {
|
|
||||||
go func() {
|
|
||||||
logger.Info("Running Stable-Diffusion task status checking ...")
|
|
||||||
for {
|
|
||||||
var jobs []model.DallJob
|
|
||||||
res := s.db.Where("progress < ?", 100).Find(&jobs)
|
|
||||||
if res.Error != nil {
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, job := range jobs {
|
|
||||||
// 5 分钟还没完成的任务直接删除
|
|
||||||
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
|
|
||||||
s.db.Delete(&job)
|
|
||||||
var user model.User
|
|
||||||
s.db.Where("id = ?", job.UserId).First(&user)
|
|
||||||
// 退回绘图次数
|
|
||||||
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
|
||||||
if res.Error == nil && res.RowsAffected > 0 {
|
|
||||||
s.db.Create(&model.PowerLog{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Amount: job.Power,
|
|
||||||
Balance: user.Power + job.Power,
|
|
||||||
Mark: types.PowerAdd,
|
|
||||||
Model: "dall-e-3",
|
|
||||||
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d", job.Id),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second * 10)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,13 +8,16 @@ package service
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
|
"github.com/shirou/gopsutil/host"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LicenseService struct {
|
type LicenseService struct {
|
||||||
@@ -27,11 +30,18 @@ type LicenseService struct {
|
|||||||
|
|
||||||
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
|
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
|
||||||
var license types.License
|
var license types.License
|
||||||
|
var machineId string
|
||||||
|
_ = levelDB.Get(types.LicenseKey, &license)
|
||||||
|
info, err := host.Info()
|
||||||
|
if err == nil {
|
||||||
|
machineId = info.HostID
|
||||||
|
}
|
||||||
|
logger.Infof("License: %+v", license)
|
||||||
return &LicenseService{
|
return &LicenseService{
|
||||||
config: server.Config.ApiConfig,
|
config: server.Config.ApiConfig,
|
||||||
levelDB: levelDB,
|
levelDB: levelDB,
|
||||||
license: &license,
|
license: &license,
|
||||||
machineId: "",
|
machineId: machineId,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,33 +119,30 @@ func (s *LicenseService) SyncLicense() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *LicenseService) fetchLicense() (*types.License, error) {
|
func (s *LicenseService) fetchLicense() (*types.License, error) {
|
||||||
//var res struct {
|
var res struct {
|
||||||
// Code types.BizCode `json:"code"`
|
Code types.BizCode `json:"code"`
|
||||||
// Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
// Data License `json:"data"`
|
Data License `json:"data"`
|
||||||
//}
|
}
|
||||||
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
|
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
|
||||||
//response, err := req.C().R().
|
response, err := req.C().R().
|
||||||
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||||
// SetSuccessResult(&res).Post(apiURL)
|
SetSuccessResult(&res).Post(apiURL)
|
||||||
//if err != nil {
|
if err != nil {
|
||||||
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
|
return nil, fmt.Errorf("发送激活请求失败: %v", err)
|
||||||
//}
|
}
|
||||||
//if response.IsErrorState() {
|
if response.IsErrorState() {
|
||||||
// return nil, fmt.Errorf("激活失败:%v", response.Status)
|
return nil, fmt.Errorf("激活失败:%v", response.Status)
|
||||||
//}
|
}
|
||||||
//if res.Code != types.Success {
|
if res.Code != types.Success {
|
||||||
// return nil, fmt.Errorf("激活失败:%v", res.Message)
|
return nil, fmt.Errorf("激活失败:%v", res.Message)
|
||||||
//}
|
}
|
||||||
|
|
||||||
return &types.License{
|
return &types.License{
|
||||||
Key: "abc",
|
Key: res.Data.License,
|
||||||
MachineId: "abc",
|
MachineId: res.Data.MachineId,
|
||||||
Configs: types.LicenseConfig{
|
Configs: res.Data.Configs,
|
||||||
UserNum: 10000,
|
ExpiredAt: res.Data.ExpiredAt,
|
||||||
DeCopy: false,
|
|
||||||
},
|
|
||||||
ExpiredAt: 0,
|
|
||||||
IsActive: true,
|
IsActive: true,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -169,29 +176,28 @@ func (s *LicenseService) GetLicense() *types.License {
|
|||||||
// IsValidApiURL 判断是否合法的中转 URL
|
// IsValidApiURL 判断是否合法的中转 URL
|
||||||
func (s *LicenseService) IsValidApiURL(uri string) error {
|
func (s *LicenseService) IsValidApiURL(uri string) error {
|
||||||
// 获得许可授权的直接放行
|
// 获得许可授权的直接放行
|
||||||
return nil
|
if s.license.IsActive {
|
||||||
//if s.license.IsActive {
|
if s.license.MachineId != s.machineId {
|
||||||
// if s.license.MachineId != s.machineId {
|
return errors.New("系统使用了盗版的许可证书")
|
||||||
// return errors.New("系统使用了盗版的许可证书")
|
}
|
||||||
// }
|
|
||||||
//
|
if time.Now().Unix() > s.license.ExpiredAt {
|
||||||
// if time.Now().Unix() > s.license.ExpiredAt {
|
return errors.New("系统许可证书已经过期")
|
||||||
// return errors.New("系统许可证书已经过期")
|
}
|
||||||
// }
|
return nil
|
||||||
// return nil
|
}
|
||||||
//}
|
|
||||||
//
|
if len(s.urlWhiteList) == 0 {
|
||||||
//if len(s.urlWhiteList) == 0 {
|
urls, err := s.fetchUrlWhiteList()
|
||||||
// urls, err := s.fetchUrlWhiteList()
|
if err == nil {
|
||||||
// if err == nil {
|
s.urlWhiteList = urls
|
||||||
// s.urlWhiteList = urls
|
}
|
||||||
// }
|
}
|
||||||
//}
|
|
||||||
//
|
for _, v := range s.urlWhiteList {
|
||||||
//for _, v := range s.urlWhiteList {
|
if strings.HasPrefix(uri, v) {
|
||||||
// if strings.HasPrefix(uri, v) {
|
return nil
|
||||||
// return nil
|
}
|
||||||
// }
|
}
|
||||||
//}
|
return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||||
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,15 +7,28 @@ package mj
|
|||||||
// * @Author yangjian102621@163.com
|
// * @Author yangjian102621@163.com
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import "geekai/core/types"
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
logger2 "geekai/logger"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
type Client interface {
|
"github.com/gin-gonic/gin"
|
||||||
Imagine(task types.MjTask) (ImageRes, error)
|
)
|
||||||
Blend(task types.MjTask) (ImageRes, error)
|
|
||||||
SwapFace(task types.MjTask) (ImageRes, error)
|
// Client MidJourney client
|
||||||
Upscale(task types.MjTask) (ImageRes, error)
|
type Client struct {
|
||||||
Variation(task types.MjTask) (ImageRes, error)
|
client *req.Client
|
||||||
QueryTask(taskId string) (QueryRes, error)
|
licenseService *service.LicenseService
|
||||||
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
type ImageReq struct {
|
type ImageReq struct {
|
||||||
@@ -33,13 +46,8 @@ type ImageRes struct {
|
|||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Properties struct {
|
Properties struct {
|
||||||
} `json:"properties"`
|
} `json:"properties"`
|
||||||
Result string `json:"result"`
|
Result string `json:"result"`
|
||||||
}
|
Channel string `json:"channel,omitempty"`
|
||||||
|
|
||||||
type ErrRes struct {
|
|
||||||
Error struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
} `json:"error"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryRes struct {
|
type QueryRes struct {
|
||||||
@@ -66,3 +74,177 @@ type QueryRes struct {
|
|||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
SubmitTime int `json:"submitTime"`
|
SubmitTime int `json:"submitTime"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
func NewClient(licenseService *service.LicenseService, db *gorm.DB) *Client {
|
||||||
|
return &Client{
|
||||||
|
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
|
||||||
|
licenseService: licenseService,
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
||||||
|
apiPath := fmt.Sprintf("mj-%s/mj/submit/imagine", task.Mode)
|
||||||
|
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
||||||
|
if task.NegPrompt != "" {
|
||||||
|
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
|
||||||
|
}
|
||||||
|
body := ImageReq{
|
||||||
|
BotType: "MID_JOURNEY",
|
||||||
|
Prompt: prompt,
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return c.doRequest(body, apiPath, task.ChannelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Blend 融图
|
||||||
|
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
||||||
|
apiPath := fmt.Sprintf("mj-%s/mj/submit/blend", task.Mode)
|
||||||
|
body := ImageReq{
|
||||||
|
BotType: "MID_JOURNEY",
|
||||||
|
Dimensions: "SQUARE",
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
for _, imgURL := range task.ImgArr {
|
||||||
|
imageData, err := utils.DownloadImage(imgURL, "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.doRequest(body, apiPath, task.ChannelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwapFace 换脸
|
||||||
|
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||||
|
apiPath := fmt.Sprintf("mj-%s/mj/insight-face/swap", task.Mode)
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) != 2 {
|
||||||
|
return ImageRes{}, errors.New("参数错误,必须上传2张图片")
|
||||||
|
}
|
||||||
|
var sourceBase64 string
|
||||||
|
var targetBase64 string
|
||||||
|
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
}
|
||||||
|
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := gin.H{
|
||||||
|
"sourceBase64": sourceBase64,
|
||||||
|
"targetBase64": targetBase64,
|
||||||
|
"accountFilter": gin.H{
|
||||||
|
"instanceId": "",
|
||||||
|
},
|
||||||
|
"state": "",
|
||||||
|
}
|
||||||
|
return c.doRequest(body, apiPath, task.ChannelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upscale 放大指定的图片
|
||||||
|
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
|
||||||
|
body := map[string]string{
|
||||||
|
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
|
||||||
|
return c.doRequest(body, apiPath, task.ChannelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||||
|
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
|
||||||
|
body := map[string]string{
|
||||||
|
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
|
||||||
|
|
||||||
|
return c.doRequest(body, apiPath, task.ChannelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
|
||||||
|
var res ImageRes
|
||||||
|
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
|
||||||
|
if channel != "" {
|
||||||
|
session = session.Where("api_url", channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
err := session.Order("last_used_at ASC").First(&apiKey).Error
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = c.licenseService.IsValidApiURL(apiKey.ApiURL); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/%s", apiKey.ApiURL, apiPath)
|
||||||
|
logger.Info("API URL: ", apiURL)
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
errMsg, _ := io.ReadAll(r.Body)
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
|
||||||
|
}
|
||||||
|
|
||||||
|
// update the api key last used time
|
||||||
|
if err = c.db.Model(&apiKey).Update("last_used_at", time.Now().Unix()).Error; err != nil {
|
||||||
|
logger.Error("update api key last used time error: ", err)
|
||||||
|
}
|
||||||
|
res.Channel = apiKey.ApiURL
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) QueryTask(taskId string, channel string) (QueryRes, error) {
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
err := c.db.Where("type", "mj").Where("enabled", true).Where("api_url", channel).First(&apiKey).Error
|
||||||
|
if err != nil {
|
||||||
|
return QueryRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", apiKey.ApiURL, taskId)
|
||||||
|
var res QueryRes
|
||||||
|
r, err := c.client.R().SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
Get(apiURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return QueryRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return QueryRes{}, errors.New("error status:" + r.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,204 +0,0 @@
|
|||||||
package mj
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"geekai/core/types"
|
|
||||||
"geekai/service"
|
|
||||||
"geekai/utils"
|
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PlusClient MidJourney Plus ProxyClient
|
|
||||||
type PlusClient struct {
|
|
||||||
Config types.MjPlusConfig
|
|
||||||
apiURL string
|
|
||||||
client *req.Client
|
|
||||||
licenseService *service.LicenseService
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
|
|
||||||
return &PlusClient{
|
|
||||||
Config: config,
|
|
||||||
apiURL: config.ApiURL,
|
|
||||||
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
|
|
||||||
licenseService: licenseService,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PlusClient) preCheck() error {
|
|
||||||
return c.licenseService.IsValidApiURL(c.Config.ApiURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
|
|
||||||
if err := c.preCheck(); err != nil {
|
|
||||||
return ImageRes{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
|
|
||||||
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
|
||||||
if task.NegPrompt != "" {
|
|
||||||
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
|
|
||||||
}
|
|
||||||
body := ImageReq{
|
|
||||||
BotType: "MID_JOURNEY",
|
|
||||||
Prompt: prompt,
|
|
||||||
Base64Array: make([]string, 0),
|
|
||||||
}
|
|
||||||
// 生成图片 Base64 编码
|
|
||||||
if len(task.ImgArr) > 0 {
|
|
||||||
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download image: ", err)
|
|
||||||
} else {
|
|
||||||
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return c.doRequest(body, apiURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Blend 融图
|
|
||||||
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
|
|
||||||
if err := c.preCheck(); err != nil {
|
|
||||||
return ImageRes{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
|
|
||||||
logger.Info("API URL: ", apiURL)
|
|
||||||
body := ImageReq{
|
|
||||||
BotType: "MID_JOURNEY",
|
|
||||||
Dimensions: "SQUARE",
|
|
||||||
Base64Array: make([]string, 0),
|
|
||||||
}
|
|
||||||
// 生成图片 Base64 编码
|
|
||||||
if len(task.ImgArr) > 0 {
|
|
||||||
for _, imgURL := range task.ImgArr {
|
|
||||||
imageData, err := utils.DownloadImage(imgURL, "")
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download image: ", err)
|
|
||||||
} else {
|
|
||||||
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.doRequest(body, apiURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SwapFace 换脸
|
|
||||||
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
|
|
||||||
if err := c.preCheck(); err != nil {
|
|
||||||
return ImageRes{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
|
||||||
// 生成图片 Base64 编码
|
|
||||||
if len(task.ImgArr) != 2 {
|
|
||||||
return ImageRes{}, errors.New("参数错误,必须上传2张图片")
|
|
||||||
}
|
|
||||||
var sourceBase64 string
|
|
||||||
var targetBase64 string
|
|
||||||
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download image: ", err)
|
|
||||||
} else {
|
|
||||||
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
|
||||||
}
|
|
||||||
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download image: ", err)
|
|
||||||
} else {
|
|
||||||
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
|
||||||
}
|
|
||||||
|
|
||||||
body := gin.H{
|
|
||||||
"sourceBase64": sourceBase64,
|
|
||||||
"targetBase64": targetBase64,
|
|
||||||
"accountFilter": gin.H{
|
|
||||||
"instanceId": "",
|
|
||||||
},
|
|
||||||
"state": "",
|
|
||||||
}
|
|
||||||
return c.doRequest(body, apiURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upscale 放大指定的图片
|
|
||||||
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
|
|
||||||
if err := c.preCheck(); err != nil {
|
|
||||||
return ImageRes{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
body := map[string]string{
|
|
||||||
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
|
||||||
"taskId": task.MessageId,
|
|
||||||
}
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
|
|
||||||
return c.doRequest(body, apiURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
|
||||||
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
|
|
||||||
if err := c.preCheck(); err != nil {
|
|
||||||
return ImageRes{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
body := map[string]string{
|
|
||||||
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
|
||||||
"taskId": task.MessageId,
|
|
||||||
}
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
|
|
||||||
|
|
||||||
return c.doRequest(body, apiURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PlusClient) doRequest(body interface{}, apiURL string) (ImageRes, error) {
|
|
||||||
var res ImageRes
|
|
||||||
var errRes ErrRes
|
|
||||||
logger.Info("API URL: ", apiURL)
|
|
||||||
r, err := req.C().R().
|
|
||||||
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
|
||||||
SetBody(body).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
SetErrorResult(&errRes).
|
|
||||||
Post(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
|
|
||||||
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
|
||||||
var res QueryRes
|
|
||||||
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
Get(apiURL)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return QueryRes{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
return QueryRes{}, errors.New("error status:" + r.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Client = &PlusClient{}
|
|
||||||
@@ -1,207 +0,0 @@
|
|||||||
package mj
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"geekai/core/types"
|
|
||||||
logger2 "geekai/logger"
|
|
||||||
"geekai/service"
|
|
||||||
"geekai/service/oss"
|
|
||||||
"geekai/service/sd"
|
|
||||||
"geekai/store"
|
|
||||||
"geekai/store/model"
|
|
||||||
"geekai/utils"
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ServicePool Mj service pool
|
|
||||||
type ServicePool struct {
|
|
||||||
services []*Service
|
|
||||||
taskQueue *store.RedisQueue
|
|
||||||
notifyQueue *store.RedisQueue
|
|
||||||
db *gorm.DB
|
|
||||||
uploaderManager *oss.UploaderManager
|
|
||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
|
||||||
licenseService *service.LicenseService
|
|
||||||
}
|
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
|
||||||
|
|
||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool {
|
|
||||||
services := make([]*Service, 0)
|
|
||||||
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
|
||||||
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
|
||||||
return &ServicePool{
|
|
||||||
taskQueue: taskQueue,
|
|
||||||
notifyQueue: notifyQueue,
|
|
||||||
services: services,
|
|
||||||
uploaderManager: manager,
|
|
||||||
db: db,
|
|
||||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
|
||||||
licenseService: licenseService,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
|
|
||||||
// stop old service
|
|
||||||
for _, s := range p.services {
|
|
||||||
s.Stop()
|
|
||||||
}
|
|
||||||
p.services = make([]*Service, 0)
|
|
||||||
|
|
||||||
for _, config := range plusConfigs {
|
|
||||||
if config.Enabled == false {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
cli := NewPlusClient(config, p.licenseService)
|
|
||||||
name := utils.Md5(config.ApiURL)
|
|
||||||
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
|
|
||||||
go func() {
|
|
||||||
plusService.Run()
|
|
||||||
}()
|
|
||||||
p.services = append(p.services, plusService)
|
|
||||||
}
|
|
||||||
|
|
||||||
// for mid-journey proxy
|
|
||||||
for _, config := range proxyConfigs {
|
|
||||||
if config.Enabled == false {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cli := NewProxyClient(config)
|
|
||||||
name := utils.Md5(config.ApiURL)
|
|
||||||
proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
|
|
||||||
go func() {
|
|
||||||
proxyService.Run()
|
|
||||||
}()
|
|
||||||
p.services = append(p.services, proxyService)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ServicePool) CheckTaskNotify() {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
var message sd.NotifyMessage
|
|
||||||
err := p.notifyQueue.LPop(&message)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cli := p.Clients.Get(uint(message.UserId))
|
|
||||||
if cli == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = cli.Send([]byte(message.Message))
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ServicePool) DownloadImages() {
|
|
||||||
go func() {
|
|
||||||
var items []model.MidJourneyJob
|
|
||||||
for {
|
|
||||||
res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
|
|
||||||
if res.Error != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// download images
|
|
||||||
for _, v := range items {
|
|
||||||
if v.OrgURL == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("try to download image: %s", v.OrgURL)
|
|
||||||
mjService := p.getService(v.ChannelId)
|
|
||||||
if mjService == nil {
|
|
||||||
logger.Errorf("Invalid task: %+v", v)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
task, _ := mjService.Client.QueryTask(v.TaskId)
|
|
||||||
if len(task.Buttons) > 0 {
|
|
||||||
v.Hash = GetImageHash(task.Buttons[0].CustomId)
|
|
||||||
}
|
|
||||||
// 如果是返回的是 discord 图片地址,则使用代理下载
|
|
||||||
proxy := false
|
|
||||||
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
|
|
||||||
proxy = true
|
|
||||||
}
|
|
||||||
imgURL, err := p.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
logger.Infof("download image %s successfully.", v.OrgURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
v.ImgURL = imgURL
|
|
||||||
p.db.Updates(&v)
|
|
||||||
|
|
||||||
cli := p.Clients.Get(uint(v.UserId))
|
|
||||||
if cli == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = cli.Send([]byte(sd.Finished))
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(time.Second * 5)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushTask push a new mj task in to task queue
|
|
||||||
func (p *ServicePool) PushTask(task types.MjTask) {
|
|
||||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
|
||||||
p.taskQueue.RPush(task)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasAvailableService check if it has available mj service in pool
|
|
||||||
func (p *ServicePool) HasAvailableService() bool {
|
|
||||||
return len(p.services) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncTaskProgress 异步拉取任务
|
|
||||||
func (p *ServicePool) SyncTaskProgress() {
|
|
||||||
go func() {
|
|
||||||
var jobs []model.MidJourneyJob
|
|
||||||
for {
|
|
||||||
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
|
||||||
if res.Error != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, job := range jobs {
|
|
||||||
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
|
|
||||||
_ = servicePlus.Notify(job)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(time.Second * 10)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ServicePool) getService(name string) *Service {
|
|
||||||
for _, s := range p.services {
|
|
||||||
if s.Name == name {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,185 +0,0 @@
|
|||||||
package mj
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"geekai/core/types"
|
|
||||||
"geekai/utils"
|
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ProxyClient MidJourney Proxy Client
|
|
||||||
type ProxyClient struct {
|
|
||||||
Config types.MjProxyConfig
|
|
||||||
apiURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
|
|
||||||
return &ProxyClient{Config: config, apiURL: config.ApiURL}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
|
|
||||||
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
|
|
||||||
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
|
||||||
if task.NegPrompt != "" {
|
|
||||||
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
|
|
||||||
}
|
|
||||||
body := ImageReq{
|
|
||||||
Prompt: prompt,
|
|
||||||
Base64Array: make([]string, 0),
|
|
||||||
}
|
|
||||||
// 生成图片 Base64 编码
|
|
||||||
if len(task.ImgArr) > 0 {
|
|
||||||
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download image: ", err)
|
|
||||||
} else {
|
|
||||||
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
logger.Info("API URL: ", apiURL)
|
|
||||||
var res ImageRes
|
|
||||||
var errRes ErrRes
|
|
||||||
r, err := req.C().R().
|
|
||||||
SetHeader("mj-api-secret", c.Config.ApiKey).
|
|
||||||
SetBody(body).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
SetErrorResult(&errRes).
|
|
||||||
Post(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
errStr, _ := io.ReadAll(r.Body)
|
|
||||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Blend 融图
|
|
||||||
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
|
|
||||||
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
|
|
||||||
body := ImageReq{
|
|
||||||
Dimensions: "SQUARE",
|
|
||||||
Base64Array: make([]string, 0),
|
|
||||||
}
|
|
||||||
// 生成图片 Base64 编码
|
|
||||||
if len(task.ImgArr) > 0 {
|
|
||||||
for _, imgURL := range task.ImgArr {
|
|
||||||
imageData, err := utils.DownloadImage(imgURL, "")
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download image: ", err)
|
|
||||||
} else {
|
|
||||||
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var res ImageRes
|
|
||||||
var errRes ErrRes
|
|
||||||
r, err := req.C().R().
|
|
||||||
SetHeader("mj-api-secret", c.Config.ApiKey).
|
|
||||||
SetBody(body).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
SetErrorResult(&errRes).
|
|
||||||
Post(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SwapFace 换脸
|
|
||||||
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
|
|
||||||
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upscale 放大指定的图片
|
|
||||||
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
|
|
||||||
body := map[string]interface{}{
|
|
||||||
"action": "UPSCALE",
|
|
||||||
"index": task.Index,
|
|
||||||
"taskId": task.MessageId,
|
|
||||||
}
|
|
||||||
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
|
|
||||||
var res ImageRes
|
|
||||||
var errRes ErrRes
|
|
||||||
r, err := req.C().R().
|
|
||||||
SetHeader("mj-api-secret", c.Config.ApiKey).
|
|
||||||
SetBody(body).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
SetErrorResult(&errRes).
|
|
||||||
Post(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
|
||||||
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
|
|
||||||
body := map[string]interface{}{
|
|
||||||
"action": "VARIATION",
|
|
||||||
"index": task.Index,
|
|
||||||
"taskId": task.MessageId,
|
|
||||||
}
|
|
||||||
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
|
|
||||||
var res ImageRes
|
|
||||||
var errRes ErrRes
|
|
||||||
r, err := req.C().R().
|
|
||||||
SetHeader("mj-api-secret", c.Config.ApiKey).
|
|
||||||
SetBody(body).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
SetErrorResult(&errRes).
|
|
||||||
Post(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
|
|
||||||
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
|
||||||
var res QueryRes
|
|
||||||
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
Get(apiURL)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return QueryRes{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
return QueryRes{}, errors.New("error status:" + r.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Client = &ProxyClient{}
|
|
||||||
@@ -11,10 +11,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/sd"
|
"geekai/service/oss"
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -23,127 +24,132 @@ import (
|
|||||||
|
|
||||||
// Service MJ 绘画服务
|
// Service MJ 绘画服务
|
||||||
type Service struct {
|
type Service struct {
|
||||||
Name string // service Name
|
client *Client // MJ Client
|
||||||
Client Client // MJ Client
|
taskQueue *store.RedisQueue
|
||||||
taskQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
db *gorm.DB
|
||||||
db *gorm.DB
|
wsService *service.WebsocketService
|
||||||
running bool
|
uploaderManager *oss.UploaderManager
|
||||||
retryCount map[uint]int
|
userService *service.UserService
|
||||||
|
clientIds map[uint]string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
|
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
Name: name,
|
db: db,
|
||||||
db: db,
|
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||||
taskQueue: taskQueue,
|
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
|
||||||
notifyQueue: notifyQueue,
|
client: client,
|
||||||
Client: cli,
|
wsService: wsService,
|
||||||
running: true,
|
uploaderManager: manager,
|
||||||
retryCount: make(map[uint]int),
|
clientIds: map[uint]string{},
|
||||||
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const failedProgress = 101
|
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
// 将数据库中未提交的人物加载到队列
|
||||||
for s.running {
|
var jobs []model.MidJourneyJob
|
||||||
|
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
|
||||||
|
for _, v := range jobs {
|
||||||
var task types.MjTask
|
var task types.MjTask
|
||||||
err := s.taskQueue.LPop(&task)
|
err := utils.JsonDecode(v.TaskInfo, &task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("taking task with error: %v", err)
|
logger.Errorf("decode task info with error: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
task.Id = v.Id
|
||||||
|
s.clientIds[task.Id] = task.ClientId
|
||||||
|
s.PushTask(task)
|
||||||
|
}
|
||||||
|
|
||||||
// 如果配置了多个中转平台的 API KEY
|
logger.Info("Starting MidJourney job consumer for service")
|
||||||
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
|
go func() {
|
||||||
if task.ChannelId != "" && task.ChannelId != s.Name {
|
for {
|
||||||
if s.retryCount[task.Id] > 5 {
|
var task types.MjTask
|
||||||
s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{})
|
err := s.taskQueue.LPop(&task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("taking task with error: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
|
|
||||||
s.taskQueue.RPush(task)
|
|
||||||
s.retryCount[task.Id]++
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// translate prompt
|
// translate prompt
|
||||||
if utils.HasChinese(task.Prompt) {
|
if utils.HasChinese(task.Prompt) {
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt), "gpt-4o-mini")
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
task.Prompt = content
|
task.Prompt = content
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf("error with translate prompt: %v", err)
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
// translate negative prompt
|
||||||
// translate negative prompt
|
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), task.TranslateModelId)
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt), "gpt-4o-mini")
|
if err == nil {
|
||||||
if err == nil {
|
task.NegPrompt = content
|
||||||
task.NegPrompt = content
|
} else {
|
||||||
} else {
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
logger.Warnf("error with translate prompt: %v", err)
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var job model.MidJourneyJob
|
|
||||||
tx := s.db.Where("id = ?", task.Id).First(&job)
|
|
||||||
if tx.Error != nil {
|
|
||||||
logger.Error("任务不存在,任务ID:", task.TaskId)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
|
|
||||||
var res ImageRes
|
|
||||||
switch task.Type {
|
|
||||||
case types.TaskImage:
|
|
||||||
res, err = s.Client.Imagine(task)
|
|
||||||
break
|
|
||||||
case types.TaskUpscale:
|
|
||||||
res, err = s.Client.Upscale(task)
|
|
||||||
break
|
|
||||||
case types.TaskVariation:
|
|
||||||
res, err = s.Client.Variation(task)
|
|
||||||
break
|
|
||||||
case types.TaskBlend:
|
|
||||||
res, err = s.Client.Blend(task)
|
|
||||||
break
|
|
||||||
case types.TaskSwapFace:
|
|
||||||
res, err = s.Client.SwapFace(task)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil || (res.Code != 1 && res.Code != 22) {
|
|
||||||
var errMsg string
|
|
||||||
if err != nil {
|
|
||||||
errMsg = err.Error()
|
|
||||||
} else {
|
|
||||||
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Error("绘画任务执行失败:", errMsg)
|
// use fast mode as default
|
||||||
job.Progress = failedProgress
|
if task.Mode == "" {
|
||||||
job.ErrMsg = errMsg
|
task.Mode = "fast"
|
||||||
// update the task progress
|
}
|
||||||
|
s.clientIds[task.Id] = task.ClientId
|
||||||
|
|
||||||
|
var job model.MidJourneyJob
|
||||||
|
tx := s.db.Where("id = ?", task.Id).First(&job)
|
||||||
|
if tx.Error != nil {
|
||||||
|
logger.Error("任务不存在,任务ID:", task.TaskId)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("handle a new MidJourney task: %+v", task)
|
||||||
|
var res ImageRes
|
||||||
|
switch task.Type {
|
||||||
|
case types.TaskImage:
|
||||||
|
res, err = s.client.Imagine(task)
|
||||||
|
break
|
||||||
|
case types.TaskUpscale:
|
||||||
|
res, err = s.client.Upscale(task)
|
||||||
|
break
|
||||||
|
case types.TaskVariation:
|
||||||
|
res, err = s.client.Variation(task)
|
||||||
|
break
|
||||||
|
case types.TaskBlend:
|
||||||
|
res, err = s.client.Blend(task)
|
||||||
|
break
|
||||||
|
case types.TaskSwapFace:
|
||||||
|
res, err = s.client.SwapFace(task)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||||||
|
var errMsg string
|
||||||
|
if err != nil {
|
||||||
|
errMsg = err.Error()
|
||||||
|
} else {
|
||||||
|
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Error("绘画任务执行失败:", errMsg)
|
||||||
|
job.Progress = service.FailTaskProgress
|
||||||
|
job.ErrMsg = errMsg
|
||||||
|
// update the task progress
|
||||||
|
s.db.Updates(&job)
|
||||||
|
// 任务失败,通知前端
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Infof("任务提交成功:%+v", res)
|
||||||
|
// 更新任务 ID/频道
|
||||||
|
job.TaskId = res.Result
|
||||||
|
job.MessageId = res.Result
|
||||||
|
job.ChannelId = res.Channel
|
||||||
s.db.Updates(&job)
|
s.db.Updates(&job)
|
||||||
// 任务失败,通知前端
|
|
||||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
logger.Infof("任务提交成功:%+v", res)
|
}()
|
||||||
// 更新任务 ID/频道
|
|
||||||
job.TaskId = res.Result
|
|
||||||
job.MessageId = res.Result
|
|
||||||
job.ChannelId = s.Name
|
|
||||||
s.db.Updates(&job)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Stop() {
|
|
||||||
s.running = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CBReq struct {
|
type CBReq struct {
|
||||||
@@ -164,46 +170,6 @@ type CBReq struct {
|
|||||||
} `json:"properties"`
|
} `json:"properties"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Notify(job model.MidJourneyJob) error {
|
|
||||||
task, err := s.Client.QueryTask(job.TaskId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 任务执行失败了
|
|
||||||
if task.FailReason != "" {
|
|
||||||
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
|
||||||
"progress": failedProgress,
|
|
||||||
"err_msg": task.FailReason,
|
|
||||||
})
|
|
||||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
|
|
||||||
return fmt.Errorf("task failed: %v", task.FailReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(task.Buttons) > 0 {
|
|
||||||
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
|
||||||
}
|
|
||||||
oldProgress := job.Progress
|
|
||||||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
|
||||||
job.Prompt = task.PromptEn
|
|
||||||
if task.ImageUrl != "" {
|
|
||||||
job.OrgURL = task.ImageUrl
|
|
||||||
}
|
|
||||||
tx := s.db.Updates(&job)
|
|
||||||
if tx.Error != nil {
|
|
||||||
return fmt.Errorf("error with update database: %v", tx.Error)
|
|
||||||
}
|
|
||||||
// 通知前端更新任务进度
|
|
||||||
if oldProgress != job.Progress {
|
|
||||||
message := sd.Running
|
|
||||||
if job.Progress == 100 {
|
|
||||||
message = sd.Finished
|
|
||||||
}
|
|
||||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetImageHash(action string) string {
|
func GetImageHash(action string) string {
|
||||||
split := strings.Split(action, "::")
|
split := strings.Split(action, "::")
|
||||||
if len(split) > 5 {
|
if len(split) > 5 {
|
||||||
@@ -211,3 +177,160 @@ func GetImageHash(action string) string {
|
|||||||
}
|
}
|
||||||
return split[len(split)-1]
|
return split[len(split)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) CheckTaskNotify() {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
var message service.NotifyMessage
|
||||||
|
err := s.notifyQueue.LPop(&message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Debugf("receive a new mj notify message: %+v", message)
|
||||||
|
client := s.wsService.Clients.Get(message.ClientId)
|
||||||
|
if client == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
utils.SendChannelMsg(client, types.ChMj, message.Message)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) DownloadImages() {
|
||||||
|
go func() {
|
||||||
|
var items []model.MidJourneyJob
|
||||||
|
for {
|
||||||
|
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
|
||||||
|
if res.Error != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// download images
|
||||||
|
for _, v := range items {
|
||||||
|
if v.OrgURL == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("try to download image: %s", v.OrgURL)
|
||||||
|
// 如果是返回的是 discord 图片地址,则使用代理下载
|
||||||
|
proxy := false
|
||||||
|
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
|
||||||
|
proxy = true
|
||||||
|
}
|
||||||
|
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
logger.Infof("download image %s successfully.", v.OrgURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
v.ImgURL = imgURL
|
||||||
|
s.db.Updates(&v)
|
||||||
|
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{
|
||||||
|
ClientId: s.clientIds[v.Id],
|
||||||
|
UserId: v.UserId,
|
||||||
|
JobId: int(v.Id),
|
||||||
|
Message: service.TaskStatusFinished})
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second * 5)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushTask push a new mj task in to task queue
|
||||||
|
func (s *Service) PushTask(task types.MjTask) {
|
||||||
|
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||||
|
s.taskQueue.RPush(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncTaskProgress 异步拉取任务
|
||||||
|
func (s *Service) SyncTaskProgress() {
|
||||||
|
go func() {
|
||||||
|
var jobs []model.MidJourneyJob
|
||||||
|
for {
|
||||||
|
res := s.db.Where("progress < ?", 100).Where("channel_id <> ?", "").Find(&jobs)
|
||||||
|
if res.Error != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, job := range jobs {
|
||||||
|
// 10 分钟还没完成的任务标记为失败
|
||||||
|
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
|
||||||
|
job.Progress = service.FailTaskProgress
|
||||||
|
job.ErrMsg = "任务超时"
|
||||||
|
s.db.Updates(&job)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("error with query task: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 任务执行失败了
|
||||||
|
if task.FailReason != "" {
|
||||||
|
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||||
|
"progress": service.FailTaskProgress,
|
||||||
|
"err_msg": task.FailReason,
|
||||||
|
})
|
||||||
|
logger.Errorf("task failed: %v", task.FailReason)
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{
|
||||||
|
ClientId: s.clientIds[job.Id],
|
||||||
|
UserId: job.UserId,
|
||||||
|
JobId: int(job.Id),
|
||||||
|
Message: service.TaskStatusFailed})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(task.Buttons) > 0 {
|
||||||
|
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||||
|
}
|
||||||
|
oldProgress := job.Progress
|
||||||
|
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||||
|
if task.ImageUrl != "" {
|
||||||
|
job.OrgURL = task.ImageUrl
|
||||||
|
}
|
||||||
|
err = s.db.Updates(&job).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("error with update database: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通知前端更新任务进度
|
||||||
|
if oldProgress != job.Progress {
|
||||||
|
message := service.TaskStatusRunning
|
||||||
|
if job.Progress == 100 {
|
||||||
|
message = service.TaskStatusFinished
|
||||||
|
}
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{
|
||||||
|
ClientId: s.clientIds[job.Id],
|
||||||
|
UserId: job.UserId,
|
||||||
|
JobId: int(job.Id),
|
||||||
|
Message: message})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 找出失败的任务,并恢复其扣减算力
|
||||||
|
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||||||
|
for _, job := range jobs {
|
||||||
|
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
|
Type: types.PowerRefund,
|
||||||
|
Model: "mid-journey",
|
||||||
|
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 更新任务状态
|
||||||
|
s.db.Model(&job).UpdateColumn("power", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second * 5)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
|||||||
fileExt := utils.GetImgExt(file.Filename)
|
fileExt := utils.GetImgExt(file.Filename)
|
||||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||||
ContentType: file.Header.Get("Content-Type"),
|
ContentType: file.Header.Get("Body-Type"),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return File{}, fmt.Errorf("error uploading to MinIO: %v", err)
|
return File{}, fmt.Errorf("error uploading to MinIO: %v", err)
|
||||||
|
|||||||
@@ -43,10 +43,8 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
|
|||||||
|
|
||||||
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
|
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
|
||||||
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
|
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
|
||||||
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
|
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
|
||||||
SetSignType(alipay.RSA2). // 设置签名类型,不设置默认 RSA2
|
SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
|
||||||
SetReturnUrl(config.ReturnURL). // 设置返回URL
|
|
||||||
SetNotifyUrl(config.NotifyURL)
|
|
||||||
|
|
||||||
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
|
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
|
||||||
return nil, fmt.Errorf("error with load payment public key: %v", err)
|
return nil, fmt.Errorf("error with load payment public key: %v", err)
|
||||||
@@ -55,23 +53,31 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
|
|||||||
return &AlipayService{config: &config, client: client}, nil
|
return &AlipayService{config: &config, client: client}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AlipayService) PayUrlMobile(outTradeNo string, amount string, subject string) (string, error) {
|
type AlipayParams struct {
|
||||||
bm := make(gopay.BodyMap)
|
OutTradeNo string `json:"out_trade_no"`
|
||||||
bm.Set("subject", subject)
|
Subject string `json:"subject"`
|
||||||
bm.Set("out_trade_no", outTradeNo)
|
TotalFee string `json:"total_fee"`
|
||||||
bm.Set("quit_url", s.config.ReturnURL)
|
ReturnURL string `json:"return_url"`
|
||||||
bm.Set("total_amount", amount)
|
NotifyURL string `json:"notify_url"`
|
||||||
bm.Set("product_code", "QUICK_WAP_WAY")
|
|
||||||
return s.client.TradeWapPay(context.Background(), bm)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AlipayService) PayUrlPc(outTradeNo string, amount string, subject string) (string, error) {
|
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
|
||||||
bm := make(gopay.BodyMap)
|
bm := make(gopay.BodyMap)
|
||||||
bm.Set("subject", subject)
|
bm.Set("subject", params.Subject)
|
||||||
bm.Set("out_trade_no", outTradeNo)
|
bm.Set("out_trade_no", params.OutTradeNo)
|
||||||
bm.Set("total_amount", amount)
|
bm.Set("quit_url", params.ReturnURL)
|
||||||
|
bm.Set("total_amount", params.TotalFee)
|
||||||
|
bm.Set("product_code", "QUICK_WAP_WAY")
|
||||||
|
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
|
||||||
|
bm := make(gopay.BodyMap)
|
||||||
|
bm.Set("subject", params.Subject)
|
||||||
|
bm.Set("out_trade_no", params.OutTradeNo)
|
||||||
|
bm.Set("total_amount", params.TotalFee)
|
||||||
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
|
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
|
||||||
return s.client.TradePagePay(context.Background(), bm)
|
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TradeVerify 交易验证
|
// TradeVerify 交易验证
|
||||||
|
|||||||
139
api/service/payment/geekpay_service.go
Normal file
139
api/service/payment/geekpay_service.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package payment
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GeekPayService Geek 支付服务
|
||||||
|
type GeekPayService struct {
|
||||||
|
config *types.GeekPayConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
|
||||||
|
return &GeekPayService{
|
||||||
|
config: &appConfig.GeekPayConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeekPayParams struct {
|
||||||
|
Method string `json:"method"` // 接口类型
|
||||||
|
Device string `json:"device"` // 设备类型
|
||||||
|
Type string `json:"type"` // 支付方式
|
||||||
|
OutTradeNo string `json:"out_trade_no"` // 商户订单号
|
||||||
|
Name string `json:"name"` // 商品名称
|
||||||
|
Money string `json:"money"` // 商品金额
|
||||||
|
ClientIP string `json:"clientip"` //用户IP地址
|
||||||
|
SubOpenId string `json:"sub_openid"` // 微信用户 openid,仅小程序支付需要
|
||||||
|
SubAppId string `json:"sub_appid"` // 小程序 AppId,仅小程序支付需要
|
||||||
|
NotifyURL string `json:"notify_url"`
|
||||||
|
ReturnURL string `json:"return_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pay 支付订单
|
||||||
|
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
|
||||||
|
p := map[string]string{
|
||||||
|
"pid": s.config.AppId,
|
||||||
|
//"method": params.Method,
|
||||||
|
"device": params.Device,
|
||||||
|
"type": params.Type,
|
||||||
|
"out_trade_no": params.OutTradeNo,
|
||||||
|
"name": params.Name,
|
||||||
|
"money": params.Money,
|
||||||
|
"clientip": params.ClientIP,
|
||||||
|
"notify_url": params.NotifyURL,
|
||||||
|
"return_url": params.ReturnURL,
|
||||||
|
"timestamp": fmt.Sprintf("%d", time.Now().Unix()),
|
||||||
|
}
|
||||||
|
p["sign"] = s.Sign(p)
|
||||||
|
p["sign_type"] = "MD5"
|
||||||
|
return s.sendRequest(s.config.ApiURL, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeekPayService) Sign(params map[string]string) string {
|
||||||
|
// 按字母顺序排序参数
|
||||||
|
var keys []string
|
||||||
|
for k := range params {
|
||||||
|
if params[k] == "" || k == "sign" || k == "sign_type" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
// 构建待签名字符串
|
||||||
|
var signStr strings.Builder
|
||||||
|
for _, k := range keys {
|
||||||
|
signStr.WriteString(k)
|
||||||
|
signStr.WriteString("=")
|
||||||
|
signStr.WriteString(params[k])
|
||||||
|
signStr.WriteString("&")
|
||||||
|
}
|
||||||
|
signString := strings.TrimSuffix(signStr.String(), "&") + s.config.PrivateKey
|
||||||
|
|
||||||
|
return utils.Md5(signString)
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeekPayResp struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
TradeNo string `json:"trade_no"`
|
||||||
|
PayURL string `json:"payurl"`
|
||||||
|
QrCode string `json:"qrcode"`
|
||||||
|
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
|
||||||
|
form := url.Values{}
|
||||||
|
for k, v := range params {
|
||||||
|
form.Add(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/mapi.php", endpoint)
|
||||||
|
logger.Infof(apiURL)
|
||||||
|
|
||||||
|
tr := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true, // 取消 SSL 证书验证
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
resp, err := client.PostForm(apiURL, form)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
logger.Debugf(string(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var r GeekPayResp
|
||||||
|
err = json.Unmarshal(body, &r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("当前支付渠道暂不支持")
|
||||||
|
}
|
||||||
|
if r.Code != 1 {
|
||||||
|
return nil, errors.New(r.Msg)
|
||||||
|
}
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
@@ -37,7 +37,7 @@ func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type HuPiPayReq struct {
|
type HuPiPayParams struct {
|
||||||
AppId string `json:"appid"`
|
AppId string `json:"appid"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
TradeOrderId string `json:"trade_order_id"`
|
TradeOrderId string `json:"trade_order_id"`
|
||||||
@@ -53,7 +53,7 @@ type HuPiPayReq struct {
|
|||||||
WapUrl string `json:"wap_url"`
|
WapUrl string `json:"wap_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type HuPiResp struct {
|
type HuPiPayResp struct {
|
||||||
Openid interface{} `json:"openid"`
|
Openid interface{} `json:"openid"`
|
||||||
UrlQrcode string `json:"url_qrcode"`
|
UrlQrcode string `json:"url_qrcode"`
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
@@ -62,7 +62,7 @@ type HuPiResp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Pay 执行支付请求操作
|
// Pay 执行支付请求操作
|
||||||
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
|
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
simple := strconv.FormatInt(time.Now().Unix(), 10)
|
simple := strconv.FormatInt(time.Now().Unix(), 10)
|
||||||
params.AppId = s.appId
|
params.AppId = s.appId
|
||||||
@@ -80,22 +80,22 @@ func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
|
|||||||
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
|
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
|
||||||
resp, err := http.PostForm(apiURL, data)
|
resp, err := http.PostForm(apiURL, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return HuPiResp{}, fmt.Errorf("error with requst api: %v", err)
|
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
all, err := io.ReadAll(resp.Body)
|
all, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return HuPiResp{}, fmt.Errorf("error with reading response: %v", err)
|
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var res HuPiResp
|
var res HuPiPayResp
|
||||||
err = utils.JsonDecode(string(all), &res)
|
err = utils.JsonDecode(string(all), &res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err)
|
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.ErrCode != 0 {
|
if res.ErrCode != 0 {
|
||||||
return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
|
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
@@ -127,10 +127,10 @@ func (s *HuPiPayService) Sign(params url.Values) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check 校验订单状态
|
// Check 校验订单状态
|
||||||
func (s *HuPiPayService) Check(tradeNo string) error {
|
func (s *HuPiPayService) Check(outTradeNo string) error {
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Add("appid", s.appId)
|
data.Add("appid", s.appId)
|
||||||
data.Add("open_order_id", tradeNo)
|
data.Add("out_trade_order", outTradeNo)
|
||||||
stamp := strconv.FormatInt(time.Now().Unix(), 10)
|
stamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||||||
data.Add("time", stamp)
|
data.Add("time", stamp)
|
||||||
data.Add("nonce_str", stamp)
|
data.Add("nonce_str", stamp)
|
||||||
|
|||||||
@@ -1,153 +0,0 @@
|
|||||||
package payment
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/md5"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"geekai/core/types"
|
|
||||||
"geekai/utils"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type JPayService struct {
|
|
||||||
config *types.JPayConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewJPayService(appConfig *types.AppConfig) *JPayService {
|
|
||||||
return &JPayService{
|
|
||||||
config: &appConfig.JPayConfig,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type JPayReq struct {
|
|
||||||
TotalFee int `json:"total_fee"`
|
|
||||||
OutTradeNo string `json:"out_trade_no"`
|
|
||||||
Subject string `json:"body"`
|
|
||||||
NotifyURL string `json:"notify_url"`
|
|
||||||
ReturnURL string `json:"callback_url"`
|
|
||||||
}
|
|
||||||
type JPayReps struct {
|
|
||||||
OutTradeNo string `json:"out_trade_no"`
|
|
||||||
OrderId string `json:"payjs_order_id"`
|
|
||||||
ReturnCode int `json:"return_code"`
|
|
||||||
ReturnMsg string `json:"return_msg"`
|
|
||||||
Sign string `json:"Sign"`
|
|
||||||
TotalFee string `json:"total_fee"`
|
|
||||||
CodeUrl string `json:"code_url,omitempty"`
|
|
||||||
Qrcode string `json:"qrcode,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r JPayReps) IsOK() bool {
|
|
||||||
return r.ReturnMsg == "SUCCESS"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (js *JPayService) Pay(param JPayReq) JPayReps {
|
|
||||||
param.NotifyURL = js.config.NotifyURL
|
|
||||||
var p = url.Values{}
|
|
||||||
encode := utils.JsonEncode(param)
|
|
||||||
m := make(map[string]interface{})
|
|
||||||
_ = utils.JsonDecode(encode, &m)
|
|
||||||
for k, v := range m {
|
|
||||||
p.Add(k, fmt.Sprintf("%v", v))
|
|
||||||
}
|
|
||||||
p.Add("mchid", js.config.AppId)
|
|
||||||
|
|
||||||
p.Add("sign", js.sign(p))
|
|
||||||
|
|
||||||
cli := http.Client{}
|
|
||||||
apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
|
|
||||||
r, err := cli.PostForm(apiURL, p)
|
|
||||||
if err != nil {
|
|
||||||
return JPayReps{ReturnMsg: err.Error()}
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
bs, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return JPayReps{ReturnMsg: err.Error()}
|
|
||||||
}
|
|
||||||
|
|
||||||
var data JPayReps
|
|
||||||
err = utils.JsonDecode(string(bs), &data)
|
|
||||||
if err != nil {
|
|
||||||
return JPayReps{ReturnMsg: err.Error()}
|
|
||||||
}
|
|
||||||
return data
|
|
||||||
}
|
|
||||||
|
|
||||||
func (js *JPayService) PayH5(p url.Values) string {
|
|
||||||
p.Add("mchid", js.config.AppId)
|
|
||||||
p.Add("sign", js.sign(p))
|
|
||||||
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (js *JPayService) sign(params url.Values) string {
|
|
||||||
params.Del(`sign`)
|
|
||||||
var keys = make([]string, 0, 0)
|
|
||||||
for key := range params {
|
|
||||||
if params.Get(key) != `` {
|
|
||||||
keys = append(keys, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sort.Strings(keys)
|
|
||||||
|
|
||||||
var pList = make([]string, 0, 0)
|
|
||||||
for _, key := range keys {
|
|
||||||
var value = strings.TrimSpace(params.Get(key))
|
|
||||||
if len(value) > 0 {
|
|
||||||
pList = append(pList, key+"="+value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var src = strings.Join(pList, "&")
|
|
||||||
src += "&key=" + js.config.PrivateKey
|
|
||||||
|
|
||||||
md5bs := md5.Sum([]byte(src))
|
|
||||||
md5res := hex.EncodeToString(md5bs[:])
|
|
||||||
return strings.ToUpper(md5res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TradeVerify 查询订单支付状态
|
|
||||||
// @param tradeNo 支付平台交易 ID
|
|
||||||
func (js *JPayService) TradeVerify(tradeNo string) error {
|
|
||||||
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
|
|
||||||
params := url.Values{}
|
|
||||||
params.Add("payjs_order_id", tradeNo)
|
|
||||||
params.Add("sign", js.sign(params))
|
|
||||||
data := strings.NewReader(params.Encode())
|
|
||||||
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with http reqeust: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with reading response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var r struct {
|
|
||||||
ReturnCode int `json:"return_code"`
|
|
||||||
Status int `json:"status"`
|
|
||||||
}
|
|
||||||
err = utils.JsonDecode(string(body), &r)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.ReturnCode == 1 && r.Status == 1 {
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
logger.Errorf("PayJs 支付验证响应:%s", string(body))
|
|
||||||
return errors.New("order not paid")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -46,18 +46,27 @@ func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
|
|||||||
return &WechatPayService{config: &config, client: client}, nil
|
return &WechatPayService{config: &config, client: client}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject string) (string, error) {
|
type WechatPayParams struct {
|
||||||
|
OutTradeNo string `json:"out_trade_no"`
|
||||||
|
TotalFee int `json:"total_fee"`
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
ClientIP string `json:"client_ip"`
|
||||||
|
ReturnURL string `json:"return_url"`
|
||||||
|
NotifyURL string `json:"notify_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
|
||||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||||
// 初始化 BodyMap
|
// 初始化 BodyMap
|
||||||
bm := make(gopay.BodyMap)
|
bm := make(gopay.BodyMap)
|
||||||
bm.Set("appid", s.config.AppId).
|
bm.Set("appid", s.config.AppId).
|
||||||
Set("mchid", s.config.MchId).
|
Set("mchid", s.config.MchId).
|
||||||
Set("description", subject).
|
Set("description", params.Subject).
|
||||||
Set("out_trade_no", outTradeNo).
|
Set("out_trade_no", params.OutTradeNo).
|
||||||
Set("time_expire", expire).
|
Set("time_expire", expire).
|
||||||
Set("notify_url", s.config.NotifyURL).
|
Set("notify_url", params.NotifyURL).
|
||||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||||
bm.Set("total", amount).
|
bm.Set("total", params.TotalFee).
|
||||||
Set("currency", "CNY")
|
Set("currency", "CNY")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -71,22 +80,22 @@ func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject s
|
|||||||
return wxRsp.Response.CodeUrl, nil
|
return wxRsp.Response.CodeUrl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WechatPayService) PayUrlH5(outTradeNo string, amount int, subject string, ip string) (string, error) {
|
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
|
||||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||||
// 初始化 BodyMap
|
// 初始化 BodyMap
|
||||||
bm := make(gopay.BodyMap)
|
bm := make(gopay.BodyMap)
|
||||||
bm.Set("appid", s.config.AppId).
|
bm.Set("appid", s.config.AppId).
|
||||||
Set("mchid", s.config.MchId).
|
Set("mchid", s.config.MchId).
|
||||||
Set("description", subject).
|
Set("description", params.Subject).
|
||||||
Set("out_trade_no", outTradeNo).
|
Set("out_trade_no", params.OutTradeNo).
|
||||||
Set("time_expire", expire).
|
Set("time_expire", expire).
|
||||||
Set("notify_url", s.config.NotifyURL).
|
Set("notify_url", params.NotifyURL).
|
||||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||||
bm.Set("total", amount).
|
bm.Set("total", params.TotalFee).
|
||||||
Set("currency", "CNY")
|
Set("currency", "CNY")
|
||||||
}).
|
}).
|
||||||
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||||
bm.Set("payer_client_ip", ip).
|
bm.Set("payer_client_ip", params.ClientIP).
|
||||||
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
||||||
bm.Set("type", "Wap")
|
bm.Set("type", "Wap")
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,143 +0,0 @@
|
|||||||
package sd
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"geekai/core/types"
|
|
||||||
"geekai/service/oss"
|
|
||||||
"geekai/store"
|
|
||||||
"geekai/store/model"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ServicePool struct {
|
|
||||||
services []*Service
|
|
||||||
taskQueue *store.RedisQueue
|
|
||||||
notifyQueue *store.RedisQueue
|
|
||||||
db *gorm.DB
|
|
||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
|
||||||
uploader *oss.UploaderManager
|
|
||||||
levelDB *store.LevelDB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
|
|
||||||
services := make([]*Service, 0)
|
|
||||||
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
|
||||||
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
|
||||||
|
|
||||||
return &ServicePool{
|
|
||||||
taskQueue: taskQueue,
|
|
||||||
notifyQueue: notifyQueue,
|
|
||||||
services: services,
|
|
||||||
db: db,
|
|
||||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
|
||||||
uploader: manager,
|
|
||||||
levelDB: levelDB,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
|
|
||||||
// stop old service
|
|
||||||
for _, s := range p.services {
|
|
||||||
s.Stop()
|
|
||||||
}
|
|
||||||
p.services = make([]*Service, 0)
|
|
||||||
|
|
||||||
for k, config := range configs {
|
|
||||||
if config.Enabled == false {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// create sd service
|
|
||||||
name := fmt.Sprintf(" sd-service-%d", k)
|
|
||||||
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
|
|
||||||
// run sd service
|
|
||||||
go func() {
|
|
||||||
service.Run()
|
|
||||||
}()
|
|
||||||
|
|
||||||
p.services = append(p.services, service)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushTask push a new mj task in to task queue
|
|
||||||
func (p *ServicePool) PushTask(task types.SdTask) {
|
|
||||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
|
||||||
p.taskQueue.RPush(task)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ServicePool) CheckTaskNotify() {
|
|
||||||
go func() {
|
|
||||||
logger.Info("Running Stable-Diffusion task notify checking ...")
|
|
||||||
for {
|
|
||||||
var message NotifyMessage
|
|
||||||
err := p.notifyQueue.LPop(&message)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
client := p.Clients.Get(uint(message.UserId))
|
|
||||||
if client == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = client.Send([]byte(message.Message))
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
|
||||||
func (p *ServicePool) CheckTaskStatus() {
|
|
||||||
go func() {
|
|
||||||
logger.Info("Running Stable-Diffusion task status checking ...")
|
|
||||||
for {
|
|
||||||
var jobs []model.SdJob
|
|
||||||
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
|
||||||
if res.Error != nil {
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, job := range jobs {
|
|
||||||
// 5 分钟还没完成的任务直接删除
|
|
||||||
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
|
|
||||||
p.db.Delete(&job)
|
|
||||||
var user model.User
|
|
||||||
p.db.Where("id = ?", job.UserId).First(&user)
|
|
||||||
// 退回绘图次数
|
|
||||||
res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
|
||||||
if res.Error == nil && res.RowsAffected > 0 {
|
|
||||||
p.db.Create(&model.PowerLog{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Amount: job.Power,
|
|
||||||
Balance: user.Power + job.Power,
|
|
||||||
Mark: types.PowerAdd,
|
|
||||||
Model: "stable-diffusion",
|
|
||||||
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second * 5)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasAvailableService check if it has available mj service in pool
|
|
||||||
func (p *ServicePool) HasAvailableService() bool {
|
|
||||||
return len(p.services) > 0
|
|
||||||
}
|
|
||||||
@@ -10,95 +10,104 @@ package sd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
|
logger2 "geekai/logger"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"strings"
|
"github.com/go-redis/redis/v8"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
// SD 绘画服务
|
// SD 绘画服务
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
httpClient *req.Client
|
httpClient *req.Client
|
||||||
config types.StableDiffusionConfig
|
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
name string // service name
|
wsService *service.WebsocketService
|
||||||
leveldb *store.LevelDB
|
userService *service.UserService
|
||||||
running bool // 运行状态
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
|
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||||
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
|
|
||||||
return &Service{
|
return &Service{
|
||||||
name: name,
|
|
||||||
config: config,
|
|
||||||
httpClient: req.C(),
|
httpClient: req.C(),
|
||||||
taskQueue: taskQueue,
|
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||||
notifyQueue: notifyQueue,
|
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||||
db: db,
|
db: db,
|
||||||
leveldb: levelDB,
|
wsService: wsService,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
running: true,
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
|
// 将数据库中未提交的人物加载到队列
|
||||||
for s.running {
|
var jobs []model.SdJob
|
||||||
|
s.db.Where("progress", 0).Find(&jobs)
|
||||||
|
for _, v := range jobs {
|
||||||
var task types.SdTask
|
var task types.SdTask
|
||||||
err := s.taskQueue.LPop(&task)
|
err := utils.JsonDecode(v.TaskInfo, &task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("taking task with error: %v", err)
|
logger.Errorf("decode task info with error: %v", err)
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// translate prompt
|
|
||||||
if utils.HasChinese(task.Params.Prompt) {
|
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
|
||||||
if err == nil {
|
|
||||||
task.Params.Prompt = content
|
|
||||||
} else {
|
|
||||||
logger.Warnf("error with translate prompt: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// translate negative prompt
|
|
||||||
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
|
||||||
if err == nil {
|
|
||||||
task.Params.NegPrompt = content
|
|
||||||
} else {
|
|
||||||
logger.Warnf("error with translate prompt: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
|
|
||||||
err = s.Txt2Img(task)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("绘画任务执行失败:", err.Error())
|
|
||||||
// update the task progress
|
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
|
||||||
"progress": -1,
|
|
||||||
"err_msg": err.Error(),
|
|
||||||
})
|
|
||||||
// 通知前端,任务失败
|
|
||||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
task.Id = int(v.Id)
|
||||||
|
s.PushTask(task)
|
||||||
}
|
}
|
||||||
}
|
logger.Infof("Starting Stable-Diffusion job consumer")
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
var task types.SdTask
|
||||||
|
err := s.taskQueue.LPop(&task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("taking task with error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) Stop() {
|
// translate prompt
|
||||||
s.running = false
|
if utils.HasChinese(task.Params.Prompt) {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.Prompt), task.TranslateModelId)
|
||||||
|
if err == nil {
|
||||||
|
task.Params.Prompt = content
|
||||||
|
} else {
|
||||||
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// translate negative prompt
|
||||||
|
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), task.TranslateModelId)
|
||||||
|
if err == nil {
|
||||||
|
task.Params.NegPrompt = content
|
||||||
|
} else {
|
||||||
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("handle a new Stable-Diffusion task: %+v", task)
|
||||||
|
err = s.Txt2Img(task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("绘画任务执行失败:", err.Error())
|
||||||
|
// update the task progress
|
||||||
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
||||||
|
"progress": service.FailTaskProgress,
|
||||||
|
"err_msg": err.Error(),
|
||||||
|
})
|
||||||
|
// 通知前端,任务失败
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Txt2ImgReq 文生图请求实体
|
// Txt2ImgReq 文生图请求实体
|
||||||
@@ -130,9 +139,8 @@ type Txt2ImgResp struct {
|
|||||||
|
|
||||||
// TaskProgressResp 任务进度响应实体
|
// TaskProgressResp 任务进度响应实体
|
||||||
type TaskProgressResp struct {
|
type TaskProgressResp struct {
|
||||||
Progress float64 `json:"progress"`
|
Progress float64 `json:"progress"`
|
||||||
EtaRelative float64 `json:"eta_relative"`
|
EtaRelative float64 `json:"eta_relative"`
|
||||||
CurrentImage string `json:"current_image"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Txt2Img 文生图 API
|
// Txt2Img 文生图 API
|
||||||
@@ -160,12 +168,19 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
}
|
}
|
||||||
var res Txt2ImgResp
|
var res Txt2ImgResp
|
||||||
var errChan = make(chan error)
|
var errChan = make(chan error)
|
||||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
|
||||||
logger.Debugf("send image request to %s", apiURL)
|
var apiKey model.ApiKey
|
||||||
|
err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
|
||||||
|
logger.Infof("send image request to %s", apiURL)
|
||||||
// send a request to sd api endpoint
|
// send a request to sd api endpoint
|
||||||
go func() {
|
go func() {
|
||||||
response, err := s.httpClient.R().
|
response, err := s.httpClient.R().
|
||||||
SetHeader("Authorization", s.config.ApiKey).
|
SetHeader("Authorization", apiKey.Value).
|
||||||
SetBody(body).
|
SetBody(body).
|
||||||
SetSuccessResult(&res).
|
SetSuccessResult(&res).
|
||||||
Post(apiURL)
|
Post(apiURL)
|
||||||
@@ -178,6 +193,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update the last used time
|
||||||
|
apiKey.LastUsedAt = time.Now().Unix()
|
||||||
|
s.db.Updates(&apiKey)
|
||||||
|
|
||||||
// 保存 Base64 图片
|
// 保存 Base64 图片
|
||||||
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -206,21 +225,15 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
|
|
||||||
// task finished
|
// task finished
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||||
// 从 leveldb 中删除预览图片数据
|
|
||||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
err, resp := s.checkTaskProgress()
|
err, resp := s.checkTaskProgress(apiKey)
|
||||||
// 更新任务进度
|
// 更新任务进度
|
||||||
if err == nil && resp.Progress > 0 {
|
if err == nil && resp.Progress > 0 {
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||||
// 发送更新状态信号
|
// 发送更新状态信号
|
||||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||||
// 保存预览图片数据
|
|
||||||
if resp.CurrentImage != "" {
|
|
||||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@@ -229,11 +242,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 执行任务
|
// 执行任务
|
||||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
|
||||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
|
||||||
var res TaskProgressResp
|
var res TaskProgressResp
|
||||||
response, err := s.httpClient.R().
|
response, err := s.httpClient.R().
|
||||||
SetHeader("Authorization", s.config.ApiKey).
|
SetHeader("Authorization", apiKey.Value).
|
||||||
SetSuccessResult(&res).
|
SetSuccessResult(&res).
|
||||||
Get(apiURL)
|
Get(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -245,3 +258,67 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
|||||||
|
|
||||||
return nil, &res
|
return nil, &res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) PushTask(task types.SdTask) {
|
||||||
|
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||||
|
s.taskQueue.RPush(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) CheckTaskNotify() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||||
|
for {
|
||||||
|
var message service.NotifyMessage
|
||||||
|
err := s.notifyQueue.LPop(&message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Debugf("notify message: %+v", message)
|
||||||
|
client := s.wsService.Clients.Get(message.ClientId)
|
||||||
|
if client == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
utils.SendChannelMsg(client, types.ChSd, message.Message)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||||
|
func (s *Service) CheckTaskStatus() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||||
|
for {
|
||||||
|
var jobs []model.SdJob
|
||||||
|
res := s.db.Where("progress < ?", 100).Find(&jobs)
|
||||||
|
if res.Error != nil {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, job := range jobs {
|
||||||
|
// 5 分钟还没完成的任务标记为失败
|
||||||
|
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
|
||||||
|
job.Progress = service.FailTaskProgress
|
||||||
|
job.ErrMsg = "任务超时"
|
||||||
|
s.db.Updates(&job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 找出失败的任务,并恢复其扣减算力
|
||||||
|
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||||||
|
for _, job := range jobs {
|
||||||
|
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
|
Type: types.PowerRefund,
|
||||||
|
Model: "stable-diffusion",
|
||||||
|
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d, Err: %s", job.Id, job.ErrMsg),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 更新任务状态
|
||||||
|
s.db.Model(&job).UpdateColumn("power", 0)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 5)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
package sd
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import logger2 "geekai/logger"
|
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
|
||||||
|
|
||||||
type NotifyMessage struct {
|
|
||||||
UserId int `json:"user_id"`
|
|
||||||
JobId int `json:"job_id"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
Running = "RUNNING"
|
|
||||||
Finished = "FINISH"
|
|
||||||
Failed = "FAIL"
|
|
||||||
)
|
|
||||||
@@ -29,7 +29,7 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
|
|||||||
|
|
||||||
func (s *SmtpService) SendVerifyCode(to string, code int) error {
|
func (s *SmtpService) SendVerifyCode(to string, code int) error {
|
||||||
subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
|
subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
|
||||||
body := fmt.Sprintf("您正在注册 %s 账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", s.config.AppName, code)
|
body := fmt.Sprintf("【%s】:您的验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", s.config.AppName, code)
|
||||||
|
|
||||||
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
|
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
|
||||||
if s.config.UseTls {
|
if s.config.UseTls {
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
logger2 "geekai/logger"
|
logger2 "geekai/logger"
|
||||||
|
"geekai/service"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/service/sd"
|
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
@@ -34,17 +34,21 @@ type Service struct {
|
|||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
wsService *service.WebsocketService
|
||||||
|
clientIds map[string]string
|
||||||
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||||
db: db,
|
db: db,
|
||||||
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
|
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
|
||||||
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
|
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
|
||||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
|
wsService: wsService,
|
||||||
|
clientIds: map[string]string{},
|
||||||
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,22 +60,17 @@ func (s *Service) PushTask(task types.SunoTask) {
|
|||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
// 将数据库中未提交的人物加载到队列
|
// 将数据库中未提交的人物加载到队列
|
||||||
var jobs []model.SunoJob
|
var jobs []model.SunoJob
|
||||||
s.db.Where("task_id", "").Find(&jobs)
|
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
|
||||||
for _, v := range jobs {
|
for _, v := range jobs {
|
||||||
s.PushTask(types.SunoTask{
|
var task types.SunoTask
|
||||||
Id: v.Id,
|
err := utils.JsonDecode(v.TaskInfo, &task)
|
||||||
Channel: v.Channel,
|
if err != nil {
|
||||||
UserId: v.UserId,
|
logger.Errorf("decode task info with error: %v", err)
|
||||||
Type: v.Type,
|
continue
|
||||||
Title: v.Title,
|
}
|
||||||
RefTaskId: v.RefTaskId,
|
task.Id = v.Id
|
||||||
RefSongId: v.RefSongId,
|
s.PushTask(task)
|
||||||
Prompt: v.Prompt,
|
s.clientIds[v.TaskId] = task.ClientId
|
||||||
Tags: v.Tags,
|
|
||||||
Model: v.ModelName,
|
|
||||||
Instrumental: v.Instrumental,
|
|
||||||
ExtendSecs: v.ExtendSecs,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
logger.Info("Starting Suno job consumer...")
|
logger.Info("Starting Suno job consumer...")
|
||||||
go func() {
|
go func() {
|
||||||
@@ -82,14 +81,21 @@ func (s *Service) Run() {
|
|||||||
logger.Errorf("taking task with error: %v", err)
|
logger.Errorf("taking task with error: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
var r RespVo
|
||||||
r, err := s.Create(task)
|
if task.Type == 3 && task.SongId != "" { // 歌曲拼接
|
||||||
|
r, err = s.Merge(task)
|
||||||
|
} else if task.Type == 4 && task.AudioURL != "" { // 上传歌曲
|
||||||
|
r, err = s.Upload(task)
|
||||||
|
} else { // 歌曲创作
|
||||||
|
r, err = s.Create(task)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("create task with error: %v", err)
|
logger.Errorf("create task with error: %v", err)
|
||||||
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||||
"err_msg": err.Error(),
|
"err_msg": err.Error(),
|
||||||
"progress": 101,
|
"progress": service.FailTaskProgress,
|
||||||
})
|
})
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,6 +104,7 @@ func (s *Service) Run() {
|
|||||||
"task_id": r.Data,
|
"task_id": r.Data,
|
||||||
"channel": r.Channel,
|
"channel": r.Channel,
|
||||||
})
|
})
|
||||||
|
s.clientIds[r.Data] = task.ClientId
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -138,7 +145,7 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var res RespVo
|
var res RespVo
|
||||||
apiURL := fmt.Sprintf("%s/task/suno/v1/submit/music", apiKey.ApiURL)
|
apiURL := fmt.Sprintf("%s/suno/submit/music", apiKey.ApiURL)
|
||||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||||
r, err := req.C().R().
|
r, err := req.C().R().
|
||||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
@@ -157,6 +164,100 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
|||||||
if res.Code != "success" {
|
if res.Code != "success" {
|
||||||
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
|
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
|
||||||
}
|
}
|
||||||
|
// update the last_use_at for api key
|
||||||
|
apiKey.LastUsedAt = time.Now().Unix()
|
||||||
|
session.Updates(&apiKey)
|
||||||
|
res.Channel = apiKey.ApiURL
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Merge(task types.SunoTask) (RespVo, error) {
|
||||||
|
// 读取 API KEY
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
|
||||||
|
if task.Channel != "" {
|
||||||
|
session = session.Where("api_url", task.Channel)
|
||||||
|
}
|
||||||
|
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||||
|
if tx.Error != nil {
|
||||||
|
return RespVo{}, errors.New("no available API KEY for Suno")
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"clip_id": task.SongId,
|
||||||
|
"is_infill": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
var res RespVo
|
||||||
|
apiURL := fmt.Sprintf("%s/suno/submit/concat", apiKey.ApiURL)
|
||||||
|
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
|
SetBody(reqBody).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
err = json.Unmarshal(body, &res)
|
||||||
|
if err != nil {
|
||||||
|
return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != "success" {
|
||||||
|
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
|
||||||
|
}
|
||||||
|
// update the last_use_at for api key
|
||||||
|
apiKey.LastUsedAt = time.Now().Unix()
|
||||||
|
session.Updates(&apiKey)
|
||||||
|
res.Channel = apiKey.ApiURL
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
|
||||||
|
// 读取 API KEY
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
|
||||||
|
if task.Channel != "" {
|
||||||
|
session = session.Where("api_url", task.Channel)
|
||||||
|
}
|
||||||
|
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||||
|
if tx.Error != nil {
|
||||||
|
return RespVo{}, errors.New("no available API KEY for Suno")
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"url": task.AudioURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
var res RespVo
|
||||||
|
apiURL := fmt.Sprintf("%s/suno/uploads/audio-url", apiKey.ApiURL)
|
||||||
|
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
|
SetBody(reqBody).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != 200 {
|
||||||
|
return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
err = json.Unmarshal(body, &res)
|
||||||
|
if err != nil {
|
||||||
|
return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != "success" {
|
||||||
|
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
|
||||||
|
}
|
||||||
|
// update the last_use_at for api key
|
||||||
|
apiKey.LastUsedAt = time.Now().Unix()
|
||||||
|
session.Updates(&apiKey)
|
||||||
res.Channel = apiKey.ApiURL
|
res.Channel = apiKey.ApiURL
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
@@ -165,24 +266,24 @@ func (s *Service) CheckTaskNotify() {
|
|||||||
go func() {
|
go func() {
|
||||||
logger.Info("Running Suno task notify checking ...")
|
logger.Info("Running Suno task notify checking ...")
|
||||||
for {
|
for {
|
||||||
var message sd.NotifyMessage
|
var message service.NotifyMessage
|
||||||
err := s.notifyQueue.LPop(&message)
|
err := s.notifyQueue.LPop(&message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
client := s.Clients.Get(uint(message.UserId))
|
logger.Debugf("notify message: %+v", message)
|
||||||
|
logger.Debugf("client id: %+v", s.wsService.Clients)
|
||||||
|
client := s.wsService.Clients.Get(message.ClientId)
|
||||||
|
logger.Debugf("%+v", client)
|
||||||
if client == nil {
|
if client == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = client.Send([]byte(message.Message))
|
utils.SendChannelMsg(client, types.ChSuno, message.Message)
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) DownloadImages() {
|
func (s *Service) DownloadFiles() {
|
||||||
go func() {
|
go func() {
|
||||||
var items []model.SunoJob
|
var items []model.SunoJob
|
||||||
for {
|
for {
|
||||||
@@ -210,7 +311,7 @@ func (s *Service) DownloadImages() {
|
|||||||
v.AudioURL = audioURL
|
v.AudioURL = audioURL
|
||||||
v.Progress = 100
|
v.Progress = 100
|
||||||
s.db.Updates(&v)
|
s.db.Updates(&v)
|
||||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: sd.Finished})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(time.Second * 10)
|
time.Sleep(time.Second * 10)
|
||||||
@@ -276,15 +377,29 @@ func (s *Service) SyncTaskProgress() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
|
||||||
} else if task.Data.FailReason != "" {
|
} else if task.Data.FailReason != "" {
|
||||||
job.Progress = 101
|
job.Progress = service.FailTaskProgress
|
||||||
job.ErrMsg = task.Data.FailReason
|
job.ErrMsg = task.Data.FailReason
|
||||||
s.db.Updates(&job)
|
s.db.Updates(&job)
|
||||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 找出失败的任务,并恢复其扣减算力
|
||||||
|
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||||||
|
for _, job := range jobs {
|
||||||
|
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
|
Type: types.PowerRefund,
|
||||||
|
Model: job.ModelName,
|
||||||
|
Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 更新任务状态
|
||||||
|
s.db.Model(&job).UpdateColumn("power", 0)
|
||||||
|
}
|
||||||
time.Sleep(time.Second * 10)
|
time.Sleep(time.Second * 10)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -328,15 +443,15 @@ type QueryRespVo struct {
|
|||||||
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
|
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
|
||||||
// 读取 API KEY
|
// 读取 API KEY
|
||||||
var apiKey model.ApiKey
|
var apiKey model.ApiKey
|
||||||
tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
|
err := s.db.Session(&gorm.Session{}).Where("type", "suno").
|
||||||
Where("api_url", channel).
|
Where("api_url", channel).
|
||||||
Where("enabled", true).
|
Where("enabled", true).
|
||||||
Order("last_used_at DESC").First(&apiKey)
|
Order("last_used_at DESC").First(&apiKey).Error
|
||||||
if tx.Error != nil {
|
if err != nil {
|
||||||
return QueryRespVo{}, errors.New("no available API KEY for Suno")
|
return QueryRespVo{}, errors.New("no available API KEY for Suno")
|
||||||
}
|
}
|
||||||
|
|
||||||
apiURL := fmt.Sprintf("%s/task/suno/v1/fetch/%s", apiKey.ApiURL, taskId)
|
apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
|
||||||
var res QueryRespVo
|
var res QueryRespVo
|
||||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,166 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
|
const FailTaskProgress = 101
|
||||||
|
const (
|
||||||
|
TaskStatusRunning = "RUNNING"
|
||||||
|
TaskStatusFinished = "FINISH"
|
||||||
|
TaskStatusFailed = "FAIL"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NotifyMessage struct {
|
||||||
|
UserId int `json:"user_id"`
|
||||||
|
ClientId string `json:"client_id"`
|
||||||
|
JobId int `json:"job_id"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||||
|
|
||||||
|
const ImagePromptOptimizeTemplate = `
|
||||||
|
Create a highly effective prompt to provide to an AI image generation tool in order to create an artwork based on a desired concept.
|
||||||
|
|
||||||
|
Please specify details about the artwork, such as the style, subject, mood, and other important characteristics you want the resulting image to have.
|
||||||
|
|
||||||
|
Remember, prompts should always be output in English.
|
||||||
|
|
||||||
|
# Steps
|
||||||
|
|
||||||
|
1. **Subject Description**: Describe the main subject of the image clearly. Include as much detail as possible about what should be in the scene. For example, "a majestic lion roaring at sunrise" or "a futuristic city with flying cars."
|
||||||
|
|
||||||
|
2. **Art Style**: Specify the art style you envision. Possible options include 'realistic', 'impressionist', a specific artist name, or imaginative styles like "cyberpunk." This helps the AI achieve your visual expectations.
|
||||||
|
|
||||||
|
3. **Mood or Atmosphere**: Convey the feeling you want the image to evoke. For instance, peaceful, chaotic, epic, etc.
|
||||||
|
|
||||||
|
4. **Color Palette and Lighting**: Mention color preferences or lighting. For example, "vibrant with shades of blue and purple" or "dim and dramatic lighting."
|
||||||
|
|
||||||
|
5. **Optional Features**: You can add any additional attributes, such as background details, attention to textures, or any specific kind of framing.
|
||||||
|
|
||||||
|
# Output Format
|
||||||
|
|
||||||
|
- **Prompt Format**: A descriptive phrase that includes key aspects of the artwork (subject, style, mood, colors, lighting, any optional features).
|
||||||
|
|
||||||
|
Here is an example of how the final prompt should look:
|
||||||
|
|
||||||
|
"An ethereal landscape featuring towering ice mountains, in an impressionist style reminiscent of Claude Monet, with a serene mood. The sky is glistening with soft purples and whites, with a gentle morning sun illuminating the scene."
|
||||||
|
|
||||||
|
**Please input the prompt words directly in English, and do not input any other explanatory statements**
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
|
||||||
|
1. **Input**:
|
||||||
|
- Subject: A white tiger in a dense jungle
|
||||||
|
- Art Style: Realistic
|
||||||
|
- Mood: Intense, mysterious
|
||||||
|
- Lighting: Dramatic contrast with light filtering through leaves
|
||||||
|
|
||||||
|
**Output Prompt**: "A realistic rendering of a white tiger stealthily moving through a dense jungle, with an intense, mysterious mood. The lighting creates strong contrasts as beams of sunlight filter through a thick canopy of leaves."
|
||||||
|
|
||||||
|
2. **Input**:
|
||||||
|
- Subject: An enchanted castle on a floating island
|
||||||
|
- Art Style: Fantasy
|
||||||
|
- Mood: Majestic, magical
|
||||||
|
- Colors: Bright blues, greens, and gold
|
||||||
|
|
||||||
|
**Output Prompt**: "A majestic fantasy castle on a floating island above the clouds, with bright blues, greens, and golds to create a magical, dreamy atmosphere. Textured cobblestone details and glistening waters surround the scene."
|
||||||
|
|
||||||
|
# Notes
|
||||||
|
|
||||||
|
- Ensure that you mix different aspects to get a comprehensive and visually compelling prompt.
|
||||||
|
- Be as descriptive as possible as it often helps generate richer, more detailed images.
|
||||||
|
- If you want the image to resemble a particular artist's work, be sure to mention the artist explicitly. e.g., "in the style of Van Gogh."
|
||||||
|
|
||||||
|
The theme of the creation is:【%s】
|
||||||
|
`
|
||||||
|
|
||||||
|
const LyricPromptTemplate = `
|
||||||
|
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
|
||||||
|
请以【%s】为主题创作一首歌曲,歌曲时间不要太短,3分钟左右,不要输出任何解释性的内容。
|
||||||
|
输出格式如下:
|
||||||
|
歌曲名称
|
||||||
|
第一节:
|
||||||
|
{{歌词内容}}
|
||||||
|
副歌:
|
||||||
|
{{歌词内容}}
|
||||||
|
|
||||||
|
第二节:
|
||||||
|
{{歌词内容}}
|
||||||
|
副歌:
|
||||||
|
{{歌词内容}}
|
||||||
|
|
||||||
|
尾声:
|
||||||
|
{{歌词内容}}
|
||||||
|
`
|
||||||
|
|
||||||
|
const VideoPromptTemplate = `
|
||||||
|
As an expert in video generation prompts, please create a detailed descriptive prompt for the following video concept. The description should include the setting, character appearance, actions, overall atmosphere, and camera angles. Please make it as detailed and vivid as possible to help ensure that every aspect of the video is accurately captured.
|
||||||
|
|
||||||
|
Please remember that regardless of the user’s input, the final output must be in English.
|
||||||
|
|
||||||
|
# Details to Include
|
||||||
|
|
||||||
|
- Describe the overall visual style of the video (e.g., animated, realistic, retro tone, etc.)
|
||||||
|
- Identify key characters or objects in the video and describe their appearance, attire, and expressions
|
||||||
|
- Describe the environment of the scene, including weather, lighting, colors, and important details
|
||||||
|
- Explain the behavior and interactions of the characters
|
||||||
|
- Include any unique camera angles, movements, or special effects
|
||||||
|
|
||||||
|
# Output Format
|
||||||
|
Provide the prompt in paragraph form, ensuring that the description is detailed enough for a video generation system to recreate the envisioned scene. Include the beginning, middle, and end of the scene to convey a complete storyline.
|
||||||
|
|
||||||
|
# Example
|
||||||
|
**User Input:**
|
||||||
|
“A small cat basking in the sun on a balcony.”
|
||||||
|
|
||||||
|
**Generated Prompt:**
|
||||||
|
On a bright spring afternoon, an orange-striped kitten lies lazily on a balcony, basking in the warm sunlight. The iron railings around the balcony cast soft shadows that dance gently with the light. The cat’s eyes are half-closed, exuding a sense of contentment and tranquility in its surroundings. In the distance, a few fluffy white clouds drift slowly across the blue sky. The camera initially focuses on the cat’s face, capturing the delicate details of its fur, and then gradually zooms out to reveal the full balcony scene, immersing viewers in a moment of calm and relaxation.
|
||||||
|
|
||||||
|
The theme of the creation is:【%s】
|
||||||
|
`
|
||||||
|
|
||||||
|
const MetaPromptTemplate = `
|
||||||
|
Given a task description or existing prompt, produce a detailed system prompt to guide a language model in completing the task effectively.
|
||||||
|
|
||||||
|
Please remember, the final output must be the same language with user’s input.
|
||||||
|
|
||||||
|
# Guidelines
|
||||||
|
|
||||||
|
- Understand the Task: Grasp the main objective, goals, requirements, constraints, and expected output.
|
||||||
|
- Minimal Changes: If an existing prompt is provided, improve it only if it's simple. For complex prompts, enhance clarity and add missing elements without altering the original structure.
|
||||||
|
- Reasoning Before Conclusions**: Encourage reasoning steps before any conclusions are reached. ATTENTION! If the user provides examples where the reasoning happens afterward, REVERSE the order! NEVER START EXAMPLES WITH CONCLUSIONS!
|
||||||
|
- Reasoning Order: Call out reasoning portions of the prompt and conclusion parts (specific fields by name). For each, determine the ORDER in which this is done, and whether it needs to be reversed.
|
||||||
|
- Conclusion, classifications, or results should ALWAYS appear last.
|
||||||
|
- Examples: Include high-quality examples if helpful, using placeholders [in brackets] for complex elements.
|
||||||
|
- What kinds of examples may need to be included, how many, and whether they are complex enough to benefit from placeholders.
|
||||||
|
- Clarity and Conciseness: Use clear, specific language. Avoid unnecessary instructions or bland statements.
|
||||||
|
- Formatting: Use markdown features for readability. DO NOT USE CODE BLOCKS UNLESS SPECIFICALLY REQUESTED.
|
||||||
|
- Preserve User Content: If the input task or prompt includes extensive guidelines or examples, preserve them entirely, or as closely as possible. If they are vague, consider breaking down into sub-steps. Keep any details, guidelines, examples, variables, or placeholders provided by the user.
|
||||||
|
- Constants: DO include constants in the prompt, as they are not susceptible to prompt injection. Such as guides, rubrics, and examples.
|
||||||
|
- Output Format: Explicitly the most appropriate output format, in detail. This should include length and syntax (e.g. short sentence, paragraph, JSON, etc.)
|
||||||
|
- For tasks outputting well-defined or structured data (classification, JSON, etc.) bias toward outputting a JSON.
|
||||||
|
- JSON should never be wrapped in code blocks unless explicitly requested.
|
||||||
|
|
||||||
|
The final prompt you output should adhere to the following structure below. Do not include any additional commentary, only output the completed system prompt. SPECIFICALLY, do not include any additional messages at the start or end of the prompt. (e.g. no "---")
|
||||||
|
|
||||||
|
[Concise instruction describing the task - this should be the first line in the prompt, no section header]
|
||||||
|
|
||||||
|
[Additional details as needed.]
|
||||||
|
|
||||||
|
[Optional sections with headings or bullet points for detailed steps.]
|
||||||
|
|
||||||
|
# Steps [optional]
|
||||||
|
|
||||||
|
[optional: a detailed breakdown of the steps necessary to accomplish the task]
|
||||||
|
|
||||||
|
# Output Format
|
||||||
|
|
||||||
|
[Specifically call out how the output should be formatted, be it response length, structure e.g. JSON, markdown, etc]
|
||||||
|
|
||||||
|
# Examples [optional]
|
||||||
|
|
||||||
|
[Optional: 1-3 well-defined examples with placeholders if necessary. Clearly mark where examples start and end, and what the input and output are. User placeholders as necessary.]
|
||||||
|
[If the examples are shorter than what a realistic example is expected to be, make a reference with () explaining how real examples should be longer / shorter / different. AND USE PLACEHOLDERS! ]
|
||||||
|
|
||||||
|
# Notes [optional]
|
||||||
|
|
||||||
|
[optional: edge cases, details, and an area to call or repeat out specific important considerations]
|
||||||
|
`
|
||||||
|
|||||||
83
api/service/user_service.go
Normal file
83
api/service/user_service.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserService struct {
|
||||||
|
db *gorm.DB
|
||||||
|
lock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserService(db *gorm.DB) *UserService {
|
||||||
|
return &UserService{db: db, lock: sync.Mutex{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncreasePower 增加用户算力
|
||||||
|
func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
tx := s.db.Begin()
|
||||||
|
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", power)).Error
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var user model.User
|
||||||
|
tx.Where("id", userId).First(&user)
|
||||||
|
err = tx.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: log.Type,
|
||||||
|
Amount: power,
|
||||||
|
Balance: user.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: log.Model,
|
||||||
|
Remark: log.Remark,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecreasePower 减少用户算力
|
||||||
|
func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
tx := s.db.Begin()
|
||||||
|
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return fmt.Errorf("扣减算力失败:%v", err)
|
||||||
|
}
|
||||||
|
var user model.User
|
||||||
|
tx.Where("id", userId).First(&user)
|
||||||
|
err = tx.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: log.Type,
|
||||||
|
Amount: power,
|
||||||
|
Balance: user.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: log.Model,
|
||||||
|
Remark: log.Remark,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return fmt.Errorf("记录算力日志失败:%v", err)
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
377
api/service/video/luma.go
Normal file
377
api/service/video/luma.go
Normal file
@@ -0,0 +1,377 @@
|
|||||||
|
package video
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
logger2 "geekai/logger"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/store"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
httpClient *req.Client
|
||||||
|
db *gorm.DB
|
||||||
|
uploadManager *oss.UploaderManager
|
||||||
|
taskQueue *store.RedisQueue
|
||||||
|
notifyQueue *store.RedisQueue
|
||||||
|
wsService *service.WebsocketService
|
||||||
|
clientIds map[uint]string
|
||||||
|
userService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||||
|
return &Service{
|
||||||
|
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||||
|
db: db,
|
||||||
|
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
|
||||||
|
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
|
||||||
|
wsService: wsService,
|
||||||
|
uploadManager: manager,
|
||||||
|
clientIds: map[uint]string{},
|
||||||
|
userService: userService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) PushTask(task types.VideoTask) {
|
||||||
|
logger.Infof("add a new Video task to the task list: %+v", task)
|
||||||
|
s.taskQueue.RPush(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Run() {
|
||||||
|
// 将数据库中未提交的人物加载到队列
|
||||||
|
var jobs []model.VideoJob
|
||||||
|
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
|
||||||
|
for _, v := range jobs {
|
||||||
|
var task types.VideoTask
|
||||||
|
err := utils.JsonDecode(v.TaskInfo, &task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("decode task info with error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
task.Id = v.Id
|
||||||
|
s.PushTask(task)
|
||||||
|
s.clientIds[v.Id] = task.ClientId
|
||||||
|
}
|
||||||
|
logger.Info("Starting Video job consumer...")
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
var task types.VideoTask
|
||||||
|
err := s.taskQueue.LPop(&task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("taking task with error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// translate prompt
|
||||||
|
if utils.HasChinese(task.Prompt) {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
|
||||||
|
if err == nil {
|
||||||
|
task.Prompt = content
|
||||||
|
} else {
|
||||||
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.ClientId != "" {
|
||||||
|
s.clientIds[task.Id] = task.ClientId
|
||||||
|
}
|
||||||
|
|
||||||
|
var r LumaRespVo
|
||||||
|
r, err = s.LumaCreate(task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("create task with error: %v", err)
|
||||||
|
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||||
|
"err_msg": err.Error(),
|
||||||
|
"progress": service.FailTaskProgress,
|
||||||
|
"cover_url": "/images/failed.jpg",
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("update task with error: %v", err)
|
||||||
|
}
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新任务信息
|
||||||
|
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||||
|
"task_id": r.Id,
|
||||||
|
"channel": r.Channel,
|
||||||
|
"prompt_ext": r.Prompt,
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("update task with error: %v", err)
|
||||||
|
s.PushTask(task)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
type LumaRespVo struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
State string `json:"state"`
|
||||||
|
QueueState interface{} `json:"queue_state"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
Video interface{} `json:"video"`
|
||||||
|
VideoRaw interface{} `json:"video_raw"`
|
||||||
|
Liked interface{} `json:"liked"`
|
||||||
|
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
|
||||||
|
Thumbnail interface{} `json:"thumbnail"`
|
||||||
|
Channel string `json:"channel,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
|
||||||
|
// 读取 API KEY
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
|
||||||
|
if task.Channel != "" {
|
||||||
|
session = session.Where("api_url", task.Channel)
|
||||||
|
}
|
||||||
|
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||||
|
if tx.Error != nil {
|
||||||
|
return LumaRespVo{}, errors.New("no available API KEY for Luma")
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"user_prompt": task.Prompt,
|
||||||
|
"expand_prompt": task.Params.PromptOptimize,
|
||||||
|
"loop": task.Params.Loop,
|
||||||
|
"image_url": task.Params.StartImgURL,
|
||||||
|
"image_end_url": task.Params.EndImgURL,
|
||||||
|
}
|
||||||
|
var res LumaRespVo
|
||||||
|
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
|
||||||
|
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
|
SetBody(reqBody).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != 200 && r.StatusCode != 201 {
|
||||||
|
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
err = json.Unmarshal(body, &res)
|
||||||
|
if err != nil {
|
||||||
|
return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// update the last_use_at for api key
|
||||||
|
apiKey.LastUsedAt = time.Now().Unix()
|
||||||
|
session.Updates(&apiKey)
|
||||||
|
res.Channel = apiKey.ApiURL
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) CheckTaskNotify() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Suno task notify checking ...")
|
||||||
|
for {
|
||||||
|
var message service.NotifyMessage
|
||||||
|
err := s.notifyQueue.LPop(&message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Debugf("Receive notify message: %+v", message)
|
||||||
|
client := s.wsService.Clients.Get(message.ClientId)
|
||||||
|
if client == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
utils.SendChannelMsg(client, types.ChLuma, message.Message)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) DownloadFiles() {
|
||||||
|
go func() {
|
||||||
|
var items []model.VideoJob
|
||||||
|
for {
|
||||||
|
res := s.db.Where("progress", 102).Find(&items)
|
||||||
|
if res.Error != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range items {
|
||||||
|
if v.WaterURL == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("try download video: %s", v.WaterURL)
|
||||||
|
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("download video with error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Infof("download video success: %s", videoURL)
|
||||||
|
v.WaterURL = videoURL
|
||||||
|
|
||||||
|
if v.VideoURL != "" {
|
||||||
|
logger.Infof("try download no water video: %s", v.VideoURL)
|
||||||
|
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("download video with error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Infof("download no water video success: %s", videoURL)
|
||||||
|
v.VideoURL = videoURL
|
||||||
|
v.Progress = 100
|
||||||
|
s.db.Updates(&v)
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncTaskProgress 异步拉取任务
|
||||||
|
func (s *Service) SyncTaskProgress() {
|
||||||
|
go func() {
|
||||||
|
var jobs []model.VideoJob
|
||||||
|
for {
|
||||||
|
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
|
||||||
|
if res.Error != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, job := range jobs {
|
||||||
|
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("query task with error: %v", err)
|
||||||
|
// 更新任务信息
|
||||||
|
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||||
|
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
|
||||||
|
"err_msg": err.Error(),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("task: %+v", task)
|
||||||
|
if task.State == "completed" { // 更新任务信息
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"progress": 102, // 102 表示资源未下载完成,
|
||||||
|
"water_url": task.Video.Url,
|
||||||
|
"raw_data": utils.JsonEncode(task),
|
||||||
|
"prompt_ext": task.Prompt,
|
||||||
|
"cover_url": task.Thumbnail.Url,
|
||||||
|
}
|
||||||
|
if task.Video.DownloadUrl != "" {
|
||||||
|
data["video_url"] = task.Video.DownloadUrl
|
||||||
|
}
|
||||||
|
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("更新数据库失败:%v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// 找出失败的任务,并恢复其扣减算力
|
||||||
|
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||||||
|
for _, job := range jobs {
|
||||||
|
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
|
Type: types.PowerRefund,
|
||||||
|
Model: "luma",
|
||||||
|
Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 更新任务状态
|
||||||
|
s.db.Model(&job).UpdateColumn("power", 0)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
type LumaTaskVo struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Liked interface{} `json:"liked"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Video struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
Thumbnail string `json:"thumbnail"`
|
||||||
|
DownloadUrl string `json:"download_url"`
|
||||||
|
} `json:"video"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
UserId string `json:"user_id"`
|
||||||
|
BatchId string `json:"batch_id"`
|
||||||
|
Thumbnail struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
} `json:"thumbnail"`
|
||||||
|
VideoRaw struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
} `json:"video_raw"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
LastFrame struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
} `json:"last_frame"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
|
||||||
|
// 读取 API KEY
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
|
||||||
|
Where("api_url", channel).
|
||||||
|
Where("enabled", true).
|
||||||
|
Order("last_used_at DESC").First(&apiKey).Error
|
||||||
|
if err != nil {
|
||||||
|
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
|
||||||
|
var res LumaTaskVo
|
||||||
|
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
if r.StatusCode != 200 {
|
||||||
|
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
err = json.Unmarshal(body, &res)
|
||||||
|
if err != nil {
|
||||||
|
return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
13
api/service/ws_service.go
Normal file
13
api/service/ws_service.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "geekai/core/types"
|
||||||
|
|
||||||
|
type WebsocketService struct {
|
||||||
|
Clients *types.LMap[string, *types.WsClient] // clientId => Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWebsocketService() *WebsocketService {
|
||||||
|
return &WebsocketService{
|
||||||
|
Clients: types.NewLMap[string, *types.WsClient](),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
package wx
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
logger2 "geekai/logger"
|
|
||||||
"geekai/store/model"
|
|
||||||
"github.com/eatmoreapple/openwechat"
|
|
||||||
"github.com/skip2/go-qrcode"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 微信收款机器人
|
|
||||||
var logger = logger2.GetLogger()
|
|
||||||
|
|
||||||
type Bot struct {
|
|
||||||
bot *openwechat.Bot
|
|
||||||
token string
|
|
||||||
db *gorm.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWeChatBot(db *gorm.DB) *Bot {
|
|
||||||
bot := openwechat.DefaultBot(openwechat.Desktop)
|
|
||||||
return &Bot{
|
|
||||||
bot: bot,
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) Run() error {
|
|
||||||
logger.Info("Starting WeChat Bot...")
|
|
||||||
|
|
||||||
// set message handler
|
|
||||||
b.bot.MessageHandler = func(msg *openwechat.Message) {
|
|
||||||
b.messageHandler(msg)
|
|
||||||
}
|
|
||||||
// scan code login callback
|
|
||||||
b.bot.UUIDCallback = b.qrCodeCallBack
|
|
||||||
debug, err := strconv.ParseBool(os.Getenv("APP_DEBUG"))
|
|
||||||
if debug {
|
|
||||||
reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json")
|
|
||||||
err = b.bot.HotLogin(reloadStorage, true)
|
|
||||||
} else {
|
|
||||||
err = b.bot.Login()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("微信登录成功!")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// message handler
|
|
||||||
func (b *Bot) messageHandler(msg *openwechat.Message) {
|
|
||||||
sender, err := msg.Sender()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 只处理微信支付的推送消息
|
|
||||||
if sender.NickName == "微信支付" ||
|
|
||||||
msg.MsgType == openwechat.MsgTypeApp ||
|
|
||||||
msg.AppMsgType == openwechat.AppMsgTypeUrl {
|
|
||||||
// 解析支付金额
|
|
||||||
message := parseTransactionMessage(msg.Content)
|
|
||||||
transaction := extractTransaction(message)
|
|
||||||
logger.Infof("解析到收款信息:%+v", transaction)
|
|
||||||
if transaction.TransId != "" {
|
|
||||||
var item model.Reward
|
|
||||||
res := b.db.Where("tx_id = ?", transaction.TransId).First(&item)
|
|
||||||
if item.Id > 0 {
|
|
||||||
logger.Error("当前交易 ID 己经存在!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res = b.db.Create(&model.Reward{
|
|
||||||
TxId: transaction.TransId,
|
|
||||||
Amount: transaction.Amount,
|
|
||||||
Remark: transaction.Remark,
|
|
||||||
Status: false,
|
|
||||||
})
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Errorf("交易保存失败: %v", res.Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) qrCodeCallBack(uuid string) {
|
|
||||||
logger.Info("请使用微信扫描下面二维码登录")
|
|
||||||
q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium)
|
|
||||||
logger.Info(q.ToString(true))
|
|
||||||
}
|
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
package wx
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/xml"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Message 转账消息
|
|
||||||
type Message struct {
|
|
||||||
Des string
|
|
||||||
Url string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transaction 解析后的交易信息
|
|
||||||
type Transaction struct {
|
|
||||||
TransId string `json:"trans_id"` // 微信转账交易 ID
|
|
||||||
Amount float64 `json:"amount"` // 微信转账交易金额
|
|
||||||
Remark string `json:"remark"` // 转账备注
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析微信转账消息
|
|
||||||
func parseTransactionMessage(xmlData string) *Message {
|
|
||||||
decoder := xml.NewDecoder(strings.NewReader(xmlData))
|
|
||||||
message := Message{}
|
|
||||||
for {
|
|
||||||
token, err := decoder.Token()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
switch se := token.(type) {
|
|
||||||
case xml.StartElement:
|
|
||||||
var value string
|
|
||||||
if se.Name.Local == "des" && message.Des == "" {
|
|
||||||
if err := decoder.DecodeElement(&value, &se); err == nil {
|
|
||||||
message.Des = strings.TrimSpace(value)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if se.Name.Local == "weapp_path" || se.Name.Local == "url" {
|
|
||||||
if err := decoder.DecodeElement(&value, &se); err == nil {
|
|
||||||
if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") {
|
|
||||||
message.Url = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 兼容旧版消息记录
|
|
||||||
if message.Url == "" {
|
|
||||||
var msg struct {
|
|
||||||
XMLName xml.Name `xml:"msg"`
|
|
||||||
AppMsg struct {
|
|
||||||
Des string `xml:"des"`
|
|
||||||
Url string `xml:"url"`
|
|
||||||
} `xml:"appmsg"`
|
|
||||||
}
|
|
||||||
if err := xml.Unmarshal([]byte(xmlData), &msg); err == nil {
|
|
||||||
message.Url = msg.AppMsg.Url
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &message
|
|
||||||
}
|
|
||||||
|
|
||||||
// 导出交易信息
|
|
||||||
func extractTransaction(message *Message) Transaction {
|
|
||||||
var tx = Transaction{}
|
|
||||||
// 导出交易金额和备注
|
|
||||||
lines := strings.Split(message.Des, "\n")
|
|
||||||
for _, line := range lines {
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
if len(line) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 解析收款金额
|
|
||||||
prefix := "收款金额¥"
|
|
||||||
if strings.HasPrefix(line, prefix) {
|
|
||||||
if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil {
|
|
||||||
tx.Amount = value
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 解析收款备注
|
|
||||||
prefix = "付款方备注"
|
|
||||||
if strings.HasPrefix(line, prefix) {
|
|
||||||
tx.Remark = line[len(prefix):]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析交易 ID
|
|
||||||
parse, err := url.Parse(message.Url)
|
|
||||||
if err == nil {
|
|
||||||
tx.TransId = parse.Query().Get("id")
|
|
||||||
if tx.TransId == "" {
|
|
||||||
tx.TransId = parse.Query().Get("trans_id")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx
|
|
||||||
}
|
|
||||||
@@ -81,54 +81,6 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms
|
|||||||
// 自动将 VIP 会员的算力补充到每月赠送的最大值
|
// 自动将 VIP 会员的算力补充到每月赠送的最大值
|
||||||
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
|
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
|
||||||
logger.Info("开始进行月底账号盘点...")
|
logger.Info("开始进行月底账号盘点...")
|
||||||
var users []model.User
|
|
||||||
res := e.db.Where("vip", 1).Where("status", 1).Find(&users)
|
|
||||||
if res.Error != nil {
|
|
||||||
return "No vip users found"
|
|
||||||
}
|
|
||||||
|
|
||||||
var sysConfig model.Config
|
|
||||||
res = e.db.Where("marker", "system").First(&sysConfig)
|
|
||||||
if res.Error != nil {
|
|
||||||
return "error with get system config: " + res.Error.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
var config types.SystemConfig
|
|
||||||
err := utils.JsonDecode(sysConfig.Config, &config)
|
|
||||||
if err != nil {
|
|
||||||
return "error with decode system config: " + err.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, u := range users {
|
|
||||||
// 处理过期的 VIP
|
|
||||||
if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() {
|
|
||||||
u.Vip = false
|
|
||||||
e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if u.Power < config.VipMonthPower {
|
|
||||||
power := config.VipMonthPower - u.Power
|
|
||||||
// update user
|
|
||||||
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power))
|
|
||||||
// 记录算力变动日志
|
|
||||||
if tx.Error == nil {
|
|
||||||
var user model.User
|
|
||||||
e.db.Where("id", u.Id).First(&user)
|
|
||||||
e.db.Create(&model.PowerLog{
|
|
||||||
UserId: u.Id,
|
|
||||||
Username: u.Username,
|
|
||||||
Type: types.PowerRecharge,
|
|
||||||
Amount: power,
|
|
||||||
Mark: types.PowerAdd,
|
|
||||||
Balance: user.Power,
|
|
||||||
Model: "系统盘点",
|
|
||||||
Remark: fmt.Sprintf("VIP会员每月算力派发,:%d", config.VipMonthPower),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Info("月底盘点完成!")
|
|
||||||
return "success"
|
return "success"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,15 +29,9 @@ func NewLevelDB() (*LevelDB, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *LevelDB) Put(key string, value interface{}) error {
|
func (db *LevelDB) Put(key string, value interface{}) error {
|
||||||
var byteData []byte
|
byteData, err := json.Marshal(value)
|
||||||
if v, ok := value.(string); ok {
|
if err != nil {
|
||||||
byteData = []byte(v)
|
return err
|
||||||
} else {
|
|
||||||
b, err := json.Marshal(value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
byteData = b
|
|
||||||
}
|
}
|
||||||
return db.driver.Put([]byte(key), byteData, nil)
|
return db.driver.Put([]byte(key), byteData, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
12
api/store/model/app_type.go
Normal file
12
api/store/model/app_type.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type AppType struct {
|
||||||
|
Id uint `gorm:"primarykey"`
|
||||||
|
Name string
|
||||||
|
Icon string
|
||||||
|
Enabled bool
|
||||||
|
SortNum int
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
@@ -4,16 +4,17 @@ import "gorm.io/gorm"
|
|||||||
|
|
||||||
type ChatMessage struct {
|
type ChatMessage struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
ChatId string // 会话 ID
|
ChatId string // 会话 ID
|
||||||
UserId uint // 用户 ID
|
UserId uint // 用户 ID
|
||||||
RoleId uint // 角色 ID
|
RoleId uint // 角色 ID
|
||||||
Model string // AI模型
|
Model string // AI模型
|
||||||
Type string
|
Type string
|
||||||
Icon string
|
Icon string
|
||||||
Tokens int
|
Tokens int
|
||||||
Content string
|
TotalTokens int // 总 token 消耗
|
||||||
UseContext bool // 是否可以作为聊天上下文
|
Content string
|
||||||
DeletedAt gorm.DeletedAt
|
UseContext bool // 是否可以作为聊天上下文
|
||||||
|
DeletedAt gorm.DeletedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ChatMessage) TableName() string {
|
func (ChatMessage) TableName() string {
|
||||||
|
|||||||
@@ -12,4 +12,5 @@ type ChatModel struct {
|
|||||||
MaxContext int // 最大上下文长度
|
MaxContext int // 最大上下文长度
|
||||||
Temperature float32 // 模型温度
|
Temperature float32 // 模型温度
|
||||||
KeyId int // 绑定 API KEY ID
|
KeyId int // 绑定 API KEY ID
|
||||||
|
Type string // 模型类型
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package model
|
|||||||
|
|
||||||
type ChatRole struct {
|
type ChatRole struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
|
Tid int
|
||||||
Key string `gorm:"column:marker;unique"` // 角色唯一标识
|
Key string `gorm:"column:marker;unique"` // 角色唯一标识
|
||||||
Name string // 角色名称
|
Name string // 角色名称
|
||||||
Context string `gorm:"column:context_json"` // 角色语料信息 json
|
Context string `gorm:"column:context_json"` // 角色语料信息 json
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ type DallJob struct {
|
|||||||
Id uint `gorm:"primarykey;column:id"`
|
Id uint `gorm:"primarykey;column:id"`
|
||||||
UserId uint
|
UserId uint
|
||||||
Prompt string
|
Prompt string
|
||||||
|
TaskInfo string // 原始任务信息
|
||||||
ImgURL string
|
ImgURL string
|
||||||
OrgURL string
|
OrgURL string
|
||||||
Publish bool
|
Publish bool
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ type MidJourneyJob struct {
|
|||||||
Type string
|
Type string
|
||||||
UserId int
|
UserId int
|
||||||
TaskId string
|
TaskId string
|
||||||
|
TaskInfo string // 原始任务信息
|
||||||
ChannelId string
|
ChannelId string
|
||||||
MessageId string
|
MessageId string
|
||||||
ReferenceId string
|
ReferenceId string
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Order 充值订单
|
// Order 充值订单
|
||||||
@@ -18,6 +17,6 @@ type Order struct {
|
|||||||
Status types.OrderStatus
|
Status types.OrderStatus
|
||||||
Remark string
|
Remark string
|
||||||
PayTime int64
|
PayTime int64
|
||||||
PayWay string // 支付方式
|
PayWay string // 支付渠道
|
||||||
DeletedAt gorm.DeletedAt
|
PayType string // 支付类型
|
||||||
}
|
}
|
||||||
|
|||||||
16
api/store/model/redeem.go
Normal file
16
api/store/model/redeem.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// 兑换码
|
||||||
|
|
||||||
|
type Redeem struct {
|
||||||
|
Id uint `gorm:"primarykey;column:id"`
|
||||||
|
UserId uint // 用户 ID
|
||||||
|
Name string // 名称
|
||||||
|
Power int // 算力
|
||||||
|
Code string // 兑换码
|
||||||
|
Enabled bool // 启用状态
|
||||||
|
RedeemedAt int64 // 兑换时间
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
// 用户打赏
|
|
||||||
|
|
||||||
type Reward struct {
|
|
||||||
BaseModel
|
|
||||||
UserId uint // 用户 ID
|
|
||||||
TxId string // 交易ID
|
|
||||||
Amount float64 // 打赏金额
|
|
||||||
Remark string // 打赏备注
|
|
||||||
Status bool // 核销状态
|
|
||||||
Exchange string // 众筹兑换详情,JSON
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ type SdJob struct {
|
|||||||
Type string
|
Type string
|
||||||
UserId int
|
UserId int
|
||||||
TaskId string
|
TaskId string
|
||||||
|
TaskInfo string // 原始任务信息
|
||||||
ImgURL string
|
ImgURL string
|
||||||
Progress int
|
Progress int
|
||||||
Prompt string
|
Prompt string
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user