Compare commits

..

157 Commits

Author SHA1 Message Date
RockYang
4e6f14cb9e 更新数据库文件 2025-01-10 17:56:00 +08:00
RockYang
8dc03a7509 添加备案信息配置项,给登录页面 Logo 增加圆角 2025-01-10 17:33:25 +08:00
RockYang
57b1b44645 优化编译指令,减少程序体积 2025-01-08 16:53:09 +08:00
RockYang
aa17a33093 支持微模型绑定 Dalle 绘图的 API KEY 2025-01-08 10:50:37 +08:00
RockYang
80e27c40e9 管理后台用户算力日志页面增加过滤查询功能 2025-01-08 10:19:35 +08:00
RockYang
8250e876a5 允许配置登录注册页面的自定义 Logo 2025-01-07 15:13:34 +08:00
RockYang
9f98491368 修复微信登录功能 2025-01-07 14:44:47 +08:00
RockYang
fe160f978b 首页增加码云连接地址 2025-01-06 19:35:45 +08:00
RockYang
7da5b7163c 支持在 Chat 页面显示,隐藏对话列表 2025-01-06 19:17:18 +08:00
RockYang
cffc722622 O1 模型支持流式输出 2025-01-06 11:56:29 +08:00
RockYang
a7baf1dc9e optimize login page styles 2024-12-31 08:37:21 +08:00
RockYang
488169683f micro fixs, update database SQL file 2024-12-27 17:02:27 +08:00
RockYang
2ba3c52e6e 移动端菜单按需加载,后台可以配置是否显示 2024-12-26 18:59:58 +08:00
RockYang
18179613fc 更新移动端 Dalle 绘图页面,支持模型选择 2024-12-26 18:50:45 +08:00
RockYang
8af0fec8ec 增加移动端登录页面 2024-12-26 16:52:20 +08:00
RockYang
acee2d9d81 add 'type' field for ChatModel, support Chat and Image model 2024-12-25 18:57:18 +08:00
RockYang
cbf06eea24 fine-tune page styles, use iframe to load external page in navigation bar 2024-12-25 11:10:23 +08:00
RockYang
989b4a64d6 set white background color for close icon 2024-12-24 18:13:59 +08:00
RockYang
b01b10014a fine-tune the new UI theme 2024-12-24 17:43:40 +08:00
RockYang
e857f98e5c fine-tune the new UI theme, merge the code and fixed conflicts 2024-12-24 16:13:36 +08:00
RockYang
274cff71b1 引入tailwind css,调整样式 2024-12-24 11:07:04 +08:00
lqins
06573c5d12 邀请列表修改 2024-12-23 18:28:11 +08:00
lqins
937e5befa2 修改部分细节 2024-12-23 10:49:28 +08:00
lqins
fb403bde8b 颜色修改 2024-12-22 13:55:01 +08:00
lqins
ba174ef3ee 发送style 2024-12-22 13:00:19 +08:00
lqins
b7b702862f add hover color 2024-12-22 12:52:02 +08:00
lqins
6df2b5735b 暂无数据 2024-12-22 12:43:40 +08:00
lqins
130e151a06 Merge branch 'front-1.0' of https://gitee.com/blackfox/geekai-plus into front-1.0 2024-12-22 12:41:35 +08:00
lqins
ab903e3cc1 细节处修改 2024-12-22 12:37:32 +08:00
RockYang
237387b2ab 支持按次收费的 OpenAI 实时语音通话功能 2024-12-20 18:21:54 +08:00
lqins
0c1f650e9c style:样式切换 2024-12-19 17:09:47 +08:00
lqins
357c77ef30 style:样式切换 2024-12-19 16:57:57 +08:00
RockYang
dc7c049a7b fixed bug for luma api response data parse error 2024-12-16 16:44:24 +08:00
lqins
710b008453 feat:chat style 2024-12-11 09:36:12 +08:00
RockYang
6e7aecc568 update version 2024-11-27 16:51:14 +08:00
RockYang
b68f7e3fd1 update database file 2024-11-27 14:34:39 +08:00
RockYang
d30d5585c6 redeem export function is ready 2024-11-27 11:52:18 +08:00
lqins
1b7c7a0dc1 delect clearable 2024-11-23 15:51:44 +08:00
lqins
207f2b5ac4 change color 2024-11-23 15:47:57 +08:00
lqins
d13fa1392f feat:about account 2024-11-23 15:40:05 +08:00
lqins
9bf886fe98 go 2024-11-20 18:54:50 +08:00
lqins
aeef77ac24 test 2024-11-20 00:18:14 +08:00
廖庆斯
9a97a1ee72 feat: change theme and index style 2024-11-20 00:09:25 +08:00
RockYang
6aaf607ed7 fixed bug for chat context not work for chating with image 2024-11-12 18:23:27 +08:00
RockYang
cff0397735 meta prompt function is ready 2024-11-12 17:13:38 +08:00
RockYang
2aa0b51c09 auto restore user's power for failure tasks 2024-11-11 18:12:35 +08:00
RockYang
ce8a2d0222 save task origin info for AI generating jobs 2024-11-11 17:22:08 +08:00
RockYang
135755d21d enable to set the translate model 2024-11-08 18:06:39 +08:00
RockYang
5be4e83876 fixed bug for audio and video downloading 2024-11-05 11:38:32 +08:00
RockYang
cbc9eb3a59 merge database files 2024-10-30 18:12:50 +08:00
RockYang
0593359ef8 micro fix 2024-10-30 18:11:36 +08:00
RockYang
2081d3ce29 add databse sql file 2024-10-23 20:01:17 +08:00
RockYang
41d9c097e8 update database sql file 2024-10-23 18:16:36 +08:00
RockYang
1fe1e40a43 优化实时语音对话组件,处理异常 2024-10-23 18:04:09 +08:00
RockYang
ad6e2dd370 modify text link color for register page 2024-10-21 18:26:19 +08:00
RockYang
bb63f23414 Merge branch 'main' of gitee.com:blackfox/geekai-plus 2024-10-21 18:21:21 +08:00
RockYang
43f6bf74f2 更换登录页面背景图片 2024-10-21 18:21:04 +08:00
RockYang
662d7b099e 给 realtime 语音对话增加音效 2024-10-18 06:26:05 +08:00
RockYang
d5eeeea764 优化充值产品定价逻辑,确保手机端和PC端显示的价格一致 2024-10-17 18:15:25 +08:00
RockYang
43c507c597 the relay server for openai websocket is ready 2024-10-17 16:46:41 +08:00
RockYang
e356771049 add websocket relayer for openai realtime api 2024-10-16 18:16:09 +08:00
RockYang
48139290ed integrated openai realtime console 2024-10-15 19:25:18 +08:00
RockYang
bd852c82b7 add PCM16 audio stream to wave is reday 2024-10-14 18:39:50 +08:00
RockYang
13564993d7 add voice chat test case 2024-10-12 19:07:29 +08:00
RockYang
bfc1e1bc2c suno and luma task management funtion in admin console is ready 2024-10-10 17:07:40 +08:00
RockYang
ba20717a09 image task list page for admin console is ready 2024-10-09 18:17:44 +08:00
RockYang
52e40daf23 fixed bug in FileSelect component for deleting files 2024-10-08 18:00:46 +08:00
RockYang
430a7b2297 fixed bug for websocket message handler rebind 2024-10-08 16:41:19 +08:00
RockYang
c91a38a882 fixed webscoket event re-bind bug 2024-10-05 21:18:59 +08:00
RockYang
6e02bee4b7 fixed alipay mobile payment 2024-10-05 11:45:44 +08:00
RockYang
b62218110e fixed alipay mobile payment 2024-10-05 10:23:00 +08:00
RockYang
e2960b2607 auto jump to mobile page when use mobile device access the page 2024-10-04 11:25:01 +08:00
RockYang
88e7c39066 fixed bug for: websocket is not auto connected when user not login 2024-10-02 07:26:34 +08:00
RockYang
2a6dd636fa update database 2024-09-30 17:12:23 +08:00
RockYang
6bf38f78d5 add message handler ONLY when websocket connect successfully 2024-09-30 16:33:26 +08:00
RockYang
5a04a935be support wechat and alipay payment for mobile page 2024-09-30 16:20:40 +08:00
RockYang
8923e938d2 optimize the vue component communication, replace event listening with share data 2024-09-30 14:20:59 +08:00
RockYang
1a1734abf0 websocket api refactor is ready 2024-09-29 19:28:47 +08:00
RockYang
8093a3eeb2 mj websocket refactor is ready 2024-09-29 07:51:08 +08:00
RockYang
9edb3d0a82 sd websocket refactor is finished 2024-09-27 18:28:54 +08:00
RockYang
d95fab11be refactor websocket message protocol, keep the only connection for all clients 2024-09-27 17:50:54 +08:00
RockYang
6ef09c8ad5 add ws handler 2024-09-25 18:43:12 +08:00
RockYang
283a023a06 update database sql file for v4.1.4 2024-09-23 15:54:22 +08:00
RockYang
d315edef5f logout the user when it has been disabled 2024-09-20 16:49:03 +08:00
RockYang
5fa17b300e add release v4.1.4 2024-09-20 15:50:04 +08:00
RockYang
32919de7a7 add email white list check in register handler 2024-09-20 14:10:40 +08:00
RockYang
7d126aab41 add email white list 2024-09-20 10:40:37 +08:00
RockYang
16ac57ced3 payment for mobile page is ready 2024-09-19 19:03:03 +08:00
RockYang
4976b967e7 wechat payment for mobile page is ready 2024-09-19 17:59:27 +08:00
RockYang
e874178782 wechat payment for pc is ready 2024-09-19 14:42:25 +08:00
RockYang
8cb66ad01b fixed bug geek-plus#6, register page first tab not auto active 2024-09-19 09:03:14 +08:00
RockYang
f887a39912 geek payment notify api is ready 2024-09-18 22:24:05 +08:00
RockYang
2beffd3dd3 urgent bug fix: remove suno and luma task will recharge user power 2024-09-18 20:33:29 +08:00
RockYang
d8cb92d8d4 Geek Pay notify is ready 2024-09-18 18:07:49 +08:00
RockYang
158db83965 add geek payment 2024-09-18 07:03:46 +08:00
RockYang
603bfa7def recommend user to use Google Chrome 2024-09-15 10:25:50 +08:00
RockYang
829fb879a6 merge app type function branch 2024-09-14 18:17:55 +08:00
RockYang
0385e60ce1 refactor AI chat message struct, allow users to set whether the AI responds in stream, compatible with the GPT-o1 model 2024-09-14 17:06:13 +08:00
胡双明
aaea23f785 feat: 应用分类功能 2024-09-14 11:05:49 +08:00
RockYang
131efd6ba5 refactor chat message body struct 2024-09-14 07:11:45 +08:00
RockYang
866564370d return at least one chat role for getUserRoles API 2024-09-14 05:55:56 +08:00
RockYang
dcdc0d8918 add tid field for chat app role 2024-09-13 18:32:13 +08:00
RockYang
6c7fa17e50 fixed bug, filelist page support pagination, do not load captcha component for user login first time 2024-09-13 17:03:05 +08:00
胡双明
38a0d00142 Merge branch 'main' into husm_2024-09-02 2024-09-13 10:25:15 +08:00
RockYang
5c77e67b0f fixed bug for reset password 2024-09-12 17:25:19 +08:00
RockYang
961cee5e41 fixed bug for register page code verification 2024-09-12 15:42:09 +08:00
胡双明
c9cc93be8c Merge branch 'main' into husm_2024-09-02 2024-09-11 09:57:01 +08:00
RockYang
49f2e1a71e optimize download function for suno 2024-09-11 08:39:28 +08:00
胡双明
97eff6085a feat: 210 AI对话页面文件列表增加分页功能 2024-09-10 18:14:34 +08:00
RockYang
8b2e2d61af file list api support pagination 2024-09-10 15:24:36 +08:00
胡双明
c096efb416 Merge branch 'main' into husm_2024-09-02 2024-09-10 14:29:08 +08:00
RockYang
cdaf6fb9dc update v4.1.3 database sql file 2024-09-10 11:15:26 +08:00
RockYang
78f443ed6d update v4.1.3 database sql file 2024-09-10 11:11:17 +08:00
RockYang
54e8d72b10 support multiple delete users, update database sql file 2024-09-10 10:56:04 +08:00
RockYang
05161f48fd update config file 2024-09-09 18:58:03 +08:00
RockYang
e971bf6b88 merge luma page code for v4.1.3 2024-09-09 18:07:10 +08:00
RockYang
55b979784c Merge remote-tracking branch 'inet/husm_2024-09-02' into dev-4.1.3 2024-09-09 10:54:43 +08:00
RockYang
97aa922b5f 优化聊天页面代码,刷新页面之后自动加载当前对话的 role_id 和 model_id 2024-09-05 14:50:37 +08:00
胡双明
11c760a4e8 feat: 生成视频页面2 2024-09-05 11:56:02 +08:00
RockYang
87b03332d9 add auto execute task to downloading video files 2024-09-05 10:54:58 +08:00
胡双明
8b14eeadf4 feat: 生成视频页面 2024-09-05 08:50:49 +08:00
RockYang
e0ead127e0 remove unused css file 2024-09-04 22:42:56 +08:00
RockYang
0887bcdee0 adjust chat page styles 2024-09-04 18:07:39 +08:00
RockYang
67d83041d7 user can select function tools by themself 2024-09-04 14:53:21 +08:00
RockYang
1350f388f0 add page and total field for pagination vo 2024-09-03 18:32:15 +08:00
RockYang
65dde9e69d add sync lock for sub or add user's power 2024-09-03 12:09:36 +08:00
RockYang
2e5bd238b7 luma create and list api is ready 2024-09-02 18:08:50 +08:00
胡双明
8fc8fd6cba feat(Luma): 加上传前后帧调换功能 2024-09-02 17:33:03 +08:00
RockYang
dfc6c87250 optimize ChatPlus page, fixed bug for websocket reconnection 2024-09-02 16:35:15 +08:00
RockYang
b63e01225e add luma api service 2024-08-30 18:12:14 +08:00
RockYang
561b82027a update change log file 2024-08-30 16:47:52 +08:00
RockYang
f6d8fbf570 suno add new function for merging full songs and upload custom music 2024-08-30 16:46:48 +08:00
RockYang
568201ebbb add logo for favirate icon 2024-08-29 13:36:35 +08:00
RockYang
ab421f2185 luma page, upload image and remove image function is ready 2024-08-26 17:59:05 +08:00
RockYang
f71a2f5263 add download for video 2024-08-26 07:24:04 +08:00
RockYang
d000cc5a67 luma page video list component is ready 2024-08-23 18:25:58 +08:00
RockYang
04d6ba0853 luma page is ready 2024-08-20 14:31:40 +08:00
RockYang
8d7c028ca8 refactor reset password functions 2024-08-19 12:05:00 +08:00
RockYang
3ae7ebfeaf fixed styles 2024-08-19 06:42:53 +08:00
RockYang
aa42d38387 add bind mobile, bind email, bind wechat function is ready 2024-08-14 15:56:50 +08:00
RockYang
43843b92f2 add mobile and email filed for user 2024-08-13 18:40:50 +08:00
RockYang
5da879600a add verification code for login and register page 2024-08-13 14:55:47 +08:00
RockYang
87ed2064e3 add Captcha components 2024-08-13 10:01:07 +08:00
RockYang
34e96e91d4 fixed conflicts 2024-08-13 06:48:22 +08:00
RockYang
8c4c2b89ce add wechat login for login dialog 2024-08-12 18:00:34 +08:00
RockYang
373021c191 add drag icon for dragable rows 2024-08-12 14:00:50 +08:00
RockYang
740c3c1b00 release v4.1.2 2024-08-09 18:50:02 +08:00
RockYang
67c7132e6b redeem code function is ready 2024-08-09 18:19:51 +08:00
RockYang
c77843424b add redeem code function 2024-08-08 18:28:50 +08:00
RockYang
2d4959aa7d add cache for getting user info and system configs 2024-08-08 14:37:33 +08:00
RockYang
167c59a159 add clear unpaid order functions 2024-08-07 18:00:28 +08:00
RockYang
1d0006ce59 refactor stable diffusion service, use api key instead of configs 2024-08-07 17:30:59 +08:00
RockYang
6a8b4ee2f1 refactor midjourney service, use api key in database 2024-08-06 18:30:57 +08:00
RockYang
72b1515b68 show sql error message 2024-08-05 16:14:44 +08:00
RockYang
3f0252b498 update change log 2024-08-01 08:54:04 +08:00
RockYang
1d9d487f0e restore use power when removed not finish jobs 2024-07-31 16:08:46 +08:00
RockYang
1bcbf74883 remove chat debug log 2024-07-28 18:55:17 +08:00
394 changed files with 37671 additions and 43757 deletions

View File

@@ -1,5 +1,5 @@
name: Bug 报告 🐛 name: Bug 报告 🐛
description: geekai 提交错误报告 description: chatgpt-plus 提交错误报告
labels: ['Bug'] labels: ['Bug']
body: body:
- type: checkboxes - type: checkboxes

View File

@@ -1,5 +1,5 @@
name: 功能优化 🚀 name: 功能优化 🚀
description: geekai 提交优化建议 description: chatgpt-plus 提交优化建议
labels: ['feature'] labels: ['feature']
body: body:
- type: checkboxes - type: checkboxes

View File

@@ -1,290 +1,383 @@
# 更新日志 # 更新日志
## v4.1.9
- 功能优化:优化系统配置,移除已废弃的配置项
- 功能优化GPT-O1 模型支持流式输出
- 功能优化:优化代码引用快样式,支持主题切换
- 功能优化:登录,注册页面允许替换用户自己的 Logo 和 Title
- Bug 修复:修复 OpenAI 实时语音通话没有检测用户算力不足的 Bug
- 功能新增:管理后台增加算力日志查询功能,支持按用户,按模型,按日期,按类型查询算力日志
- 功能优化:支持为模型绑定 Dalle 和 chat 类型的 API KEY
- 功能新增:支持管理后台设置 ICP 备案号
## v4.1.8
- 功能优化:**UI 全新改版,支持主题切换**。 :rocket: :rocket: :rocket:
- 功能新增Gitee AI API 接口接入,目前支持 Gitee 的 SD 绘图接口,支持 Gitee 的 AI 对话接口。:rocket: :rocket: :rocket:
- Bug 修复:修复音 Luma API 更新导致任务响应解析失败的错误
- 功能优化:支持 Suno v4.0 模型支持
- Bug 修复:修复 Suno 已完成任务删除失败的 错误
- 功能新增:支持 OpenAI 实时语音通话功能,目前已经支持按次收费,支持管理员设置每次实时语音通话的算力消耗
- 功能新增:生成提示词需要消耗算力,支持管理员设置每次生成提示词的算力消耗,防止被白嫖
- 功能新增DALL-E-3 绘图支持 Flux 绘图模型,支持在管理后添加 Flux,SD 等绘图模型
- 功能优化Markdown 支持解析 emoji 表情
- 功能优化:当管理后台禁用了某个绘图菜单的时候,移动端绘图菜单也会同步禁用(不显示该功能)
## v4.1.7
- Bug 修复:手机邮箱相关的注册问题 [#IB0HS5](https://gitee.com/blackfox/geekai/issues/IB0HS5)
- Bug 修复:音乐视频无法下载,思维导图下载后看不清文字[#IB0N2E](https://gitee.com/blackfox/geekai/issues/IB0N2E)
- 功能优化:保存所有 AIGC 任务的原始信息,程序启动之后自动将未执行的任务加入到 redis 队列
- 功能优化:失败的任务自动退回算力,而不需要在删除的时候再退回
- 功能新增:支持设置一个专门的模型来翻译提示词,提供 Mate 提示词生成功能
- Bug 修复:修复图片对话的时候,上下文不起作用的 Bug
- 功能新增:管理后台新增批量导出兑换码功能
## v4.1.6
- 功能新增:**支持 OpenAI 实时语音对话功能** :rocket: :rocket: :rocket:, Beta 版,目前没有做算力计费控制,目前只有 VIP 用户可以使用。
- 功能优化:优化 MysQL 容器配置文档,解决 MysQL 容器资源占用过高问题
- 功能新增:管理后台增加 AI 绘图任务管理,可在管理后台浏览和删除用户的绘图任务
- 功能新增:管理后台增加 Suno 和 Luma 任务管理功能
- Bug 修复:修复管理后台删除兑换码报 404 错误
- 功能优化:优化充值产品定价逻辑,可以设置原价和优惠价,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!****升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!****升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**。
## v4.1.5
- 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接
- Bug 修复:兼容手机端原生微信支付和支付宝支付渠道
- Bug 修复:修复删除绘图任务时候因为字段长度过短导致 SQL 执行失败问题
- 功能优化:优化 Vue 组件通信代码,使用共享数据来替换之前的事件订阅模式,效率更高一些
- 功能优化:优化思维导图生成功果页面,优化用户体验
## v4.1.4
- 功能优化:用户文件列表组件增加分页功能支持
- Bug 修复:修复用户注册失败 Bug注册操作只弹出一次行为验证码
- 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
- 功能新增:给 AI 应用(角色)增加分类,前端支持分类筛选
- 功能优化:允许用户在聊天页面设置是否使用流式输出或者一次性输出,兼容 GPT-O1 模型。
- 功能优化:移除 PayJS 支付渠道支持PayJs 已经关闭注册服务,请使用其他支付方式。
- 功能新增:新增 GeeK 易支付支付渠道支持支付宝微信支付QQ 钱包京东支付抖音支付Paypal 支付等支付方式
- Bug 修复:修复注册页面 tab 组件没有自动选中问题 [#6](https://github.com/yangjian102621/geekai-plus/issues/6)
- 功能优化Luma 生成视频任务增加自动翻译功能
- Bug 修复Suno 和 Luma 任务没有判断用户算力
- 功能新增:邮箱注册增加邮箱后缀白名单,防止使用某些垃圾邮箱注册薅羊毛
- 功能优化:清空未支付订单时,只清空超过 15 分钟未支付的订单
## v4.1.3
- 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
- 功能优化:重构找回密码模块,支持通过手机或者邮箱找回密码
- 功能优化:管理后台给可以拖动排序的组件添加拖动图标
- 功能优化Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
- Bug 修复:手机端角色和模型选择不生效
- Bug 修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常
- 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
- 功能优化:给算力增减服务全部加上数据库事务和同步锁
- 功能优化:支持用户在前端对话界面选择插件
- 功能新增:支持 Luma 文生视频功能
## v4.1.2
- Bug 修复:修复思维导图页面获取模型失败的问题
- 功能优化:优化 MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力
- Bug 修复:修复后台拖动排序组件 Bug
- 功能优化:更新数据库失败时候显示具体的的报错信息
- Bug 修复:修复管理后台对话详情页内容显示异常问题
- 功能优化:管理后台新增清空所有未支付订单的功能
- 功能优化:给会话信息和系统配置数据加上缓存功能,减少 http 请求
- 功能新增:移除微信机器人收款功能,增加卡密功能,支持用户使用卡密兑换算力
## v4.1.1 ## v4.1.1
* Bug修复修复 GPT 模型 function call 调用后没有输出的问题
* 功能新增:允许获取 License 授权用户可以自定义版权信息 - Bug 修复:修复 GPT 模型 function call 调用后没有输出的问题
* 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件 - 功能新增:允许获取 License 授权用户可以自定义版权信息
* 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求 - 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件
* 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用 - 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求
* 功能新增MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息 - 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
* 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式 - 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
* 功能新增:允许在管理后台设置首页显示的导航菜单 - 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式
* Bug修复修复注册页面先显示关闭注册组件然后再显示注册组件 - 功能新增:允许在管理后台设置首页显示的导航菜单
* 功能新增:增加 Suno 文生歌曲功能 - Bug 修复:修复注册页面先显示关闭注册组件,然后再显示注册组件
* 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加 - 功能新增:增加 Suno 文生歌曲功能
* 功能优化:在所有列表页面增加返回顶部按钮 - 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
- 功能优化:在所有列表页面增加返回顶部按钮
## v4.1.0 ## v4.1.0
* bug修复修复移动端修改聊天标题不生效的问题
* Bug修复修复用户注册不显示用户名的问题
* Bug修复修复管理后台拖动排序不生效的问题
* 功能优化:允许用户设置自定义首页背景图片
* 功能新增:**支持AI解读 PDF, Word, Excel等文件**
* 功能优化:优化聊天界面的用户上传文件的列表样式
* 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
* 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
- bug 修复:修复移动端修改聊天标题不生效的问题
- Bug 修复:修复用户注册不显示用户名的问题
- Bug 修复:修复管理后台拖动排序不生效的问题
- 功能优化:允许用户设置自定义首页背景图片
- 功能新增:**支持 AI 解读 PDF, Word, Excel 等文件**
- 功能优化:优化聊天界面的用户上传文件的列表样式
- 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
- 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
## v4.0.9 ## v4.0.9
* 环境升级:升级 Golang 到 go1.22.4
* 功能增加:接入微信商户号支付渠道 - 环境升级:升级 Golang 到 go1.22.4
* Bug修复修复前端页面菜单把页面撑开底部留白问题 - 功能增加:接入微信商户号支付渠道
* 功能优化:聊天页面自动根据内容调整输入框的高度 - Bug 修复:修复前端页面菜单把页面撑开,底部留白问题
* Bug修复修复Dalle绘图失败退回算力的问题 - 功能优化:聊天页面自动根据内容调整输入框的高度
* 功能优化:邀请码注册时被邀请人也可以获得赠送的算力 - Bug 修复:修复 Dalle 绘图失败退回算力的问题
* 功能优化:允许设置邮件验证码的抬头 - 功能优化:邀请码注册时被邀请人也可以获得赠送的算力
* Bug修复修复免费模型不会记录聊天记录的bug - 功能优化:允许设置邮件验证码的抬头
* Bug修复修复聊天输入公式显示异常的Bug - Bug 修复:修复免费模型不会记录聊天记录的 bug
- Bug 修复:修复聊天输入公式显示异常的 Bug
## v4.0.8 ## v4.0.8
* 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
* 功能优化:当数据库更新失败的时候记录错误日志 - 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
* 功能优化:聊天输入框会随着输入内容的增多自动调整高度 - 功能优化:当数据库更新失败的时候记录错误日志
* Bug修复修复移动端聊天页面模型切换不生效的Bug - 功能优化:聊天输入框会随着输入内容的增多自动调整高度
* 功能优化给PC端扫码支付增加签名验证和有效期验证 - Bug 修复:修复移动端聊天页面模型切换不生效的 Bug
* Bug修复修复支付码生成API权限控制的问题 - 功能优化:给 PC 端扫码支付增加签名验证和有效期验证
* Bug修复模型算力设置为0时不扣减用户算力并且不记录算力消费日志 - Bug 修复:修复支付码生成 API 权限控制的问题
* 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片 - Bug 修复:模型算力设置为 0 时,不扣减用户算力,并且不记录算力消费日志
* 功能新增H5端支持 Dalle 绘图 - 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
- 功能新增H5 端支持 Dalle 绘图
## v4.0.7 ## v4.0.7
* 功能优化:升级quic-go支持 Go1.21 - 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片
* 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片 - Bug 修复:修复弹窗等于图形验证码一直验证失败的问题
* Bug修复修复弹窗等于图形验证码一直验证失败的问题 - 功能重构:重构前端 UI 页面,增加顶部导航
* 功能重构:重构前端 UI 页面,增加顶部导航 - 功能优化:优化 Vue 非父子组件之间的通信方式
* 功能优化:优化 Vue 非父子组件之间的通信方式 - 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
* 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
## v4.0.6 ## v4.0.6
* Bug修复修复PC端画廊页面的瀑布流组件样式错乱问题 - Bug 修复:修复 PC 端画廊页面的瀑布流组件样式错乱问题
* 功能新增:给思维导图增加 ToolBar实现思维导图的放大缩小和定位 - 功能新增:给思维导图增加 ToolBar实现思维导图的放大缩小和定位
* Bug修复修复思维导图不扣费的Bug - Bug 修复:修复思维导图不扣费的 Bug
* Bug修复修复管理后台角色删除失败的Bug - Bug 修复:修复管理后台角色删除失败的 Bug
* Bug修复兼容最新版秋叶SD懒人包的 SD API新增 scheduler 参数 - Bug 修复:兼容最新版秋叶 SD 懒人包的 SD API新增 scheduler 参数
* 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY - 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
* Bug修复修复注册用户提示注册人数达到上限的 Bug - Bug 修复:修复注册用户提示注册人数达到上限的 Bug
* 功能优化将MJ,SD,Dall绘画页面的任务列表全改成瀑布流组件 - 功能优化:将 MJ,SD,Dall 绘画页面的任务列表全改成瀑布流组件
## v4.0.5 ## v4.0.5
* 功能优化:已授权系统在后台显示授权信息 - 功能优化:已授权系统在后台显示授权信息
* 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误 - 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
* 功能优化:优化首页登录注册页面的 UI - 功能优化:优化首页登录注册页面的 UI
* BUG修复修复License验证的逻辑漏洞 - BUG 修复:修复 License 验证的逻辑漏洞
* Bug修复后台添加用户的时候密码规则限制跟前台注册保持一致 - Bug 修复:后台添加用户的时候密码规则限制跟前台注册保持一致
* 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题 - 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
* 功能新增:移动端新增 DALL-E 绘画功能 - 功能新增:移动端新增 DALL-E 绘画功能
* 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题 - 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
* 功能新增:移动支持免登录预览功能 - 功能新增:移动支持免登录预览功能
* Bug修复解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题 - Bug 修复:解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
* Bug修复修复部分中转 API 模型会出现第一输出的字符被淹没的Bug - Bug 修复:修复部分中转 API 模型会出现第一输出的字符被淹没的 Bug
## v4.0.4 ## v4.0.4
* Bug修复修复统一千问第二句不回复的问题 - Bug 修复:修复统一千问第二句不回复的问题
* 功能优化MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度 - 功能优化MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
* 功能新增Dalle AI 绘画功能实现 - 功能新增Dalle AI 绘画功能实现
* Bug修复修复思维导图格式乱码问题 - Bug 修复:修复思维导图格式乱码问题
* 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题 - 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
* 功能新增:支持从应用列表直接和某个应用对话 - 功能新增:支持从应用列表直接和某个应用对话
* 功能优化优化算力日志的页面和首页的UI - 功能优化:优化算力日志的页面和首页的 UI
* 功能新增:支持思维导图导出 PNG 图片下载 - 功能新增:支持思维导图导出 PNG 图片下载
## v4.0.3 ## v4.0.3
* 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型 - 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
* Bug修复兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug - Bug 修复:兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
* Bug修复修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题 - Bug 修复:修复 MidJourney 在任务超时后出现后面的任务覆盖前面任务的问题
* 功能新增:支持上传图片和视觉模型 - 功能新增:支持上传图片和视觉模型
* 功能优化:优化聊天页面的复制代码按钮样式乱码 - 功能优化:优化聊天页面的复制代码按钮样式乱码
* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图 - 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
* 功能新增支持为角色绑定对话模型比如绑定某个角色只能用GPT3.5或者 GPT4 - 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用 GPT3.5 或者 GPT4
* 功能新增:支持为模型绑定 API KEY比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。 - 功能新增:支持为模型绑定 API KEY比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
* 功能新增:支持管理后台 Logo 修改 - 功能新增:支持管理后台 Logo 修改
## 4.0.2 ## 4.0.2
* 功能新增:支持前端菜单可以配置 - 功能新增:支持前端菜单可以配置
* 功能优化:在登录和注册界面标题显示软件版本号 - 功能优化:在登录和注册界面标题显示软件版本号
* 功能优化MJ 绘画支持 --sref 和 --cref 图片一致性参数 - 功能优化MJ 绘画支持 --sref 和 --cref 图片一致性参数
* 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题 - 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
* Bug修复解决因为图片上传使用相对路径而导致融图失败的问题。 - Bug 修复:解决因为图片上传使用相对路径而导致融图失败的问题。
* 功能新增:手机端支持 Stable-Diffusion 绘画 - 功能新增:手机端支持 Stable-Diffusion 绘画
* 功能新增:管理后台登录页面增加行为验证码,防止爆破 - 功能新增:管理后台登录页面增加行为验证码,防止爆破
## v4.0.1 ## v4.0.1
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口SDAPI 兼容各种 stable-diffusion - 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口SDAPI 兼容各种 stable-diffusion
发行版,稳定性更强一些 发行版,稳定性更强一些
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API兼容 - 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API兼容
MJ-Plus 中转 MJ-Plus 中转
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力 - 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
* Bug修复修复 iphone 手机无法通过图形验证码的Bug使用滑动验证码替换 - Bug 修复:修复 iphone 手机无法通过图形验证码的 Bug使用滑动验证码替换
* Bug修复修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug - Bug 修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的 Bug
## v4.0.0 ## v4.0.0
非兼容版本重大重构引入算力概念将系统中所有的能力AI对话MJ绘画SD绘画DALL绘画全部使用算力来兑换。 非兼容版本重大重构引入算力概念将系统中所有的能力AI 对话MJ 绘画SD 绘画DALL 绘画)全部使用算力来兑换。
只要你的算力值余额不为0你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 只要你的算力值余额不为 0你就可以进行任何操作。比如一次 GPT3.5 对话消耗 1 个单位算力,一次 GPT4 对话消耗 10 个算力。一次 MJ
对话消耗15个算力... 对话消耗 15 个算力...
* 功能重构:重构整体系统,全部采用算力来进行结算 - 功能重构:重构整体系统,全部采用算力来进行结算
* 功能优化SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽 - 功能优化SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
* 功能优化:移动端聊天页面图片支持预览和放大功能 - 功能优化:移动端聊天页面图片支持预览和放大功能
* 功能优化MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题 - 功能优化MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
* 功能优化:**PC端不登录也可以预览功能只有在发起操作的时候才需要登录** - 功能优化:**PC 端不登录也可以预览功能,只有在发起操作的时候才需要登录**
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能 - 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
* 功能新增支持H5支付 - 功能新增:支持 H5 支付
* 功能优化:支持数学公式的识别和美化输出 - 功能优化:支持数学公式的识别和美化输出
* 功能新增:新增算力消费日志功能 - 功能新增:新增算力消费日志功能
* 功能优化:整合 XXL-JOB 实现订单清理每日算力派发VIP 算力重置等任务 - 功能优化:整合 XXL-JOB 实现订单清理每日算力派发VIP 算力重置等任务
* 功能新增:管理后台新增7日内新增用户和新增订单统计 - 功能新增:管理后台新增 7 日内新增用户和新增订单统计
## v3.2.7 ## v3.2.7
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能 - 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能 - 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
* Bug修复修复 issue [ - Bug 修复:修复 issue [
管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532) 管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532)
* 功能优化:在对话和聊天记录表中新增冗余字段 model存储对话模型 - 功能优化:在对话和聊天记录表中新增冗余字段 model存储对话模型
* Bug修复IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144) - Bug 修复IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
* Bug修复重新生成按钮功能失效问题 - Bug 修复:重新生成按钮功能失效问题
* Bug修复对话输入HTML标签不显示的问题 - Bug 修复:对话输入 HTML 标签不显示的问题
* 功能优化gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY - 功能优化gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
* 功能新增:新增删除文件功能 - 功能新增:新增删除文件功能
* Bug修复解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载 - Bug 修复:解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
* 功能新增:后台管理新怎对话查看和检索功能 - 功能新增:后台管理新怎对话查看和检索功能
## v3.2.6 ## v3.2.6
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号 - 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
* 功能优化:兼用旧版本微信收款消息解析 - 功能优化:兼用旧版本微信收款消息解析
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源 - 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
* 功能新增:新增图片发布功能,画廊只显示用户已发布的图片 - 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
* 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码 - 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
* 功能新增:新增网站公告,可以在管理后台自定义配置 - 功能新增:新增网站公告,可以在管理后台自定义配置
* 功能新增:新增阿里通义千问大模型支持 - 功能新增:新增阿里通义千问大模型支持
* Bug修复修复 MJ 放大任务失败时候 img_call 会增加的 Bug - Bug 修复:修复 MJ 放大任务失败时候 img_call 会增加的 Bug
* 功能优化新增虎皮椒和PayJS订单状态校验功能增加安全性 - 功能优化:新增虎皮椒和 PayJS 订单状态校验功能,增加安全性
* Bug修复修复微信转账交易 ID 提取失败 Bug - Bug 修复:修复微信转账交易 ID 提取失败 Bug
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug - 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
* 功能新增:新增短信宝短信平台发送平台集成 - 功能新增:新增短信宝短信平台发送平台集成
## v3.2.5 ## v3.2.5
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。 - 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅 - 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
Plus 账号了!!! Plus 账号了!!!
* 功能优化:增强 markdown 图片和引用块解析。 - 功能优化:增强 markdown 图片和引用块解析。
* 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。 - 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
* 功能优化function call 兼用中转 API。 - 功能优化function call 兼用中转 API。
* Bug修复修复部分已知的 Bug。 - Bug 修复:修复部分已知的 Bug。
## v3.2.4.1 ## v3.2.4.1
* 功能新增:新增 PayJs 支付通道 - 功能新增:新增 PayJs 支付通道
* Bug修复紧急修复后台添加用户失败问题 - Bug 修复:紧急修复后台添加用户失败问题
* Bug修复紧急修复使用中转 API-KEY 无法绘图的问题 - Bug 修复:紧急修复使用中转 API-KEY 无法绘图的问题
* Bug修复允许用户关闭手机和邮箱注册通道移除验证码依赖 - Bug 修复:允许用户关闭手机和邮箱注册通道,移除验证码依赖
## v3.2.4 ## v3.2.4
* 功能新增:重磅更新,支持邮箱注册 - 功能新增:重磅更新,支持邮箱注册
* 功能优化:优化函数调用授权 - 功能优化:优化函数调用授权
* 功能优化:给用户表新增 nickname 字段 - 功能优化:给用户表新增 nickname 字段
* 功能优化:管理后台给聊天角色增加启用/禁用开关 - 功能优化:管理后台给聊天角色增加启用/禁用开关
* Bug修复SD绘画出现重复扣减绘图次数 - Bug 修复SD 绘画出现重复扣减绘图次数
* 功能优化:优化聊天对话导出样式,适应移动端 - 功能优化:优化聊天对话导出样式,适应移动端
* 功能新增:众筹核销可以选择兑换对话还是绘图的额度 - 功能新增:众筹核销可以选择兑换对话还是绘图的额度
* Bug修复修复[从历史记录获取reply有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92) - Bug 修复:修复[从历史记录获取 reply 有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92)
* Bug修复修复 MidJourney 绘图任务调度Bug为 task_id 建议唯一索引 - Bug 修复:修复 MidJourney 绘图任务调度 Bug为 task_id 建议唯一索引
* 功能重构:重构了 API KEY模块支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。 - 功能重构:重构了 API KEY 模块,支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。
## v3.2.3 ## v3.2.3
* 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现 - 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现
* 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数 - 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数
* Bug修复修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73) - Bug 修复:修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73)
* Bug修复确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道 - Bug 修复:确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道
* 功能新增:新增删除绘图任何和图片功能 - 功能新增:新增删除绘图任何和图片功能
* Bug修复修复虎皮椒支付二维码重复扫码时报错问题 - Bug 修复:修复虎皮椒支付二维码重复扫码时报错问题
* 功能优化:自动将 AI 绘画中的中文提示词翻译成英文 - 功能优化:自动将 AI 绘画中的中文提示词翻译成英文
* 功能优化优化AI绘画的大图压缩算法新增图片缓存 - 功能优化:优化 AI 绘画的大图压缩算法,新增图片缓存
* 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率 - 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率
* Bug修复修复[Azure Api 更换api-version参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71) - Bug 修复:修复[Azure Api 更换 api-version 参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71)
* Bug修复修复科大讯飞 V1.5 API 请求失败的问题 - Bug 修复:修复科大讯飞 V1.5 API 请求失败的问题
* Bug修复绘图失败后自动恢复用户的剩余绘图次数 - Bug 修复:绘图失败后,自动恢复用户的剩余绘图次数
* 功能新增:为移动端新增 SD 绘图功能,分享功能 - 功能新增:为移动端新增 SD 绘图功能,分享功能
## v3.2.2 ## v3.2.2
* 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务 - 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务
* 功能新增AI绘画页面增加翻译和重写提示词功能 - 功能新增AI 绘画页面增加翻译和重写提示词功能
* 功能优化OSS上传组件支持在 Bucket 下设置二级目录 - 功能优化OSS 上传组件支持在 Bucket 下设置二级目录
* Bug修复修复阿里云 OSS 访问路径错误 - Bug 修复:修复阿里云 OSS 访问路径错误
* 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket - 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket
## v3.2.1 ## v3.2.1
* 功能优化:切换角色和模型的时候自动创建新的对话 - 功能优化:切换角色和模型的时候自动创建新的对话
* Bug修复修复文件上传失败No such file bug - Bug 修复:修复文件上传失败 No such file bug
* 功能新增MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数 - 功能新增MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数
* Bug修复[PC端对话在刷新后异常](https://github.com/yangjian102621/chatgpt-plus/issues/59) - Bug 修复:[PC 端对话在刷新后异常](https://github.com/yangjian102621/chatgpt-plus/issues/59)
* 功能新增:增加 arm64 架构打包脚本 - 功能新增:增加 arm64 架构打包脚本
* 功能新增:支持 dall-e3 绘图的 API 地址自定义配置 - 功能新增:支持 dall-e3 绘图的 API 地址自定义配置
* 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道 - 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道
## v3.2.0 ## v3.2.0
* 功能新增:新增邀请注册功能 - 功能新增:新增邀请注册功能
* 功能优化增加中间件自动对HTTP请求的参数去掉首尾空格 - 功能优化:增加中间件自动对 HTTP 请求的参数去掉首尾空格
* 功能优化:增加中间件自动为大图片生成缩略图 - 功能优化:增加中间件自动为大图片生成缩略图
* 功能优化MidJourney 页面图片加载优化,实现图片预览懒加载 - 功能优化MidJourney 页面图片加载优化,实现图片预览懒加载
* 功能新增:新增 DALL-E-3 绘画支持,并作为对话页面默认绘画插件 - 功能新增:新增 DALL-E-3 绘画支持,并作为对话页面默认绘画插件
* Bug修复修复阿里云 OSS 域名设置不起做用的bug - Bug 修复:修复阿里云 OSS 域名设置不起做用的 bug
* Bug修复修复MidJourney绘图失败后重复添加到队列的问题 - Bug 修复:修复 MidJourney 绘图失败后重复添加到队列的问题
## v3.1.9 ## v3.1.9
* 功能新增:增加讯飞星火大模型 v3.0 支持 - 功能新增:增加讯飞星火大模型 v3.0 支持
* 功能新增:新增找回密码功能 - 功能新增:新增找回密码功能
* 功能新增:支持 Markdown 代码复制功能 - 功能新增:支持 Markdown 代码复制功能
* Bug修复: xxl-job 任务调度失败的 Bug - Bug 修复: xxl-job 任务调度失败的 Bug
* 功能优化:优化前端页面菜单图标,使用自定义图标替换 icon-font - 功能优化:优化前端页面菜单图标,使用自定义图标替换 icon-font
* Bug修复Stable-Diffusion 绘画成功之后没有扣减用户画图次数 - Bug 修复Stable-Diffusion 绘画成功之后没有扣减用户画图次数
* 功能优化:优化会员充值页面 ItemList 组件 - 功能优化:优化会员充值页面 ItemList 组件
* 功能优化:给首页 Logo 增加链接 - 功能优化:给首页 Logo 增加链接
* Bug修复[新建会话时,提示"请输入合法的手机号" ](https://github.com/yangjian102621/chatgpt-plus/issues/51) - Bug 修复:[新建会话时,提示"请输入合法的手机号" ](https://github.com/yangjian102621/chatgpt-plus/issues/51)
* Bug修复聊天上下文失效问题 - Bug 修复:聊天上下文失效问题
* 功能优化:关闭注册时显示联系管理员二维码 - 功能优化:关闭注册时显示联系管理员二维码
* 功能优化:移除 leveldb 依赖,使用 redis 替换相应的功能 - 功能优化:移除 leveldb 依赖,使用 redis 替换相应的功能
* Bug修复后台启用用户 VIP 不生效问题 - Bug 修复:后台启用用户 VIP 不生效问题
* 功能优化:充值支付页面的支付说明文字可以后台配置 - 功能优化:充值支付页面的支付说明文字可以后台配置
* Bug修复ChatGLM百度文心科大讯飞模型输出代码不换行问题 - Bug 修复ChatGLM百度文心科大讯飞模型输出代码不换行问题
## v3.1.8 ## v3.1.8
1. 功能新增:新增会员套餐充值,点卡充值,订单系统,集成支付宝支付通道 1. 功能新增:新增会员套餐充值,点卡充值,订单系统,集成支付宝支付通道
2. Bug修复修复 MidJourney API 参数版本更新导致调用失败问题 2. Bug 修复:修复 MidJourney API 参数版本更新导致调用失败问题
3. Bug修复修复 Stable Diffusion 调用后没有更新绘图调用次数问题 3. Bug 修复:修复 Stable Diffusion 调用后没有更新绘图调用次数问题
4. Bug修复修复七牛云上传报错 expired token 4. Bug 修复:修复七牛云上传报错 expired token
5. Bug修复修复高权重模型导致的对话次数为负数的漏洞 5. Bug 修复:修复高权重模型导致的对话次数为负数的漏洞
6. 功能优化:将聊天报错信息定义为统一常量,方便修改 6. 功能优化:将聊天报错信息定义为统一常量,方便修改
7. 功能优化:优化 markdown 表格显示样式,覆写 Element-Plus 表格样式 7. 功能优化:优化 markdown 表格显示样式,覆写 Element-Plus 表格样式
8. 功能优化:增加倒数计时组件,定期自动清理未支付的订单 8. 功能优化:增加倒数计时组件,定期自动清理未支付的订单
## v3.1.7 ## v3.1.7
1. 功能新增支持文心4.0 AI 模型 1. 功能新增:支持文心 4.0 AI 模型
2. 功能新增:可以在管理后台为用户绑定指定的 AI 模型,如只给某个用户使用 GPT-4 模型 2. 功能新增:可以在管理后台为用户绑定指定的 AI 模型,如只给某个用户使用 GPT-4 模型
3. 功能新增模型新增权重字段不同的模型每次调用耗费的点数可以设置不同比如GPT4GPT3.510倍 3. 功能新增:模型新增权重字段,不同的模型每次调用耗费的点数可以设置不同,比如 GPT4GPT3.510
4. 功能新增:新增系统配置关闭 AI 模型的函数功能 4. 功能新增:新增系统配置关闭 AI 模型的函数功能
5. 功能优化:优化 MidJourney 专业绘画页面图片预览样式 5. 功能优化:优化 MidJourney 专业绘画页面图片预览样式
## v3.1.6 ## v3.1.6
1. 功能新增新增AI 绘画照片墙功能页面,供用户查看所有的 AI 绘画作品 1. 功能新增:新增 AI 绘画照片墙功能页面,供用户查看所有的 AI 绘画作品
2. 功能新增:新增 AI 角色应用功能页面,用户可以添加自己感兴趣的应用 2. 功能新增:新增 AI 角色应用功能页面,用户可以添加自己感兴趣的应用
3. 功能优化:优化瀑布流组件的页面布局 3. 功能优化:优化瀑布流组件的页面布局
4. 功能优化:新注册用户成功之后自动登录 4. 功能优化:新注册用户成功之后自动登录
@@ -296,55 +389,55 @@
2. 功能新增:新增科大讯飞星火大模型 API 接入支持 2. 功能新增:新增科大讯飞星火大模型 API 接入支持
3. 功能重构:将 chat_handler 的所有功能实现放入单独的包中 3. 功能重构:将 chat_handler 的所有功能实现放入单独的包中
4. 功能新增:新增系统配置 `enabled_function` 用于启用和关闭函数功能 4. 功能新增:新增系统配置 `enabled_function` 用于启用和关闭函数功能
5. Bug修复修复管理后台更新 API Key 失败的 Bug 5. Bug 修复:修复管理后台更新 API Key 失败的 Bug
6. Bug修复修复新建的对话无法更新对话标题的 Bug 6. Bug 修复:修复新建的对话无法更新对话标题的 Bug
7. 功能优化:其他一些小的体验优化工作 7. 功能优化:其他一些小的体验优化工作
## v3.1.4 ## v3.1.4
1. 功能新增:新增阿里云 OSS 图片上传实现目前已支持本地存储七牛云Minio和阿里云 OSS 四种存储介质。 1. 功能新增:新增阿里云 OSS 图片上传实现目前已支持本地存储七牛云Minio 和阿里云 OSS 四种存储介质。
2. 功能新增:**增加 Stable Diffusion 绘画功能页面**。 2. 功能新增:**增加 Stable Diffusion 绘画功能页面**。
3. 功能重构:将 [chatgpt-plus-exts](https://github.com/yangjian102621/chatgpt-plus-exts) 合并到本项目,部署更加简单,无需部署两个项目了。 3. 功能重构:将 [chatgpt-plus-exts](https://github.com/yangjian102621/chatgpt-plus-exts) 合并到本项目,部署更加简单,无需部署两个项目了。
4. Bug修复修复[用户注册报错BUG #37](https://github.com/yangjian102621/chatgpt-plus/issues/37)。 4. Bug 修复:修复[用户注册报错 BUG #37](https://github.com/yangjian102621/chatgpt-plus/issues/37)。
5. Bug修复修复 MidJourney API 接口升级导致图片文保存失败的 Bug。 5. Bug 修复:修复 MidJourney API 接口升级导致图片文保存失败的 Bug。
6. 功能优化:增加阿里云短信服务配置项 `Sign``CodeTempId` 用来配置自己的短信签名和短信验证码模版 ID。 6. 功能优化:增加阿里云短信服务配置项 `Sign``CodeTempId` 用来配置自己的短信签名和短信验证码模版 ID。
7. 功能优化:添加系统配置用来设置自定义的众筹微信收款二维码。 7. 功能优化:添加系统配置用来设置自定义的众筹微信收款二维码。
8. 功能优化:优化绘画页面的弹窗样式和页面布局。 8. 功能优化:优化绘画页面的弹窗样式和页面布局。
## v3.1.3 ## v3.1.3
1. 页面重构:重后 Home 页面拆分成聊天MJ绘画SD 绘画,应用广场等多个功能菜单。 1. 页面重构:重后 Home 页面拆分成聊天MJ 绘画SD 绘画,应用广场等多个功能菜单。
2. 功能新增:新增 MidJourney 专业绘画页面,开放更高级的 MJ 绘画姿势。 2. 功能新增:新增 MidJourney 专业绘画页面,开放更高级的 MJ 绘画姿势。
3. 功能优化:采用队列的方式控制绘画任务并发,简化任务回调通知逻辑,给任务回调加锁。 3. 功能优化:采用队列的方式控制绘画任务并发,简化任务回调通知逻辑,给任务回调加锁。
4. 功能优化:精简用户表字段,删除用户名和昵称,只保留手机号。 4. 功能优化:精简用户表字段,删除用户名和昵称,只保留手机号。
5. 功能优化:优化文件上传服务工厂实现,只创建激活的 Uploader 服务,节省资源。 5. 功能优化:优化文件上传服务工厂实现,只创建激活的 Uploader 服务,节省资源。
6. Bug修复修复 JWT token 有效期计算错误的 Bug。 6. Bug 修复:修复 JWT token 有效期计算错误的 Bug。
## v3.1.2 ## v3.1.2
1. 功能新增:新增七牛云 OSS 实现目前已支持三种文件上传服务Local, Minio, QiNiu OSS。 1. 功能新增:新增七牛云 OSS 实现目前已支持三种文件上传服务Local, Minio, QiNiu OSS。
2. 功能新增:新增桌面版,使用 electron 套壳网页版。 2. 功能新增:新增桌面版,使用 electron 套壳网页版。
3. Bug修复自动去除众筹核销时候转账单号中的空格防止复制的时候多复制了空格。 3. Bug 修复:自动去除众筹核销时候转账单号中的空格,防止复制的时候多复制了空格。
4. 功能优化ChatPlus.vue 页面支持通过 chat_id path variable 来定位到指定的聊天。 4. 功能优化ChatPlus.vue 页面支持通过 chat_id path variable 来定位到指定的聊天。
5. 功能优化:取消导出聊天页面的授权验证 5. 功能优化:取消导出聊天页面的授权验证
6. 功能优化:所有路由跳转都使用绝对路径 6. 功能优化:所有路由跳转都使用绝对路径
## v3.1.1 ## v3.1.1
紧急修复版本采用弹窗的方式显示验证码解决验证码在低分辨率下被掩盖的Bug 紧急修复版本,采用弹窗的方式显示验证码,解决验证码在低分辨率下被掩盖的 Bug
## v3.1.0(大版本更新) ## v3.1.0(大版本更新)
1. 功能重构:将聊天模型独立拆分,以便支持多平台模型,目前已经内置支持 OPenAIAzure 以及 1. 功能重构:将聊天模型独立拆分,以便支持多平台模型,目前已经内置支持 OPenAIAzure 以及
ChatGLM用户可以在这两个平台的模型中随意切换体验不同的模型聊天。 ChatGLM用户可以在这两个平台的模型中随意切换体验不同的模型聊天。
2. 功能重构:重写系统 API 授权机制,使用 JWT 替换传统的 session 会话授权,使得 API 授权变得更加灵活。 2. 功能重构:重写系统 API 授权机制,使用 JWT 替换传统的 session 会话授权,使得 API 授权变得更加灵活。
3. 功能重构重构文件夹上传服务支持多种文件上传存储handler目前已经实现本地存储和 minio oss 存储。 3. 功能重构:重构文件夹上传服务,支持多种文件上传存储 handler目前已经实现本地存储和 minio oss 存储。
4. 功能优化:更新头像自动删除旧的图片资源。 4. 功能优化:更新头像自动删除旧的图片资源。
5. 功能优化:将应用日志在终端输出的同时存盘,方便 docker 部署查看日志。 5. 功能优化:将应用日志在终端输出的同时存盘,方便 docker 部署查看日志。
6. 功能新增:允许用户配置自己的 OPenAIAzure 以及 ChatGLM API KEY。 6. 功能新增:允许用户配置自己的 OPenAIAzure 以及 ChatGLM API KEY。
7. 功能优化:优化移动版的行为验证码样式,修复低分辨率显示器验证码被遮挡的 Bug 7. 功能优化:优化移动版的行为验证码样式,修复低分辨率显示器验证码被遮挡的 Bug
8. 升级 gin, element-plusredis 组件到最新版本。 8. 升级 gin, element-plusredis 组件到最新版本。
9. Bug修复修复若干已知的的 Bug 9. Bug 修复:修复若干已知的的 Bug
## v3.0.7 ## v3.0.7
@@ -354,7 +447,7 @@
4. 功能新增:支持导出聊天记录为 PDF 文件。 4. 功能新增:支持导出聊天记录为 PDF 文件。
5. 功能优化:在后台 dashboard 页面新增统计今日众筹收入。 5. 功能优化:在后台 dashboard 页面新增统计今日众筹收入。
6. 功能优化:支持用户设置默认的 GPT 模型 6. 功能优化:支持用户设置默认的 GPT 模型
7. Bug修复修复若干已知的的 Bug 7. Bug 修复:修复若干已知的的 Bug
## v3.0.6 ## v3.0.6
@@ -362,8 +455,8 @@
2. 管理后台:新增重置用户密码功能 2. 管理后台:新增重置用户密码功能
3. 管理后台:支持关闭注册功能,新增添加用户功能,适用于内部使用场景 3. 管理后台:支持关闭注册功能,新增添加用户功能,适用于内部使用场景
4. 管理后台:新增仪表盘页面,统计当天的新增用户,新增会话数据,以及 Token 消耗 4. 管理后台:新增仪表盘页面,统计当天的新增用户,新增会话数据,以及 Token 消耗
5. Bug修复修复注册页面验证码不显示 Bug 5. Bug 修复:修复注册页面验证码不显示 Bug
6. Bug修复优化上下文 Token 计算算法,修复聊天上下文超出限制时循环发送消息的 Bug 6. Bug 修复:优化上下文 Token 计算算法,修复聊天上下文超出限制时循环发送消息的 Bug
7. 功能修正:允许用户使用手机号码登录 7. 功能修正:允许用户使用手机号码登录
8. 功能优化:更新系统配置后同步更新服务端内存变量数据 8. 功能优化:更新系统配置后同步更新服务端内存变量数据
9. 功能优化:优化打包脚本,减少容器镜像大小 9. 功能优化:优化打包脚本,减少容器镜像大小
@@ -421,5 +514,5 @@
4. 新增聊天设置功能,用户可以导入自己的 API KEY 4. 新增聊天设置功能,用户可以导入自己的 API KEY
5. 保存聊天记录,支持聊天上下文。 5. 保存聊天记录,支持聊天上下文。
6. 重构后台管理模块,更友好,扩展性更好的后台管理系统。 6. 重构后台管理模块,更友好,扩展性更好的后台管理系统。
7. 引入 ip2region 组件记录用户的登录IP和地址。 7. 引入 ip2region 组件,记录用户的登录 IP 和地址。
8. 支持会话搜索过滤。 8. 支持会话搜索过滤。

View File

@@ -1,92 +1,19 @@
# GeekAI # GeekAI-PLUS
> 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务 基于 GeekAI 项目开发的高级版增加了很多高级功能比如思维导图Dalle 绘画等。**高级版源码不会一次性开放,只提供镜像给大家免费使用**,源码会逐步逐步按照版同步迁移到[社区版GeekAI](https://github.com/yangjian102621/geekai)。所以如果大家想要二次开发,请移步去社区版
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure, ## 演示站点
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。 [Geek-AI 创作系统](https://www.geekai.me)
主要特性: ## 文档地址
[Geek-AI 文档](https://www.geekai.me/docs/)
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。 ## 部署
- 基于 Websocket 实现,完美的打字机体验 1. 安装 docker 和 docker-compose 程序,这个自行解决。
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。 2. 直接在项目根目录运行启动命令:
- 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型。 ```shell
- 支持 Suno 文生音乐 docker-compose up -d
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。 ```
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
绘画函数插件。
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus)
- [x] 更友好的 UI 界面
- [x] 支持 Dall-E 文生图功能
- [x] 支持文生思维导图
- [x] 支持为模型绑定指定的 API KEY支持为角色绑定指定的模型等功能
- [x] 支持网站 Logo 版权等信息的修改
## 功能截图 ## 功能截图
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。 请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
### 体验地址
> 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/>
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
## 快速部署
请参考文档 [**GeekAI 快速部署**](https://docs.geekai.me/install/)。
## 使用须知
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
## 项目地址
* Github 地址https://github.com/yangjian102621/geekai
* 码云地址https://gitee.com/blackfox/geekai
## 客户端下载
目前已经支持 Win/Linux/Mac/Android 客户端下载地址为https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
## TODOLIST
* [ ] 支持基于知识库的 AI 问答
* [ ] 文生视频,文生歌曲功能
* [ ] 微信支付功能
## 项目文档
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
加微信进入微信讨论群可获取 **一键部署脚本添加好友时请注明来自Github!!!)。**
![微信名片](docs/imgs/wx.png)
## 参与贡献
个人的力量始终有限,任何形式的贡献都是欢迎的,包括但不限于贡献代码,优化文档,提交 issue 和 PR 等。
#### 特此声明:由于个人时间有限,不接受在微信或者微信群给开发者提 Bug有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合
### Commit 类型
* feat: 新特性或功能
* fix: 缺陷修复
* docs: 文档更新
* style: 代码风格或者组件样式更新
* refactor: 代码重构,不引入新功能和缺陷修复
* opt: 性能优化
* chore: 一些不涉及到功能变动的小提交,比如修改文字表述,修改注释等
## 打赏
如果你觉得这个项目对你有帮助,并且情况允许的话,可以请作者喝杯咖啡,非常感谢你的支持~
![打赏](docs/imgs/donate.png)
![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/geekai&type=Date)

View File

@@ -3,11 +3,11 @@ NAME := geekai
all: amd64 arm64 all: amd64 arm64
amd64: amd64:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-linux main.go CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "-s -w" -o bin/$(NAME)-linux main.go
.PHONY: amd64 .PHONY: amd64
arm64: arm64:
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -o bin/$(NAME)-linux main.go CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -ldflags "-s -w" -o bin/$(NAME)-linux main.go
.PHONY: arm64 .PHONY: arm64
clean: clean:

View File

@@ -3,8 +3,6 @@ ProxyURL = "" # 如 http://127.0.0.1:7777
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local" MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
StaticDir = "./static" # 静态资源的目录 StaticDir = "./static" # 静态资源的目录
StaticUrl = "/static" # 静态资源访问 URL StaticUrl = "/static" # 静态资源访问 URL
AesEncryptKey = ""
WeChatBot = false
TikaHost = "http://tika:9998" TikaHost = "http://tika:9998"
[Session] [Session]
@@ -65,23 +63,6 @@ TikaHost = "http://tika:9998"
SubDir = "" SubDir = ""
Domain = "" Domain = ""
[[MjProxyConfigs]]
Enabled = true
ApiURL = "http://midjourney-proxy:8082"
ApiKey = "sk-geekmaster"
[[MjPlusConfigs]]
Enabled = false
ApiURL = "https://api.chat-plus.net"
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
ApiKey = "sk-xxx"
[[SdConfigs]]
Enabled = false
ApiURL = ""
ApiKey = ""
Txt2ImgJsonPath = "res/sd/text2img.json"
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动 [XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动
Enabled = false # 是否启用 XXL JOB 服务 Enabled = false # 是否启用 XXL JOB 服务
ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址 ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
@@ -90,6 +71,15 @@ TikaHost = "http://tika:9998"
AccessToken = "xxl-job-api-token" # 执行器 API 通信 token AccessToken = "xxl-job-api-token" # 执行器 API 通信 token
RegistryKey = "chatgpt-plus" # 任务注册 key RegistryKey = "chatgpt-plus" # 任务注册 key
[SmtpConfig] # 注意阿里云服务器禁用了25号端口请使用 465 端口,并开启 TLS 连接
UseTls = false
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"
From = "test@163.com" # 发件邮箱人地址
Password = "" #邮箱 stmp 服务授权码
# 支付宝商户支付
[AlipayConfig] [AlipayConfig]
Enabled = false # 启用支付宝支付通道 Enabled = false # 启用支付宝支付通道
SandBox = false # 是否启用沙盒模式 SandBox = false # 是否启用沙盒模式
@@ -99,31 +89,13 @@ TikaHost = "http://tika:9998"
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书 PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书 AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书 RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
# 虎皮椒支付
[HuPiPayConfig] [HuPiPayConfig]
Enabled = false Enabled = false
Name = "wechat"
AppId = "" AppId = ""
AppSecret = "" AppSecret = ""
ApiURL = "https://api.xunhupay.com" ApiURL = "https://api.xunhupay.com"
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
[SmtpConfig] # 注意阿里云服务器禁用了25号端口请使用 465 端口,并开启 TLS 连接
UseTls = false
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"
From = "test@163.com" # 发件邮箱人地址
Password = "" #邮箱 stmp 服务授权码
[JPayConfig] # PayJs 支付配置
Enabled = false
Name = "wechat" # 请不要改动
AppId = "" # 商户 ID
PrivateKey = "" # 秘钥
ApiURL = "https://payjs.cn"
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的
# 微信商户支付 # 微信商户支付
[WechatPayConfig] [WechatPayConfig]
@@ -133,6 +105,11 @@ TikaHost = "http://tika:9998"
SerialNo = "" # API 证书序列号 SerialNo = "" # API 证书序列号
PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内 PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内
ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置的 ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置的
NotifyURL = "https://ai.r9it.com/api/payment/wechat/notify" # 支付成功异步回调地址,域名改成自己的
ReturnURL = "" # 支付成功同步回调地址
# 易支付
[GeekPayConfig]
Enabled = true
AppId = "" # 商户ID
PrivateKey = "" # 商户私钥
ApiURL = "https://pay.geekai.cn"
Methods = ["alipay", "wxpay", "qqpay", "jdpay", "douyin", "paypal"] # 支持的支付方式

View File

@@ -15,12 +15,6 @@ import (
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/nfnt/resize"
"golang.org/x/image/webp"
"gorm.io/gorm"
"image" "image"
"image/jpeg" "image/jpeg"
"io" "io"
@@ -29,6 +23,13 @@ import (
"runtime/debug" "runtime/debug"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/nfnt/resize"
"golang.org/x/image/webp"
"gorm.io/gorm"
) )
type AppServer struct { type AppServer struct {
@@ -51,9 +52,9 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
func (s *AppServer) Init(debug bool, client *redis.Client) { func (s *AppServer) Init(debug bool, client *redis.Client) {
if debug { // 调试模式允许跨域请求 API if debug { // 调试模式允许跨域请求 API
s.Debug = debug s.Debug = debug
s.Engine.Use(corsMiddleware())
logger.Info("Enabled debug mode") logger.Info("Enabled debug mode")
} }
s.Engine.Use(corsMiddleware())
s.Engine.Use(staticResourceMiddleware()) s.Engine.Use(staticResourceMiddleware())
s.Engine.Use(authorizeMiddleware(s, client)) s.Engine.Use(authorizeMiddleware(s, client))
s.Engine.Use(parameterHandlerMiddleware()) s.Engine.Use(parameterHandlerMiddleware())
@@ -65,13 +66,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
func (s *AppServer) Run(db *gorm.DB) error { func (s *AppServer) Run(db *gorm.DB) error {
// load system configs // load system configs
var sysConfig model.Config var sysConfig model.Config
res := db.Where("marker", "system").First(&sysConfig) err := db.Where("marker", "system").First(&sysConfig).Error
if res.Error != nil {
return res.Error
}
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to load system config: %v", err)
}
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil {
return fmt.Errorf("failed to decode system config: %v", err)
} }
logger.Infof("http://%s", s.Config.Listen) logger.Infof("http://%s", s.Config.Listen)
return s.Engine.Run(s.Config.Listen) return s.Engine.Run(s.Config.Listen)
@@ -101,9 +102,9 @@ func corsMiddleware() gin.HandlerFunc {
c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
//允许跨域设置可以返回其他子段,可以自定义字段 //允许跨域设置可以返回其他子段,可以自定义字段
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, Chat-Token, Admin-Authorization") c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
// 允许浏览器(客户端)可以解析的头部 (重要) // 允许浏览器(客户端)可以解析的头部 (重要)
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
//设置缓存时间 //设置缓存时间
c.Header("Access-Control-Max-Age", "172800") c.Header("Access-Control-Max-Age", "172800")
//允许客户端传递校验信息比如 cookie (重要) //允许客户端传递校验信息比如 cookie (重要)
@@ -127,12 +128,19 @@ func corsMiddleware() gin.HandlerFunc {
// 用户授权验证 // 用户授权验证
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
var tokenString string var tokenString string
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/") isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
if isAdminApi { // 后台管理 API if isAdminApi { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader) tokenString = c.GetHeader(types.AdminAuthHeader)
} else if c.Request.URL.Path == "/api/chat/new" { } else if clientProtocols != "" { // Websocket 连接
tokenString = c.Query("token") // 解析子协议内容
protocols := strings.Split(clientProtocols, ",")
if protocols[0] == "realtime" {
tokenString = strings.TrimSpace(protocols[1][25:])
} else if protocols[0] == "token" {
tokenString = strings.TrimSpace(protocols[1])
}
} else { } else {
tokenString = c.GetHeader(types.UserAuthHeader) tokenString = c.GetHeader(types.UserAuthHeader)
} }
@@ -201,33 +209,29 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/admin/logout" || c.Request.URL.Path == "/api/admin/logout" ||
c.Request.URL.Path == "/api/admin/login/captcha" || c.Request.URL.Path == "/api/admin/login/captcha" ||
c.Request.URL.Path == "/api/user/register" || c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/user/session" ||
c.Request.URL.Path == "/api/chat/history" || c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" || c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" || c.Request.URL.Path == "/api/chat/list" ||
c.Request.URL.Path == "/api/role/list" || c.Request.URL.Path == "/api/app/list" ||
c.Request.URL.Path == "/api/app/type/list" ||
c.Request.URL.Path == "/api/app/list/user" ||
c.Request.URL.Path == "/api/model/list" || c.Request.URL.Path == "/api/model/list" ||
c.Request.URL.Path == "/api/mj/imgWall" || c.Request.URL.Path == "/api/mj/imgWall" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/mj/notify" || c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/imgWall" || c.Request.URL.Path == "/api/sd/imgWall" ||
c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/dall/imgWall" || c.Request.URL.Path == "/api/dall/imgWall" ||
c.Request.URL.Path == "/api/dall/client" ||
c.Request.URL.Path == "/api/product/list" || c.Request.URL.Path == "/api/product/list" ||
c.Request.URL.Path == "/api/menu/list" || c.Request.URL.Path == "/api/menu/list" ||
c.Request.URL.Path == "/api/markMap/client" || c.Request.URL.Path == "/api/markMap/client" ||
c.Request.URL.Path == "/api/payment/alipay/notify" ||
c.Request.URL.Path == "/api/payment/hupipay/notify" ||
c.Request.URL.Path == "/api/payment/payjs/notify" ||
c.Request.URL.Path == "/api/payment/wechat/notify" ||
c.Request.URL.Path == "/api/payment/doPay" || c.Request.URL.Path == "/api/payment/doPay" ||
c.Request.URL.Path == "/api/payment/payWays" || c.Request.URL.Path == "/api/payment/payWays" ||
c.Request.URL.Path == "/api/suno/client" || c.Request.URL.Path == "/api/suno/detail" ||
c.Request.URL.Path == "/api/suno/Detail" ||
c.Request.URL.Path == "/api/suno/play" || c.Request.URL.Path == "/api/suno/play" ||
c.Request.URL.Path == "/api/download" ||
c.Request.URL.Path == "/api/dall/models" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") || strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") || strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") || strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
@@ -367,6 +371,7 @@ func staticResourceMiddleware() gin.HandlerFunc {
// 直接输出图像数据流 // 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes()) c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
c.Abort() // 中断请求 c.Abort() // 中断请求
} }
c.Next() c.Next()
} }

View File

@@ -38,7 +38,6 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload", BasePath: "./static/upload",
}, },
}, },
WeChatBot: false,
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false}, AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
} }
} }

View File

@@ -9,14 +9,15 @@ package types
// ApiRequest API 请求实体 // ApiRequest API 请求实体
type ApiRequest struct { type ApiRequest struct {
Model string `json:"model,omitempty"` // 兼容百度文心一言 Model string `json:"model,omitempty"`
Temperature float32 `json:"temperature"` Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens,omitempty"` // 兼容百度文心一言 MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"` MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
Messages []interface{} `json:"messages,omitempty"` Stream bool `json:"stream,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM Messages []interface{} `json:"messages,omitempty"`
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台 Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
ResponseFormat interface{} `json:"response_format,omitempty"` // 响应格式
ToolChoice string `json:"tool_choice,omitempty"` ToolChoice string `json:"tool_choice,omitempty"`
@@ -52,16 +53,17 @@ type Delta struct {
// ChatSession 聊天会话对象 // ChatSession 聊天会话对象
type ChatSession struct { type ChatSession struct {
SessionId string `json:"session_id"` UserId uint `json:"user_id"`
UserId uint `json:"user_id"` ClientIP string `json:"client_ip"` // 客户端 IP
ClientIP string `json:"client_ip"` // 客户端 IP ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段 Model ChatModel `json:"model"` // GPT 模型
Model ChatModel `json:"model"` // GPT 模型 Start int64 `json:"start"` // 开始请求时间戳
Tools []int `json:"tools"` // 工具函数列表
Stream bool `json:"stream"` // 是否采用流式输出
} }
type ChatModel struct { type ChatModel struct {
Id uint `json:"id"` Id uint `json:"id"`
Platform string `json:"platform"`
Name string `json:"name"` Name string `json:"name"`
Value string `json:"value"` Value string `json:"value"`
Power int `json:"power"` Power int `json:"power"`
@@ -91,7 +93,7 @@ const (
PowerConsume = PowerType(2) // 消费 PowerConsume = PowerType(2) // 消费
PowerRefund = PowerType(3) // 任务SD,MJ执行失败退款 PowerRefund = PowerType(3) // 任务SD,MJ执行失败退款
PowerInvite = PowerType(4) // 邀请奖励 PowerInvite = PowerType(4) // 邀请奖励
PowerReward = PowerType(5) // 众筹 PowerRedeem = PowerType(5) // 众筹
PowerGift = PowerType(6) // 系统赠送 PowerGift = PowerType(6) // 系统赠送
) )
@@ -103,9 +105,12 @@ func (t PowerType) String() string {
return "消费" return "消费"
case PowerRefund: case PowerRefund:
return "退款" return "退款"
case PowerReward: case PowerRedeem:
return "众筹" return "兑换"
case PowerGift:
return "赠送"
case PowerInvite:
return "邀请"
} }
return "其他" return "其他"
} }

View File

@@ -17,15 +17,17 @@ var ErrConClosed = errors.New("connection Closed")
// WsClient websocket client // WsClient websocket client
type WsClient struct { type WsClient struct {
Id string
Conn *websocket.Conn Conn *websocket.Conn
lock sync.Mutex lock sync.Mutex
mt int mt int
Closed bool Closed bool
} }
func NewWsClient(conn *websocket.Conn) *WsClient { func NewWsClient(conn *websocket.Conn, id string) *WsClient {
return &WsClient{ return &WsClient{
Conn: conn, Conn: conn,
Id: id,
lock: sync.Mutex{}, lock: sync.Mutex{},
mt: 2, // fixed bug for 'Invalid UTF-8 in text frame' mt: 2, // fixed bug for 'Invalid UTF-8 in text frame'
Closed: false, Closed: false,

View File

@@ -12,28 +12,23 @@ import (
) )
type AppConfig struct { type AppConfig struct {
Path string `toml:"-"` Path string `toml:"-"`
Listen string Listen string
Session Session Session Session
AdminSession Session AdminSession Session
ProxyURL string ProxyURL string
MysqlDns string // mysql 连接地址 MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录 StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息 Redis RedisConfig // redis 连接信息
ApiConfig ApiConfig // ChatPlus API authorization configs ApiConfig ApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config OSS OSSConfig // OSS config
MjProxyConfigs []MjProxyConfig // MJ proxy config SmtpConfig SmtpConfig // 邮件发送配置
MjPlusConfigs []MjPlusConfig // MJ plus config
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig XXLConfig XXLConfig
AlipayConfig AlipayConfig // 支付宝支付渠道配置 AlipayConfig AlipayConfig // 支付宝支付渠道配置
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置 HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
SmtpConfig SmtpConfig // 邮件发送配置 GeekPayConfig GeekPayConfig // GEEK 支付配置
JPayConfig JPayConfig // payjs 支付配置
WechatPayConfig WechatPayConfig // 微信支付渠道配置 WechatPayConfig WechatPayConfig // 微信支付渠道配置
TikaHost string // TiKa 服务器地址 TikaHost string // TiKa 服务器地址
} }
@@ -53,27 +48,6 @@ type ApiConfig struct {
Token string Token string
} }
type MjProxyConfig struct {
Enabled bool
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type StableDiffusionConfig struct {
Enabled bool
Model string // 模型名称
ApiURL string
ApiKey string
}
type MjPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type AlipayConfig struct { type AlipayConfig struct {
Enabled bool // 是否启用该支付通道 Enabled bool // 是否启用该支付通道
SandBox bool // 是否沙盒环境 SandBox bool // 是否沙盒环境
@@ -83,8 +57,8 @@ type AlipayConfig struct {
PublicKey string // 用户公钥文件路径 PublicKey string // 用户公钥文件路径
AlipayPublicKey string // 支付宝公钥文件路径 AlipayPublicKey string // 支付宝公钥文件路径
RootCert string // Root 秘钥路径 RootCert string // Root 秘钥路径
NotifyURL string // 异步通知回调 NotifyURL string // 异步通知地址
ReturnURL string // 支付成功返回地址 ReturnURL string // 同步通知地址
} }
type WechatPayConfig struct { type WechatPayConfig struct {
@@ -94,29 +68,27 @@ type WechatPayConfig struct {
SerialNo string // 商户证书的证书序列号 SerialNo string // 商户证书的证书序列号
PrivateKey string // 用户私钥文件路径 PrivateKey string // 用户私钥文件路径
ApiV3Key string // API V3 秘钥 ApiV3Key string // API V3 秘钥
NotifyURL string // 异步通知回调 NotifyURL string // 异步通知地址
ReturnURL string // 支付成功返回地址
} }
type HuPiPayConfig struct { //虎皮椒第四方支付配置 type HuPiPayConfig struct { //虎皮椒第四方支付配置
Enabled bool // 是否启用该支付通道 Enabled bool // 是否启用该支付通道
Name string // 支付名称wechat/alipay
AppId string // App ID AppId string // App ID
AppSecret string // app 密钥 AppSecret string // app 密钥
ApiURL string // 支付网关 ApiURL string // 支付网关
NotifyURL string // 异步通知回调 NotifyURL string // 异步通知地址
ReturnURL string // 支付成功返回地址 ReturnURL string // 同步通知地址
} }
// JPayConfig PayJs 支付配置 // GeekPayConfig GEEK支付配置
type JPayConfig struct { type GeekPayConfig struct {
Enabled bool Enabled bool
Name string // 支付名称,默认 wechat AppId string // 商户 ID
AppId string // 商户 ID PrivateKey string // 私钥
PrivateKey string // 私钥 ApiURL string // API 网关
ApiURL string // API 网关 NotifyURL string // 异步通知地址
NotifyURL string // 异步回调地址 ReturnURL string // 同步通知地址
ReturnURL string // 支付成功返回地址 Methods []string // 支付方式
} }
type XXLConfig struct { // XXL 任务调度配置 type XXLConfig struct { // XXL 任务调度配置
@@ -156,31 +128,30 @@ func (c RedisConfig) Url() string {
} }
type SystemConfig struct { type SystemConfig struct {
Title string `json:"title,omitempty"` // 网站标题 Title string `json:"title,omitempty"` // 网站标题
Slogan string `json:"slogan,omitempty"` // 网站 slogan Slogan string `json:"slogan,omitempty"` // 网站 slogan
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题 AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
Logo string `json:"logo,omitempty"` Logo string `json:"logo,omitempty"` // 圆形 Logo
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值 InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力 DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值 InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值 VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式支持手机mobile邮箱注册email账号密码注册 RegisterWays []string `json:"register_ways,omitempty"` // 注册方式支持手机mobile邮箱注册email账号密码注册
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册 EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间 OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明 VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力 MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力 MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力 SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力 DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力 SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址 WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
@@ -188,8 +159,15 @@ type SystemConfig struct {
ContextDeep int `json:"context_deep,omitempty"` ContextDeep int `json:"context_deep,omitempty"`
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词 SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
MjMode string `json:"mj_mode"` // midjourney 默认的API模式relax, fast, turbo
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
Copyright string `json:"copyright"` // 版权信息
ICP string `json:"icp"` // ICP 备案号
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
TranslateModelId int `json:"translate_model_id"` // 用来做提示词翻译的大模型 id
IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
Copyright string `json:"copyright"` // 版权信息
} }

View File

@@ -16,7 +16,7 @@ type MKey interface {
string | int | uint string | int | uint
} }
type MValue interface { type MValue interface {
*WsClient | *ChatSession | context.CancelFunc | []Message *WsClient | *ChatSession | context.CancelFunc | []interface{}
} }
type LMap[K MKey, T MValue] struct { type LMap[K MKey, T MValue] struct {
lock sync.RWMutex lock sync.RWMutex

View File

@@ -22,3 +22,18 @@ type OrderRemark struct {
Price float64 `json:"price"` Price float64 `json:"price"`
Discount float64 `json:"discount"` Discount float64 `json:"discount"`
} }
var PayMethods = map[string]string{
"alipay": "支付宝商号",
"wechat": "微信商号",
"hupi": "虎皮椒",
"geek": "易支付",
}
var PayNames = map[string]string{
"alipay": "支付宝",
"wxpay": "微信支付",
"qqpay": "QQ钱包",
"jdpay": "京东支付",
"douyin": "抖音支付",
"paypal": "PayPal支付",
}

View File

@@ -24,30 +24,35 @@ const (
// MjTask MidJourney 任务 // MjTask MidJourney 任务
type MjTask struct { type MjTask struct {
Id uint `json:"id"` Id uint `json:"id"` // 任务ID
TaskId string `json:"task_id"` TaskId string `json:"task_id"` // 中转任务ID
ImgArr []string `json:"img_arr"` ClientId string `json:"client_id"`
ChannelId string `json:"channel_id"` ImgArr []string `json:"img_arr"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
NegPrompt string `json:"neg_prompt,omitempty"` NegPrompt string `json:"neg_prompt,omitempty"`
Params string `json:"full_prompt"` Params string `json:"full_prompt"`
Index int `json:"index,omitempty"` Index int `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"` MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"` MessageHash string `json:"message_hash,omitempty"`
RetryCount int `json:"retry_count"` ChannelId string `json:"channel_id"` // 渠道ID用来区分是哪个渠道创建的任务一个任务的 create 和 action 操作必须要再同一个渠道
Mode string `json:"mode"` // 绘画模式relax, fast, turbo
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
} }
type SdTask struct { type SdTask struct {
Id int `json:"id"` // job 数据库ID Id int `json:"id"` // job 数据库ID
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` ClientId string `json:"client_id"`
Params SdTaskParams `json:"params"` UserId int `json:"user_id"`
RetryCount int `json:"retry_count"` Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
} }
type SdTaskParams struct { type SdTaskParams struct {
ClientId string `json:"client_id"` // 客户端ID
TaskId string `json:"task_id"` TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词 Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词 NegPrompt string `json:"neg_prompt"` // 反向提示词
@@ -68,29 +73,63 @@ type SdTaskParams struct {
// DallTask DALL-E task // DallTask DALL-E task
type DallTask struct { type DallTask struct {
JobId uint `json:"job_id"` ClientId string `json:"client_id"`
UserId uint `json:"user_id"` ModelId uint `json:"model_id"`
Prompt string `json:"prompt"` ModelName string `json:"model_name"`
N int `json:"n"` Id uint `json:"id"`
Quality string `json:"quality"` UserId uint `json:"user_id"`
Size string `json:"size"` Prompt string `json:"prompt"`
Style string `json:"style"` N int `json:"n"`
Quality string `json:"quality"`
Power int `json:"power"` Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
} }
type SunoTask struct { type SunoTask struct {
ClientId string `json:"client_id"`
Id uint `json:"id"` Id uint `json:"id"`
Channel string `json:"channel"` Channel string `json:"channel"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Type int `json:"type"` Type int `json:"type"`
TaskId string `json:"task_id"`
Title string `json:"title"` Title string `json:"title"`
RefTaskId string `json:"ref_task_id"` RefTaskId string `json:"ref_task_id,omitempty"`
RefSongId string `json:"ref_song_id"` RefSongId string `json:"ref_song_id,omitempty"`
Prompt string `json:"prompt"` // 提示词/歌词 Prompt string `json:"prompt"` // 提示词/歌词
Tags string `json:"tags"` Tags string `json:"tags"`
Model string `json:"model"` Model string `json:"model"`
Instrumental bool `json:"instrumental"` // 是否纯音乐 Instrumental bool `json:"instrumental"` // 是否纯音乐
ExtendSecs int `json:"extend_secs"` // 延长秒杀 ExtendSecs int `json:"extend_secs,omitempty"` // 延长秒杀
SongId string `json:"song_id,omitempty"` // 合并歌曲ID
AudioURL string `json:"audio_url"` // 用户上传音频地址
}
const (
VideoLuma = "luma"
VideoRunway = "runway"
VideoCog = "cog"
)
type VideoTask struct {
ClientId string `json:"client_id"`
Id uint `json:"id"`
Channel string `json:"channel"`
UserId int `json:"user_id"`
Type string `json:"type"`
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
Params VideoParams `json:"params"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
}
type VideoParams struct {
PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词
Loop bool `json:"loop"` // 是否循环参考图
StartImgURL string `json:"start_img_url"` // 第一帧参考图地址
EndImgURL string `json:"end_img_url"` // 最后一帧参考图地址
Model string `json:"model"` // 使用哪个模型生成视频
Radio string `json:"radio"` // 视频尺寸
Style string `json:"style"` // 风格
Duration int `json:"duration"` // 视频时长(秒)
} }

View File

@@ -17,21 +17,48 @@ type BizVo struct {
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
} }
// WsMessage Websocket message // ReplyMessage 对话回复消息结构
type WsMessage struct { type ReplyMessage struct {
Type WsMsgType `json:"type"` // 消息类别start, end, img Channel WsChannel `json:"channel"` // 消息频道,目前只有 chat
Content interface{} `json:"content"` ClientId string `json:"clientId"` // 客户端ID
Type WsMsgType `json:"type"` // 消息类别
Body interface{} `json:"body"`
} }
type WsMsgType string type WsMsgType string
type WsChannel string
const ( const (
WsStart = WsMsgType("start") MsgTypeText = WsMsgType("text") // 输出内容
WsMiddle = WsMsgType("middle") MsgTypeEnd = WsMsgType("end")
WsEnd = WsMsgType("end") MsgTypeErr = WsMsgType("error")
WsErr = WsMsgType("error") MsgTypePing = WsMsgType("ping") // 心跳消息
ChPing = WsChannel("ping")
ChChat = WsChannel("chat")
ChMj = WsChannel("mj")
ChSd = WsChannel("sd")
ChDall = WsChannel("dall")
ChSuno = WsChannel("suno")
ChLuma = WsChannel("luma")
) )
// InputMessage 对话输入消息结构
type InputMessage struct {
Channel WsChannel `json:"channel"` // 消息频道
Type WsMsgType `json:"type"` // 消息类别
Body interface{} `json:"body"`
}
type ChatMessage struct {
Tools []int `json:"tools,omitempty"` // 允许调用工具列表
Stream bool `json:"stream,omitempty"` // 是否采用流式输出
RoleId int `json:"role_id"`
ModelId int `json:"model_id"`
ChatId string `json:"chat_id"`
Content string `json:"content"`
}
type BizCode int type BizCode int
const ( const (

View File

@@ -8,7 +8,6 @@ require (
github.com/BurntSushi/toml v1.1.0 github.com/BurntSushi/toml v1.1.0
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
github.com/eatmoreapple/openwechat v1.2.1
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang-jwt/jwt/v5 v5.0.0
@@ -30,7 +29,6 @@ require (
github.com/go-pay/gopay v1.5.101 github.com/go-pay/gopay v1.5.101
github.com/google/go-tika v0.3.1 github.com/google/go-tika v0.3.1
github.com/microcosm-cc/bluemonday v1.0.26 github.com/microcosm-cc/bluemonday v1.0.26
github.com/mojocn/base64Captcha v1.3.6
github.com/shirou/gopsutil v3.21.11+incompatible github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.3.1 github.com/shopspring/decimal v1.3.1
github.com/syndtr/goleveldb v1.0.0 github.com/syndtr/goleveldb v1.0.0
@@ -45,9 +43,13 @@ require (
github.com/go-pay/util v0.0.2 // indirect github.com/go-pay/util v0.0.2 // indirect
github.com/go-pay/xlog v0.0.2 // indirect github.com/go-pay/xlog v0.0.2 // indirect
github.com/go-pay/xtime v0.0.2 // indirect github.com/go-pay/xtime v0.0.2 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/gorilla/css v1.0.0 // indirect github.com/gorilla/css v1.0.0 // indirect
github.com/gravityblast/fresh v0.0.0-20240621171608-8d1fef547a99 // indirect
github.com/howeyc/fsnotify v0.9.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/pilu/config v0.0.0-20131214182432-3eb99e6c0b9a // indirect
github.com/pilu/fresh v0.0.0-20240621171608-8d1fef547a99 // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect github.com/tklauser/numcpus v0.7.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect

View File

@@ -28,8 +28,6 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
@@ -84,8 +82,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
@@ -104,11 +100,15 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gravityblast/fresh v0.0.0-20240621171608-8d1fef547a99 h1:A6qlLfihaWef15viqtecCz4XknZcgjgD7mEuhu7bHEc=
github.com/gravityblast/fresh v0.0.0-20240621171608-8d1fef547a99/go.mod h1:ukFDwXV66bGV7JnfyxFKuKiVp4zH4orBKXML+VCSrhI=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/howeyc/fsnotify v0.9.0 h1:0gtV5JmOKH4A8SsFxG2BczSeXWWPvcMT0euZt5gDAxY=
github.com/howeyc/fsnotify v0.9.0/go.mod h1:41HzSPxBGeFRQKEEwgh49TRw/nKBsYZ2cF1OzPjSJsA=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I= github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI= github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
@@ -141,6 +141,9 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0 h1:LgmjED/yQILqmUED4GaXjrINWe7YJh4HM6z2EvEINPs= github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0 h1:LgmjED/yQILqmUED4GaXjrINWe7YJh4HM6z2EvEINPs=
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs= github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58= github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
@@ -157,8 +160,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mojocn/base64Captcha v1.3.6 h1:gZEKu1nsKpttuIAQgWHO+4Mhhls8cAKyiV2Ew03H+Tw=
github.com/mojocn/base64Captcha v1.3.6/go.mod h1:i5CtHvm+oMbj1UzEPXaA8IH/xHFZ3DGY3Wh3dBpZ28E=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
@@ -176,6 +177,10 @@ github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:Ff
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU= github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pilu/config v0.0.0-20131214182432-3eb99e6c0b9a h1:Tg4E4cXPZSZyd3H1tJlYo6ZreXV0ZJvE/lorNqyw1AU=
github.com/pilu/config v0.0.0-20131214182432-3eb99e6c0b9a/go.mod h1:9Or9aIl95Kp43zONcHd5tLZGKXb9iLx0pZjau0uJ5zg=
github.com/pilu/fresh v0.0.0-20240621171608-8d1fef547a99 h1:+X7Gb40b5Bl3v5+3MiGK8Jhemjp65MHc+nkVCfq1Yfc=
github.com/pilu/fresh v0.0.0-20240621171608-8d1fef547a99/go.mod h1:2LLTtftTZSdAPR/iVyennXZDLZOYzyDn+T0qEKJ8eSw=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -220,12 +225,10 @@ github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gt
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0= github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU= github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY= github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY=
github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
@@ -267,7 +270,6 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.13.0/go.mod h1:6mmbMOeV28HuMTgA6OSRkdXKYw/t5W9Uwn2Yv1r3Yxk=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
@@ -300,6 +302,7 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -323,7 +326,6 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=

View File

@@ -14,6 +14,7 @@ import (
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/service"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
@@ -28,33 +29,49 @@ import (
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
Captcha string `json:"captcha"` // 验证码
CaptchaId string `json:"captcha_id"` // 验证码id
}
const SuperManagerID = 1 const SuperManagerID = 1
type ManagerHandler struct { type ManagerHandler struct {
handler.BaseHandler handler.BaseHandler
redis *redis.Client redis *redis.Client
captcha *service.CaptchaService
} }
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler { func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, captcha *service.CaptchaService) *ManagerHandler {
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client} return &ManagerHandler{
BaseHandler: handler.BaseHandler{DB: db, App: app},
redis: client,
captcha: captcha,
}
} }
// Login 登录 // Login 登录
func (h *ManagerHandler) Login(c *gin.Context) { func (h *ManagerHandler) Login(c *gin.Context) {
var data Manager var data struct {
Username string `json:"username"`
Password string `json:"password"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
if h.App.SysConfig.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
var manager model.AdminUser var manager model.AdminUser
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager) res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil { if res.Error != nil {

View File

@@ -8,6 +8,7 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
@@ -15,6 +16,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -53,17 +55,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey.Enabled = data.Enabled apiKey.Enabled = data.Enabled
apiKey.ProxyURL = data.ProxyURL apiKey.ProxyURL = data.ProxyURL
apiKey.Name = data.Name apiKey.Name = data.Name
res := h.DB.Save(&apiKey) err := h.DB.Save(&apiKey).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
var keyVo vo.ApiKey var keyVo vo.ApiKey
err := utils.CopyObject(apiKey, &keyVo) err = utils.CopyObject(apiKey, &keyVo)
if err != nil { if err != nil {
resp.ERROR(c, "数据拷贝失败!") resp.ERROR(c, fmt.Sprintf("拷贝数据失败:%v", err))
return return
} }
keyVo.Id = apiKey.Id keyVo.Id = apiKey.Id
@@ -71,20 +72,18 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
resp.SUCCESS(c, keyVo) resp.SUCCESS(c, keyVo)
} }
// List 获取 API KEY 列表
func (h *ApiKeyHandler) List(c *gin.Context) { func (h *ApiKeyHandler) List(c *gin.Context) {
status := h.GetBool(c, "status") status := h.GetBool(c, "status")
t := h.GetTrim(c, "type") t := c.Query("type")
platform := h.GetTrim(c, "platform")
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if status { if status {
session = session.Where("enabled", true) session = session.Where("enabled", true)
} }
if t != "" { if t != "" {
session = session.Where("type", t) types := strings.Split(t, "|")
} session = session.Where("type IN ?", types)
if platform != "" {
session = session.Where("platform", platform)
} }
var items []model.ApiKey var items []model.ApiKey
@@ -119,10 +118,9 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) err := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -135,10 +133,9 @@ func (h *ApiKeyHandler) Remove(c *gin.Context) {
return return
} }
res := h.DB.Where("id", id).Delete(&model.ApiKey{}) err := h.DB.Where("id", id).Delete(&model.ApiKey{}).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@@ -8,6 +8,7 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
@@ -21,16 +22,16 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
type ChatRoleHandler struct { type ChatAppHandler struct {
handler.BaseHandler handler.BaseHandler
} }
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// Save 创建或者更新某个角色 // Save 创建或者更新某个角色
func (h *ChatRoleHandler) Save(c *gin.Context) { func (h *ChatAppHandler) Save(c *gin.Context) {
var data vo.ChatRole var data vo.ChatRole
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -45,11 +46,16 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
role.Id = data.Id role.Id = data.Id
if data.CreatedAt > 0 { if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0) role.CreatedAt = time.Unix(data.CreatedAt, 0)
} else {
err = h.DB.Where("marker", data.Key).First(&role).Error
if err == nil {
resp.ERROR(c, fmt.Sprintf("角色 %s 已存在", data.Key))
return
}
} }
res := h.DB.Save(&role) err = h.DB.Save(&role).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
// 填充 ID 数据 // 填充 ID 数据
@@ -58,7 +64,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
resp.SUCCESS(c, data) resp.SUCCESS(c, data)
} }
func (h *ChatRoleHandler) List(c *gin.Context) { func (h *ChatAppHandler) List(c *gin.Context) {
var items []model.ChatRole var items []model.ChatRole
var roles = make([]vo.ChatRole, 0) var roles = make([]vo.ChatRole, 0)
res := h.DB.Order("sort_num ASC").Find(&items) res := h.DB.Order("sort_num ASC").Find(&items)
@@ -69,13 +75,18 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
// initialize model mane for role // initialize model mane for role
modelIds := make([]int, 0) modelIds := make([]int, 0)
typeIds := make([]int, 0)
for _, v := range items { for _, v := range items {
if v.ModelId > 0 { if v.ModelId > 0 {
modelIds = append(modelIds, v.ModelId) modelIds = append(modelIds, v.ModelId)
} }
if v.Tid > 0 {
typeIds = append(typeIds, v.Tid)
}
} }
modelNameMap := make(map[int]string) modelNameMap := make(map[int]string)
typeNameMap := make(map[int]string)
if len(modelIds) > 0 { if len(modelIds) > 0 {
var models []model.ChatModel var models []model.ChatModel
tx := h.DB.Where("id IN ?", modelIds).Find(&models) tx := h.DB.Where("id IN ?", modelIds).Find(&models)
@@ -85,6 +96,15 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
} }
} }
} }
if len(typeIds) > 0 {
var appTypes []model.AppType
tx := h.DB.Where("id IN ?", typeIds).Find(&appTypes)
if tx.Error == nil {
for _, m := range appTypes {
typeNameMap[int(m.Id)] = m.Name
}
}
}
for _, v := range items { for _, v := range items {
var role vo.ChatRole var role vo.ChatRole
@@ -94,6 +114,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
role.CreatedAt = v.CreatedAt.Unix() role.CreatedAt = v.CreatedAt.Unix()
role.UpdatedAt = v.UpdatedAt.Unix() role.UpdatedAt = v.UpdatedAt.Unix()
role.ModelName = modelNameMap[role.ModelId] role.ModelName = modelNameMap[role.ModelId]
role.TypeName = typeNameMap[role.Tid]
roles = append(roles, role) roles = append(roles, role)
} }
} }
@@ -102,7 +123,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
} }
// Sort 更新角色排序 // Sort 更新角色排序
func (h *ChatRoleHandler) Sort(c *gin.Context) { func (h *ChatAppHandler) Sort(c *gin.Context) {
var data struct { var data struct {
Ids []uint `json:"ids"` Ids []uint `json:"ids"`
Sorts []int `json:"sorts"` Sorts []int `json:"sorts"`
@@ -114,10 +135,9 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
} }
@@ -125,7 +145,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *ChatRoleHandler) Set(c *gin.Context) { func (h *ChatAppHandler) Set(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`
Filed string `json:"filed"` Filed string `json:"filed"`
@@ -137,16 +157,15 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *ChatRoleHandler) Remove(c *gin.Context) { func (h *ChatAppHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id <= 0 { if id <= 0 {

View File

@@ -0,0 +1,148 @@
package admin
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatAppTypeHandler struct {
handler.BaseHandler
}
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// Save 创建或更新App类型
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Enabled bool `json:"enabled"`
Icon string `json:"icon"`
SortNum int `json:"sort_num"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id == 0 { // for add
err := h.DB.Where("name", data.Name).First(&model.AppType{}).Error
if err == nil {
resp.ERROR(c, "当前分类已经存在")
return
}
err = h.DB.Create(&model.AppType{
Name: data.Name,
Icon: data.Icon,
Enabled: data.Enabled,
SortNum: data.SortNum,
}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
} else { // for update
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).Updates(map[string]interface{}{
"name": data.Name,
"icon": data.Icon,
"enabled": data.Enabled,
}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c)
}
// List 获取App类型列表
func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType
var appTypes = make([]vo.AppType, 0)
err := h.DB.Order("sort_num ASC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, v := range items {
var appType vo.AppType
err = utils.CopyObject(v, &appType)
if err != nil {
continue
}
appType.Id = v.Id
appType.CreatedAt = v.CreatedAt.Unix()
appTypes = append(appTypes, appType)
}
resp.SUCCESS(c, appTypes)
}
// Remove 删除App类型
func (h *ChatAppTypeHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Where("id", id).Delete(&model.AppType{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Enable 启用|禁用
func (h *ChatAppTypeHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Sort 更新排序
func (h *ChatAppTypeHandler) Sort(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
Sorts []int `json:"sorts"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
for index, id := range data.Ids {
err := h.DB.Model(&model.AppType{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c)
}

View File

@@ -259,10 +259,9 @@ func (h *ChatHandler) RemoveChat(c *gin.Context) {
// RemoveMessage 删除聊天记录 // RemoveMessage 删除聊天记录
func (h *ChatHandler) RemoveMessage(c *gin.Context) { func (h *ChatHandler) RemoveMessage(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}) err := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}).Error
if tx.Error != nil { if err != nil {
logger.Error("error with update database", tx.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@@ -43,6 +43,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
Temperature float32 `json:"temperature"` // 模型温度 Temperature float32 `json:"temperature"` // 模型温度
KeyId int `json:"key_id,omitempty"` KeyId int `json:"key_id,omitempty"`
CreatedAt int64 `json:"created_at"` CreatedAt int64 `json:"created_at"`
Type string `json:"type"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -65,7 +66,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
item.MaxContext = data.MaxContext item.MaxContext = data.MaxContext
item.Temperature = data.Temperature item.Temperature = data.Temperature
item.KeyId = data.KeyId item.KeyId = data.KeyId
item.Type = data.Type
var res *gorm.DB var res *gorm.DB
if data.Id > 0 { if data.Id > 0 {
res = h.DB.Save(&item) res = h.DB.Save(&item)
@@ -147,10 +148,9 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) err := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -168,10 +168,9 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) err := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
} }
@@ -186,10 +185,9 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
return return
} }
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}) err := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@@ -12,8 +12,6 @@ import (
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/service" "geekai/service"
"geekai/service/mj"
"geekai/service/sd"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
@@ -28,16 +26,12 @@ type ConfigHandler struct {
handler.BaseHandler handler.BaseHandler
levelDB *store.LevelDB levelDB *store.LevelDB
licenseService *service.LicenseService licenseService *service.LicenseService
mjServicePool *mj.ServicePool
sdServicePool *sd.ServicePool
} }
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler { func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
return &ConfigHandler{ return &ConfigHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db}, BaseHandler: handler.BaseHandler{App: app, DB: db},
levelDB: levelDB, levelDB: levelDB,
mjServicePool: mjPool,
sdServicePool: sdPool,
licenseService: licenseService, licenseService: licenseService,
} }
} }
@@ -147,57 +141,69 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
resp.SUCCESS(c, license) resp.SUCCESS(c, license)
} }
// GetAppConfig 获取内置配置 // FixData 修复数据
func (h *ConfigHandler) GetAppConfig(c *gin.Context) { func (h *ConfigHandler) FixData(c *gin.Context) {
resp.SUCCESS(c, gin.H{ resp.ERROR(c, "当前升级版本没有数据需要修正!")
"mj_plus": h.App.Config.MjPlusConfigs, return
"mj_proxy": h.App.Config.MjProxyConfigs, //var fixed bool
"sd": h.App.Config.SdConfigs, //version := "data_fix_4.1.4"
}) //err := h.levelDB.Get(version, &fixed)
} //if err == nil || fixed {
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
// SaveDrawingConfig 保存AI绘画配置 // return
func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) { //}
var data struct { //tx := h.DB.Begin()
Sd []types.StableDiffusionConfig `json:"sd"` //var users []model.User
MjPlus []types.MjPlusConfig `json:"mj_plus"` //err = tx.Find(&users).Error
MjProxy []types.MjProxyConfig `json:"mj_proxy"` //if err != nil {
} // resp.ERROR(c, err.Error())
if err := c.ShouldBindJSON(&data); err != nil { // return
resp.ERROR(c, types.InvalidArgs) //}
return //for _, user := range users {
} // if user.Email != "" || user.Mobile != "" {
// continue
changed := false // }
if configChanged(data.Sd, h.App.Config.SdConfigs) { // if utils.IsValidEmail(user.Username) {
logger.Debugf("SD 配置变动了") // user.Email = user.Username
h.App.Config.SdConfigs = data.Sd // } else if utils.IsValidMobile(user.Username) {
h.sdServicePool.InitServices(data.Sd) // user.Mobile = user.Username
changed = true // }
} // err = tx.Save(&user).Error
// if err != nil {
if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) { // resp.ERROR(c, err.Error())
logger.Debugf("MidJourney 配置变动了") // tx.Rollback()
h.App.Config.MjPlusConfigs = data.MjPlus // return
h.App.Config.MjProxyConfigs = data.MjProxy // }
h.mjServicePool.InitServices(data.MjPlus, data.MjProxy) //}
changed = true //
} //var orders []model.Order
//err = h.DB.Find(&orders).Error
if changed { //if err != nil {
err := core.SaveConfig(h.App.Config) // resp.ERROR(c, err.Error())
if err != nil { // return
resp.ERROR(c, "更新配置文档失败!") //}
return //for _, order := range orders {
} // if order.PayWay == "支付宝" {
} // order.PayWay = "alipay"
// order.PayType = "alipay"
resp.SUCCESS(c) // } else if order.PayWay == "微信支付" {
// order.PayWay = "wechat"
} // order.PayType = "wxpay"
// } else if order.PayWay == "hupi" {
func configChanged(c1 interface{}, c2 interface{}) bool { // order.PayType = "wxpay"
encode1 := utils.JsonEncode(c1) // }
encode2 := utils.JsonEncode(c2) // err = tx.Save(&order).Error
return utils.Md5(encode1) != utils.Md5(encode2) // if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//tx.Commit()
//err = h.levelDB.Put(version, true)
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//resp.SUCCESS(c)
} }

View File

@@ -60,13 +60,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
stats.Tokens += item.Tokens stats.Tokens += item.Tokens
} }
// 众筹收入
var rewards []model.Reward
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
for _, item := range rewards {
stats.Income += item.Amount
}
// 订单收入 // 订单收入
var orders []model.Order var orders []model.Order
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders) res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
@@ -101,13 +94,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens) historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
} }
// 浮点数相加?
// 统计最近7天的众筹
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
for _, item := range rewards {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
// 统计最近7天的订单 // 统计最近7天的订单
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders) res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
for _, item := range orders { for _, item := range orders {

View File

@@ -69,10 +69,9 @@ func (h *FunctionHandler) Set(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) err := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -102,10 +101,9 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id > 0 { if id > 0 {
res := h.DB.Delete(&model.Function{Id: uint(id)}) err := h.DB.Delete(&model.Function{Id: uint(id)}).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
} }

View File

@@ -0,0 +1,254 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ImageHandler struct {
handler.BaseHandler
userService *service.UserService
uploader *oss.UploaderManager
}
func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *ImageHandler {
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
type imageQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`
CreatedAt []string `json:"created_at"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// MjList Midjourney 任务列表
func (h *ImageHandler) MjList(c *gin.Context) {
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.MidJourneyJob{}).Count(&total)
var list []model.MidJourneyJob
var items = make([]vo.MidJourneyJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.MidJourneyJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// SdList Stable Diffusion 任务列表
func (h *ImageHandler) SdList(c *gin.Context) {
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.SdJob{}).Count(&total)
var list []model.SdJob
var items = make([]vo.SdJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.SdJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// DallList DALL-E 任务列表
func (h *ImageHandler) DallList(c *gin.Context) {
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.DallJob{}).Count(&total)
var list []model.DallJob
var items = make([]vo.DallJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.DallJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
func (h *ImageHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tab := c.Query("tab")
tx := h.DB.Begin()
var md, remark, imgURL string
var power, userId, progress int
switch tab {
case "mj":
var job model.MidJourneyJob
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
resp.ERROR(c, "记录不存在")
return
}
tx.Delete(&job)
md = "mid-journey"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "sd":
var job model.SdJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "stable-diffusion"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "dall":
var job model.DallJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "dall-e-3"
power = job.Power
userId = int(job.UserId)
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
default:
resp.ERROR(c, types.InvalidArgs)
return
}
if progress != 100 {
err := h.userService.IncreasePower(userId, power, model.PowerLog{
Type: types.PowerRefund,
Model: md,
Remark: remark,
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(imgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,200 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MediaHandler struct {
handler.BaseHandler
userService *service.UserService
uploader *oss.UploaderManager
}
func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *MediaHandler {
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
type mediaQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`
CreatedAt []string `json:"created_at"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// SunoList Suno 任务列表
func (h *MediaHandler) SunoList(c *gin.Context) {
var data mediaQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.SunoJob{}).Count(&total)
var list []model.SunoJob
var items = make([]vo.SunoJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.SunoJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// LumaList Luma 视频任务列表
func (h *MediaHandler) LumaList(c *gin.Context) {
var data mediaQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.VideoJob{}).Count(&total)
var list []model.VideoJob
var items = make([]vo.VideoJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.VideoJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
if job.VideoURL == "" {
job.VideoURL = job.WaterURL
}
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
func (h *MediaHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tab := c.Query("tab")
tx := h.DB.Begin()
var md, remark, fileURL string
var power, userId, progress int
switch tab {
case "suno":
var job model.SunoJob
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
resp.ERROR(c, "记录不存在")
return
}
tx.Delete(&job)
md = "suno"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("SUNO 任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
fileURL = job.AudioURL
break
case "luma":
var job model.VideoJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = job.Type
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("LUMA 任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
fileURL = job.VideoURL
if fileURL == "" {
fileURL = job.WaterURL
}
break
default:
resp.ERROR(c, types.InvalidArgs)
return
}
if progress != 100 {
err := h.userService.IncreasePower(userId, power, model.PowerLog{
Type: types.PowerRefund,
Model: md,
Remark: remark,
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(fileURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}

View File

@@ -41,17 +41,16 @@ func (h *MenuHandler) Save(c *gin.Context) {
return return
} }
res := h.DB.Save(&model.Menu{ err := h.DB.Save(&model.Menu{
Id: data.Id, Id: data.Id,
Name: data.Name, Name: data.Name,
Icon: data.Icon, Icon: data.Icon,
URL: data.URL, URL: data.URL,
SortNum: data.SortNum, SortNum: data.SortNum,
Enabled: data.Enabled, Enabled: data.Enabled,
}) }).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -85,10 +84,9 @@ func (h *MenuHandler) Enable(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled) err := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -106,10 +104,9 @@ func (h *MenuHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index]) err := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
} }
@@ -121,10 +118,9 @@ func (h *MenuHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id > 0 { if id > 0 {
res := h.DB.Where("id", id).Delete(&model.Menu{}) err := h.DB.Where("id", id).Delete(&model.Menu{}).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
} }

View File

@@ -15,6 +15,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -67,6 +68,16 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix() order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix() order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
if !ok {
payMethod = item.PayWay
}
payName, ok := types.PayNames[item.PayType]
if !ok {
payName = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
list = append(list, order) list = append(list, order)
} else { } else {
logger.Error(err) logger.Error(err)
@@ -92,12 +103,33 @@ func (h *OrderHandler) Remove(c *gin.Context) {
return return
} }
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{}) err := h.DB.Where("id = ?", id).Delete(&model.Order{}).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *OrderHandler) Clear(c *gin.Context) {
var orders []model.Order
err := h.DB.Where("status <> ?", 2).Where("pay_time", 0).Find(&orders).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
deleteIds := make([]uint, 0)
for _, order := range orders {
// 只删除 15 分钟内的未支付订单
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
deleteIds = append(deleteIds, order.Id)
}
}
err = h.DB.Where("id IN ?", deleteIds).Delete(&model.Order{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -31,6 +31,7 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
func (h *PowerLogHandler) List(c *gin.Context) { func (h *PowerLogHandler) List(c *gin.Context) {
var data struct { var data struct {
Username string `json:"username"` Username string `json:"username"`
UserId uint `json:"userid"`
Type int `json:"type"` Type int `json:"type"`
Model string `json:"model"` Model string `json:"model"`
Date []string `json:"date"` Date []string `json:"date"`
@@ -49,6 +50,12 @@ func (h *PowerLogHandler) List(c *gin.Context) {
if data.Type > 0 { if data.Type > 0 {
session = session.Where("type", data.Type) session = session.Where("type", data.Type)
} }
if data.UserId > 0 {
session = session.Where("user_id", data.UserId)
}
if data.Username != "" {
session = session.Where("username", data.Username)
}
if len(data.Date) == 2 { if len(data.Date) == 2 {
start := data.Date[0] + " 00:00:00" start := data.Date[0] + " 00:00:00"
end := data.Date[1] + " 00:00:00" end := data.Date[1] + " 00:00:00"

View File

@@ -55,17 +55,16 @@ func (h *ProductHandler) Save(c *gin.Context) {
if item.Id > 0 { if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0) item.CreatedAt = time.Unix(data.CreatedAt, 0)
} }
res := h.DB.Save(&item) err := h.DB.Save(&item).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
var itemVo vo.Product var itemVo vo.Product
err := utils.CopyObject(item, &itemVo) err = utils.CopyObject(item, &itemVo)
if err != nil { if err != nil {
resp.ERROR(c, "数据拷贝失败") resp.ERROR(c, "数据拷贝失败: "+err.Error())
return return
} }
itemVo.Id = item.Id itemVo.Id = item.Id
@@ -106,10 +105,9 @@ func (h *ProductHandler) Enable(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled) err := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -127,10 +125,9 @@ func (h *ProductHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index]) err := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
} }
@@ -142,10 +139,9 @@ func (h *ProductHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id > 0 { if id > 0 {
res := h.DB.Where("id", id).Delete(&model.Product{}) err := h.DB.Where("id", id).Delete(&model.Product{}).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }
} }

View File

@@ -0,0 +1,219 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/csv"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type RedeemHandler struct {
handler.BaseHandler
}
func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RedeemHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
code := c.Query("code")
status := h.GetInt(c, "status", -1)
session := h.DB.Session(&gorm.Session{})
if code != "" {
session = session.Where("code LIKE ?", "%"+code+"%")
}
if status >= 0 {
session = session.Where("redeemed_at", status)
}
var total int64
session.Model(&model.Redeem{}).Count(&total)
var redeems []model.Redeem
offset := (page - 1) * pageSize
err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&redeems).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var items = make([]vo.Redeem, 0)
userIds := make([]uint, 0)
for _, v := range redeems {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
}
for _, v := range redeems {
var r vo.Redeem
err = utils.CopyObject(v, &r)
if err != nil {
continue
}
r.Id = v.Id
r.Username = userMap[v.UserId].Username
r.CreatedAt = v.CreatedAt.Unix()
items = append(items, r)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
// Export 导出 CVS 文件
func (h *RedeemHandler) Export(c *gin.Context) {
var data struct {
Status int `json:"status"`
Ids []int `json:"ids"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
}
session := h.DB.Session(&gorm.Session{})
if data.Status >= 0 {
session = session.Where("redeemed_at", data.Status)
}
if len(data.Ids) > 0 {
session = session.Where("id IN ?", data.Ids)
}
var items []model.Redeem
err := session.Order("id DESC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 设置响应头,告诉浏览器这是一个附件,需要下载
c.Header("Content-Disposition", "attachment; filename=output.csv")
c.Header("Content-Type", "text/csv")
// 创建一个 CSV writer
writer := csv.NewWriter(c.Writer)
// 写入 CSV 文件的标题行
headers := []string{"名称", "兑换码", "算力", "创建时间"}
if err := writer.Write(headers); err != nil {
resp.ERROR(c, err.Error())
return
}
// 写入数据行
records := make([][]string, 0)
for _, item := range items {
records = append(records, []string{item.Name, item.Code, fmt.Sprintf("%d", item.Power), item.CreatedAt.Format("2006-01-02 15:04:05")})
}
for _, record := range records {
if err := writer.Write(record); err != nil {
resp.ERROR(c, err.Error())
return
}
}
// 确保所有数据都已写入响应
writer.Flush()
if err := writer.Error(); err != nil {
resp.ERROR(c, err.Error())
return
}
}
func (h *RedeemHandler) Create(c *gin.Context) {
var data struct {
Name string `json:"name"`
Power int `json:"power"`
Num int `json:"num"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
counter := 0
codes := make([]string, 0)
var errMsg = ""
if data.Num > 0 {
for i := 0; i < data.Num; i++ {
code, err := utils.GenRedeemCode(32)
if err != nil {
errMsg = err.Error()
continue
}
err = h.DB.Create(&model.Redeem{
Code: code,
Name: data.Name,
Power: data.Power,
Enabled: true,
}).Error
if err != nil {
errMsg = err.Error()
continue
}
codes = append(codes, code)
counter++
}
}
if counter == 0 {
resp.ERROR(c, errMsg)
return
}
resp.SUCCESS(c, gin.H{
"counter": counter,
})
}
func (h *RedeemHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
Value interface{} `json:"value"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Model(&model.Redeem{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *RedeemHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Where("id", id).Delete(&model.Redeem{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -1,81 +0,0 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type RewardHandler struct {
handler.BaseHandler
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RewardHandler) List(c *gin.Context) {
var items []model.Reward
res := h.DB.Order("id DESC").Find(&items)
var rewards = make([]vo.Reward, 0)
if res.Error == nil {
userIds := make([]uint, 0)
for _, v := range items {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
}
for _, v := range items {
var r vo.Reward
err := utils.CopyObject(v, &r)
if err != nil {
continue
}
r.Id = v.Id
r.Username = userMap[v.UserId].Username
r.CreatedAt = v.CreatedAt.Unix()
r.UpdatedAt = v.UpdatedAt.Unix()
rewards = append(rewards, r)
}
}
resp.SUCCESS(c, rewards)
}
func (h *RewardHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -19,6 +19,8 @@ import (
"geekai/utils/resp" "geekai/utils/resp"
"time" "time"
"github.com/go-redis/redis/v8"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -26,10 +28,11 @@ import (
type UserHandler struct { type UserHandler struct {
handler.BaseHandler handler.BaseHandler
licenseService *service.LicenseService licenseService *service.LicenseService
redis *redis.Client
} }
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler { func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService, redisCli *redis.Client) *UserHandler {
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService} return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
} }
// List 用户列表 // List 用户列表
@@ -37,6 +40,8 @@ func (h *UserHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1) page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20) pageSize := h.GetInt(c, "page_size", 20)
username := h.GetTrim(c, "username") username := h.GetTrim(c, "username")
mobile := h.GetTrim(c, "mobile")
email := h.GetTrim(c, "email")
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
var items []model.User var items []model.User
@@ -47,9 +52,15 @@ func (h *UserHandler) List(c *gin.Context) {
if username != "" { if username != "" {
session = session.Where("username LIKE ?", "%"+username+"%") session = session.Where("username LIKE ?", "%"+username+"%")
} }
if mobile != "" {
session = session.Where("mobile LIKE ?", "%"+mobile+"%")
}
if email != "" {
session = session.Where("email LIKE ?", "%"+email+"%")
}
session.Model(&model.User{}).Count(&total) session.Model(&model.User{}).Count(&total)
res := session.Offset(offset).Limit(pageSize).Find(&items) res := session.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var user vo.User var user vo.User
@@ -73,6 +84,8 @@ func (h *UserHandler) Save(c *gin.Context) {
Id uint `json:"id"` Id uint `json:"id"`
Password string `json:"password"` Password string `json:"password"`
Username string `json:"username"` Username string `json:"username"`
Mobile string `json:"mobile"`
Email string `json:"email"`
ChatRoles []string `json:"chat_roles"` ChatRoles []string `json:"chat_roles"`
ChatModels []int `json:"chat_models"` ChatModels []int `json:"chat_models"`
ExpiredTime string `json:"expired_time"` ExpiredTime string `json:"expired_time"`
@@ -102,6 +115,8 @@ func (h *UserHandler) Save(c *gin.Context) {
} }
var oldPower = user.Power var oldPower = user.Power
user.Username = data.Username user.Username = data.Username
user.Email = data.Email
user.Mobile = data.Mobile
user.Status = data.Status user.Status = data.Status
user.Vip = data.Vip user.Vip = data.Vip
user.Power = data.Power user.Power = data.Power
@@ -109,7 +124,8 @@ func (h *UserHandler) Save(c *gin.Context) {
user.ChatModels = utils.JsonEncode(data.ChatModels) user.ChatModels = utils.JsonEncode(data.ChatModels)
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime) user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user) res = h.DB.Select("username", "mobile", "email", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil { if res.Error != nil {
logger.Error("error with update database", res.Error) logger.Error("error with update database", res.Error)
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, res.Error.Error())
@@ -135,6 +151,13 @@ func (h *UserHandler) Save(c *gin.Context) {
CreatedAt: time.Now(), CreatedAt: time.Now(),
}) })
} }
// 如果禁用了用户,则将用户踢下线
if user.Status == false {
key := fmt.Sprintf("users/%v", user.Id)
if _, err := h.redis.Del(c, key).Result(); err != nil {
logger.Error("error with delete session: ", err)
}
}
} else { } else {
// 检查用户是否已经存在 // 检查用户是否已经存在
h.DB.Where("username", data.Username).First(&user) h.DB.Where("username", data.Username).First(&user)
@@ -147,6 +170,8 @@ func (h *UserHandler) Save(c *gin.Context) {
u := model.User{ u := model.User{
Username: data.Username, Username: data.Username,
Password: utils.GenPassword(data.Password, salt), Password: utils.GenPassword(data.Password, salt),
Mobile: data.Mobile,
Email: data.Email,
Avatar: "/images/avatar/user.png", Avatar: "/images/avatar/user.png",
Salt: salt, Salt: salt,
Power: data.Power, Power: data.Power,
@@ -168,8 +193,7 @@ func (h *UserHandler) Save(c *gin.Context) {
} }
if res.Error != nil { if res.Error != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, res.Error.Error())
resp.ERROR(c, "更新数据库失败")
return return
} }
@@ -205,33 +229,69 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
} }
func (h *UserHandler) Remove(c *gin.Context) { func (h *UserHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := c.Query("id")
if id <= 0 { ids := c.QueryArray("ids[]")
if id != "" {
ids = append(ids, id)
}
if len(ids) == 0 {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
// 删除用户
res := h.DB.Where("id = ?", id).Delete(&model.User{}) tx := h.DB.Begin()
if res.Error != nil { var err error
resp.ERROR(c, "删除失败") for _, id = range ids {
// 删除用户
if err = tx.Where("id", id).Delete(&model.User{}).Error; err != nil {
break
}
// 删除聊天记录
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatItem{}).Error; err != nil {
break
}
// 删除聊天历史记录
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatMessage{}).Error; err != nil {
break
}
// 删除登录日志
if err = tx.Where("user_id = ?", id).Delete(&model.UserLoginLog{}).Error; err != nil {
break
}
// 删除算力日志
if err = tx.Where("user_id = ?", id).Delete(&model.PowerLog{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.InviteLog{}).Error; err != nil {
break
}
// 删除众筹日志
if err = tx.Where("user_id = ?", id).Delete(&model.Redeem{}).Error; err != nil {
break
}
// 删除绘图任务
if err = tx.Where("user_id = ?", id).Delete(&model.MidJourneyJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.SdJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.DallJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.SunoJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.VideoJob{}).Error; err != nil {
break
}
}
if err != nil {
resp.ERROR(c, err.Error())
tx.Rollback()
return return
} }
tx.Commit()
// 删除聊天记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
// 删除聊天历史记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
// 删除登录日志
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
// 删除算力日志
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
// 删除众筹日志
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
// 删除绘图任务
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
// 删除订单
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
resp.SUCCESS(c) resp.SUCCESS(c)
} }

View File

@@ -8,13 +8,13 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"errors"
"fmt"
"gorm.io/gorm" "gorm.io/gorm"
"strings" "strings"
@@ -85,7 +85,7 @@ func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
} }
var user model.User var user model.User
res := h.DB.First(&user, userId) res := h.DB.Where("id", userId).First(&user)
// 更新缓存 // 更新缓存
if res.Error == nil { if res.Error == nil {
c.Set(types.LoginUserCache, user) c.Set(types.LoginUserCache, user)

View File

@@ -0,0 +1,44 @@
package handler
import (
"geekai/core"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatAppTypeHandler struct {
BaseHandler
}
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 获取App类型列表
func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType
var appTypes = make([]vo.AppType, 0)
err := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, v := range items {
var appType vo.AppType
err = utils.CopyObject(v, &appType)
if err != nil {
continue
}
appType.Id = v.Id
appType.CreatedAt = v.CreatedAt.Unix()
appTypes = append(appTypes, appType)
}
resp.SUCCESS(c, appTypes)
}

View File

@@ -1,4 +1,4 @@
package chatimpl package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved. // * Copyright 2023 The Geek-AI Authors. All rights reserved.
@@ -15,8 +15,6 @@ import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/service" "geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
@@ -33,136 +31,31 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
) )
var logger = logger2.GetLogger()
type ChatHandler struct { type ChatHandler struct {
handler.BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
licenseService *service.LicenseService licenseService *service.LicenseService
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
userService *service.UserService
} }
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler { func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
return &ChatHandler{ return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db}, BaseHandler: BaseHandler{App: app, DB: db},
redis: redis, redis: redis,
uploadManager: manager, uploadManager: manager,
licenseService: licenseService, licenseService: licenseService,
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
ChatContexts: types.NewLMap[string, []types.Message](), ChatContexts: types.NewLMap[string, []interface{}](),
userService: userService,
} }
} }
// ChatHandle 处理聊天 WebSocket 请求
func (h *ChatHandler) ChatHandle(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
modelId := h.GetInt(c, "model_id", 0)
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res := h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
modelId = chatRole.ModelId
}
// get model info
var chatModel model.ChatModel
res = h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort()
return
}
session := &types.ChatSession{
SessionId: sessionId,
ClientIP: c.ClientIP(),
UserId: h.GetLoginUserId(c),
}
// use old chat data override the chat model and role ID
var chat model.ChatItem
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
if res.Error == nil {
chatModel.Id = chat.ModelId
roleId = int(chat.RoleId)
}
session.ChatId = chatId
session.Model = types.ChatModel{
Id: chatModel.Id,
Name: chatModel.Name,
Value: chatModel.Value,
Power: chatModel.Power,
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
KeyId: chatModel.KeyId}
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
cancelFunc := h.ReqCancelFunc.Get(sessionId)
if cancelFunc != nil {
cancelFunc()
h.ReqCancelFunc.Delete(sessionId)
}
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 Chat 心跳消息:", message.Content)
continue
}
logger.Info("Receive a message: ", message.Content)
ctx, cancel := context.WithCancel(context.Background())
h.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil {
logger.Error(err)
utils.ReplyMessage(client, err.Error())
} else {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
logger.Infof("回答完毕: %v", message.Content)
}
}
}()
}
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error { func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
if !h.App.Debug { if !h.App.Debug {
defer func() { defer func() {
@@ -204,45 +97,54 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
var req = types.ApiRequest{ var req = types.ApiRequest{
Model: session.Model.Value, Model: session.Model.Value,
Stream: true, }
// 兼容 GPT-O1 模型
if strings.HasPrefix(session.Model.Value, "o1-") {
utils.SendChunkMsg(ws, "> AI 正在思考...\n")
req.Stream = session.Stream
session.Start = time.Now().Unix()
} else {
req.MaxTokens = session.Model.MaxTokens
req.Temperature = session.Model.Temperature
req.Stream = session.Stream
} }
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res = h.DB.Where("enabled", true).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
}
if v, ok := parameters["required"]; v == nil || !ok {
tool.Function.Parameters["required"] = []string{}
}
tools = append(tools, tool)
}
if len(tools) > 0 { if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
req.Tools = tools var items []model.Function
req.ToolChoice = "auto" res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
}
if v, ok := parameters["required"]; v == nil || !ok {
tool.Function.Parameters["required"] = []string{}
}
tools = append(tools, tool)
}
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
} }
} }
// 加载聊天上下文 // 加载聊天上下文
chatCtx := make([]types.Message, 0) chatCtx := make([]interface{}, 0)
messages := make([]types.Message, 0) messages := make([]interface{}, 0)
if h.App.SysConfig.EnableContext { if h.App.SysConfig.EnableContext {
if h.ChatContexts.Has(session.ChatId) { if h.ChatContexts.Has(session.ChatId) {
messages = h.ChatContexts.Get(session.ChatId) messages = h.ChatContexts.Get(session.ChatId)
@@ -270,8 +172,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model) tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens tokens += tks + promptTokens
for _, v := range messages { for i := len(messages) - 1; i >= 0; i-- {
tks, _ := utils.CalcTokens(v.Content, req.Model) v := messages[i]
tks, _ = utils.CalcTokens(utils.JsonEncode(v), req.Model)
// 上下文 token 超出了模型的最大上下文长度 // 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext { if tokens+tks >= session.Model.MaxContext {
break break
@@ -289,8 +192,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
logger.Debugf("聊天上下文:%+v", chatCtx) logger.Debugf("聊天上下文:%+v", chatCtx)
} }
reqMgs := make([]interface{}, 0) reqMgs := make([]interface{}, 0)
for _, m := range chatCtx {
reqMgs = append(reqMgs, m) for i := len(chatCtx) - 1; i >= 0; i-- {
reqMgs = append(reqMgs, chatCtx[i])
} }
fullPrompt := prompt fullPrompt := prompt
@@ -355,7 +259,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
logger.Debugf("%+v", req.Messages) logger.Debugf("%+v", req.Messages)
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, ws)
} }
// Tokens 统计 token 数量 // Tokens 统计 token 数量
@@ -442,7 +346,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
if err != nil { if err != nil {
return nil, err return nil, err
} }
logger.Debugf(utils.JsonEncode(req)) logger.Debugf("对话请求消息体:%+v", req)
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL) apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 创建 HttpClient 请求对象 // 创建 HttpClient 请求对象
@@ -468,7 +372,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
} else { } else {
client = http.DefaultClient client = http.DefaultClient
} }
logger.Debugf("Sending %s request, Channel:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model) logger.Infof("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
// 更新API KEY 最后使用时间 // 更新API KEY 最后使用时间
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix()) h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
@@ -481,115 +385,112 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
if session.Model.Power > 0 { if session.Model.Power > 0 {
power = session.Model.Power power = session.Model.Power
} }
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userVo.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,
Amount: power,
Mark: types.PowerSub,
Balance: u.Power,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
CreatedAt: time.Now(),
})
}
err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{
Type: types.PowerConsume,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
})
if err != nil {
logger.Error(err)
}
} }
func (h *ChatHandler) saveChatHistory( func (h *ChatHandler) saveChatHistory(
req types.ApiRequest, req types.ApiRequest,
prompt string, usage Usage,
contents []string,
message types.Message, message types.Message,
chatCtx []types.Message,
session *types.ChatSession, session *types.ChatSession,
role model.ChatRole, role model.ChatRole,
userVo vo.User, userVo vo.User,
promptCreatedAt time.Time, promptCreatedAt time.Time,
replyCreatedAt time.Time) { replyCreatedAt time.Time) {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文 // 更新上下文消息
if h.App.SysConfig.EnableContext { if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx := req.Messages // 提问消息
chatCtx = append(chatCtx, message) // 回复消息 chatCtx = append(chatCtx, message) // 回复消息
h.ChatContexts.Put(session.ChatId, chatCtx) h.ChatContexts.Put(session.ChatId, chatCtx)
} }
// 追加聊天记录 // 追加聊天记录
// for prompt // for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model) var promptTokens, replyTokens, totalTokens int
if err != nil { if usage.PromptTokens > 0 {
logger.Error(err) promptTokens = usage.PromptTokens
} else {
promptTokens, _ = utils.CalcTokens(usage.Content, req.Model)
} }
historyUserMsg := model.ChatMessage{ historyUserMsg := model.ChatMessage{
UserId: userVo.Id, UserId: userVo.Id,
ChatId: session.ChatId, ChatId: session.ChatId,
RoleId: role.Id, RoleId: role.Id,
Type: types.PromptMsg, Type: types.PromptMsg,
Icon: userVo.Avatar, Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt), Content: template.HTMLEscapeString(usage.Prompt),
Tokens: promptToken, Tokens: promptTokens,
UseContext: true, TotalTokens: promptTokens,
Model: req.Model, UseContext: true,
Model: req.Model,
} }
historyUserMsg.CreatedAt = promptCreatedAt historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg) err := h.DB.Save(&historyUserMsg).Error
if res.Error != nil { if err != nil {
logger.Error("failed to save prompt history message: ", res.Error) logger.Error("failed to save prompt history message: ", err)
} }
// for reply // for reply
// 计算本次对话消耗的总 token 数量 // 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model) if usage.CompletionTokens > 0 {
totalTokens := replyTokens + getTotalTokens(req) replyTokens = usage.CompletionTokens
totalTokens = usage.TotalTokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
totalTokens = replyTokens + getTotalTokens(req)
}
historyReplyMsg := model.ChatMessage{ historyReplyMsg := model.ChatMessage{
UserId: userVo.Id, UserId: userVo.Id,
ChatId: session.ChatId, ChatId: session.ChatId,
RoleId: role.Id, RoleId: role.Id,
Type: types.ReplyMsg, Type: types.ReplyMsg,
Icon: role.Icon, Icon: role.Icon,
Content: message.Content, Content: usage.Content,
Tokens: totalTokens, Tokens: replyTokens,
UseContext: true, TotalTokens: totalTokens,
Model: req.Model, UseContext: true,
Model: req.Model,
} }
historyReplyMsg.CreatedAt = replyCreatedAt historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg) err = h.DB.Create(&historyReplyMsg).Error
if res.Error != nil { if err != nil {
logger.Error("failed to save reply history message: ", res.Error) logger.Error("failed to save reply history message: ", err)
} }
// 更新用户算力 // 更新用户算力
if session.Model.Power > 0 { if session.Model.Power > 0 {
h.subUserPower(userVo, session, promptToken, replyTokens) h.subUserPower(userVo, session, promptTokens, replyTokens)
} }
// 保存当前会话 // 保存当前会话
var chatItem model.ChatItem var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
if res.Error != nil { if err != nil {
chatItem.ChatId = session.ChatId chatItem.ChatId = session.ChatId
chatItem.UserId = userVo.Id chatItem.UserId = userVo.Id
chatItem.RoleId = role.Id chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 { if utf8.RuneCountInString(usage.Prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..." chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
} else { } else {
chatItem.Title = prompt chatItem.Title = usage.Prompt
} }
chatItem.Model = req.Model chatItem.Model = req.Model
h.DB.Create(&chatItem) err = h.DB.Create(&chatItem).Error
if err != nil {
logger.Error("failed to save chat item: ", err)
}
} }
} }

View File

@@ -1,4 +1,4 @@
package chatimpl package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved. // * Copyright 2023 The Geek-AI Authors. All rights reserved.
@@ -28,31 +28,40 @@ func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
var items = make([]vo.ChatItem, 0) var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem var chats []model.ChatItem
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats) h.DB.Where("user_id", userId).Order("id DESC").Find(&chats)
if res.Error == nil { if len(chats) == 0 {
var roleIds = make([]uint, 0) resp.SUCCESS(c, items)
for _, chat := range chats { return
roleIds = append(roleIds, chat.RoleId) }
}
var roles []model.ChatRole
res = h.DB.Find(&roles, roleIds)
if res.Error == nil {
roleMap := make(map[uint]model.ChatRole)
for _, role := range roles {
roleMap[role.Id] = role
}
for _, chat := range chats { var roleIds = make([]uint, 0)
var item vo.ChatItem var modelValues = make([]string, 0)
err := utils.CopyObject(chat, &item) for _, chat := range chats {
if err == nil { roleIds = append(roleIds, chat.RoleId)
item.Id = chat.Id modelValues = append(modelValues, chat.Model)
item.Icon = roleMap[chat.RoleId].Icon }
items = append(items, item)
}
}
}
var roles []model.ChatRole
var models []model.ChatModel
roleMap := make(map[uint]model.ChatRole)
modelMap := make(map[string]model.ChatModel)
h.DB.Where("id IN ?", roleIds).Find(&roles)
h.DB.Where("value IN ?", modelValues).Find(&models)
for _, role := range roles {
roleMap[role.Id] = role
}
for _, m := range models {
modelMap[m.Value] = m
}
for _, chat := range chats {
var item vo.ChatItem
err := utils.CopyObject(chat, &item)
if err == nil {
item.Id = chat.Id
item.Icon = roleMap[chat.RoleId].Icon
item.ModelId = modelMap[chat.Model].Id
items = append(items, item)
}
} }
resp.SUCCESS(c, items) resp.SUCCESS(c, items)
} }

View File

@@ -30,29 +30,25 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
func (h *ChatModelHandler) List(c *gin.Context) { func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0) var chatModels = make([]vo.ChatModel, 0)
var res *gorm.DB session := h.DB.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true)
session := h.DB.Session(&gorm.Session{}).Where("enabled", true)
t := c.Query("type") t := c.Query("type")
if t != "" { if t != "" {
session = session.Where("type", t) session = session.Where("type", t)
} }
// 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) { session = session.Where("open", true)
res = session.Where("open", true).Order("sort_num ASC").Find(&items) if h.IsLogin(c) {
} else {
user, _ := h.GetLoginUser(c) user, _ := h.GetLoginUser(c)
var models []int var models []int
err := utils.JsonDecode(user.ChatModels, &models) err := utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型 // 查询用户有权限访问的模型以及所有开放的模型
res = h.DB.Where("enabled = ?", true).Where( if err == nil {
h.DB.Where("id IN ?", models).Or("open", true), session = session.Or("id IN ?", models)
).Order("sort_num ASC").Find(&items) }
} }
res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var cm vo.ChatModel var cm vo.ChatModel

View File

@@ -1,4 +1,4 @@
package chatimpl package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved. // * Copyright 2023 The Geek-AI Authors. All rights reserved.
@@ -17,15 +17,41 @@ import (
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
req2 "github.com/imroc/req/v3"
"io" "io"
"strings" "strings"
"time" "time"
req2 "github.com/imroc/req/v3"
) )
type Usage struct {
Prompt string `json:"prompt,omitempty"`
Content string `json:"content,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIResVo struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage Usage `json:"usage"`
}
// OPenAI 消息发送实现 // OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage( func (h *ChatHandler) sendOpenAiMessage(
chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,
@@ -37,7 +63,7 @@ func (h *ChatHandler) sendOpenAiMessage(
start := time.Now() start := time.Now()
var apiKey = model.ApiKey{} var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey) response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start)) logger.Info("HTTP请求完成耗时", time.Since(start))
if err != nil { if err != nil {
if strings.Contains(err.Error(), "context canceled") { if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt) return fmt.Errorf("用户取消了请求:%s", prompt)
@@ -49,17 +75,29 @@ func (h *ChatHandler) sendOpenAiMessage(
defer response.Body.Close() defer response.Body.Close()
} }
if response.StatusCode != 200 {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
}
contentType := response.Header.Get("Content-Type") contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") { if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间 replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息 // 循环读取 Chunk 消息
var message = types.Message{} var message = types.Message{Role: "assistant"}
var contents = make([]string, 0) var contents = make([]string, 0)
var function model.Function var function model.Function
var toolCall = false var toolCall = false
var arguments = make([]string, 0) var arguments = make([]string, 0)
if strings.HasPrefix(req.Model, "o1-") {
content := fmt.Sprintf("AI 思考结束,耗时:%d 秒。\n\n", time.Now().Unix()-session.Start)
contents = append(contents, "> AI 正在思考中...\n")
contents = append(contents, content)
utils.SendChunkMsg(ws, content)
}
scanner := bufio.NewScanner(response.Body) scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 { if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -78,7 +116,7 @@ func (h *ChatHandler) sendOpenAiMessage(
} }
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.ReplyMessage(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。") utils.SendChunkMsg(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
break break
} }
@@ -106,8 +144,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if res.Error == nil { if res.Error == nil {
toolCall = true toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) utils.SendChunkMsg(ws, callMsg)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
contents = append(contents, callMsg) contents = append(contents, callMsg)
} }
continue continue
@@ -121,17 +158,10 @@ func (h *ChatHandler) sendOpenAiMessage(
// output stopped // output stopped
if responseBody.Choices[0].FinishReason != "" { if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了 break // 输出完成或者输出中断了
} else { } else { // 正常输出结果
content := responseBody.Choices[0].Delta.Content content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content)) contents = append(contents, utils.InterfaceToString(content))
if isNew { utils.SendChunkMsg(ws, content)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} }
} // end for } // end for
@@ -149,39 +179,62 @@ func (h *ChatHandler) sendOpenAiMessage(
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params) logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
params["user_id"] = userVo.Id params["user_id"] = userVo.Id
var apiRes types.BizVo var apiRes types.BizVo
r, err := req2.C().R().SetHeader("Content-Type", "application/json"). r, err := req2.C().R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", function.Token). SetHeader("Authorization", function.Token).
SetBody(params). SetBody(params).Post(function.Action)
SetSuccessResult(&apiRes).Post(function.Action)
errMsg := "" errMsg := ""
if err != nil { if err != nil {
errMsg = err.Error() errMsg = err.Error()
} else if r.IsErrorState() {
errMsg = r.Status
}
if errMsg != "" || apiRes.Code != types.Success {
msg := "调用函数工具出错:" + apiRes.Message + errMsg
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else { } else {
utils.ReplyChunkMessage(ws, types.WsMessage{ all, _ := io.ReadAll(r.Body)
Type: types.WsMiddle, err = json.Unmarshal(all, &apiRes)
Content: apiRes.Data, if err != nil {
}) errMsg = err.Error()
contents = append(contents, utils.InterfaceToString(apiRes.Data)) } else if apiRes.Code != types.Success {
errMsg = apiRes.Message
}
} }
if errMsg != "" {
errMsg = "调用函数工具出错:" + errMsg
contents = append(contents, errMsg)
} else {
errMsg = utils.InterfaceToString(apiRes.Data)
contents = append(contents, errMsg)
}
utils.SendChunkMsg(ws, errMsg)
} }
// 消息发送成功 // 消息发送成功
if len(contents) > 0 { if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) usage := Usage{
Prompt: prompt,
Content: strings.Join(contents, ""),
PromptTokens: 0,
CompletionTokens: 0,
TotalTokens: 0,
}
message.Content = usage.Content
h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
} }
} else { } else { // 非流式输出
body, _ := io.ReadAll(response.Body) var respVo OpenAIResVo
return fmt.Errorf("请求 OpenAI API 失败:%s", body) body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("读取响应失败:%v", body)
}
err = json.Unmarshal(body, &respVo)
if err != nil {
return fmt.Errorf("解析响应失败:%v", body)
}
content := respVo.Choices[0].Message.Content
if strings.HasPrefix(req.Model, "o1-") {
content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
}
utils.SendChunkMsg(ws, content)
respVo.Usage.Prompt = prompt
respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
} }
return nil return nil

View File

@@ -29,10 +29,37 @@ func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
// List 获取用户聊天应用列表 // List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) { func (h *ChatRoleHandler) List(c *gin.Context) {
tid := h.GetInt(c, "tid", 0)
var roles []model.ChatRole
session := h.DB.Where("enable", true)
if tid > 0 {
session = session.Where("tid", tid)
}
err := session.Order("sort_num ASC").Find(&roles).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
var v vo.ChatRole
err := utils.CopyObject(r, &v)
if err == nil {
v.Id = r.Id
roleVos = append(roleVos, v)
}
}
resp.SUCCESS(c, roleVos)
}
// ListByUser 获取用户添加的角色列表
func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
var roles []model.ChatRole var roles []model.ChatRole
query := h.DB.Where("enable", true) session := h.DB.Where("enable", true)
// 如果用户没登录,则获取所有角色
if userId > 0 { if userId > 0 {
var user model.User var user model.User
h.DB.First(&user, userId) h.DB.First(&user, userId)
@@ -42,12 +69,16 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
resp.ERROR(c, "角色解析失败!") resp.ERROR(c, "角色解析失败!")
return return
} }
query = query.Where("marker IN ?", roleKeys) // 保证用户至少有一个角色可用
if len(roleKeys) > 0 {
session = session.Where("marker IN ?", roleKeys)
}
} }
if id > 0 { if id > 0 {
query = query.Or("id", id) session = session.Or("id", id)
} }
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles) res := session.Order("sort_num ASC").Find(&roles)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, res.Error.Error())
return return
@@ -81,10 +112,9 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)) err = h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }

View File

@@ -8,34 +8,33 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service"
"geekai/service/dalle" "geekai/service/dalle"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"net/http"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
) )
type DallJobHandler struct { type DallJobHandler struct {
BaseHandler BaseHandler
redis *redis.Client dallService *dalle.Service
service *dalle.Service uploader *oss.UploaderManager
uploader *oss.UploaderManager userService *service.UserService
} }
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler { func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
return &DallJobHandler{ return &DallJobHandler{
service: service, dallService: service,
uploader: manager, uploader: manager,
userService: userService,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
@@ -43,82 +42,50 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
} }
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *DallJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.service.Clients.Delete(uint(userId))
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 DallE 心跳消息:", message.Content)
continue
}
}
}()
}
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
return true
}
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *DallJobHandler) Image(c *gin.Context) { func (h *DallJobHandler) Image(c *gin.Context) {
if !h.preCheck(c) {
return
}
var data types.DallTask var data types.DallTask
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
var chatModel model.ChatModel
if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil {
resp.ERROR(c, "模型不存在")
return
}
// 检查用户剩余算力
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < chatModel.Power {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return
}
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
task := types.DallTask{
ClientId: data.ClientId,
UserId: uint(userId),
ModelId: chatModel.Id,
ModelName: chatModel.Value,
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
TranslateModelId: h.App.SysConfig.TranslateModelId,
Power: chatModel.Power,
}
job := model.DallJob{ job := model.DallJob{
UserId: uint(userId), UserId: uint(userId),
Prompt: data.Prompt, Prompt: data.Prompt,
Power: h.App.SysConfig.DallPower, Power: chatModel.Power,
TaskInfo: utils.JsonEncode(task),
} }
res := h.DB.Create(&job) res := h.DB.Create(&job)
if res.Error != nil { if res.Error != nil {
@@ -126,19 +93,18 @@ func (h *DallJobHandler) Image(c *gin.Context) {
return return
} }
h.service.PushTask(types.DallTask{ task.Id = job.Id
JobId: job.Id, h.dallService.PushTask(task)
UserId: uint(userId),
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
Power: job.Power,
})
client := h.service.Clients.Get(job.UserId) // 扣减算力
if client != nil { err = h.userService.DecreasePower(int(user.Id), chatModel.Power, model.PowerLog{
_ = client.Send([]byte("Task Updated")) Type: types.PowerConsume,
Model: chatModel.Value,
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
})
if err != nil {
resp.ERROR(c, "error with decrease power: "+err.Error())
return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -174,11 +140,11 @@ func (h *DallJobHandler) JobList(c *gin.Context) {
} }
// JobList 获取任务列表 // JobList 获取任务列表
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) { func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if finish { if finish {
session = session.Where("progress = ?", 100).Order("id DESC") session = session.Where("progress >= ?", 100).Order("id DESC")
} else { } else {
session = session.Where("progress < ?", 100).Order("id ASC") session = session.Where("progress < ?", 100).Order("id ASC")
} }
@@ -192,11 +158,14 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize) session = session.Offset(offset).Limit(pageSize)
} }
// 统计总数
var total int64
session.Model(&model.DallJob{}).Count(&total)
var items []model.DallJob var items []model.DallJob
res := session.Find(&items) res := session.Find(&items)
if res.Error != nil { if res.Error != nil {
return res.Error, nil return res.Error, vo.Page{}
} }
var jobs = make([]vo.DallJob, 0) var jobs = make([]vo.DallJob, 0)
@@ -209,28 +178,28 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
jobs = append(jobs, job) jobs = append(jobs, job)
} }
return nil, jobs return nil, vo.NewPage(total, page, pageSize, jobs)
} }
// Remove remove task image // Remove remove task image
func (h *DallJobHandler) Remove(c *gin.Context) { func (h *DallJobHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
userId := h.GetInt(c, "user_id", 0) userId := h.GetLoginUserId(c)
var job model.DallJob var job model.DallJob
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil { if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在") resp.ERROR(c, "记录不存在")
return return
} }
// remove job recode // 删除任务
res := h.DB.Delete(&model.DallJob{Id: job.Id}) err := h.DB.Delete(&job).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, err.Error())
return return
} }
// remove image // remove image
err := h.uploader.GetUploadHandler().Delete(job.ImgURL) err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
if err != nil { if err != nil {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
@@ -241,15 +210,36 @@ func (h *DallJobHandler) Remove(c *gin.Context) {
// Publish 发布/取消发布图片到画廊显示 // Publish 发布/取消发布图片到画廊显示
func (h *DallJobHandler) Publish(c *gin.Context) { func (h *DallJobHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
userId := h.GetInt(c, "user_id", 0) userId := h.GetLoginUserId(c)
action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享 action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享
res := h.DB.Model(&model.DallJob{Id: uint(id), UserId: uint(userId)}).UpdateColumn("publish", action) err := h.DB.Model(&model.DallJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *DallJobHandler) GetModels(c *gin.Context) {
var models []model.ChatModel
err := h.DB.Where("type", "img").Where("enabled", true).Find(&models).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var modelVos []vo.ChatModel
for _, v := range models {
var modelVo vo.ChatModel
err := utils.CopyObject(v, &modelVo)
if err != nil {
continue
}
modelVo.Id = v.Id
modelVos = append(modelVos, modelVo)
}
resp.SUCCESS(c, modelVos)
}

View File

@@ -8,15 +8,17 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service"
"geekai/service/dalle" "geekai/service/dalle"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"errors"
"fmt"
"strings" "strings"
"time" "time"
@@ -31,6 +33,7 @@ type FunctionHandler struct {
config types.ApiConfig config types.ApiConfig
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
dallService *dalle.Service dallService *dalle.Service
userService *service.UserService
} }
func NewFunctionHandler( func NewFunctionHandler(
@@ -38,7 +41,8 @@ func NewFunctionHandler(
db *gorm.DB, db *gorm.DB,
config *types.AppConfig, config *types.AppConfig,
manager *oss.UploaderManager, manager *oss.UploaderManager,
dallService *dalle.Service) *FunctionHandler { dallService *dalle.Service,
userService *service.UserService) *FunctionHandler {
return &FunctionHandler{ return &FunctionHandler{
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: server, App: server,
@@ -47,6 +51,7 @@ func NewFunctionHandler(
config: config.ApiConfig, config: config.ApiConfig,
uploadManager: manager, uploadManager: manager,
dallService: dallService, dallService: dallService,
userService: userService,
} }
} }
@@ -112,10 +117,13 @@ func (h *FunctionHandler) WeiBo(c *gin.Context) {
SetHeader("AppId", h.config.AppId). SetHeader("AppId", h.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)). SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetSuccessResult(&res).Get(url) SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() { if err != nil {
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err)) resp.ERROR(c, fmt.Sprintf("%v", err))
return return
} }
if r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("error http code status: %v", r.Status))
}
if res.Code != types.Success { if res.Code != types.Success {
resp.ERROR(c, res.Message) resp.ERROR(c, res.Message)
@@ -148,8 +156,12 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
SetHeader("AppId", h.config.AppId). SetHeader("AppId", h.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)). SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetSuccessResult(&res).Get(url) SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() { if err != nil {
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err)) resp.ERROR(c, fmt.Sprintf("%v", err))
return
}
if r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("%v", r.Err))
return return
} }
@@ -163,7 +175,7 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
for _, v := range res.Data.Items { for _, v := range res.Data.Items {
builder = append(builder, v.Title) builder = append(builder, v.Title)
} }
builder = append(builder, fmt.Sprintf("%s", res.Data.Title)) builder = append(builder, res.Data.Title)
resp.SUCCESS(c, strings.Join(builder, "\n\n")) resp.SUCCESS(c, strings.Join(builder, "\n\n"))
} }
@@ -195,32 +207,71 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
// create dall task // create dall task
prompt := utils.InterfaceToString(params["prompt"]) prompt := utils.InterfaceToString(params["prompt"])
job := model.DallJob{ task := types.DallTask{
UserId: user.Id, UserId: user.Id,
Prompt: prompt, Prompt: prompt,
Power: h.App.SysConfig.DallPower, ModelId: 0,
ModelName: "dall-e-3",
TranslateModelId: h.App.SysConfig.TranslateModelId,
N: 1,
Quality: "standard",
Size: "1024x1024",
Style: "vivid",
Power: h.App.SysConfig.DallPower,
} }
res = h.DB.Create(&job) job := model.DallJob{
UserId: user.Id,
if res.Error != nil { Prompt: prompt,
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error()) Power: h.App.SysConfig.DallPower,
TaskInfo: utils.JsonEncode(task),
}
err := h.DB.Create(&job).Error
if err != nil {
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+err.Error())
return return
} }
content, err := h.dallService.Image(types.DallTask{ task.Id = job.Id
JobId: job.Id, content, err := h.dallService.Image(task, true)
UserId: user.Id,
Prompt: job.Prompt,
N: 1,
Quality: "standard",
Size: "1024x1024",
Style: "vivid",
Power: job.Power,
}, true)
if err != nil { if err != nil {
resp.ERROR(c, "任务执行失败:"+err.Error()) resp.ERROR(c, "任务执行失败:"+err.Error())
return return
} }
// 扣减算力
err = h.userService.DecreasePower(int(user.Id), job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: task.ModelName,
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(job.Prompt, 10)),
})
if err != nil {
resp.ERROR(c, "扣减算力失败:"+err.Error())
return
}
resp.SUCCESS(c, content) resp.SUCCESS(c, content)
} }
// List 获取所有的工具函数列表
func (h *FunctionHandler) List(c *gin.Context) {
var items []model.Function
err := h.DB.Where("enabled", true).Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
tools := make([]vo.Function, 0)
for _, v := range items {
var f vo.Function
err = utils.CopyObject(v, &f)
if err != nil {
continue
}
f.Action = ""
f.Token = ""
tools = append(tools, f)
}
resp.SUCCESS(c, tools)
}

View File

@@ -9,7 +9,6 @@ package handler
import ( import (
"geekai/core" "geekai/core"
"geekai/core/types"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
@@ -59,23 +58,16 @@ func (h *InviteHandler) Code(c *gin.Context) {
// List Log 用户邀请记录 // List Log 用户邀请记录
func (h *InviteHandler) List(c *gin.Context) { func (h *InviteHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)
var data struct { pageSize := h.GetInt(c, "page_size", 20)
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId) session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
var total int64 var total int64
session.Model(&model.InviteLog{}).Count(&total) session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog var items []model.InviteLog
var list = make([]vo.InviteLog, 0) var list = make([]vo.InviteLog, 0)
offset := (data.Page - 1) * data.PageSize offset := (page - 1) * pageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var v vo.InviteLog var v vo.InviteLog
@@ -89,7 +81,7 @@ func (h *InviteHandler) List(c *gin.Context) {
} }
} }
} }
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
} }
// Hits 访问邀请码 // Hits 访问邀请码

View File

@@ -8,110 +8,66 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
) )
// MarkMapHandler 生成思维导图 // MarkMapHandler 生成思维导图
type MarkMapHandler struct { type MarkMapHandler struct {
BaseHandler BaseHandler
clients *types.LMap[int, *types.WsClient] clients *types.LMap[int, *types.WsClient]
userService *service.UserService
} }
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *MarkMapHandler {
return &MarkMapHandler{ return &MarkMapHandler{
BaseHandler: BaseHandler{App: app, DB: db}, BaseHandler: BaseHandler{App: app, DB: db},
clients: types.NewLMap[int, *types.WsClient](), clients: types.NewLMap[int, *types.WsClient](),
userService: userService,
} }
} }
func (h *MarkMapHandler) Client(c *gin.Context) { // Generate 生成思维导图
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) func (h *MarkMapHandler) Generate(c *gin.Context) {
if err != nil { var data struct {
logger.Error(err) Prompt string `json:"prompt"`
ModelId int `json:"model_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return return
} }
modelId := h.GetInt(c, "model_id", 0) userId := h.GetLoginUserId(c)
userId := h.GetInt(c, "user_id", 0)
client := types.NewWsClient(ws)
h.clients.Put(userId, client)
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.clients.Delete(userId)
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
continue
}
// change model
if message.Type == "model_id" {
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
continue
}
logger.Info("Receive a message: ", message.Content)
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
}
}
}()
}
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
var user model.User var user model.User
res := h.DB.Model(&model.User{}).First(&user, userId) err := h.DB.Where("id", userId).First(&user, userId).Error
if res.Error != nil { if err != nil {
return fmt.Errorf("error with query user info: %v", res.Error) resp.ERROR(c, "error with query user info")
return
} }
var chatModel model.ChatModel var chatModel model.ChatModel
res = h.DB.Where("id", modelId).First(&chatModel) err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
if res.Error != nil { if err != nil {
return fmt.Errorf("error with query chat model: %v", res.Error) resp.ERROR(c, "error with query chat model")
} return
if user.Status == false {
return errors.New("当前用户被禁用")
} }
if user.Power < chatModel.Power { if user.Power < chatModel.Power {
return fmt.Errorf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power) resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power))
return
} }
messages := make([]interface{}, 0) messages := make([]interface{}, 0)
messages = append(messages, types.Message{Role: "system", Content: ` messages = append(messages, types.Message{Role: "system", Content: `
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子: 你是一位非常优秀的思维导图助手, 你能帮助用户整理思路,根据用户提供的主题或内容,快速生成结构清晰,有条理的思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
# Geek-AI 助手 # Geek-AI 助手
## 完整的开源系统 ## 完整的开源系统
@@ -128,130 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
### 支付宝 ### 支付宝
### 微信 ### 微信
另外,除此之外不要任何解释性语句。 请直接生成结果,不要任何解释性语句。
`}) `})
messages = append(messages, types.Message{Role: "user", Content: prompt}) messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图要求结构清晰有条理", data.Prompt)})
var req = types.ApiRequest{ content, err := utils.SendOpenAIMessage(h.DB, messages, data.ModelId)
Model: chatModel.Value,
Stream: true,
Messages: messages,
}
var apiKey model.ApiKey
response, err := h.doRequest(req, chatModel, &apiKey)
if err != nil { if err != nil {
return fmt.Errorf("请求 OpenAI API 失败: %s", err) resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
} return
defer response.Body.Close()
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
// 循环读取 Chunk 消息
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return fmt.Errorf("error with decode data: %v", line)
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].FinishReason == "stop" {
break
}
if isNew {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(client, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} // end for
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
} }
// 扣减算力 // 扣减算力
if chatModel.Power > 0 { if chatModel.Power > 0 {
res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power)) err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
if res.Error == nil { Type: types.PowerConsume,
// 记录算力消费日志 Model: chatModel.Value,
var u model.User Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
h.DB.Where("id", userId).First(&u) })
h.DB.Create(&model.PowerLog{ if err != nil {
UserId: u.Id, resp.ERROR(c, "error with save power log, "+err.Error())
Username: u.Username, return
Type: types.PowerConsume,
Amount: chatModel.Power,
Mark: types.PowerSub,
Balance: u.Power,
Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
CreatedAt: time.Now(),
})
} }
} }
return nil resp.SUCCESS(c, content)
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
session := h.DB.Session(&gorm.Session{})
// if the chat model bind a KEY, use it directly
if chatModel.KeyId > 0 {
session = session.Where("id", chatModel.KeyId)
} else { // use the last unused key
session = session.Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC")
}
res := session.First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
} else {
client = http.DefaultClient
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
return client.Do(request)
} }

View File

@@ -8,7 +8,6 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"encoding/base64"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
@@ -19,27 +18,27 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"net/http"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
) )
type MidJourneyHandler struct { type MidJourneyHandler struct {
BaseHandler BaseHandler
pool *mj.ServicePool mjService *mj.Service
snowflake *service.Snowflake snowflake *service.Snowflake
uploader *oss.UploaderManager uploader *oss.UploaderManager
userService *service.UserService
} }
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler { func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
return &MidJourneyHandler{ return &MidJourneyHandler{
snowflake: snowflake, snowflake: snowflake,
pool: pool, mjService: service,
uploader: manager, uploader: manager,
userService: userService,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
@@ -59,40 +58,15 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
return false return false
} }
if !h.pool.HasAvailableService() {
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
return false
}
return true return true
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) { func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct { var data struct {
TaskType string `json:"task_type"` TaskType string `json:"task_type"`
ClientId string `json:"client_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"` NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"` Rate string `json:"rate"`
@@ -178,10 +152,23 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
resp.ERROR(c, "error with generate task id: "+err.Error()) resp.ERROR(c, "error with generate task id: "+err.Error())
return return
} }
task := types.MjTask{
ClientId: data.ClientId,
TaskId: taskId,
Type: types.TaskType(data.TaskType),
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
Mode: h.App.SysConfig.MjMode,
TranslateModelId: h.App.SysConfig.TranslateModelId,
}
job := model.MidJourneyJob{ job := model.MidJourneyJob{
Type: data.TaskType, Type: data.TaskType,
UserId: userId, UserId: userId,
TaskId: taskId, TaskId: taskId,
TaskInfo: utils.JsonEncode(task),
Progress: 0, Progress: 0,
Prompt: fmt.Sprintf("%s %s", data.Prompt, params), Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
Power: h.App.SysConfig.MjPower, Power: h.App.SysConfig.MjPower,
@@ -201,44 +188,26 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return return
} }
h.pool.PushTask(types.MjTask{ task.Id = job.Id
Id: job.Id, h.mjService.PushTask(task)
TaskId: taskId,
Type: types.TaskType(data.TaskType),
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
})
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power // update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
// 记录算力变化日志 Type: types.PowerConsume,
if tx.Error == nil && tx.RowsAffected > 0 { Model: "mid-journey",
user, _ := h.GetLoginUser(c) Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
h.DB.Create(&model.PowerLog{ })
UserId: user.Id, if err != nil {
Username: user.Username, resp.ERROR(c, err.Error())
Type: types.PowerConsume, return
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
CreatedAt: time.Now(),
})
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
type reqVo struct { type reqVo struct {
Index int `json:"index"` Index int `json:"index"`
ClientId string `json:"client_id"`
ChannelId string `json:"channel_id"` ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"` MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"` MessageHash string `json:"message_hash"`
@@ -259,51 +228,44 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
taskId, _ := h.snowflake.Next(true) taskId, _ := h.snowflake.Next(true)
job := model.MidJourneyJob{ task := types.MjTask{
Type: types.TaskUpscale.String(), ClientId: data.ClientId,
ReferenceId: data.MessageId,
UserId: userId,
TaskId: taskId,
Progress: 0,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{
Id: job.Id,
Type: types.TaskUpscale, Type: types.TaskUpscale,
UserId: userId, UserId: userId,
ChannelId: data.ChannelId, ChannelId: data.ChannelId,
Index: data.Index, Index: data.Index,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
}) Mode: h.App.SysConfig.MjMode,
}
job := model.MidJourneyJob{
Type: types.TaskUpscale.String(),
UserId: userId,
TaskId: taskId,
TaskInfo: utils.JsonEncode(task),
Progress: 0,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
task.Id = job.Id
h.mjService.PushTask(task)
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power // update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
// 记录算力变化日志 Type: types.PowerConsume,
if tx.Error == nil && tx.RowsAffected > 0 { Model: "mid-journey",
user, _ := h.GetLoginUser(c) Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
h.DB.Create(&model.PowerLog{ })
UserId: user.Id, if err != nil {
Username: user.Username, resp.ERROR(c, err.Error())
Type: types.PowerConsume, return
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -322,53 +284,44 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
taskId, _ := h.snowflake.Next(true) taskId, _ := h.snowflake.Next(true)
job := model.MidJourneyJob{ task := types.MjTask{
Type: types.TaskVariation.String(), Type: types.TaskVariation,
ChannelId: data.ChannelId, ClientId: data.ClientId,
ReferenceId: data.MessageId,
UserId: userId, UserId: userId,
TaskId: taskId, Index: data.Index,
Progress: 0, ChannelId: data.ChannelId,
Power: h.App.SysConfig.MjActionPower, MessageId: data.MessageId,
CreatedAt: time.Now(), MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
}
job := model.MidJourneyJob{
Type: types.TaskVariation.String(),
ChannelId: data.ChannelId,
UserId: userId,
TaskId: taskId,
TaskInfo: utils.JsonEncode(task),
Progress: 0,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
} }
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error()) resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return return
} }
h.pool.PushTask(types.MjTask{ task.Id = job.Id
Id: job.Id, h.mjService.PushTask(task)
Type: types.TaskVariation,
UserId: userId, err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Index: data.Index, Type: types.PowerConsume,
ChannelId: data.ChannelId, Model: "mid-journey",
MessageId: data.MessageId, Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
MessageHash: data.MessageHash,
}) })
if err != nil {
client := h.pool.Clients.Get(uint(job.UserId)) resp.ERROR(c, err.Error())
if client != nil { return
_ = client.Send([]byte("Task Updated"))
} }
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -403,7 +356,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
} }
// JobList 获取 MJ 任务列表 // JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) { func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if finish { if finish {
session = session.Where("progress >= ?", 100).Order("id DESC") session = session.Where("progress >= ?", 100).Order("id DESC")
@@ -421,10 +374,14 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
session = session.Offset(offset).Limit(pageSize) session = session.Offset(offset).Limit(pageSize)
} }
// 统计总数
var total int64
session.Model(&model.MidJourneyJob{}).Count(&total)
var items []model.MidJourneyJob var items []model.MidJourneyJob
res := session.Find(&items) res := session.Find(&items)
if res.Error != nil { if res.Error != nil {
return res.Error, nil return res.Error, vo.Page{}
} }
var jobs = make([]vo.MidJourneyJob, 0) var jobs = make([]vo.MidJourneyJob, 0)
@@ -434,17 +391,9 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
if err != nil { if err != nil {
continue continue
} }
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
jobs = append(jobs, job) jobs = append(jobs, job)
} }
return nil, jobs return nil, vo.NewPage(total, page, pageSize, jobs)
} }
// Remove remove task image // Remove remove task image
@@ -457,40 +406,12 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
return return
} }
// remove job recode // remove job
tx := h.DB.Begin() err := h.DB.Delete(&job).Error
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// refund power
err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
if err != nil { if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
var user model.User
h.DB.Where("id = ?", job.UserId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Commit()
// remove image // remove image
err = h.uploader.GetUploadHandler().Delete(job.ImgURL) err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
@@ -498,11 +419,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -511,10 +427,9 @@ func (h *MidJourneyHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
userId := h.GetInt(c, "user_id", 0) userId := h.GetInt(c, "user_id", 0)
action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享 action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享
res := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action) err := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败")
return return
} }

View File

@@ -17,19 +17,21 @@ import (
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"io"
"net/http"
"time" "time"
) )
type UploadHandler struct { type NetHandler struct {
BaseHandler BaseHandler
uploaderManager *oss.UploaderManager uploaderManager *oss.UploaderManager
} }
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler { func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *NetHandler {
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager} return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
} }
func (h *UploadHandler) Upload(c *gin.Context) { func (h *NetHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file") file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
@@ -60,9 +62,11 @@ func (h *UploadHandler) Upload(c *gin.Context) {
resp.SUCCESS(c, file) resp.SUCCESS(c, file)
} }
func (h *UploadHandler) List(c *gin.Context) { func (h *NetHandler) List(c *gin.Context) {
var data struct { var data struct {
Urls []string `json:"urls,omitempty"` Urls []string `json:"urls,omitempty"`
Page int `json:"page"`
PageSize int `json:"page_size"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -77,25 +81,36 @@ func (h *UploadHandler) List(c *gin.Context) {
if len(data.Urls) > 0 { if len(data.Urls) > 0 {
session = session.Where("url IN ?", data.Urls) session = session.Where("url IN ?", data.Urls)
} }
session.Find(&items) // 统计总数
if len(items) > 0 { var total int64
for _, v := range items { session.Model(&model.File{}).Count(&total)
var file vo.File
err := utils.CopyObject(v, &file) if data.Page > 0 && data.PageSize > 0 {
if err != nil { offset := (data.Page - 1) * data.PageSize
logger.Error(err) session = session.Offset(offset).Limit(data.PageSize)
continue }
} err := session.Order("id desc").Find(&items).Error
file.CreatedAt = v.CreatedAt.Unix() if err != nil {
files = append(files, file) resp.ERROR(c, err.Error())
} return
} }
resp.SUCCESS(c, files) for _, v := range items {
var file vo.File
err := utils.CopyObject(v, &file)
if err != nil {
logger.Error(err)
continue
}
file.CreatedAt = v.CreatedAt.Unix()
files = append(files, file)
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, files))
} }
// Remove remove files // Remove remove files
func (h *UploadHandler) Remove(c *gin.Context) { func (h *NetHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
var file model.File var file model.File
@@ -119,3 +134,28 @@ func (h *UploadHandler) Remove(c *gin.Context) {
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey) _ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *NetHandler) Download(c *gin.Context) {
fileUrl := c.Query("url")
// 使用http工具下载文件
if fileUrl == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
// 使用http.Get下载文件
r, err := http.Get(fileUrl)
if err != nil {
resp.ERROR(c, err.Error())
return
}
defer r.Body.Close()
if r.StatusCode != http.StatusOK {
resp.ERROR(c, "error status"+r.Status)
return
}
c.Status(http.StatusOK)
// 将下载的文件内容写入响应
_, _ = io.Copy(c.Writer, r.Body)
}

View File

@@ -48,6 +48,16 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix() order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix() order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
if !ok {
payMethod = item.PayWay
}
payName, ok := types.PayNames[item.PayType]
if !ok {
payName = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
list = append(list, order) list = append(list, order)
} else { } else {
logger.Error(err) logger.Error(err)

View File

@@ -9,7 +9,6 @@ package handler
import ( import (
"embed" "embed"
"encoding/base64"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
@@ -18,10 +17,7 @@ import (
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/shopspring/decimal"
"math"
"net/http" "net/http"
"net/url"
"sync" "sync"
"time" "time"
@@ -34,21 +30,15 @@ type PayWay struct {
Value string `json:"value"` Value string `json:"value"`
} }
var (
PayWayAlipay = PayWay{Name: "支付宝", Value: "alipay"}
PayWayXunHu = PayWay{Name: "虎皮椒", Value: "hupi"}
PayWayJs = PayWay{Name: "PayJS", Value: "payjs"}
PayWayWechat = PayWay{Name: "微信支付", Value: "wechat"}
)
// PaymentHandler 支付服务回调 handler // PaymentHandler 支付服务回调 handler
type PaymentHandler struct { type PaymentHandler struct {
BaseHandler BaseHandler
alipayService *payment.AlipayService alipayService *payment.AlipayService
huPiPayService *payment.HuPiPayService huPiPayService *payment.HuPiPayService
jsPayService *payment.JPayService geekPayService *payment.GeekPayService
wechatPayService *payment.WechatPayService wechatPayService *payment.WechatPayService
snowflake *service.Snowflake snowflake *service.Snowflake
userService *service.UserService
fs embed.FS fs embed.FS
lock sync.Mutex lock sync.Mutex
signKey string // 用来签名的随机秘钥 signKey string // 用来签名的随机秘钥
@@ -58,17 +48,19 @@ func NewPaymentHandler(
server *core.AppServer, server *core.AppServer,
alipayService *payment.AlipayService, alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService, huPiPayService *payment.HuPiPayService,
jsPayService *payment.JPayService, geekPayService *payment.GeekPayService,
wechatPayService *payment.WechatPayService, wechatPayService *payment.WechatPayService,
db *gorm.DB, db *gorm.DB,
userService *service.UserService,
snowflake *service.Snowflake, snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler { fs embed.FS) *PaymentHandler {
return &PaymentHandler{ return &PaymentHandler{
alipayService: alipayService, alipayService: alipayService,
huPiPayService: huPiPayService, huPiPayService: huPiPayService,
jsPayService: jsPayService, geekPayService: geekPayService,
wechatPayService: wechatPayService, wechatPayService: wechatPayService,
snowflake: snowflake, snowflake: snowflake,
userService: userService,
fs: fs, fs: fs,
lock: sync.Mutex{}, lock: sync.Mutex{},
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
@@ -79,309 +71,167 @@ func NewPaymentHandler(
} }
} }
func (h *PaymentHandler) DoPay(c *gin.Context) { func (h *PaymentHandler) Pay(c *gin.Context) {
orderNo := h.GetTrim(c, "order_no") var data struct {
payWay := h.GetTrim(c, "pay_way") PayWay string `json:"pay_way"`
t := h.GetInt(c, "t", 0) PayType string `json:"pay_type"`
sign := h.GetTrim(c, "sign") ProductId int `json:"product_id"`
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, payWay, t, h.signKey) UserId int `json:"user_id"`
newSign := utils.Sha256(signStr) Device string `json:"device"`
if newSign != sign { Host string `json:"host"`
resp.ERROR(c, "订单签名错误!")
return
} }
if err := c.ShouldBindJSON(&data); err != nil {
// 检查二维码是否过期
if time.Now().Unix()-int64(t) > int64(h.App.SysConfig.OrderPayTimeout) {
resp.ERROR(c, "支付二维码已过期,请重新生成!")
return
}
if orderNo == "" {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
var order model.Order var product model.Product
res := h.DB.Where("order_no = ?", orderNo).First(&order) err := h.DB.Where("id", data.ProductId).First(&product).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "Order not found") resp.ERROR(c, "Product not found")
return return
} }
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回 orderNo, err := h.snowflake.Next(false)
if order.Status == types.OrderPaidSuccess { if err != nil {
resp.ERROR(c, "订单已支付成功,无需重复支付!") resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
err = h.DB.Where("id", data.UserId).First(&user).Error
if err != nil {
resp.NotAuth(c)
return return
} }
// 更新扫码状态 amount := product.Discount
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned) var payURL, returnURL, notifyURL string
switch data.PayWay {
case "alipay":
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
notifyURL = h.App.Config.AlipayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
}
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
returnURL = h.App.Config.AlipayConfig.ReturnURL
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
money := fmt.Sprintf("%.2f", amount)
if data.Device == "wechat" {
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
} else {
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
}
if payWay == "alipay" { // 支付宝
amount := fmt.Sprintf("%.2f", order.Amount)
uri, err := h.alipayService.PayUrlMobile(order.OrderNo, amount, order.Subject)
if err != nil { if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error()) resp.ERROR(c, "error with generate pay url: "+err.Error())
return return
} }
break
c.Redirect(302, uri) case "wechat":
return if h.App.Config.WechatPayConfig.NotifyURL != "" {
} else if payWay == "hupi" { // 虎皮椒支付 notifyURL = h.App.Config.WechatPayConfig.NotifyURL
params := payment.HuPiPayReq{ } else {
Version: "1.1", notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
TradeOrderId: orderNo, }
TotalFee: fmt.Sprintf("%f", order.Amount), if data.Device == "wechat" {
Title: order.Subject, payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL, OutTradeNo: orderNo,
WapName: "极客学长", TotalFee: int(amount * 100),
Subject: product.Name,
NotifyURL: notifyURL,
ClientIP: c.ClientIP(),
})
} else {
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
OutTradeNo: orderNo,
TotalFee: int(amount * 100),
Subject: product.Name,
NotifyURL: notifyURL,
})
} }
r, err := h.huPiPayService.Pay(params)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
break
c.Redirect(302, r.URL)
}
resp.ERROR(c, "Invalid operations")
}
// PayQrcode 生成支付 URL 二维码
func (h *PaymentHandler) PayQrcode(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
var payWay string
var notifyURL string
switch data.PayWay {
case "hupi": case "hupi":
payWay = PayWayXunHu.Value if h.App.Config.HuPiPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
break
case "payjs":
payWay = PayWayJs.Value
notifyURL = h.App.Config.JPayConfig.NotifyURL
break
case "alipay":
payWay = PayWayAlipay.Value
notifyURL = h.App.Config.AlipayConfig.NotifyURL
break
default:
payWay = PayWayWechat.Value
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
}
// PayJs 单独处理,只能用官方生成的二维码
if data.PayWay == "payjs" {
params := payment.JPayReq{
TotalFee: int(math.Ceil(order.Amount * 100)),
OutTradeNo: order.OrderNo,
Subject: product.Name,
}
r := h.jsPayService.Pay(params)
if r.IsOK() {
resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
return
} else { } else {
resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg) notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
return
} }
} if h.App.Config.HuPiPayConfig.ReturnURL != "" {
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
var logo string
if data.PayWay == "alipay" {
logo = "res/img/alipay.jpg"
} else if data.PayWay == "hupi" {
if h.App.Config.HuPiPayConfig.Name == "wechat" {
logo = "res/img/wechat-pay.jpg"
} else { } else {
logo = "res/img/alipay.jpg" returnURL = fmt.Sprintf("%s/payReturn", data.Host)
} }
} else if data.PayWay == "wechat" { r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
logo = "res/img/wechat-pay.jpg"
}
file, err := h.fs.Open(logo)
if err != nil {
resp.ERROR(c, "error with open qrcode log file: "+err.Error())
return
}
parse, err := url.Parse(notifyURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
timestamp := time.Now().Unix()
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, data.PayWay, timestamp, h.signKey)
sign := utils.Sha256(signStr)
var imageURL string
if data.PayWay == "wechat" {
payUrl, err := h.wechatPayService.PayUrlNative(order.OrderNo, int(math.Floor(order.Amount*100)), product.Name)
if err != nil {
resp.ERROR(c, "error with generating wechat payment qrcode: "+err.Error())
return
} else {
imageURL = payUrl
}
} else {
imageURL = fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign)
}
imgData, err := utils.GenQrcode(imageURL, 400, file)
if err != nil {
resp.ERROR(c, err.Error())
return
}
imgDataBase64 := base64.StdEncoding.EncodeToString(imgData)
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
}
// Mobile 移动端支付
func (h *PaymentHandler) Mobile(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
var payWay string
var notifyURL, returnURL string
var payURL string
switch data.PayWay {
case "hupi":
payWay = PayWayXunHu.Name
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
parse, _ := url.Parse(h.App.Config.HuPiPayConfig.ReturnURL)
baseURL := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
params := payment.HuPiPayReq{
Version: "1.1", Version: "1.1",
TradeOrderId: orderNo, TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount), TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name, Title: product.Name,
NotifyURL: notifyURL, NotifyURL: notifyURL,
ReturnURL: returnURL, ReturnURL: returnURL,
CallbackURL: returnURL, WapName: "GeekAI助手",
WapName: "极客学长", })
WapUrl: baseURL,
Type: "WAP",
}
r, err := h.huPiPayService.Pay(params)
if err != nil { if err != nil {
errMsg := "error with generating Pay Hupi URL: " + err.Error() resp.ERROR(c, err.Error())
logger.Error(errMsg)
resp.ERROR(c, errMsg)
return return
} }
payURL = r.URL payURL = r.URL
case "payjs": break
payWay = PayWayJs.Name case "geek":
notifyURL = h.App.Config.JPayConfig.NotifyURL if h.App.Config.GeekPayConfig.NotifyURL != "" {
returnURL = h.App.Config.JPayConfig.ReturnURL notifyURL = h.App.Config.GeekPayConfig.NotifyURL
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart() } else {
params := url.Values{} notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
params.Add("total_fee", fmt.Sprintf("%d", totalFee)) }
params.Add("out_trade_no", orderNo) if h.App.Config.GeekPayConfig.ReturnURL != "" {
params.Add("body", product.Name) data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
params.Add("notify_url", notifyURL) }
params.Add("auto", "0") if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
payURL = h.jsPayService.PayH5(params) returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
case "alipay": } else {
payWay = PayWayAlipay.Name returnURL = fmt.Sprintf("%s/payReturn", data.Host)
payURL, err = h.alipayService.PayUrlMobile(orderNo, fmt.Sprintf("%.2f", amount), product.Name) }
params := payment.GeekPayParams{
OutTradeNo: orderNo,
Method: "web",
Name: product.Name,
Money: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
Type: data.PayType,
ReturnURL: returnURL,
NotifyURL: notifyURL,
}
res, err := h.geekPayService.Pay(params)
if err != nil { if err != nil {
errMsg := "error with generating Alipay URL: " + err.Error() resp.ERROR(c, err.Error())
resp.ERROR(c, errMsg)
return
}
case "wechat":
payWay = PayWayWechat.Name
payURL, err = h.wechatPayService.PayUrlH5(orderNo, int(amount*100), product.Name, c.ClientIP())
if err != nil {
errMsg := "error with generating Wechat URL: " + err.Error()
logger.Error(errMsg)
resp.ERROR(c, errMsg)
return return
} }
payURL = res.PayURL
default: default:
resp.ERROR(c, "Unsupported pay way: "+data.PayWay) resp.ERROR(c, "不支持的支付渠道")
return return
} }
// 创建订单 // 创建订单
remark := types.OrderRemark{ remark := types.OrderRemark{
Days: product.Days, Days: product.Days,
@@ -390,7 +240,6 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
Price: product.Price, Price: product.Price,
Discount: product.Discount, Discount: product.Discount,
} }
order := model.Order{ order := model.Order{
UserId: user.Id, UserId: user.Id,
Username: user.Username, Username: user.Username,
@@ -399,26 +248,24 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
Subject: product.Name, Subject: product.Name,
Amount: amount, Amount: amount,
Status: types.OrderNotPaid, Status: types.OrderNotPaid,
PayWay: payWay, PayWay: data.PayWay,
PayType: data.PayType,
Remark: utils.JsonEncode(remark), Remark: utils.JsonEncode(remark),
} }
res = h.DB.Create(&order) err = h.DB.Create(&order).Error
if res.Error != nil || res.RowsAffected == 0 { if err != nil {
resp.ERROR(c, "error with create order: "+res.Error.Error()) resp.ERROR(c, "error with create order: "+err.Error())
return return
} }
resp.SUCCESS(c, payURL)
resp.SUCCESS(c, gin.H{"url": payURL, "order_no": orderNo})
} }
// 异步通知回调公共逻辑 // 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var order model.Order var order model.Order
res := h.DB.Where("order_no = ?", orderNo).First(&order) err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
if res.Error != nil { if err != nil {
err := fmt.Errorf("error with fetch order: %v", res.Error) return fmt.Errorf("error with fetch order: %v", err)
logger.Error(err)
return err
} }
h.lock.Lock() h.lock.Lock()
@@ -430,45 +277,24 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
} }
var user model.User var user model.User
res = h.DB.First(&user, order.UserId) err = h.DB.First(&user, order.UserId).Error
if res.Error != nil { if err != nil {
err := fmt.Errorf("error with fetch user info: %v", res.Error) return fmt.Errorf("error with fetch user info: %v", err)
logger.Error(err)
return err
} }
var remark types.OrderRemark var remark types.OrderRemark
err := utils.JsonDecode(order.Remark, &remark) err = utils.JsonDecode(order.Remark, &remark)
if err != nil { if err != nil {
err := fmt.Errorf("error with decode order remark: %v", err) return fmt.Errorf("error with decode order remark: %v", err)
logger.Error(err)
return err
} }
var opt string // 增加用户算力
var power int err = h.userService.IncreasePower(int(order.UserId), remark.Power, model.PowerLog{
if remark.Days > 0 { // VIP 充值 Type: types.PowerRecharge,
if user.ExpiredTime >= time.Now().Unix() { Model: order.PayWay,
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix() Remark: fmt.Sprintf("充值算力,金额:%f订单号%s", order.Amount, order.OrderNo),
opt = "VIP充值VIP 没到期,只延期不增加算力" })
} else { if err != nil {
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Power += h.App.SysConfig.VipMonthPower
power = h.App.SysConfig.VipMonthPower
opt = "VIP充值"
}
user.Vip = true
} else { // 充值点卡,直接增加次数即可
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
// 更新用户信息
res = h.DB.Updates(&user)
if res.Error != nil {
err := fmt.Errorf("error with update user info: %v", res.Error)
logger.Error(err)
return err return err
} }
@@ -476,29 +302,16 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
order.PayTime = time.Now().Unix() order.PayTime = time.Now().Unix()
order.Status = types.OrderPaidSuccess order.Status = types.OrderPaidSuccess
order.TradeNo = tradeNo order.TradeNo = tradeNo
res = h.DB.Updates(&order) err = h.DB.Updates(&order).Error
if res.Error != nil { if err != nil {
err := fmt.Errorf("error with update order info: %v", res.Error) return fmt.Errorf("error with update order info: %v", err)
logger.Error(err)
return err
} }
// 更新产品销量 // 更新产品销量
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1)) err = h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).
UpdateColumn("sales", gorm.Expr("sales + ?", 1)).Error
// 记录算力充值日志 if err != nil {
if power > 0 { return fmt.Errorf("error with update product sales: %v", err)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRecharge,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: order.PayWay,
Remark: fmt.Sprintf("%s金额%f订单号%s", opt, order.Amount, order.OrderNo),
CreatedAt: time.Now(),
})
} }
return nil return nil
@@ -506,20 +319,22 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
// GetPayWays 获取支付方式 // GetPayWays 获取支付方式
func (h *PaymentHandler) GetPayWays(c *gin.Context) { func (h *PaymentHandler) GetPayWays(c *gin.Context) {
data := gin.H{} payWays := make([]gin.H, 0)
if h.App.Config.AlipayConfig.Enabled { if h.App.Config.AlipayConfig.Enabled {
data["alipay"] = gin.H{"name": "alipay"} payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
} }
if h.App.Config.HuPiPayConfig.Enabled { if h.App.Config.HuPiPayConfig.Enabled {
data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name} payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
} }
if h.App.Config.JPayConfig.Enabled { if h.App.Config.GeekPayConfig.Enabled {
data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name} for _, v := range h.App.Config.GeekPayConfig.Methods {
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
}
} }
if h.App.Config.WechatPayConfig.Enabled { if h.App.Config.WechatPayConfig.Enabled {
data["wechat"] = gin.H{"name": "wechat"} payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
} }
resp.SUCCESS(c, data) resp.SUCCESS(c, payWays)
} }
// HuPiPayNotify 虎皮椒支付异步回调 // HuPiPayNotify 虎皮椒支付异步回调
@@ -532,15 +347,17 @@ func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
orderNo := c.Request.Form.Get("trade_order_id") orderNo := c.Request.Form.Get("trade_order_id")
tradeNo := c.Request.Form.Get("open_order_id") tradeNo := c.Request.Form.Get("open_order_id")
logger.Infof("收到虎皮椒订单支付回调,订单 NO%s交易流水号%s", orderNo, tradeNo) logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
if err = h.huPiPayService.Check(tradeNo); err != nil { if err = h.huPiPayService.Check(orderNo); err != nil {
logger.Error("订单校验失败:", err) logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
err = h.notify(orderNo, tradeNo) err = h.notify(orderNo, tradeNo)
if err != nil { if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
@@ -556,18 +373,18 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
return return
} }
// TODO验证交易签名 result := h.alipayService.TradeVerify(c.Request)
res := h.alipayService.TradeVerify(c.Request) logger.Infof("收到支付宝商号订单支付回调:%+v", result)
logger.Infof("验证支付结果:%+v", res) if !result.Success() {
if !res.Success() { logger.Error("订单校验失败:", result.Message)
logger.Error("订单校验失败:", res.Message)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
tradeNo := c.Request.Form.Get("trade_no") tradeNo := c.Request.Form.Get("trade_no")
err = h.notify(res.OutTradeNo, tradeNo) err = h.notify(result.OutTradeNo, tradeNo)
if err != nil { if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
@@ -575,33 +392,30 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
c.String(http.StatusOK, "success") c.String(http.StatusOK, "success")
} }
// PayJsNotify PayJs 支付异步回调 // GeekPayNotify 支付异步回调
func (h *PaymentHandler) PayJsNotify(c *gin.Context) { func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
err := c.Request.ParseForm() var params = make(map[string]string)
if err != nil { for k := range c.Request.URL.Query() {
params[k] = c.Query(k)
}
logger.Infof("收到GeekPay订单支付回调%+v", params)
// 检查支付状态
if params["trade_status"] != "TRADE_SUCCESS" {
c.String(http.StatusOK, "success")
return
}
sign := h.geekPayService.Sign(params)
if sign != c.Query("sign") {
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
orderNo := c.Request.Form.Get("out_trade_no") err := h.notify(params["out_trade_no"], params["trade_no"])
returnCode := c.Request.Form.Get("return_code")
logger.Infof("收到PayJs订单支付回调订单 NO%s支付结果代码%v", orderNo, returnCode)
// 支付失败
if returnCode != "1" {
return
}
// 校验订单支付状态
tradeNo := c.Request.Form.Get("payjs_order_id")
err = h.jsPayService.TradeVerify(tradeNo)
if err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil { if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
@@ -618,6 +432,7 @@ func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
} }
result := h.wechatPayService.TradeVerify(c.Request) result := h.wechatPayService.TradeVerify(c.Request)
logger.Infof("收到微信商号订单支付回调:%+v", result)
if !result.Success() { if !result.Success() {
logger.Error("订单校验失败:", err) logger.Error("订单校验失败:", err)
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
@@ -629,6 +444,7 @@ func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
err = h.notify(result.OutTradeNo, result.TradeId) err = h.notify(result.OutTradeNo, result.TradeId)
if err != nil { if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }

View File

@@ -0,0 +1,155 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// 提示词生成 handler
// 使用 AI 生成绘画指令,歌词,视频生成指令等
type PromptHandler struct {
BaseHandler
userService *service.UserService
}
func NewPromptHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *PromptHandler {
return &PromptHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
userService: userService,
}
}
// Lyric 生成歌词
func (h *PromptHandler) Lyric(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if h.App.SysConfig.PromptPower > 0 {
userId := h.GetLoginUserId(c)
h.userService.DecreasePower(int(userId), h.App.SysConfig.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成歌词",
})
}
resp.SUCCESS(c, content)
}
// Image 生成 AI 绘画提示词
func (h *PromptHandler) Image(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if h.App.SysConfig.PromptPower > 0 {
userId := h.GetLoginUserId(c)
h.userService.DecreasePower(int(userId), h.App.SysConfig.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成绘画提示词",
})
}
resp.SUCCESS(c, strings.Trim(content, `"`))
}
// Video 生成视频提示词
func (h *PromptHandler) Video(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if h.App.SysConfig.PromptPower > 0 {
userId := h.GetLoginUserId(c)
h.userService.DecreasePower(int(userId), h.App.SysConfig.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成视频脚本",
})
}
resp.SUCCESS(c, strings.Trim(content, `"`))
}
// MetaPrompt 生成元提示词
func (h *PromptHandler) MetaPrompt(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
messages := make([]interface{}, 0)
messages = append(messages, types.Message{
Role: "system",
Content: service.MetaPromptTemplate,
})
messages = append(messages, types.Message{
Role: "user",
Content: "Task, Goal, or the Role to actor is:\n" + data.Prompt,
})
content, err := utils.SendOpenAIMessage(h.DB, messages, 0)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, strings.Trim(content, `"`))
}
func (h *PromptHandler) getPromptModel() string {
if h.App.SysConfig.TranslateModelId > 0 {
var chatModel model.ChatModel
h.DB.Where("id", h.App.SysConfig.TranslateModelId).First(&chatModel)
return chatModel.Value
}
return "gpt-4o"
}

View File

@@ -0,0 +1,221 @@
package handler
import (
"encoding/json"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// OpenAI Realtime API Relay Server
type RealtimeHandler struct {
BaseHandler
userService *service.UserService
}
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *service.UserService) *RealtimeHandler {
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
}
func (h *RealtimeHandler) Connection(c *gin.Context) {
// 获取客户端请求中指定的子协议
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
md := c.Query("model")
userId := h.GetLoginUserId(c)
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
c.Abort()
return
}
// 将 HTTP 协议升级为 Websocket 协议
subProtocols := strings.Split(clientProtocols, ",")
ws, err := (&websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: subProtocols,
}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
defer ws.Close()
// 目前只针对 VIP 用户可以访问
if !user.Vip {
sendError(ws, "当前功能只针对 VIP 用户开放")
c.Abort()
return
}
var apiKey model.ApiKey
h.DB.Where("type", "realtime").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
if apiKey.Id == 0 {
sendError(ws, "管理员未配置 Realtime API KEY")
c.Abort()
return
}
apiURL := fmt.Sprintf("%s/v1/realtime?model=%s", apiKey.ApiURL, md)
// 连接到真实的后端服务器,传入相同的子协议
headers := http.Header{}
// 修正子协议内容
subProtocols[1] = "openai-insecure-api-key." + apiKey.Value
if clientProtocols != "" {
headers.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ","))
}
backendConn, _, err := websocket.DefaultDialer.Dial(apiURL, headers)
if err != nil {
sendError(ws, "桥接后端 API 失败:"+err.Error())
c.Abort()
return
}
defer backendConn.Close()
// 确保协议一致性,如果失败返回
if ws.Subprotocol() != backendConn.Subprotocol() {
sendError(ws, "Websocket 子协议不匹配")
c.Abort()
return
}
// 更新API KEY 最后使用时间
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
// 开始双向转发
errorChan := make(chan error, 2)
go relay(ws, backendConn, errorChan)
go relay(backendConn, ws, errorChan)
// 等待其中一个连接关闭
err = <-errorChan
logger.Infof("Relay ended: %v", err)
}
func relay(src, dst *websocket.Conn, errorChan chan error) {
for {
messageType, message, err := src.ReadMessage()
if err != nil {
errorChan <- err
return
}
err = dst.WriteMessage(messageType, message)
if err != nil {
errorChan <- err
return
}
}
}
func sendError(ws *websocket.Conn, message string) {
err := ws.WriteJSON(map[string]string{"event_id": "event_01", "type": "error", "error": message})
if err != nil {
logger.Error(err)
}
}
// OpenAI 实时语音对话,一次性对话
func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
var apiKey model.ApiKey
err := h.DB.Session(&gorm.Session{}).Where("type", "realtime").Where("enabled", true).First(&apiKey).Error
if err != nil {
resp.ERROR(c, fmt.Sprintf("error with fetch OpenAI API KEY%v", err))
return
}
// 检查用户是否还有算力
userId := h.GetLoginUserId(c)
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
resp.ERROR(c, fmt.Sprintf("error with fetch user%v", err))
return
}
if user.Power < h.App.SysConfig.AdvanceVoicePower {
resp.ERROR(c, "当前用户算力不足,无法使用该功能")
return
}
var response utils.OpenAIResponse
client := req.C()
if len(apiKey.ProxyURL) > 5 {
client.SetProxyURL(apiKey.ApiURL)
}
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
logger.Infof("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, "advanced-voice")
r, err := client.R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(types.ApiRequest{
Model: "advanced-voice",
Temperature: 0.9,
MaxTokens: 1024,
Stream: false,
Messages: []interface{}{types.Message{
Role: "user",
Content: "实时语音通话",
}},
}).Post(apiURL)
if err != nil {
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败%v", err))
return
}
if r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败%v", r.Status))
return
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &response)
if err != nil {
resp.ERROR(c, fmt.Sprintf("解析API数据失败%v, %s", err, string(body)))
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 扣减算力
err = h.userService.DecreasePower(int(userId), h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
Type: types.PowerConsume,
Model: "advanced-voice",
Remark: "实时语音通话",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
logger.Infof("Response: %v", response.Choices[0].Message.Content)
// 提取链接
re := regexp.MustCompile(`\[(.*?)\]\((.*?)\)`)
links := re.FindAllStringSubmatch(response.Choices[0].Message.Content, -1)
var url = ""
if len(links) > 0 {
url = links[0][2]
}
resp.SUCCESS(c, url)
}

View File

@@ -0,0 +1,88 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"sync"
"time"
)
type RedeemHandler struct {
BaseHandler
lock sync.Mutex
userService *service.UserService
}
func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *RedeemHandler {
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
}
func (h *RedeemHandler) Verify(c *gin.Context) {
var data struct {
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c)
h.lock.Lock()
defer h.lock.Unlock()
var item model.Redeem
res := h.DB.Where("code", data.Code).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的兑换码!")
return
}
if !item.Enabled {
resp.ERROR(c, "当前兑换码已被禁用!")
return
}
if item.RedeemedAt > 0 {
resp.ERROR(c, "当前兑换码已使用,请勿重复使用!")
return
}
tx := h.DB.Begin()
err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{
Type: types.PowerRedeem,
Model: "兑换码",
Remark: fmt.Sprintf("兑换码核销,算力:%d兑换码%s...", item.Power, item.Code[:10]),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 更新核销状态
item.RedeemedAt = time.Now().Unix()
item.UserId = userId
err = tx.Updates(&item).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Commit()
resp.SUCCESS(c)
}

View File

@@ -1,108 +0,0 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"math"
"strings"
"sync"
"time"
)
type RewardHandler struct {
BaseHandler
lock sync.Mutex
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Verify 打赏码核销
func (h *RewardHandler) Verify(c *gin.Context) {
var data struct {
TxId string `json:"tx_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.HACKER(c)
return
}
// 移除转账单号中间的空格,防止有人复制的时候多复制了空格
data.TxId = strings.ReplaceAll(data.TxId, " ", "")
h.lock.Lock()
defer h.lock.Unlock()
var item model.Reward
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的交易流水号!")
return
}
if item.Status {
resp.ERROR(c, "当前交易流水号已经被核销,请不要重复核销!")
return
}
tx := h.DB.Begin()
exchange := vo.RewardExchange{}
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
exchange.Power = int(power)
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
if res.Error != nil {
tx.Rollback()
logger.Error("添加应用失败:", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
// 更新核销状态
item.Status = true
item.UserId = user.Id
item.Exchange = utils.JsonEncode(exchange)
res = tx.Updates(&item)
if res.Error != nil {
tx.Rollback()
logger.Error("添加应用失败:", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力充值日志
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerReward,
Amount: exchange.Power,
Balance: user.Power + exchange.Power,
Mark: types.PowerAdd,
Model: "众筹支付",
Remark: fmt.Sprintf("充值算力,金额:%f价格%f", item.Amount, h.App.SysConfig.PowerPrice),
CreatedAt: time.Now(),
})
tx.Commit()
resp.SUCCESS(c)
}

View File

@@ -19,11 +19,8 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"net/http"
"time" "time"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
@@ -31,19 +28,27 @@ import (
type SdJobHandler struct { type SdJobHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
pool *sd.ServicePool sdService *sd.Service
uploader *oss.UploaderManager uploader *oss.UploaderManager
snowflake *service.Snowflake snowflake *service.Snowflake
leveldb *store.LevelDB leveldb *store.LevelDB
userService *service.UserService
} }
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler { func NewSdJobHandler(app *core.AppServer,
db *gorm.DB,
service *sd.Service,
manager *oss.UploaderManager,
snowflake *service.Snowflake,
userService *service.UserService,
levelDB *store.LevelDB) *SdJobHandler {
return &SdJobHandler{ return &SdJobHandler{
pool: pool, sdService: service,
uploader: manager, uploader: manager,
snowflake: snowflake, snowflake: snowflake,
leveldb: levelDB, leveldb: levelDB,
userService: userService,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
@@ -51,27 +56,6 @@ func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, man
} }
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SdJobHandler) preCheck(c *gin.Context) bool { func (h *SdJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
@@ -79,11 +63,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
return false return false
} }
if !h.pool.HasAvailableService() {
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
return false
}
if user.Power < h.App.SysConfig.SdPower { if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false return false
@@ -130,29 +109,37 @@ func (h *SdJobHandler) Image(c *gin.Context) {
resp.ERROR(c, "error with generate task id: "+err.Error()) resp.ERROR(c, "error with generate task id: "+err.Error())
return return
} }
params := types.SdTaskParams{
TaskId: taskId, task := types.SdTask{
Prompt: data.Prompt, ClientId: data.ClientId,
NegPrompt: data.NegPrompt, Type: types.TaskImage,
Steps: data.Steps, Params: types.SdTaskParams{
Sampler: data.Sampler, TaskId: taskId,
FaceFix: data.FaceFix, Prompt: data.Prompt,
CfgScale: data.CfgScale, NegPrompt: data.NegPrompt,
Seed: data.Seed, Steps: data.Steps,
Height: data.Height, Sampler: data.Sampler,
Width: data.Width, FaceFix: data.FaceFix,
HdFix: data.HdFix, CfgScale: data.CfgScale,
HdRedrawRate: data.HdRedrawRate, Seed: data.Seed,
HdScale: data.HdScale, Height: data.Height,
HdScaleAlg: data.HdScaleAlg, Width: data.Width,
HdSteps: data.HdSteps, HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
},
UserId: userId,
TranslateModelId: h.App.SysConfig.TranslateModelId,
} }
job := model.SdJob{ job := model.SdJob{
UserId: userId, UserId: userId,
Type: types.TaskImage.String(), Type: types.TaskImage.String(),
TaskId: params.TaskId, TaskId: taskId,
Params: utils.JsonEncode(params), Params: utils.JsonEncode(task.Params),
TaskInfo: utils.JsonEncode(task),
Prompt: data.Prompt, Prompt: data.Prompt,
Progress: 0, Progress: 0,
Power: h.App.SysConfig.SdPower, Power: h.App.SysConfig.SdPower,
@@ -164,34 +151,18 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return return
} }
h.pool.PushTask(types.SdTask{ task.Id = int(job.Id)
Id: int(job.Id), h.sdService.PushTask(task)
Type: types.TaskImage,
Params: params,
UserId: userId,
})
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power // update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
// 记录算力变化日志 Type: types.PowerConsume,
if tx.Error == nil && tx.RowsAffected > 0 { Model: "stable-diffusion",
user, _ := h.GetLoginUser(c) Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
h.DB.Create(&model.PowerLog{ })
UserId: user.Id, if err != nil {
Username: user.Username, resp.ERROR(c, err.Error())
Type: types.PowerConsume, return
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -228,11 +199,11 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
} }
// JobList 获取 MJ 任务列表 // JobList 获取 MJ 任务列表
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) { func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if finish { if finish {
session = session.Where("progress = ?", 100).Order("id DESC") session = session.Where("progress >= ?", 100).Order("id DESC")
} else { } else {
session = session.Where("progress < ?", 100).Order("id ASC") session = session.Where("progress < ?", 100).Order("id ASC")
} }
@@ -247,10 +218,14 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
session = session.Offset(offset).Limit(pageSize) session = session.Offset(offset).Limit(pageSize)
} }
// 统计总数
var total int64
session.Model(&model.SdJob{}).Count(&total)
var items []model.SdJob var items []model.SdJob
res := session.Find(&items) res := session.Find(&items)
if res.Error != nil { if res.Error != nil {
return res.Error, nil return res.Error, vo.Page{}
} }
var jobs = make([]vo.SdJob, 0) var jobs = make([]vo.SdJob, 0)
@@ -260,62 +235,47 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
if err != nil { if err != nil {
continue continue
} }
if item.Progress < 100 {
// 从 leveldb 中获取图片预览数据
var imageData string
err = h.leveldb.Get(item.TaskId, &imageData)
if err == nil {
job.ImgURL = "data:image/png;base64," + imageData
}
}
jobs = append(jobs, job) jobs = append(jobs, job)
} }
return nil, jobs return nil, vo.NewPage(total, page, pageSize, jobs)
} }
// Remove remove task image // Remove remove task image
func (h *SdJobHandler) Remove(c *gin.Context) { func (h *SdJobHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
userId := h.GetInt(c, "user_id", 0) userId := h.GetLoginUserId(c)
var job model.SdJob var job model.SdJob
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil { if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在") resp.ERROR(c, "记录不存在")
return return
} }
// remove job recode // 删除任务
res := h.DB.Delete(&model.SdJob{Id: job.Id}) err := h.DB.Delete(&job).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, err.Error())
return return
} }
// remove image // remove image
err := h.uploader.GetUploadHandler().Delete(job.ImgURL) err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
if err != nil { if err != nil {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte(sd.Finished))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// Publish 发布/取消发布图片到画廊显示 // Publish 发布/取消发布图片到画廊显示
func (h *SdJobHandler) Publish(c *gin.Context) { func (h *SdJobHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
userId := h.GetInt(c, "user_id", 0) userId := h.GetLoginUserId(c)
action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享 action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享
res := h.DB.Model(&model.SdJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action) err := h.DB.Model(&model.SdJob{Id: uint(id), UserId: int(userId)}).UpdateColumn("publish", action).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败")
return return
} }

View File

@@ -56,15 +56,17 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
var check bool if h.App.SysConfig.EnabledVerify {
if data.X != 0 { var check bool
check = h.captcha.SlideCheck(data) if data.X != 0 {
} else { check = h.captcha.SlideCheck(data)
check = h.captcha.Check(data) } else {
} check = h.captcha.Check(data)
if !check { }
resp.ERROR(c, "验证码错误,请先完人机验证") if !check {
return resp.ERROR(c, "请先完人机验证")
return
}
} }
code := utils.RandomNumber(6) code := utils.RandomNumber(6)
@@ -74,6 +76,20 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
resp.ERROR(c, "系统已禁用邮箱注册!") resp.ERROR(c, "系统已禁用邮箱注册!")
return return
} }
// 检查邮箱后缀是否在白名单
if len(h.App.SysConfig.EmailWhiteList) > 0 {
inWhiteList := false
for _, suffix := range h.App.SysConfig.EmailWhiteList {
if strings.HasSuffix(data.Receiver, suffix) {
inWhiteList = true
break
}
}
if !inWhiteList {
resp.ERROR(c, "邮箱后缀不在白名单中")
return
}
}
err = h.smtp.SendVerifyCode(data.Receiver, code) err = h.smtp.SendVerifyCode(data.Receiver, code)
} else { } else {
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") { if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {

View File

@@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/suno" "geekai/service/suno"
"geekai/store/model" "geekai/store/model"
@@ -18,53 +19,33 @@ import (
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
"time" "time"
) )
type SunoHandler struct { type SunoHandler struct {
BaseHandler BaseHandler
service *suno.Service sunoService *suno.Service
uploader *oss.UploaderManager uploader *oss.UploaderManager
userService *service.UserService
} }
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler { func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
return &SunoHandler{ return &SunoHandler{
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
}, },
service: service, sunoService: service,
uploader: uploader, uploader: uploader,
userService: userService,
} }
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SunoHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SunoHandler) Create(c *gin.Context) { func (h *SunoHandler) Create(c *gin.Context) {
var data struct { var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Instrumental bool `json:"instrumental"` Instrumental bool `json:"instrumental"`
Lyrics string `json:"lyrics"` Lyrics string `json:"lyrics"`
@@ -72,21 +53,65 @@ func (h *SunoHandler) Create(c *gin.Context) {
Tags string `json:"tags"` Tags string `json:"tags"`
Title string `json:"title"` Title string `json:"title"`
Type int `json:"type"` Type int `json:"type"`
RefTaskId string `json:"ref_task_id"` // 续写的任务id RefTaskId string `json:"ref_task_id"` // 续写的任务id
ExtendSecs int `json:"extend_secs"` // 续写秒数 ExtendSecs int `json:"extend_secs"` // 续写秒数
RefSongId string `json:"ref_song_id"` // 续写的歌曲id RefSongId string `json:"ref_song_id"` // 续写的歌曲id
SongId string `json:"song_id,omitempty"` // 要拼接的歌曲id
AudioURL string `json:"audio_url,omitempty"` // 上传自己创作的歌曲
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.SunoPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
// 歌曲拼接
if data.SongId != "" && data.Type == 3 {
var song model.SunoJob
if err := h.DB.Where("song_id = ?", data.SongId).First(&song).Error; err == nil {
data.Instrumental = song.Instrumental
data.Model = song.ModelName
data.Tags = song.Tags
}
// 拼接歌词
var refSong model.SunoJob
if err := h.DB.Where("song_id = ?", data.RefSongId).First(&refSong).Error; err == nil {
data.Prompt = fmt.Sprintf("%s\n%s", song.Prompt, refSong.Prompt)
}
}
task := types.SunoTask{
ClientId: data.ClientId,
UserId: int(h.GetLoginUserId(c)),
Type: data.Type,
Title: data.Title,
RefTaskId: data.RefTaskId,
RefSongId: data.RefSongId,
ExtendSecs: data.ExtendSecs,
Prompt: data.Prompt,
Tags: data.Tags,
Model: data.Model,
Instrumental: data.Instrumental,
SongId: data.SongId,
AudioURL: data.AudioURL,
}
// 插入数据库 // 插入数据库
job := model.SunoJob{ job := model.SunoJob{
UserId: int(h.GetLoginUserId(c)), UserId: task.UserId,
Prompt: data.Prompt, Prompt: data.Prompt,
Instrumental: data.Instrumental, Instrumental: data.Instrumental,
ModelName: data.Model, ModelName: data.Model,
TaskInfo: utils.JsonEncode(task),
Tags: data.Tags, Tags: data.Tags,
Title: data.Title, Title: data.Title,
Type: data.Type, Type: data.Type,
@@ -94,6 +119,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
RefTaskId: data.RefTaskId, RefTaskId: data.RefTaskId,
ExtendSecs: data.ExtendSecs, ExtendSecs: data.ExtendSecs,
Power: h.App.SysConfig.SunoPower, Power: h.App.SysConfig.SunoPower,
SongId: utils.RandString(32),
} }
if data.Lyrics != "" { if data.Lyrics != "" {
job.Prompt = data.Lyrics job.Prompt = data.Lyrics
@@ -105,49 +131,28 @@ func (h *SunoHandler) Create(c *gin.Context) {
} }
// 创建任务 // 创建任务
h.service.PushTask(types.SunoTask{ task.Id = job.Id
Id: job.Id, h.sunoService.PushTask(task)
UserId: job.UserId,
Type: job.Type,
Title: job.Title,
RefTaskId: data.RefTaskId,
RefSongId: data.RefSongId,
ExtendSecs: data.ExtendSecs,
Prompt: job.Prompt,
Tags: data.Tags,
Model: data.Model,
Instrumental: data.Instrumental,
})
// update user's power // update user's power
tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
// 记录算力变化日志 Type: types.PowerConsume,
if tx.Error == nil && tx.RowsAffected > 0 { Model: job.ModelName,
user, _ := h.GetLoginUser(c) Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
h.DB.Create(&model.PowerLog{ CreatedAt: time.Now(),
UserId: user.Id, })
Username: user.Username, if err != nil {
Type: types.PowerConsume, resp.ERROR(c, err.Error())
Amount: job.Power, return
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: job.ModelName,
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
CreatedAt: time.Now(),
})
} }
client := h.service.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *SunoHandler) List(c *gin.Context) { func (h *SunoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0) page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 0) pageSize := h.GetInt(c, "page_size", 20)
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId) session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
// 统计总数 // 统计总数
@@ -209,8 +214,20 @@ func (h *SunoHandler) Remove(c *gin.Context) {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
// 只有失败或者已完成的任务可以删除
if !(job.Progress == service.FailTaskProgress || job.Progress == 100) {
resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
return
}
// 删除任务 // 删除任务
h.DB.Delete(&job) err = h.DB.Delete(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 删除文件 // 删除文件
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL) _ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
_ = h.uploader.GetUploadHandler().Delete(job.AudioURL) _ = h.uploader.GetUploadHandler().Delete(job.AudioURL)
@@ -306,40 +323,3 @@ func (h *SunoHandler) Play(c *gin.Context) {
} }
h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1)) h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1))
} }
const genLyricTemplate = `
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
请以【%s】为主题创作一首歌曲歌曲时间不要太短3分钟左右不要输出任何解释性的内容。
输出格式如下:
歌曲名称
第一节:
{{歌词内容}}
副歌:
{{歌词内容}}
第二节:
{{歌词内容}}
副歌:
{{歌词内容}}
尾声:
{{歌词内容}}
`
// Lyric 生成歌词
func (h *SunoHandler) Lyric(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}

View File

@@ -3,15 +3,52 @@ package handler
import ( import (
"geekai/service" "geekai/service"
"geekai/service/payment" "geekai/service/payment"
"github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
) )
type TestHandler struct { type TestHandler struct {
db *gorm.DB db *gorm.DB
snowflake *service.Snowflake snowflake *service.Snowflake
js *payment.JPayService js *payment.GeekPayService
} }
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.JPayService) *TestHandler { func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js} return &TestHandler{db: db, snowflake: snowflake, js: js}
} }
func (h *TestHandler) SseTest(c *gin.Context) {
//c.Header("Body-Type", "text/event-stream")
//c.Header("Cache-Control", "no-cache")
//c.Header("Connection", "keep-alive")
//
//
//// 模拟实时数据更新
//for i := 0; i < 10; i++ {
// // 发送 SSE 数据
// _, err := fmt.Fprintf(c.Writer, "data: %v\n\n", data)
// if err != nil {
// return
// }
// c.Writer.Flush() // 确保立即发送数据
// time.Sleep(1 * time.Second) // 每秒发送一次数据
//}
//c.Abort()
}
func (h *TestHandler) PostTest(c *gin.Context) {
var data struct {
Message string `json:"message"`
UserId uint `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 将参数存储在上下文中
c.Set("data", data)
c.Next()
}

View File

@@ -33,6 +33,8 @@ type UserHandler struct {
searcher *xdb.Searcher searcher *xdb.Searcher
redis *redis.Client redis *redis.Client
licenseService *service.LicenseService licenseService *service.LicenseService
captcha *service.CaptchaService
userService *service.UserService
} }
func NewUserHandler( func NewUserHandler(
@@ -40,12 +42,16 @@ func NewUserHandler(
db *gorm.DB, db *gorm.DB,
searcher *xdb.Searcher, searcher *xdb.Searcher,
client *redis.Client, client *redis.Client,
captcha *service.CaptchaService,
userService *service.UserService,
licenseService *service.LicenseService) *UserHandler { licenseService *service.LicenseService) *UserHandler {
return &UserHandler{ return &UserHandler{
BaseHandler: BaseHandler{DB: db, App: app}, BaseHandler: BaseHandler{DB: db, App: app},
searcher: searcher, searcher: searcher,
redis: client, redis: client,
captcha: captcha,
licenseService: licenseService, licenseService: licenseService,
userService: userService,
} }
} }
@@ -55,14 +61,33 @@ func (h *UserHandler) Register(c *gin.Context) {
var data struct { var data struct {
RegWay string `json:"reg_way"` RegWay string `json:"reg_way"`
Username string `json:"username"` Username string `json:"username"`
Mobile string `json:"mobile"`
Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
Code string `json:"code"` Code string `json:"code"`
InviteCode string `json:"invite_code"` InviteCode string `json:"invite_code"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
data.Password = strings.TrimSpace(data.Password) data.Password = strings.TrimSpace(data.Password)
if len(data.Password) < 8 { if len(data.Password) < 8 {
resp.ERROR(c, "密码长度不能少于8个字符") resp.ERROR(c, "密码长度不能少于8个字符")
@@ -79,8 +104,15 @@ func (h *UserHandler) Register(c *gin.Context) {
// 检查验证码 // 检查验证码
var key string var key string
if data.RegWay == "email" || data.RegWay == "mobile" { if data.RegWay == "email" {
key = CodeStorePrefix + data.Username key = CodeStorePrefix + data.Email
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
return
}
} else if data.RegWay == "mobile" {
key = CodeStorePrefix + data.Mobile
code, err := h.redis.Get(c, key).Result() code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code { if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误") resp.ERROR(c, "验证码错误")
@@ -98,26 +130,37 @@ func (h *UserHandler) Register(c *gin.Context) {
} }
} }
salt := utils.RandString(8)
user := model.User{
Username: data.Username,
Password: utils.GenPassword(data.Password, salt),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
Power: h.App.SysConfig.InitPower,
}
// check if the username is existing // check if the username is existing
var item model.User var item model.User
res := h.DB.Where("username = ?", data.Username).First(&item) session := h.DB.Session(&gorm.Session{})
if data.Mobile != "" {
session = session.Where("mobile = ?", data.Mobile)
user.Username = data.Mobile
user.Mobile = data.Mobile
} else if data.Email != "" {
session = session.Where("email = ?", data.Email)
user.Username = data.Email
user.Email = data.Email
} else if data.Username != "" {
session = session.Where("username = ?", data.Username)
}
session.First(&item)
if item.Id > 0 { if item.Id > 0 {
resp.ERROR(c, "该用户名已经被注册") resp.ERROR(c, "该用户名已经被注册")
return return
} }
salt := utils.RandString(8)
user := model.User{
Username: data.Username,
Password: utils.GenPassword(data.Password, salt),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
Power: h.App.SysConfig.InitPower,
}
// 被邀请人也获得赠送算力 // 被邀请人也获得赠送算力
if data.InviteCode != "" { if data.InviteCode != "" {
user.Power += h.App.SysConfig.InvitePower user.Power += h.App.SysConfig.InvitePower
@@ -128,10 +171,9 @@ func (h *UserHandler) Register(c *gin.Context) {
user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)) user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
} }
res = h.DB.Create(&user) tx := h.DB.Begin()
if res.Error != nil { if err := tx.Create(&user).Error; err != nil {
resp.ERROR(c, "保存数据失败") resp.ERROR(c, err.Error())
logger.Error(res.Error)
return return
} }
@@ -140,35 +182,35 @@ func (h *UserHandler) Register(c *gin.Context) {
// 增加邀请数量 // 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1)) h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 { if h.App.SysConfig.InvitePower > 0 {
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower)) err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{
// 记录邀请算力充值日志 Type: types.PowerInvite,
var inviter model.User Model: "Invite",
h.DB.Where("id", inviteCode.UserId).First(&inviter) Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
h.DB.Create(&model.PowerLog{
UserId: inviter.Id,
Username: inviter.Username,
Type: types.PowerInvite,
Amount: h.App.SysConfig.InvitePower,
Balance: inviter.Power,
Mark: types.PowerAdd,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
CreatedAt: time.Now(),
}) })
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
} }
// 添加邀请记录 // 添加邀请记录
h.DB.Create(&model.InviteLog{ err := tx.Create(&model.InviteLog{
InviterId: inviteCode.UserId, InviterId: inviteCode.UserId,
UserId: user.Id, UserId: user.Id,
Username: user.Username, Username: user.Username,
InviteCode: inviteCode.Code, InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower), Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
}) }).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
} }
tx.Commit()
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码 _ = h.redis.Del(c, key) // 注册成功,删除短信验证码
// 自动登录创建 token // 自动登录创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id, "user_id": user.Id,
@@ -193,20 +235,41 @@ func (h *UserHandler) Login(c *gin.Context) {
var data struct { var data struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
needVerify, err := h.redis.Get(c, verifyKey).Bool()
if h.App.SysConfig.EnabledVerify && needVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
var user model.User var user model.User
res := h.DB.Where("username = ?", data.Username).First(&user) res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil { if res.Error != nil {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名不存在") resp.ERROR(c, "用户名不存在")
return return
} }
password := utils.GenPassword(data.Password, user.Salt) password := utils.GenPassword(data.Password, user.Salt)
if password != user.Password { if password != user.Password {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名或密码错误") resp.ERROR(c, "用户名或密码错误")
return return
} }
@@ -239,11 +302,13 @@ func (h *UserHandler) Login(c *gin.Context) {
return return
} }
// 保存到 redis // 保存到 redis
key := fmt.Sprintf("users/%d", user.Id) sessionKey := fmt.Sprintf("users/%d", user.Id)
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil { if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error()) resp.ERROR(c, "error with save token: "+err.Error())
return return
} }
// 移除登录行为验证码
h.redis.Del(c, verifyKey)
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username}) resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
} }
@@ -285,8 +350,10 @@ func (h *UserHandler) CLogin(c *gin.Context) {
// CLoginCallback 第三方登录回调 // CLoginCallback 第三方登录回调
func (h *UserHandler) CLoginCallback(c *gin.Context) { func (h *UserHandler) CLoginCallback(c *gin.Context) {
loginType := h.GetTrim(c, "login_type") loginType := c.Query("login_type")
code := h.GetTrim(c, "code") code := c.Query("code")
userId := h.GetInt(c, "user_id", 0)
action := c.Query("action")
var res types.BizVo var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL) apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
@@ -311,11 +378,34 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
// login successfully // login successfully
data := res.Data.(map[string]interface{}) data := res.Data.(map[string]interface{})
session := gin.H{}
var user model.User var user model.User
tx := h.DB.Debug().Where("openid", data["openid"]).First(&user) if action == "bind" && userId > 0 {
if tx.Error != nil { // user not exist, create new user err = h.DB.Where("openid", data["openid"]).First(&user).Error
// 检测最大注册人数 if err == nil {
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
return
}
err = h.DB.Where("id", userId).First(&user).Error
if err != nil {
resp.ERROR(c, "绑定用户不存在")
return
}
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
if err != nil {
resp.ERROR(c, "更新用户信息失败,"+err.Error())
return
}
resp.SUCCESS(c, gin.H{"token": ""})
return
}
session := gin.H{}
tx := h.DB.Where("openid", data["openid"]).First(&user)
if tx.Error != nil {
// create new user
var totalUser int64 var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser) h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum { if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
@@ -326,16 +416,15 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
salt := utils.RandString(8) salt := utils.RandString(8)
password := fmt.Sprintf("%d", utils.RandomNumber(8)) password := fmt.Sprintf("%d", utils.RandomNumber(8))
user = model.User{ user = model.User{
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)), Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
Password: utils.GenPassword(password, salt), Password: utils.GenPassword(password, salt),
Avatar: fmt.Sprintf("%s", data["avatar"]), Avatar: fmt.Sprintf("%s", data["avatar"]),
Salt: salt, Salt: salt,
Status: true, Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色 ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型 Power: h.App.SysConfig.InitPower,
Power: h.App.SysConfig.InitPower, OpenId: fmt.Sprintf("%s", data["openid"]),
OpenId: fmt.Sprintf("%s", data["openid"]), Nickname: fmt.Sprintf("%s", data["nickname"]),
Nickname: fmt.Sprintf("%s", data["nickname"]),
} }
tx = h.DB.Create(&user) tx = h.DB.Create(&user)
@@ -383,18 +472,24 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
// Session 获取/验证会话 // Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) { func (h *UserHandler) Session(c *gin.Context) {
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err == nil { if err != nil {
var userVo vo.User resp.NotAuth(c, err.Error())
err := utils.CopyObject(user, &userVo) return
if err != nil {
resp.ERROR(c)
}
userVo.Id = user.Id
resp.SUCCESS(c, userVo)
} else {
resp.NotAuth(c)
} }
var userVo vo.User
err = utils.CopyObject(user, &userVo)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 用户 VIP 到期
if user.ExpiredTime > 0 && user.ExpiredTime < time.Now().Unix() {
h.DB.Model(&user).UpdateColumn("vip", false)
}
userVo.Id = user.Id
resp.SUCCESS(c, userVo)
} }
type userProfile struct { type userProfile struct {
@@ -481,20 +576,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
} }
newPass := utils.GenPassword(data.Password, user.Salt) newPass := utils.GenPassword(data.Password, user.Salt)
res := h.DB.Model(&user).UpdateColumn("password", newPass) err = h.DB.Model(&user).UpdateColumn("password", newPass).Error
if res.Error != nil { if err != nil {
logger.Error("error with update database", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// ResetPass 重置密码 // ResetPass 找回密码
func (h *UserHandler) ResetPass(c *gin.Context) { func (h *UserHandler) ResetPass(c *gin.Context) {
var data struct { var data struct {
Username string `json:"username"` Type string `json:"type"` // 验证类别mobile, email
Mobile string `json:"mobile"` // 手机号
Email string `json:"email"` // 邮箱地址
Code string `json:"code"` // 验证码 Code string `json:"code"` // 验证码
Password string `json:"password"` // 新密码 Password string `json:"password"` // 新密码
} }
@@ -503,37 +599,47 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
return return
} }
session := h.DB.Session(&gorm.Session{})
var key string
if data.Type == "email" {
session = session.Where("email", data.Email)
key = CodeStorePrefix + data.Email
} else if data.Type == "mobile" {
session = session.Where("mobile", data.Mobile)
key = CodeStorePrefix + data.Mobile
} else {
resp.ERROR(c, "验证类别错误")
return
}
var user model.User var user model.User
res := h.DB.Where("username", data.Username).First(&user) err := session.First(&user).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "用户不存在!") resp.ERROR(c, "用户不存在!")
return return
} }
// 检查验证码 // 检查验证码
key := CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result() code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code { if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误") resp.ERROR(c, "验证码错误")
return return
} }
password := utils.GenPassword(data.Password, user.Salt) password := utils.GenPassword(data.Password, user.Salt)
user.Password = password err = h.DB.Model(&user).UpdateColumn("password", password).Error
res = h.DB.Updates(&user) if err != nil {
if res.Error != nil { resp.ERROR(c, err.Error())
resp.ERROR(c)
} else { } else {
h.redis.Del(c, key) h.redis.Del(c, key)
resp.SUCCESS(c) resp.SUCCESS(c)
} }
} }
// BindUsername 重置账 // BindMobile 绑定手机
func (h *UserHandler) BindUsername(c *gin.Context) { func (h *UserHandler) BindMobile(c *gin.Context) {
var data struct { var data struct {
Username string `json:"username"` Mobile string `json:"mobile"`
Code string `json:"code"` Code string `json:"code"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -541,7 +647,7 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
} }
// 检查验证码 // 检查验证码
key := CodeStorePrefix + data.Username key := CodeStorePrefix + data.Mobile
code, err := h.redis.Get(c, key).Result() code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code { if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误") resp.ERROR(c, "验证码错误")
@@ -550,22 +656,56 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
// 检查手机号是否被其他账号绑定 // 检查手机号是否被其他账号绑定
var item model.User var item model.User
res := h.DB.Where("username = ?", data.Username).First(&item) res := h.DB.Where("mobile", data.Mobile).First(&item)
if res.Error == nil { if res.Error == nil {
resp.ERROR(c, "该号已经其他账号绑定") resp.ERROR(c, "该手机号已经绑定了其他账号,请更换手机号")
return return
} }
user, err := h.GetLoginUser(c) userId := h.GetLoginUserId(c)
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("mobile", data.Mobile).Error
if err != nil { if err != nil {
resp.NotAuth(c) resp.ERROR(c, err.Error())
return return
} }
res = h.DB.Model(&user).UpdateColumn("username", data.Username) _ = h.redis.Del(c, key) // 删除短信验证码
if res.Error != nil { resp.SUCCESS(c)
logger.Error(res.Error) }
resp.ERROR(c, "更新数据库失败")
// BindEmail 绑定邮箱
func (h *UserHandler) BindEmail(c *gin.Context) {
var data struct {
Email string `json:"email"`
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 检查验证码
key := CodeStorePrefix + data.Email
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
return
}
// 检查手机号是否被其他账号绑定
var item model.User
res := h.DB.Where("email", data.Email).First(&item)
if res.Error == nil {
resp.ERROR(c, "该邮箱地址已经绑定了其他账号,请更邮箱地址")
return
}
userId := h.GetLoginUserId(c)
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("email", data.Email).Error
if err != nil {
resp.ERROR(c, err.Error())
return return
} }

View File

@@ -0,0 +1,215 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/service/video"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type VideoHandler struct {
BaseHandler
videoService *video.Service
uploader *oss.UploaderManager
userService *service.UserService
}
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
return &VideoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
videoService: service,
uploader: uploader,
userService: userService,
}
}
func (h *VideoHandler) LumaCreate(c *gin.Context) {
var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"`
FirstFrameImg string `json:"first_frame_img,omitempty"`
EndFrameImg string `json:"end_frame_img,omitempty"`
ExpandPrompt bool `json:"expand_prompt,omitempty"`
Loop bool `json:"loop,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.LumaPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
if data.Prompt == "" {
resp.ERROR(c, "prompt is needed")
return
}
userId := int(h.GetLoginUserId(c))
params := types.VideoParams{
PromptOptimize: data.ExpandPrompt,
Loop: data.Loop,
StartImgURL: data.FirstFrameImg,
EndImgURL: data.EndFrameImg,
}
task := types.VideoTask{
ClientId: data.ClientId,
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
TranslateModelId: h.App.SysConfig.TranslateModelId,
}
// 插入数据库
job := model.VideoJob{
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Power: h.App.SysConfig.LumaPower,
TaskInfo: utils.JsonEncode(task),
}
tx := h.DB.Create(&job)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
// 创建任务
task.Id = job.Id
h.videoService.PushTask(task)
// update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "luma",
Remark: fmt.Sprintf("Luma 文生视频任务ID%d", job.Id),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *VideoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
t := c.Query("type")
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
all := h.GetBool(c, "all")
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
if t != "" {
session = session.Where("type", t)
}
if all {
session = session.Where("publish", 0).Where("progress", 100)
} else {
session = session.Where("user_id", h.GetLoginUserId(c))
}
// 统计总数
var total int64
session.Model(&model.VideoJob{}).Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var list []model.VideoJob
err := session.Order("id desc").Find(&list).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 转换为 VO
items := make([]vo.VideoJob, 0)
for _, v := range list {
var item vo.VideoJob
err = utils.CopyObject(v, &item)
if err != nil {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
if item.VideoURL == "" {
item.VideoURL = v.WaterURL
}
items = append(items, item)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *VideoHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.VideoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 只有失败或者超时的任务才能删除
if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) {
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
return
}
// 删除任务
err = h.DB.Delete(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 删除文件
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
}
func (h *VideoHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
publish := h.GetBool(c, "publish")
var job model.VideoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
err = h.DB.Model(&job).UpdateColumn("publish", publish).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

151
api/handler/ws_handler.go Normal file
View File

@@ -0,0 +1,151 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
)
// Websocket 连接处理 handler
type WebsocketHandler struct {
BaseHandler
wsService *service.WebsocketService
chatHandler *ChatHandler
}
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
return &WebsocketHandler{
BaseHandler: BaseHandler{App: app, DB: db},
chatHandler: chatHandler,
wsService: s,
}
}
func (h *WebsocketHandler) Client(c *gin.Context) {
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
ws, err := (&websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: strings.Split(clientProtocols, ","),
}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
clientId := c.Query("client_id")
client := types.NewWsClient(ws, clientId)
userId := h.GetLoginUserId(c)
if userId == 0 {
_ = client.Send([]byte("Invalid user_id"))
c.Abort()
return
}
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
_ = client.Send([]byte("Invalid user_id"))
c.Abort()
return
}
h.wsService.Clients.Put(clientId, client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
h.wsService.Clients.Delete(clientId)
break
}
var message types.InputMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
logger.Debugf("Receive a message:%+v", message)
if message.Type == types.MsgTypePing {
utils.SendChannelMsg(client, types.ChPing, "pong")
continue
}
// 当前只处理聊天消息,其他消息全部丢弃
var chatMessage types.ChatMessage
err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
if err != nil || message.Channel != types.ChChat {
logger.Warnf("invalid message body:%+v", message.Body)
continue
}
var chatRole model.ChatRole
err = h.DB.First(&chatRole, chatMessage.RoleId).Error
if err != nil || !chatRole.Enable {
utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
continue
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
chatMessage.RoleId = chatRole.ModelId
}
// get model info
var chatModel model.ChatModel
err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
if err != nil || chatModel.Enabled == false {
utils.SendAndFlush(client, "当前AI模型暂未启用请更换模型后再发起对话")
continue
}
session := &types.ChatSession{
ClientIP: c.ClientIP(),
UserId: userId,
}
// use old chat data override the chat model and role ID
var chat model.ChatItem
h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
if chat.Id > 0 {
chatModel.Id = chat.ModelId
chatMessage.RoleId = int(chat.RoleId)
}
session.ChatId = chatMessage.ChatId
session.Tools = chatMessage.Tools
session.Stream = chatMessage.Stream
// 复制模型数据
err = utils.CopyObject(chatModel, &session.Model)
if err != nil {
logger.Error(err, chatModel)
}
session.Model.Id = chatModel.Id
ctx, cancel := context.WithCancel(context.Background())
h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
if err != nil {
logger.Error(err)
utils.SendAndFlush(client, err.Error())
} else {
utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
logger.Infof("回答完毕: %v", message.Body)
}
}
}()
}

View File

@@ -14,7 +14,6 @@ import (
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/handler/admin" "geekai/handler/admin"
"geekai/handler/chatimpl"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/service" "geekai/service"
"geekai/service/dalle" "geekai/service/dalle"
@@ -24,7 +23,7 @@ import (
"geekai/service/sd" "geekai/service/sd"
"geekai/service/sms" "geekai/service/sms"
"geekai/service/suno" "geekai/service/suno"
"geekai/service/wx" "geekai/service/video"
"geekai/store" "geekai/store"
"io" "io"
"log" "log"
@@ -128,10 +127,10 @@ func main() {
// 创建控制器 // 创建控制器
fx.Provide(handler.NewChatRoleHandler), fx.Provide(handler.NewChatRoleHandler),
fx.Provide(handler.NewUserHandler), fx.Provide(handler.NewUserHandler),
fx.Provide(chatimpl.NewChatHandler), fx.Provide(handler.NewChatHandler),
fx.Provide(handler.NewUploadHandler), fx.Provide(handler.NewNetHandler),
fx.Provide(handler.NewSmsHandler), fx.Provide(handler.NewSmsHandler),
fx.Provide(handler.NewRewardHandler), fx.Provide(handler.NewRedeemHandler),
fx.Provide(handler.NewCaptchaHandler), fx.Provide(handler.NewCaptchaHandler),
fx.Provide(handler.NewMidJourneyHandler), fx.Provide(handler.NewMidJourneyHandler),
fx.Provide(handler.NewChatModelHandler), fx.Provide(handler.NewChatModelHandler),
@@ -146,8 +145,8 @@ func main() {
fx.Provide(admin.NewAdminHandler), fx.Provide(admin.NewAdminHandler),
fx.Provide(admin.NewApiKeyHandler), fx.Provide(admin.NewApiKeyHandler),
fx.Provide(admin.NewUserHandler), fx.Provide(admin.NewUserHandler),
fx.Provide(admin.NewChatRoleHandler), fx.Provide(admin.NewChatAppHandler),
fx.Provide(admin.NewRewardHandler), fx.Provide(admin.NewRedeemHandler),
fx.Provide(admin.NewDashboardHandler), fx.Provide(admin.NewDashboardHandler),
fx.Provide(admin.NewChatModelHandler), fx.Provide(admin.NewChatModelHandler),
fx.Provide(admin.NewProductHandler), fx.Provide(admin.NewProductHandler),
@@ -161,13 +160,12 @@ func main() {
return service.NewCaptchaService(config.ApiConfig) return service.NewCaptchaService(config.ApiConfig)
}), }),
fx.Provide(oss.NewUploaderManager), fx.Provide(oss.NewUploaderManager),
fx.Provide(mj.NewService),
fx.Provide(dalle.NewService), fx.Provide(dalle.NewService),
fx.Invoke(func(service *dalle.Service) { fx.Invoke(func(s *dalle.Service) {
service.Run() s.Run()
service.CheckTaskNotify() s.CheckTaskNotify()
service.DownloadImages() s.DownloadImages()
service.CheckTaskStatus() s.CheckTaskStatus()
}), }),
// 邮件服务 // 邮件服务
@@ -178,36 +176,22 @@ func main() {
licenseService.SyncLicense() licenseService.SyncLicense()
}), }),
// 微信机器人服务
fx.Provide(wx.NewWeChatBot),
fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) {
if config.WeChatBot {
err := bot.Run()
if err != nil {
logger.Error("微信登录失败:", err)
}
}
}),
// MidJourney service pool // MidJourney service pool
fx.Provide(mj.NewServicePool), fx.Provide(mj.NewService),
fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) { fx.Provide(mj.NewClient),
pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs) fx.Invoke(func(s *mj.Service) {
if pool.HasAvailableService() { s.Run()
pool.DownloadImages() s.SyncTaskProgress()
pool.CheckTaskNotify() s.CheckTaskNotify()
pool.SyncTaskProgress() s.DownloadImages()
}
}), }),
// Stable Diffusion 机器人 // Stable Diffusion 机器人
fx.Provide(sd.NewServicePool), fx.Provide(sd.NewService),
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) { fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
pool.InitServices(config.SdConfigs) s.Run()
if pool.HasAvailableService() { s.CheckTaskStatus()
pool.CheckTaskNotify() s.CheckTaskNotify()
pool.CheckTaskStatus()
}
}), }),
fx.Provide(suno.NewService), fx.Provide(suno.NewService),
@@ -215,9 +199,16 @@ func main() {
s.Run() s.Run()
s.SyncTaskProgress() s.SyncTaskProgress()
s.CheckTaskNotify() s.CheckTaskNotify()
s.DownloadImages() s.DownloadFiles()
}), }),
fx.Provide(video.NewService),
fx.Invoke(func(s *video.Service) {
s.Run()
s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadFiles()
}),
fx.Provide(service.NewUserService),
fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay), fx.Provide(payment.NewHuPiPay),
fx.Provide(payment.NewJPayService), fx.Provide(payment.NewJPayService),
@@ -234,8 +225,9 @@ func main() {
// 注册路由 // 注册路由
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
group := s.Engine.Group("/api/role/") group := s.Engine.Group("/api/app/")
group.GET("list", h.List) group.GET("list", h.List)
group.GET("list/user", h.ListByUser)
group.POST("update", h.UpdateRole) group.POST("update", h.UpdateRole)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) { fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
@@ -247,14 +239,14 @@ func main() {
group.GET("profile", h.Profile) group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate) group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass) group.POST("password", h.UpdatePass)
group.POST("bind/username", h.BindUsername) group.POST("bind/mobile", h.BindMobile)
group.POST("bind/email", h.BindEmail)
group.POST("resetPass", h.ResetPass) group.POST("resetPass", h.ResetPass)
group.GET("clogin", h.CLogin) group.GET("clogin", h.CLogin)
group.GET("clogin/callback", h.CLoginCallback) group.GET("clogin/callback", h.CLoginCallback)
}), }),
fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
group := s.Engine.Group("/api/chat/") group := s.Engine.Group("/api/chat/")
group.Any("new", h.ChatHandle)
group.GET("list", h.List) group.GET("list", h.List)
group.GET("detail", h.Detail) group.GET("detail", h.Detail)
group.POST("update", h.Update) group.POST("update", h.Update)
@@ -264,10 +256,11 @@ func main() {
group.POST("tokens", h.Tokens) group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate) group.GET("stop", h.StopGenerate)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) { fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
s.Engine.POST("/api/upload", h.Upload) s.Engine.POST("/api/upload", h.Upload)
s.Engine.POST("/api/upload/list", h.List) s.Engine.POST("/api/upload/list", h.List)
s.Engine.GET("/api/upload/remove", h.Remove) s.Engine.GET("/api/upload/remove", h.Remove)
s.Engine.GET("/api/download", h.Download)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
group := s.Engine.Group("/api/sms/") group := s.Engine.Group("/api/sms/")
@@ -280,13 +273,12 @@ func main() {
group.GET("slide/get", h.SlideGet) group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck) group.POST("slide/check", h.SlideCheck)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) { fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
group := s.Engine.Group("/api/reward/") group := s.Engine.Group("/api/redeem/")
group.POST("verify", h.Verify) group.POST("verify", h.Verify)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
group := s.Engine.Group("/api/mj/") group := s.Engine.Group("/api/mj/")
group.Any("client", h.Client)
group.POST("image", h.Image) group.POST("image", h.Image)
group.POST("upscale", h.Upscale) group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation) group.POST("variation", h.Variation)
@@ -297,7 +289,6 @@ func main() {
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd") group := s.Engine.Group("/api/sd")
group.Any("client", h.Client)
group.POST("image", h.Image) group.POST("image", h.Image)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall) group.GET("imgWall", h.ImgWall)
@@ -312,13 +303,12 @@ func main() {
// 管理后台控制器 // 管理后台控制器
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
group := s.Engine.Group("/api/admin/") group := s.Engine.Group("/api/admin/config")
group.POST("config/update", h.Update) group.POST("update", h.Update)
group.GET("config/get", h.Get) group.GET("get", h.Get)
group.POST("active", h.Active) group.POST("active", h.Active)
group.GET("config/get/license", h.GetLicense) group.GET("fixData", h.FixData)
group.GET("config/get/app", h.GetAppConfig) group.GET("license", h.GetLicense)
group.POST("config/update/draw", h.SaveDrawingConfig)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
group := s.Engine.Group("/api/admin/") group := s.Engine.Group("/api/admin/")
@@ -346,7 +336,7 @@ func main() {
group.GET("loginLog", h.LoginLog) group.GET("loginLog", h.LoginLog)
group.POST("resetPass", h.ResetPass) group.POST("resetPass", h.ResetPass)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ChatRoleHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
group := s.Engine.Group("/api/admin/role/") group := s.Engine.Group("/api/admin/role/")
group.GET("list", h.List) group.GET("list", h.List)
group.POST("save", h.Save) group.POST("save", h.Save)
@@ -354,10 +344,13 @@ func main() {
group.POST("set", h.Set) group.POST("set", h.Set)
group.GET("remove", h.Remove) group.GET("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) { fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
group := s.Engine.Group("/api/admin/reward/") group := s.Engine.Group("/api/admin/redeem/")
group.GET("list", h.List) group.GET("list", h.List)
group.POST("remove", h.Remove) group.POST("create", h.Create)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.POST("export", h.Export)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) { fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
group := s.Engine.Group("/api/admin/dashboard/") group := s.Engine.Group("/api/admin/dashboard/")
@@ -377,14 +370,12 @@ func main() {
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) { fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
group := s.Engine.Group("/api/payment/") group := s.Engine.Group("/api/payment/")
group.GET("doPay", h.DoPay) group.POST("doPay", h.Pay)
group.GET("payWays", h.GetPayWays) group.GET("payWays", h.GetPayWays)
group.POST("qrcode", h.PayQrcode) group.POST("notify/alipay", h.AlipayNotify)
group.POST("mobile", h.Mobile) group.GET("notify/geek", h.GeekPayNotify)
group.POST("alipay/notify", h.AlipayNotify) group.POST("notify/wechat", h.WechatPayNotify)
group.POST("hupipay/notify", h.HuPiPayNotify) group.POST("notify/hupi", h.HuPiPayNotify)
group.POST("payjs/notify", h.PayJsNotify)
group.POST("wechat/notify", h.WechatPayNotify)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
group := s.Engine.Group("/api/admin/product/") group := s.Engine.Group("/api/admin/product/")
@@ -398,6 +389,7 @@ func main() {
group := s.Engine.Group("/api/admin/order/") group := s.Engine.Group("/api/admin/order/")
group.POST("list", h.List) group.POST("list", h.List)
group.GET("remove", h.Remove) group.GET("remove", h.Remove)
group.GET("clear", h.Clear)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) { fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
group := s.Engine.Group("/api/order/") group := s.Engine.Group("/api/order/")
@@ -413,7 +405,7 @@ func main() {
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) { fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
group := s.Engine.Group("/api/invite/") group := s.Engine.Group("/api/invite/")
group.GET("code", h.Code) group.GET("code", h.Code)
group.POST("list", h.List) group.GET("list", h.List)
group.GET("hits", h.Hits) group.GET("hits", h.Hits)
}), }),
@@ -438,6 +430,7 @@ func main() {
group.POST("weibo", h.WeiBo) group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao) group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3) group.POST("dalle3", h.Dall3)
group.GET("list", h.List)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/") group := s.Engine.Group("/api/admin/chat/")
@@ -471,23 +464,21 @@ func main() {
}), }),
fx.Provide(handler.NewMarkMapHandler), fx.Provide(handler.NewMarkMapHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) { fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
group := s.Engine.Group("/api/markMap/") s.Engine.POST("/api/markMap/gen", h.Generate)
group.Any("client", h.Client)
}), }),
fx.Provide(handler.NewDallJobHandler), fx.Provide(handler.NewDallJobHandler),
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
group := s.Engine.Group("/api/dall") group := s.Engine.Group("/api/dall")
group.Any("client", h.Client)
group.POST("image", h.Image) group.POST("image", h.Image)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall) group.GET("imgWall", h.ImgWall)
group.GET("remove", h.Remove) group.GET("remove", h.Remove)
group.GET("publish", h.Publish) group.GET("publish", h.Publish)
group.GET("models", h.GetModels)
}), }),
fx.Provide(handler.NewSunoHandler), fx.Provide(handler.NewSunoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
group := s.Engine.Group("/api/suno") group := s.Engine.Group("/api/suno")
group.Any("client", h.Client)
group.POST("create", h.Create) group.POST("create", h.Create)
group.GET("list", h.List) group.GET("list", h.List)
group.GET("remove", h.Remove) group.GET("remove", h.Remove)
@@ -495,13 +486,53 @@ func main() {
group.POST("update", h.Update) group.POST("update", h.Update)
group.GET("detail", h.Detail) group.GET("detail", h.Detail)
group.GET("play", h.Play) group.GET("play", h.Play)
group.POST("lyric", h.Lyric) }),
fx.Provide(handler.NewVideoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
group := s.Engine.Group("/api/video")
group.POST("luma/create", h.LumaCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}),
fx.Provide(admin.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
group := s.Engine.Group("/api/admin/app/type")
group.POST("save", h.Save)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
}),
fx.Provide(handler.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
group := s.Engine.Group("/api/app/type")
group.GET("list", h.List)
}),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test")
group.Any("sse", h.PostTest, h.SseTest)
}),
fx.Provide(service.NewWebsocketService),
fx.Provide(handler.NewWebsocketHandler),
fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) {
s.Engine.Any("/api/ws", h.Client)
}),
fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt")
group.POST("/lyric", h.Lyric)
group.POST("/image", h.Image)
group.POST("/video", h.Video)
group.POST("/meta", h.MetaPrompt)
}), }),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) { fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
go func() { go func() {
err := s.Run(db) err := s.Run(db)
if err != nil { if err != nil {
log.Fatal(err) logger.Error(err)
os.Exit(0)
} }
}() }()
}), }),
@@ -517,6 +548,26 @@ func main() {
}, },
}) })
}), }),
fx.Provide(admin.NewImageHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) {
group := s.Engine.Group("/api/admin/image")
group.POST("/list/mj", h.MjList)
group.POST("/list/sd", h.SdList)
group.POST("/list/dall", h.DallList)
group.GET("/remove", h.Remove)
}),
fx.Provide(admin.NewMediaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
group := s.Engine.Group("/api/admin/media")
group.POST("/list/suno", h.SunoList)
group.POST("/list/luma", h.LumaList)
group.GET("/remove", h.Remove)
}),
fx.Provide(handler.NewRealtimeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
s.Engine.Any("/api/realtime", h.Connection)
s.Engine.POST("/api/realtime/voice", h.VoiceChat)
}),
) )
// 启动应用程序 // 启动应用程序
go func() { go func() {

BIN
api/res/img/geek-pay.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

BIN
api/res/img/qq-pay.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

View File

@@ -8,19 +8,19 @@ package dalle
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/service" "geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/sd"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"github.com/go-redis/redis/v8" "io"
"time" "time"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -35,17 +35,21 @@ type Service struct {
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client userService *service.UserService
wsService *service.WebsocketService
clientIds map[uint]string
} }
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
return &Service{ return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3), httpClient: req.C().SetTimeout(time.Minute * 3),
db: db, db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli), notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](), wsService: wsService,
uploadManager: manager, uploadManager: manager,
userService: userService,
clientIds: map[uint]string{},
} }
} }
@@ -56,6 +60,20 @@ func (s *Service) PushTask(task types.DallTask) {
} }
func (s *Service) Run() { func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.DallJob
s.db.Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var task types.DallTask
err := utils.JsonDecode(v.TaskInfo, &task)
if err != nil {
logger.Errorf("decode task info with error: %v", err)
continue
}
task.Id = v.Id
s.PushTask(task)
}
logger.Info("Starting DALL-E job consumer...") logger.Info("Starting DALL-E job consumer...")
go func() { go func() {
for { for {
@@ -66,14 +84,15 @@ func (s *Service) Run() {
continue continue
} }
logger.Infof("handle a new DALL-E task: %+v", task) logger.Infof("handle a new DALL-E task: %+v", task)
s.clientIds[task.Id] = task.ClientId
_, err = s.Image(task, false) _, err = s.Image(task, false)
if err != nil { if err != nil {
logger.Errorf("error with image task: %v", err) logger.Errorf("error with image task: %v", err)
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"progress": -1, "progress": service.FailTaskProgress,
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
} }
} }
}() }()
@@ -82,17 +101,18 @@ func (s *Service) Run() {
type imgReq struct { type imgReq struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
N int `json:"n"` N int `json:"n,omitempty"`
Size string `json:"size"` Size string `json:"size,omitempty"`
Quality string `json:"quality"` Quality string `json:"quality,omitempty"`
Style string `json:"style"` Style string `json:"style,omitempty"`
} }
type imgRes struct { type imgRes struct {
Created int64 `json:"created"` Created int64 `json:"created"`
Data []struct { Data []struct {
RevisedPrompt string `json:"revised_prompt"` RevisedPrompt string `json:"revised_prompt,omitempty"`
Url string `json:"url"` Url string `json:"url,omitempty"`
B64Json string `json:"b64_json,omitempty"`
} `json:"data"` } `json:"data"`
} }
@@ -110,45 +130,27 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
prompt := task.Prompt prompt := task.Prompt
// translate prompt // translate prompt
if utils.HasChinese(prompt) { if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, prompt), task.TranslateModelId)
if err == nil { if err == nil {
prompt = content prompt = content
logger.Debugf("重写后提示词:%s", prompt) logger.Debugf("重写后提示词:%s", prompt)
} }
} }
var user model.User var chatModel model.ChatModel
s.db.Where("id", task.UserId).First(&user) s.db.Where("id = ?", task.ModelId).First(&chatModel)
if user.Power < task.Power {
return "", errors.New("insufficient of power")
}
// 更新用户算力
tx := s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
}
// get image generation API KEY // get image generation API KEY
var apiKey model.ApiKey var apiKey model.ApiKey
tx = s.db.Where("type", "dalle"). session := s.db.Where("enabled", true)
Where("enabled", true). if chatModel.KeyId > 0 {
Order("last_used_at ASC").First(&apiKey) session = session.Where("id = ?", chatModel.KeyId)
if tx.Error != nil { } else {
return "", fmt.Errorf("no available IMG api key: %v", tx.Error) session = session.Where("type = ?", "dalle")
}
err := session.Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return "", fmt.Errorf("no available Image Generation api key: %v", err)
} }
var res imgRes var res imgRes
@@ -158,7 +160,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
} }
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL) apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
reqBody := imgReq{ reqBody := imgReq{
Model: "dall-e-3", Model: chatModel.Value,
Prompt: prompt, Prompt: prompt,
N: 1, N: 1,
Size: task.Size, Size: task.Size,
@@ -166,35 +168,54 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
Quality: task.Quality, Quality: task.Quality,
} }
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody) logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value). SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody). SetBody(reqBody).
SetErrorResult(&errRes). SetErrorResult(&errRes).
SetSuccessResult(&res). SetSuccessResult(&res).
Post(apiURL) Post(apiURL)
if err != nil { if err != nil {
logger.Errorf("error with send request: %v", err)
return "", fmt.Errorf("error with send request: %v", err) return "", fmt.Errorf("error with send request: %v", err)
} }
if r.IsErrorState() { if r.IsErrorState() {
logger.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error) return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
} }
all, _ := io.ReadAll(r.Body)
logger.Debugf("response: %+v", string(all))
// update the api key last use time // update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// update task progress var imgURL string
tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ var data = map[string]interface{}{
"progress": 100, "progress": 100,
"org_url": res.Data[0].Url,
"prompt": prompt, "prompt": prompt,
}) }
if tx.Error != nil { // 如果返回的是base64则需要上传到oss
return "", fmt.Errorf("err with update database: %v", tx.Error) if res.Data[0].B64Json != "" {
imgURL, err = s.uploadManager.GetUploadHandler().PutBase64(res.Data[0].B64Json)
if err != nil {
return "", fmt.Errorf("error with upload image: %v", err)
}
logger.Infof("upload image to oss: %s", imgURL)
data["img_url"] = imgURL
} else {
imgURL = res.Data[0].Url
}
data["org_url"] = imgURL
// update task progress
err = s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(data).Error
if err != nil {
return "", fmt.Errorf("err with update database: %v", err)
} }
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
var content string var content string
if sync { if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) imgURL, err := s.downloadImage(task.Id, int(task.UserId), res.Data[0].Url)
if err != nil { if err != nil {
return "", fmt.Errorf("error with download image: %v", err) return "", fmt.Errorf("error with download image: %v", err)
} }
@@ -208,19 +229,58 @@ func (s *Service) CheckTaskNotify() {
go func() { go func() {
logger.Info("Running DALL-E task notify checking ...") logger.Info("Running DALL-E task notify checking ...")
for { for {
var message sd.NotifyMessage var message service.NotifyMessage
err := s.notifyQueue.LPop(&message) err := s.notifyQueue.LPop(&message)
if err != nil { if err != nil {
continue continue
} }
client := s.Clients.Get(uint(message.UserId))
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil { if client == nil {
continue continue
} }
err = client.Send([]byte(message.Message)) utils.SendChannelMsg(client, types.ChDall, message.Message)
if err != nil { }
continue }()
}
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running DALL-E task status checking ...")
for {
// 检查未完成任务进度
var jobs []model.DallJob
s.db.Where("progress < ?", 100).Find(&jobs)
for _, job := range jobs {
// 超时的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
}
} }
// 找出失败的任务,并恢复其扣减算力
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
for _, job := range jobs {
var task types.DallTask
err := utils.JsonDecode(job.TaskInfo, &task)
if err != nil {
continue
}
err = s.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: task.ModelName,
Remark: fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg),
})
if err != nil {
continue
}
// 更新任务状态
s.db.Model(&job).UpdateColumn("power", 0)
}
time.Sleep(time.Second * 10)
} }
}() }()
} }
@@ -268,47 +328,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
if res.Error != nil { if res.Error != nil {
return "", err return "", err
} }
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
return imgURL, nil return imgURL, nil
} }
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.DallJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
s.db.Delete(&job)
var user model.User
s.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d", job.Id),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}

View File

@@ -8,13 +8,16 @@ package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/store" "geekai/store"
"strings"
"time" "time"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"github.com/shirou/gopsutil/host"
) )
type LicenseService struct { type LicenseService struct {
@@ -27,11 +30,18 @@ type LicenseService struct {
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService { func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
var license types.License var license types.License
var machineId string
_ = levelDB.Get(types.LicenseKey, &license)
info, err := host.Info()
if err == nil {
machineId = info.HostID
}
logger.Infof("License: %+v", license)
return &LicenseService{ return &LicenseService{
config: server.Config.ApiConfig, config: server.Config.ApiConfig,
levelDB: levelDB, levelDB: levelDB,
license: &license, license: &license,
machineId: "", machineId: machineId,
} }
} }
@@ -109,33 +119,30 @@ func (s *LicenseService) SyncLicense() {
} }
func (s *LicenseService) fetchLicense() (*types.License, error) { func (s *LicenseService) fetchLicense() (*types.License, error) {
//var res struct { var res struct {
// Code types.BizCode `json:"code"` Code types.BizCode `json:"code"`
// Message string `json:"message"` Message string `json:"message"`
// Data License `json:"data"` Data License `json:"data"`
//} }
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check") apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
//response, err := req.C().R(). response, err := req.C().R().
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}). SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
// SetSuccessResult(&res).Post(apiURL) SetSuccessResult(&res).Post(apiURL)
//if err != nil { if err != nil {
// return nil, fmt.Errorf("发送激活请求失败: %v", err) return nil, fmt.Errorf("发送激活请求失败: %v", err)
//} }
//if response.IsErrorState() { if response.IsErrorState() {
// return nil, fmt.Errorf("激活失败:%v", response.Status) return nil, fmt.Errorf("激活失败:%v", response.Status)
//} }
//if res.Code != types.Success { if res.Code != types.Success {
// return nil, fmt.Errorf("激活失败:%v", res.Message) return nil, fmt.Errorf("激活失败:%v", res.Message)
//} }
return &types.License{ return &types.License{
Key: "abc", Key: res.Data.License,
MachineId: "abc", MachineId: res.Data.MachineId,
Configs: types.LicenseConfig{ Configs: res.Data.Configs,
UserNum: 10000, ExpiredAt: res.Data.ExpiredAt,
DeCopy: false,
},
ExpiredAt: 0,
IsActive: true, IsActive: true,
}, nil }, nil
} }
@@ -169,29 +176,28 @@ func (s *LicenseService) GetLicense() *types.License {
// IsValidApiURL 判断是否合法的中转 URL // IsValidApiURL 判断是否合法的中转 URL
func (s *LicenseService) IsValidApiURL(uri string) error { func (s *LicenseService) IsValidApiURL(uri string) error {
// 获得许可授权的直接放行 // 获得许可授权的直接放行
return nil if s.license.IsActive {
//if s.license.IsActive { if s.license.MachineId != s.machineId {
// if s.license.MachineId != s.machineId { return errors.New("系统使用了盗版的许可证书")
// return errors.New("系统使用了盗版的许可证书") }
// }
// if time.Now().Unix() > s.license.ExpiredAt {
// if time.Now().Unix() > s.license.ExpiredAt { return errors.New("系统许可证书已经过期")
// return errors.New("系统许可证书已经过期") }
// } return nil
// return nil }
//}
// if len(s.urlWhiteList) == 0 {
//if len(s.urlWhiteList) == 0 { urls, err := s.fetchUrlWhiteList()
// urls, err := s.fetchUrlWhiteList() if err == nil {
// if err == nil { s.urlWhiteList = urls
// s.urlWhiteList = urls }
// } }
//}
// for _, v := range s.urlWhiteList {
//for _, v := range s.urlWhiteList { if strings.HasPrefix(uri, v) {
// if strings.HasPrefix(uri, v) { return nil
// return nil }
// } }
//} return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
} }

View File

@@ -7,15 +7,28 @@ package mj
// * @Author yangjian102621@163.com // * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import "geekai/core/types" import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"io"
"time"
type Client interface { "github.com/gin-gonic/gin"
Imagine(task types.MjTask) (ImageRes, error) )
Blend(task types.MjTask) (ImageRes, error)
SwapFace(task types.MjTask) (ImageRes, error) // Client MidJourney client
Upscale(task types.MjTask) (ImageRes, error) type Client struct {
Variation(task types.MjTask) (ImageRes, error) client *req.Client
QueryTask(taskId string) (QueryRes, error) licenseService *service.LicenseService
db *gorm.DB
} }
type ImageReq struct { type ImageReq struct {
@@ -33,13 +46,8 @@ type ImageRes struct {
Description string `json:"description"` Description string `json:"description"`
Properties struct { Properties struct {
} `json:"properties"` } `json:"properties"`
Result string `json:"result"` Result string `json:"result"`
} Channel string `json:"channel,omitempty"`
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} }
type QueryRes struct { type QueryRes struct {
@@ -66,3 +74,177 @@ type QueryRes struct {
Status string `json:"status"` Status string `json:"status"`
SubmitTime int `json:"submitTime"` SubmitTime int `json:"submitTime"`
} }
var logger = logger2.GetLogger()
func NewClient(licenseService *service.LicenseService, db *gorm.DB) *Client {
return &Client{
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
licenseService: licenseService,
db: db,
}
}
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/submit/imagine", task.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// Blend 融图
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/submit/blend", task.Mode)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// SwapFace 换脸
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/insight-face/swap", task.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// Upscale 放大指定的图片
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
return c.doRequest(body, apiPath, task.ChannelId)
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
return c.doRequest(body, apiPath, task.ChannelId)
}
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
var res ImageRes
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
if channel != "" {
session = session.Where("api_url", channel)
}
var apiKey model.ApiKey
err := session.Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return ImageRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
}
if err = c.licenseService.IsValidApiURL(apiKey.ApiURL); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/%s", apiKey.ApiURL, apiPath)
logger.Info("API URL: ", apiURL)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(body).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
errMsg, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
}
// update the api key last used time
if err = c.db.Model(&apiKey).Update("last_used_at", time.Now().Unix()).Error; err != nil {
logger.Error("update api key last used time error: ", err)
}
res.Channel = apiKey.ApiURL
return res, nil
}
func (c *Client) QueryTask(taskId string, channel string) (QueryRes, error) {
var apiKey model.ApiKey
err := c.db.Where("type", "mj").Where("enabled", true).Where("api_url", channel).First(&apiKey).Error
if err != nil {
return QueryRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
}
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", apiKey.ApiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+apiKey.Value).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}

View File

@@ -1,204 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/utils"
"github.com/imroc/req/v3"
"time"
"github.com/gin-gonic/gin"
)
// PlusClient MidJourney Plus ProxyClient
type PlusClient struct {
Config types.MjPlusConfig
apiURL string
client *req.Client
licenseService *service.LicenseService
}
func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
return &PlusClient{
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
licenseService: licenseService,
}
}
func (c *PlusClient) preCheck() error {
return c.licenseService.IsValidApiURL(c.Config.ApiURL)
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
return c.doRequest(body, apiURL)
}
// Blend 融图
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
return c.doRequest(body, apiURL)
}
// SwapFace 换脸
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
return c.doRequest(body, apiURL)
}
// Upscale 放大指定的图片
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
return c.doRequest(body, apiURL)
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
return c.doRequest(body, apiURL)
}
func (c *PlusClient) doRequest(body interface{}, apiURL string) (ImageRes, error) {
var res ImageRes
var errRes ErrRes
logger.Info("API URL: ", apiURL)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &PlusClient{}

View File

@@ -1,207 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"strings"
"time"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
licenseService *service.LicenseService
}
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
uploaderManager: manager,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
licenseService: licenseService,
}
}
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
// stop old service
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for _, config := range plusConfigs {
if config.Enabled == false {
continue
}
cli := NewPlusClient(config, p.licenseService)
name := utils.Md5(config.ApiURL)
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
plusService.Run()
}()
p.services = append(p.services, plusService)
}
// for mid-journey proxy
for _, config := range proxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := utils.Md5(config.ApiURL)
proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
proxyService.Run()
}()
p.services = append(p.services, proxyService)
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var message sd.NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
cli := p.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (p *ServicePool) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
mjService := p.getService(v.ChannelId)
if mjService == nil {
logger.Errorf("Invalid task: %+v", v)
continue
}
task, _ := mjService.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
// 如果是返回的是 discord 图片地址,则使用代理下载
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := p.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
p.db.Updates(&v)
cli := p.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(sd.Finished))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
var jobs []model.MidJourneyJob
for {
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}
}
time.Sleep(time.Second * 10)
}
}()
}
func (p *ServicePool) getService(name string) *Service {
for _, s := range p.services {
if s.Name == name {
return s
}
}
return nil
}

View File

@@ -1,185 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
)
// ProxyClient MidJourney Proxy Client
type ProxyClient struct {
Config types.MjProxyConfig
apiURL string
}
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
return &ProxyClient{Config: config, apiURL: config.ApiURL}
}
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
body := ImageReq{
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能请使用 MidJourney-Plus")
}
// Upscale 放大指定的图片
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "UPSCALE",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "VARIATION",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &ProxyClient{}

View File

@@ -11,10 +11,11 @@ import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/sd" "geekai/service/oss"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"github.com/go-redis/redis/v8"
"strings" "strings"
"time" "time"
@@ -23,127 +24,132 @@ import (
// Service MJ 绘画服务 // Service MJ 绘画服务
type Service struct { type Service struct {
Name string // service Name client *Client // MJ Client
Client Client // MJ Client taskQueue *store.RedisQueue
taskQueue *store.RedisQueue notifyQueue *store.RedisQueue
notifyQueue *store.RedisQueue db *gorm.DB
db *gorm.DB wsService *service.WebsocketService
running bool uploaderManager *oss.UploaderManager
retryCount map[uint]int userService *service.UserService
clientIds map[uint]string
} }
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service { func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService, userService *service.UserService) *Service {
return &Service{ return &Service{
Name: name, db: db,
db: db, taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
taskQueue: taskQueue, notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
notifyQueue: notifyQueue, client: client,
Client: cli, wsService: wsService,
running: true, uploaderManager: manager,
retryCount: make(map[uint]int), clientIds: map[uint]string{},
userService: userService,
} }
} }
const failedProgress = 101
func (s *Service) Run() { func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name) // 将数据库中未提交的人物加载到队列
for s.running { var jobs []model.MidJourneyJob
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var task types.MjTask var task types.MjTask
err := s.taskQueue.LPop(&task) err := utils.JsonDecode(v.TaskInfo, &task)
if err != nil { if err != nil {
logger.Errorf("taking task with error: %v", err) logger.Errorf("decode task info with error: %v", err)
continue continue
} }
task.Id = v.Id
s.clientIds[task.Id] = task.ClientId
s.PushTask(task)
}
// 如果配置了多个中转平台的 API KEY logger.Info("Starting MidJourney job consumer for service")
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表 go func() {
if task.ChannelId != "" && task.ChannelId != s.Name { for {
if s.retryCount[task.Id] > 5 { var task types.MjTask
s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{}) err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue continue
} }
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
s.retryCount[task.Id]++
time.Sleep(time.Second)
continue
}
// translate prompt // translate prompt
if utils.HasChinese(task.Prompt) { if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
if err == nil { if err == nil {
task.Prompt = content task.Prompt = content
} else { } else {
logger.Warnf("error with translate prompt: %v", err) logger.Warnf("error with translate prompt: %v", err)
}
} }
} // translate negative prompt
// translate negative prompt if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), task.TranslateModelId)
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt), "gpt-4o-mini") if err == nil {
if err == nil { task.NegPrompt = content
task.NegPrompt = content } else {
} else { logger.Warnf("error with translate prompt: %v", err)
logger.Warnf("error with translate prompt: %v", err) }
}
}
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil {
logger.Error("任务不存在任务ID", task.TaskId)
continue
}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) {
var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
} }
logger.Error("绘画任务执行失败:", errMsg) // use fast mode as default
job.Progress = failedProgress if task.Mode == "" {
job.ErrMsg = errMsg task.Mode = "fast"
// update the task progress }
s.clientIds[task.Id] = task.ClientId
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil {
logger.Error("任务不存在任务ID", task.TaskId)
continue
}
logger.Infof("handle a new MidJourney task: %+v", task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.client.Upscale(task)
break
case types.TaskVariation:
res, err = s.client.Variation(task)
break
case types.TaskBlend:
res, err = s.client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) {
var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
}
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = service.FailTaskProgress
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue
}
logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = res.Channel
s.db.Updates(&job) s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
continue
} }
logger.Infof("任务提交成功:%+v", res) }()
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
}
func (s *Service) Stop() {
s.running = false
} }
type CBReq struct { type CBReq struct {
@@ -164,46 +170,6 @@ type CBReq struct {
} `json:"properties"` } `json:"properties"`
} }
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": failedProgress,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := sd.Running
if job.Progress == 100 {
message = sd.Finished
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
}
return nil
}
func GetImageHash(action string) string { func GetImageHash(action string) string {
split := strings.Split(action, "::") split := strings.Split(action, "::")
if len(split) > 5 { if len(split) > 5 {
@@ -211,3 +177,160 @@ func GetImageHash(action string) string {
} }
return split[len(split)-1] return split[len(split)-1]
} }
func (s *Service) CheckTaskNotify() {
go func() {
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("receive a new mj notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChMj, message.Message)
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
// 如果是返回的是 discord 图片地址,则使用代理下载
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[v.Id],
UserId: v.UserId,
JobId: int(v.Id),
Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.MidJourneyJob
for {
res := s.db.Where("progress < ?", 100).Where("channel_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
// 10 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
continue
}
task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
if err != nil {
logger.Errorf("error with query task: %v", err)
continue
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": task.FailReason,
})
logger.Errorf("task failed: %v", task.FailReason)
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: service.TaskStatusFailed})
continue
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
err = s.db.Updates(&job).Error
if err != nil {
logger.Errorf("error with update database: %v", err)
continue
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := service.TaskStatusRunning
if job.Progress == 100 {
message = service.TaskStatusFinished
}
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: message})
}
}
// 找出失败的任务,并恢复其扣减算力
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
for _, job := range jobs {
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "mid-journey",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg),
})
if err != nil {
continue
}
// 更新任务状态
s.db.Model(&job).UpdateColumn("power", 0)
}
time.Sleep(time.Second * 5)
}
}()
}

View File

@@ -89,7 +89,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
fileExt := utils.GetImgExt(file.Filename) fileExt := utils.GetImgExt(file.Filename)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{ info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Content-Type"), ContentType: file.Header.Get("Body-Type"),
}) })
if err != nil { if err != nil {
return File{}, fmt.Errorf("error uploading to MinIO: %v", err) return File{}, fmt.Errorf("error uploading to MinIO: %v", err)

View File

@@ -43,10 +43,8 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
//client.DebugSwitch = gopay.DebugOn // 开启调试模式 //client.DebugSwitch = gopay.DebugOn // 开启调试模式
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间 client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8 SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
SetSignType(alipay.RSA2). // 设置签名类型,不设置默认 RSA2 SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
SetReturnUrl(config.ReturnURL). // 设置返回URL
SetNotifyUrl(config.NotifyURL)
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil { if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
return nil, fmt.Errorf("error with load payment public key: %v", err) return nil, fmt.Errorf("error with load payment public key: %v", err)
@@ -55,23 +53,31 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
return &AlipayService{config: &config, client: client}, nil return &AlipayService{config: &config, client: client}, nil
} }
func (s *AlipayService) PayUrlMobile(outTradeNo string, amount string, subject string) (string, error) { type AlipayParams struct {
bm := make(gopay.BodyMap) OutTradeNo string `json:"out_trade_no"`
bm.Set("subject", subject) Subject string `json:"subject"`
bm.Set("out_trade_no", outTradeNo) TotalFee string `json:"total_fee"`
bm.Set("quit_url", s.config.ReturnURL) ReturnURL string `json:"return_url"`
bm.Set("total_amount", amount) NotifyURL string `json:"notify_url"`
bm.Set("product_code", "QUICK_WAP_WAY")
return s.client.TradeWapPay(context.Background(), bm)
} }
func (s *AlipayService) PayUrlPc(outTradeNo string, amount string, subject string) (string, error) { func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap) bm := make(gopay.BodyMap)
bm.Set("subject", subject) bm.Set("subject", params.Subject)
bm.Set("out_trade_no", outTradeNo) bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("total_amount", amount) bm.Set("quit_url", params.ReturnURL)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "QUICK_WAP_WAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
}
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY") bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
return s.client.TradePagePay(context.Background(), bm) return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
} }
// TradeVerify 交易验证 // TradeVerify 交易验证

View File

@@ -0,0 +1,139 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strings"
"time"
)
// GeekPayService Geek 支付服务
type GeekPayService struct {
config *types.GeekPayConfig
}
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
return &GeekPayService{
config: &appConfig.GeekPayConfig,
}
}
type GeekPayParams struct {
Method string `json:"method"` // 接口类型
Device string `json:"device"` // 设备类型
Type string `json:"type"` // 支付方式
OutTradeNo string `json:"out_trade_no"` // 商户订单号
Name string `json:"name"` // 商品名称
Money string `json:"money"` // 商品金额
ClientIP string `json:"clientip"` //用户IP地址
SubOpenId string `json:"sub_openid"` // 微信用户 openid仅小程序支付需要
SubAppId string `json:"sub_appid"` // 小程序 AppId仅小程序支付需要
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
}
// Pay 支付订单
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
p := map[string]string{
"pid": s.config.AppId,
//"method": params.Method,
"device": params.Device,
"type": params.Type,
"out_trade_no": params.OutTradeNo,
"name": params.Name,
"money": params.Money,
"clientip": params.ClientIP,
"notify_url": params.NotifyURL,
"return_url": params.ReturnURL,
"timestamp": fmt.Sprintf("%d", time.Now().Unix()),
}
p["sign"] = s.Sign(p)
p["sign_type"] = "MD5"
return s.sendRequest(s.config.ApiURL, p)
}
func (s *GeekPayService) Sign(params map[string]string) string {
// 按字母顺序排序参数
var keys []string
for k := range params {
if params[k] == "" || k == "sign" || k == "sign_type" {
continue
}
keys = append(keys, k)
}
sort.Strings(keys)
// 构建待签名字符串
var signStr strings.Builder
for _, k := range keys {
signStr.WriteString(k)
signStr.WriteString("=")
signStr.WriteString(params[k])
signStr.WriteString("&")
}
signString := strings.TrimSuffix(signStr.String(), "&") + s.config.PrivateKey
return utils.Md5(signString)
}
type GeekPayResp struct {
Code int `json:"code"`
Msg string `json:"msg"`
TradeNo string `json:"trade_no"`
PayURL string `json:"payurl"`
QrCode string `json:"qrcode"`
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
}
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
form := url.Values{}
for k, v := range params {
form.Add(k, v)
}
apiURL := fmt.Sprintf("%s/mapi.php", endpoint)
logger.Infof(apiURL)
tr := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true, // 取消 SSL 证书验证
},
}
client := &http.Client{Transport: tr}
resp, err := client.PostForm(apiURL, form)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
logger.Debugf(string(body))
if err != nil {
return nil, err
}
var r GeekPayResp
err = json.Unmarshal(body, &r)
if err != nil {
return nil, errors.New("当前支付渠道暂不支持")
}
if r.Code != 1 {
return nil, errors.New(r.Msg)
}
return &r, nil
}

View File

@@ -37,7 +37,7 @@ func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
} }
} }
type HuPiPayReq struct { type HuPiPayParams struct {
AppId string `json:"appid"` AppId string `json:"appid"`
Version string `json:"version"` Version string `json:"version"`
TradeOrderId string `json:"trade_order_id"` TradeOrderId string `json:"trade_order_id"`
@@ -53,7 +53,7 @@ type HuPiPayReq struct {
WapUrl string `json:"wap_url"` WapUrl string `json:"wap_url"`
} }
type HuPiResp struct { type HuPiPayResp struct {
Openid interface{} `json:"openid"` Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"` UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"` URL string `json:"url"`
@@ -62,7 +62,7 @@ type HuPiResp struct {
} }
// Pay 执行支付请求操作 // Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) { func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
data := url.Values{} data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10) simple := strconv.FormatInt(time.Now().Unix(), 10)
params.AppId = s.appId params.AppId = s.appId
@@ -80,22 +80,22 @@ func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL) apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data) resp, err := http.PostForm(apiURL, data)
if err != nil { if err != nil {
return HuPiResp{}, fmt.Errorf("error with requst api: %v", err) return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
all, err := io.ReadAll(resp.Body) all, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return HuPiResp{}, fmt.Errorf("error with reading response: %v", err) return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
} }
var res HuPiResp var res HuPiPayResp
err = utils.JsonDecode(string(all), &res) err = utils.JsonDecode(string(all), &res)
if err != nil { if err != nil {
return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err) return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
} }
if res.ErrCode != 0 { if res.ErrCode != 0 {
return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg) return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
} }
return res, nil return res, nil
@@ -127,10 +127,10 @@ func (s *HuPiPayService) Sign(params url.Values) string {
} }
// Check 校验订单状态 // Check 校验订单状态
func (s *HuPiPayService) Check(tradeNo string) error { func (s *HuPiPayService) Check(outTradeNo string) error {
data := url.Values{} data := url.Values{}
data.Add("appid", s.appId) data.Add("appid", s.appId)
data.Add("open_order_id", tradeNo) data.Add("out_trade_order", outTradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10) stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp) data.Add("time", stamp)
data.Add("nonce_str", stamp) data.Add("nonce_str", stamp)

View File

@@ -1,153 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strings"
)
type JPayService struct {
config *types.JPayConfig
}
func NewJPayService(appConfig *types.AppConfig) *JPayService {
return &JPayService{
config: &appConfig.JPayConfig,
}
}
type JPayReq struct {
TotalFee int `json:"total_fee"`
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"body"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"callback_url"`
}
type JPayReps struct {
OutTradeNo string `json:"out_trade_no"`
OrderId string `json:"payjs_order_id"`
ReturnCode int `json:"return_code"`
ReturnMsg string `json:"return_msg"`
Sign string `json:"Sign"`
TotalFee string `json:"total_fee"`
CodeUrl string `json:"code_url,omitempty"`
Qrcode string `json:"qrcode,omitempty"`
}
func (r JPayReps) IsOK() bool {
return r.ReturnMsg == "SUCCESS"
}
func (js *JPayService) Pay(param JPayReq) JPayReps {
param.NotifyURL = js.config.NotifyURL
var p = url.Values{}
encode := utils.JsonEncode(param)
m := make(map[string]interface{})
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
p.Add(k, fmt.Sprintf("%v", v))
}
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
cli := http.Client{}
apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
r, err := cli.PostForm(apiURL, p)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
defer r.Body.Close()
bs, err := io.ReadAll(r.Body)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
var data JPayReps
err = utils.JsonDecode(string(bs), &data)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
return data
}
func (js *JPayService) PayH5(p url.Values) string {
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
}
func (js *JPayService) sign(params url.Values) string {
params.Del(`sign`)
var keys = make([]string, 0, 0)
for key := range params {
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += "&key=" + js.config.PrivateKey
md5bs := md5.Sum([]byte(src))
md5res := hex.EncodeToString(md5bs[:])
return strings.ToUpper(md5res)
}
// TradeVerify 查询订单支付状态
// @param tradeNo 支付平台交易 ID
func (js *JPayService) TradeVerify(tradeNo string) error {
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
params := url.Values{}
params.Add("payjs_order_id", tradeNo)
params.Add("sign", js.sign(params))
data := strings.NewReader(params.Encode())
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ReturnCode int `json:"return_code"`
Status int `json:"status"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ReturnCode == 1 && r.Status == 1 {
return nil
} else {
logger.Errorf("PayJs 支付验证响应:%s", string(body))
return errors.New("order not paid")
}
}

View File

@@ -46,18 +46,27 @@ func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
return &WechatPayService{config: &config, client: client}, nil return &WechatPayService{config: &config, client: client}, nil
} }
func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject string) (string, error) { type WechatPayParams struct {
OutTradeNo string `json:"out_trade_no"`
TotalFee int `json:"total_fee"`
Subject string `json:"subject"`
ClientIP string `json:"client_ip"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339) expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap // 初始化 BodyMap
bm := make(gopay.BodyMap) bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId). bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId). Set("mchid", s.config.MchId).
Set("description", subject). Set("description", params.Subject).
Set("out_trade_no", outTradeNo). Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire). Set("time_expire", expire).
Set("notify_url", s.config.NotifyURL). Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) { SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", amount). bm.Set("total", params.TotalFee).
Set("currency", "CNY") Set("currency", "CNY")
}) })
@@ -71,22 +80,22 @@ func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject s
return wxRsp.Response.CodeUrl, nil return wxRsp.Response.CodeUrl, nil
} }
func (s *WechatPayService) PayUrlH5(outTradeNo string, amount int, subject string, ip string) (string, error) { func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339) expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap // 初始化 BodyMap
bm := make(gopay.BodyMap) bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId). bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId). Set("mchid", s.config.MchId).
Set("description", subject). Set("description", params.Subject).
Set("out_trade_no", outTradeNo). Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire). Set("time_expire", expire).
Set("notify_url", s.config.NotifyURL). Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) { SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", amount). bm.Set("total", params.TotalFee).
Set("currency", "CNY") Set("currency", "CNY")
}). }).
SetBodyMap("scene_info", func(bm gopay.BodyMap) { SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", ip). bm.Set("payer_client_ip", params.ClientIP).
SetBodyMap("h5_info", func(bm gopay.BodyMap) { SetBodyMap("h5_info", func(bm gopay.BodyMap) {
bm.Set("type", "Wap") bm.Set("type", "Wap")
}) })

View File

@@ -1,143 +0,0 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core/types"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
uploader *oss.UploaderManager
levelDB *store.LevelDB
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
uploader: manager,
levelDB: levelDB,
}
}
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
// stop old service
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for k, config := range configs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf(" sd-service-%d", k)
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
// run sd service
go func() {
service.Run()
}()
p.services = append(p.services, service)
}
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := p.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
p.db.Delete(&job)
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@@ -10,95 +10,104 @@ package sd
import ( import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger"
"geekai/service" "geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"strings" "github.com/go-redis/redis/v8"
"time" "time"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"gorm.io/gorm" "gorm.io/gorm"
) )
var logger = logger2.GetLogger()
// SD 绘画服务 // SD 绘画服务
type Service struct { type Service struct {
httpClient *req.Client httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
name string // service name wsService *service.WebsocketService
leveldb *store.LevelDB userService *service.UserService
running bool // 运行状态
} }
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
return &Service{ return &Service{
name: name,
config: config,
httpClient: req.C(), httpClient: req.C(),
taskQueue: taskQueue, taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: notifyQueue, notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db, db: db,
leveldb: levelDB, wsService: wsService,
uploadManager: manager, uploadManager: manager,
running: true, userService: userService,
} }
} }
func (s *Service) Run() { func (s *Service) Run() {
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name) // 将数据库中未提交的人物加载到队列
for s.running { var jobs []model.SdJob
s.db.Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var task types.SdTask var task types.SdTask
err := s.taskQueue.LPop(&task) err := utils.JsonDecode(v.TaskInfo, &task)
if err != nil { if err != nil {
logger.Errorf("taking task with error: %v", err) logger.Errorf("decode task info with error: %v", err)
continue
}
// translate prompt
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
if err == nil {
task.Params.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
continue continue
} }
task.Id = int(v.Id)
s.PushTask(task)
} }
} logger.Infof("Starting Stable-Diffusion job consumer")
go func() {
for {
var task types.SdTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
func (s *Service) Stop() { // translate prompt
s.running = false if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.Prompt), task.TranslateModelId)
if err == nil {
task.Params.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), task.TranslateModelId)
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("handle a new Stable-Diffusion task: %+v", task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue
}
}
}()
} }
// Txt2ImgReq 文生图请求实体 // Txt2ImgReq 文生图请求实体
@@ -130,9 +139,8 @@ type Txt2ImgResp struct {
// TaskProgressResp 任务进度响应实体 // TaskProgressResp 任务进度响应实体
type TaskProgressResp struct { type TaskProgressResp struct {
Progress float64 `json:"progress"` Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"` EtaRelative float64 `json:"eta_relative"`
CurrentImage string `json:"current_image"`
} }
// Txt2Img 文生图 API // Txt2Img 文生图 API
@@ -160,12 +168,19 @@ func (s *Service) Txt2Img(task types.SdTask) error {
} }
var res Txt2ImgResp var res Txt2ImgResp
var errChan = make(chan error) var errChan = make(chan error)
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
logger.Debugf("send image request to %s", apiURL) var apiKey model.ApiKey
err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
}
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
logger.Infof("send image request to %s", apiURL)
// send a request to sd api endpoint // send a request to sd api endpoint
go func() { go func() {
response, err := s.httpClient.R(). response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey). SetHeader("Authorization", apiKey.Value).
SetBody(body). SetBody(body).
SetSuccessResult(&res). SetSuccessResult(&res).
Post(apiURL) Post(apiURL)
@@ -178,6 +193,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
return return
} }
// update the last used time
apiKey.LastUsedAt = time.Now().Unix()
s.db.Updates(&apiKey)
// 保存 Base64 图片 // 保存 Base64 图片
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0]) imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
if err != nil { if err != nil {
@@ -206,21 +225,15 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished // task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil return nil
default: default:
err, resp := s.checkTaskProgress() err, resp := s.checkTaskProgress(apiKey)
// 更新任务进度 // 更新任务进度
if err == nil && resp.Progress > 0 { if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号 // 发送更新状态信号
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
}
} }
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -229,11 +242,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
} }
// 执行任务 // 执行任务
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL) apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
var res TaskProgressResp var res TaskProgressResp
response, err := s.httpClient.R(). response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey). SetHeader("Authorization", apiKey.Value).
SetSuccessResult(&res). SetSuccessResult(&res).
Get(apiURL) Get(apiURL)
if err != nil { if err != nil {
@@ -245,3 +258,67 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
return nil, &res return nil, &res
} }
func (s *Service) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChSd, message.Message)
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
}
}
// 找出失败的任务,并恢复其扣减算力
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
for _, job := range jobs {
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d Err: %s", job.Id, job.ErrMsg),
})
if err != nil {
continue
}
// 更新任务状态
s.db.Model(&job).UpdateColumn("power", 0)
}
time.Sleep(time.Second * 5)
}
}()
}

View File

@@ -1,24 +0,0 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import logger2 "geekai/logger"
var logger = logger2.GetLogger()
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const (
Running = "RUNNING"
Finished = "FINISH"
Failed = "FAIL"
)

View File

@@ -29,7 +29,7 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
func (s *SmtpService) SendVerifyCode(to string, code int) error { func (s *SmtpService) SendVerifyCode(to string, code int) error {
subject := fmt.Sprintf("%s 注册验证码", s.config.AppName) subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
body := fmt.Sprintf("您正在注册 %s 账户,注册验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", s.config.AppName, code) body := fmt.Sprintf("【%s】您的验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", s.config.AppName, code)
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host) auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
if s.config.UseTls { if s.config.UseTls {

View File

@@ -13,8 +13,8 @@ import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/sd"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
@@ -34,17 +34,21 @@ type Service struct {
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client wsService *service.WebsocketService
clientIds map[string]string
userService *service.UserService
} }
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
return &Service{ return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3), httpClient: req.C().SetTimeout(time.Minute * 3),
db: db, db: db,
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli), notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager, uploadManager: manager,
wsService: wsService,
clientIds: map[string]string{},
userService: userService,
} }
} }
@@ -56,22 +60,17 @@ func (s *Service) PushTask(task types.SunoTask) {
func (s *Service) Run() { func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列 // 将数据库中未提交的人物加载到队列
var jobs []model.SunoJob var jobs []model.SunoJob
s.db.Where("task_id", "").Find(&jobs) s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
for _, v := range jobs { for _, v := range jobs {
s.PushTask(types.SunoTask{ var task types.SunoTask
Id: v.Id, err := utils.JsonDecode(v.TaskInfo, &task)
Channel: v.Channel, if err != nil {
UserId: v.UserId, logger.Errorf("decode task info with error: %v", err)
Type: v.Type, continue
Title: v.Title, }
RefTaskId: v.RefTaskId, task.Id = v.Id
RefSongId: v.RefSongId, s.PushTask(task)
Prompt: v.Prompt, s.clientIds[v.TaskId] = task.ClientId
Tags: v.Tags,
Model: v.ModelName,
Instrumental: v.Instrumental,
ExtendSecs: v.ExtendSecs,
})
} }
logger.Info("Starting Suno job consumer...") logger.Info("Starting Suno job consumer...")
go func() { go func() {
@@ -82,14 +81,21 @@ func (s *Service) Run() {
logger.Errorf("taking task with error: %v", err) logger.Errorf("taking task with error: %v", err)
continue continue
} }
var r RespVo
r, err := s.Create(task) if task.Type == 3 && task.SongId != "" { // 歌曲拼接
r, err = s.Merge(task)
} else if task.Type == 4 && task.AudioURL != "" { // 上传歌曲
r, err = s.Upload(task)
} else { // 歌曲创作
r, err = s.Create(task)
}
if err != nil { if err != nil {
logger.Errorf("create task with error: %v", err) logger.Errorf("create task with error: %v", err)
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(), "err_msg": err.Error(),
"progress": 101, "progress": service.FailTaskProgress,
}) })
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue continue
} }
@@ -98,6 +104,7 @@ func (s *Service) Run() {
"task_id": r.Data, "task_id": r.Data,
"channel": r.Channel, "channel": r.Channel,
}) })
s.clientIds[r.Data] = task.ClientId
} }
}() }()
} }
@@ -138,7 +145,7 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
} }
var res RespVo var res RespVo
apiURL := fmt.Sprintf("%s/task/suno/v1/submit/music", apiKey.ApiURL) apiURL := fmt.Sprintf("%s/suno/submit/music", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody) logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R(). r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value). SetHeader("Authorization", "Bearer "+apiKey.Value).
@@ -157,6 +164,100 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
if res.Code != "success" { if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message) return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
} }
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) Merge(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"clip_id": task.SongId,
"is_infill": false,
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/submit/concat", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"url": task.AudioURL,
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/uploads/audio-url", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 {
return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL res.Channel = apiKey.ApiURL
return res, nil return res, nil
} }
@@ -165,24 +266,24 @@ func (s *Service) CheckTaskNotify() {
go func() { go func() {
logger.Info("Running Suno task notify checking ...") logger.Info("Running Suno task notify checking ...")
for { for {
var message sd.NotifyMessage var message service.NotifyMessage
err := s.notifyQueue.LPop(&message) err := s.notifyQueue.LPop(&message)
if err != nil { if err != nil {
continue continue
} }
client := s.Clients.Get(uint(message.UserId)) logger.Debugf("notify message: %+v", message)
logger.Debugf("client id: %+v", s.wsService.Clients)
client := s.wsService.Clients.Get(message.ClientId)
logger.Debugf("%+v", client)
if client == nil { if client == nil {
continue continue
} }
err = client.Send([]byte(message.Message)) utils.SendChannelMsg(client, types.ChSuno, message.Message)
if err != nil {
continue
}
} }
}() }()
} }
func (s *Service) DownloadImages() { func (s *Service) DownloadFiles() {
go func() { go func() {
var items []model.SunoJob var items []model.SunoJob
for { for {
@@ -210,7 +311,7 @@ func (s *Service) DownloadImages() {
v.AudioURL = audioURL v.AudioURL = audioURL
v.Progress = 100 v.Progress = 100
s.db.Updates(&v) s.db.Updates(&v)
s.notifyQueue.RPush(sd.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: sd.Finished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
} }
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
@@ -276,15 +377,29 @@ func (s *Service) SyncTaskProgress() {
} }
} }
tx.Commit() tx.Commit()
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
} else if task.Data.FailReason != "" { } else if task.Data.FailReason != "" {
job.Progress = 101 job.Progress = service.FailTaskProgress
job.ErrMsg = task.Data.FailReason job.ErrMsg = task.Data.FailReason
s.db.Updates(&job) s.db.Updates(&job)
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
} }
} }
// 找出失败的任务,并恢复其扣减算力
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
for _, job := range jobs {
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: job.ModelName,
Remark: fmt.Sprintf("Suno 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
})
if err != nil {
continue
}
// 更新任务状态
s.db.Model(&job).UpdateColumn("power", 0)
}
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
} }
}() }()
@@ -328,15 +443,15 @@ type QueryRespVo struct {
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) { func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
// 读取 API KEY // 读取 API KEY
var apiKey model.ApiKey var apiKey model.ApiKey
tx := s.db.Session(&gorm.Session{}).Where("type", "suno"). err := s.db.Session(&gorm.Session{}).Where("type", "suno").
Where("api_url", channel). Where("api_url", channel).
Where("enabled", true). Where("enabled", true).
Order("last_used_at DESC").First(&apiKey) Order("last_used_at DESC").First(&apiKey).Error
if tx.Error != nil { if err != nil {
return QueryRespVo{}, errors.New("no available API KEY for Suno") return QueryRespVo{}, errors.New("no available API KEY for Suno")
} }
apiURL := fmt.Sprintf("%s/task/suno/v1/fetch/%s", apiKey.ApiURL, taskId) apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
var res QueryRespVo var res QueryRespVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL) r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)

View File

@@ -1,4 +1,166 @@
package service package service
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]" const FailTaskProgress = 101
const (
TaskStatusRunning = "RUNNING"
TaskStatusFinished = "FINISH"
TaskStatusFailed = "FAIL"
)
type NotifyMessage struct {
UserId int `json:"user_id"`
ClientId string `json:"client_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
const ImagePromptOptimizeTemplate = `
Create a highly effective prompt to provide to an AI image generation tool in order to create an artwork based on a desired concept.
Please specify details about the artwork, such as the style, subject, mood, and other important characteristics you want the resulting image to have.
Remember, prompts should always be output in English.
# Steps
1. **Subject Description**: Describe the main subject of the image clearly. Include as much detail as possible about what should be in the scene. For example, "a majestic lion roaring at sunrise" or "a futuristic city with flying cars."
2. **Art Style**: Specify the art style you envision. Possible options include 'realistic', 'impressionist', a specific artist name, or imaginative styles like "cyberpunk." This helps the AI achieve your visual expectations.
3. **Mood or Atmosphere**: Convey the feeling you want the image to evoke. For instance, peaceful, chaotic, epic, etc.
4. **Color Palette and Lighting**: Mention color preferences or lighting. For example, "vibrant with shades of blue and purple" or "dim and dramatic lighting."
5. **Optional Features**: You can add any additional attributes, such as background details, attention to textures, or any specific kind of framing.
# Output Format
- **Prompt Format**: A descriptive phrase that includes key aspects of the artwork (subject, style, mood, colors, lighting, any optional features).
Here is an example of how the final prompt should look:
"An ethereal landscape featuring towering ice mountains, in an impressionist style reminiscent of Claude Monet, with a serene mood. The sky is glistening with soft purples and whites, with a gentle morning sun illuminating the scene."
**Please input the prompt words directly in English, and do not input any other explanatory statements**
# Examples
1. **Input**:
- Subject: A white tiger in a dense jungle
- Art Style: Realistic
- Mood: Intense, mysterious
- Lighting: Dramatic contrast with light filtering through leaves
**Output Prompt**: "A realistic rendering of a white tiger stealthily moving through a dense jungle, with an intense, mysterious mood. The lighting creates strong contrasts as beams of sunlight filter through a thick canopy of leaves."
2. **Input**:
- Subject: An enchanted castle on a floating island
- Art Style: Fantasy
- Mood: Majestic, magical
- Colors: Bright blues, greens, and gold
**Output Prompt**: "A majestic fantasy castle on a floating island above the clouds, with bright blues, greens, and golds to create a magical, dreamy atmosphere. Textured cobblestone details and glistening waters surround the scene."
# Notes
- Ensure that you mix different aspects to get a comprehensive and visually compelling prompt.
- Be as descriptive as possible as it often helps generate richer, more detailed images.
- If you want the image to resemble a particular artist's work, be sure to mention the artist explicitly. e.g., "in the style of Van Gogh."
The theme of the creation is:【%s】
`
const LyricPromptTemplate = `
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
请以【%s】为主题创作一首歌曲歌曲时间不要太短3分钟左右不要输出任何解释性的内容。
输出格式如下:
歌曲名称
第一节:
{{歌词内容}}
副歌:
{{歌词内容}}
第二节:
{{歌词内容}}
副歌:
{{歌词内容}}
尾声:
{{歌词内容}}
`
const VideoPromptTemplate = `
As an expert in video generation prompts, please create a detailed descriptive prompt for the following video concept. The description should include the setting, character appearance, actions, overall atmosphere, and camera angles. Please make it as detailed and vivid as possible to help ensure that every aspect of the video is accurately captured.
Please remember that regardless of the users input, the final output must be in English.
# Details to Include
- Describe the overall visual style of the video (e.g., animated, realistic, retro tone, etc.)
- Identify key characters or objects in the video and describe their appearance, attire, and expressions
- Describe the environment of the scene, including weather, lighting, colors, and important details
- Explain the behavior and interactions of the characters
- Include any unique camera angles, movements, or special effects
# Output Format
Provide the prompt in paragraph form, ensuring that the description is detailed enough for a video generation system to recreate the envisioned scene. Include the beginning, middle, and end of the scene to convey a complete storyline.
# Example
**User Input:**
“A small cat basking in the sun on a balcony.”
**Generated Prompt:**
On a bright spring afternoon, an orange-striped kitten lies lazily on a balcony, basking in the warm sunlight. The iron railings around the balcony cast soft shadows that dance gently with the light. The cats eyes are half-closed, exuding a sense of contentment and tranquility in its surroundings. In the distance, a few fluffy white clouds drift slowly across the blue sky. The camera initially focuses on the cats face, capturing the delicate details of its fur, and then gradually zooms out to reveal the full balcony scene, immersing viewers in a moment of calm and relaxation.
The theme of the creation is:【%s】
`
const MetaPromptTemplate = `
Given a task description or existing prompt, produce a detailed system prompt to guide a language model in completing the task effectively.
Please remember, the final output must be the same language with users input.
# Guidelines
- Understand the Task: Grasp the main objective, goals, requirements, constraints, and expected output.
- Minimal Changes: If an existing prompt is provided, improve it only if it's simple. For complex prompts, enhance clarity and add missing elements without altering the original structure.
- Reasoning Before Conclusions**: Encourage reasoning steps before any conclusions are reached. ATTENTION! If the user provides examples where the reasoning happens afterward, REVERSE the order! NEVER START EXAMPLES WITH CONCLUSIONS!
- Reasoning Order: Call out reasoning portions of the prompt and conclusion parts (specific fields by name). For each, determine the ORDER in which this is done, and whether it needs to be reversed.
- Conclusion, classifications, or results should ALWAYS appear last.
- Examples: Include high-quality examples if helpful, using placeholders [in brackets] for complex elements.
- What kinds of examples may need to be included, how many, and whether they are complex enough to benefit from placeholders.
- Clarity and Conciseness: Use clear, specific language. Avoid unnecessary instructions or bland statements.
- Formatting: Use markdown features for readability. DO NOT USE CODE BLOCKS UNLESS SPECIFICALLY REQUESTED.
- Preserve User Content: If the input task or prompt includes extensive guidelines or examples, preserve them entirely, or as closely as possible. If they are vague, consider breaking down into sub-steps. Keep any details, guidelines, examples, variables, or placeholders provided by the user.
- Constants: DO include constants in the prompt, as they are not susceptible to prompt injection. Such as guides, rubrics, and examples.
- Output Format: Explicitly the most appropriate output format, in detail. This should include length and syntax (e.g. short sentence, paragraph, JSON, etc.)
- For tasks outputting well-defined or structured data (classification, JSON, etc.) bias toward outputting a JSON.
- JSON should never be wrapped in code blocks unless explicitly requested.
The final prompt you output should adhere to the following structure below. Do not include any additional commentary, only output the completed system prompt. SPECIFICALLY, do not include any additional messages at the start or end of the prompt. (e.g. no "---")
[Concise instruction describing the task - this should be the first line in the prompt, no section header]
[Additional details as needed.]
[Optional sections with headings or bullet points for detailed steps.]
# Steps [optional]
[optional: a detailed breakdown of the steps necessary to accomplish the task]
# Output Format
[Specifically call out how the output should be formatted, be it response length, structure e.g. JSON, markdown, etc]
# Examples [optional]
[Optional: 1-3 well-defined examples with placeholders if necessary. Clearly mark where examples start and end, and what the input and output are. User placeholders as necessary.]
[If the examples are shorter than what a realistic example is expected to be, make a reference with () explaining how real examples should be longer / shorter / different. AND USE PLACEHOLDERS! ]
# Notes [optional]
[optional: edge cases, details, and an area to call or repeat out specific important considerations]
`

View File

@@ -0,0 +1,83 @@
package service
import (
"fmt"
"geekai/core/types"
"geekai/store/model"
"gorm.io/gorm"
"sync"
"time"
)
type UserService struct {
db *gorm.DB
lock sync.Mutex
}
func NewUserService(db *gorm.DB) *UserService {
return &UserService{db: db, lock: sync.Mutex{}}
}
// IncreasePower 增加用户算力
func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error {
s.lock.Lock()
defer s.lock.Unlock()
tx := s.db.Begin()
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", power)).Error
if err != nil {
tx.Rollback()
return err
}
var user model.User
tx.Where("id", userId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: log.Type,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: log.Model,
Remark: log.Remark,
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return nil
}
// DecreasePower 减少用户算力
func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error {
s.lock.Lock()
defer s.lock.Unlock()
tx := s.db.Begin()
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error
if err != nil {
tx.Rollback()
return fmt.Errorf("扣减算力失败:%v", err)
}
var user model.User
tx.Where("id", userId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: log.Type,
Amount: power,
Balance: user.Power,
Mark: types.PowerSub,
Model: log.Model,
Remark: log.Remark,
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
return fmt.Errorf("记录算力日志失败:%v", err)
}
tx.Commit()
return nil
}

377
api/service/video/luma.go Normal file
View File

@@ -0,0 +1,377 @@
package video
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"io"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
wsService *service.WebsocketService
clientIds map[uint]string
userService *service.UserService
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
wsService: wsService,
uploadManager: manager,
clientIds: map[uint]string{},
userService: userService,
}
}
func (s *Service) PushTask(task types.VideoTask) {
logger.Infof("add a new Video task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.VideoJob
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var task types.VideoTask
err := utils.JsonDecode(v.TaskInfo, &task)
if err != nil {
logger.Errorf("decode task info with error: %v", err)
continue
}
task.Id = v.Id
s.PushTask(task)
s.clientIds[v.Id] = task.ClientId
}
logger.Info("Starting Video job consumer...")
go func() {
for {
var task types.VideoTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
if task.ClientId != "" {
s.clientIds[task.Id] = task.ClientId
}
var r LumaRespVo
r, err = s.LumaCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
"cover_url": "/images/failed.jpg",
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
// 更新任务信息
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Id,
"channel": r.Channel,
"prompt_ext": r.Prompt,
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
s.PushTask(task)
}
}
}()
}
type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
QueueState interface{} `json:"queue_state"`
CreatedAt string `json:"created_at"`
Video interface{} `json:"video"`
VideoRaw interface{} `json:"video_raw"`
Liked interface{} `json:"liked"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
Thumbnail interface{} `json:"thumbnail"`
Channel string `json:"channel,omitempty"`
}
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return LumaRespVo{}, errors.New("no available API KEY for Luma")
}
reqBody := map[string]interface{}{
"user_prompt": task.Prompt,
"expand_prompt": task.Params.PromptOptimize,
"loop": task.Params.Loop,
"image_url": task.Params.StartImgURL,
"image_end_url": task.Params.EndImgURL,
}
var res LumaRespVo
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 && r.StatusCode != 201 {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("Receive notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChLuma, message.Message)
}
}()
}
func (s *Service) DownloadFiles() {
go func() {
var items []model.VideoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
if v.WaterURL == "" {
continue
}
logger.Infof("try download video: %s", v.WaterURL)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
logger.Infof("download video success: %s", videoURL)
v.WaterURL = videoURL
if v.VideoURL != "" {
logger.Infof("try download no water video: %s", v.VideoURL)
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
}
logger.Infof("download no water video success: %s", videoURL)
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 10)
}
}()
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.VideoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
// 更新任务信息
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
"err_msg": err.Error(),
})
continue
}
logger.Debugf("task: %+v", task)
if task.State == "completed" { // 更新任务信息
data := map[string]interface{}{
"progress": 102, // 102 表示资源未下载完成,
"water_url": task.Video.Url,
"raw_data": utils.JsonEncode(task),
"prompt_ext": task.Prompt,
"cover_url": task.Thumbnail.Url,
}
if task.Video.DownloadUrl != "" {
data["video_url"] = task.Video.DownloadUrl
}
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
if err != nil {
logger.Errorf("更新数据库失败:%v", err)
continue
}
}
}
// 找出失败的任务,并恢复其扣减算力
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
for _, job := range jobs {
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "luma",
Remark: fmt.Sprintf("Luma 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
})
if err != nil {
continue
}
// 更新任务状态
s.db.Model(&job).UpdateColumn("power", 0)
}
time.Sleep(time.Second * 10)
}
}()
}
type LumaTaskVo struct {
Id string `json:"id"`
Liked interface{} `json:"liked"`
State string `json:"state"`
Video struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
Thumbnail string `json:"thumbnail"`
DownloadUrl string `json:"download_url"`
} `json:"video"`
Prompt string `json:"prompt"`
UserId string `json:"user_id"`
BatchId string `json:"batch_id"`
Thumbnail struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"thumbnail"`
VideoRaw struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"video_raw"`
CreatedAt string `json:"created_at"`
LastFrame struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"last_frame"`
}
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
}
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
var res LumaTaskVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
if r.StatusCode != 200 {
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil
}

13
api/service/ws_service.go Normal file
View File

@@ -0,0 +1,13 @@
package service
import "geekai/core/types"
type WebsocketService struct {
Clients *types.LMap[string, *types.WsClient] // clientId => Client
}
func NewWebsocketService() *WebsocketService {
return &WebsocketService{
Clients: types.NewLMap[string, *types.WsClient](),
}
}

View File

@@ -1,101 +0,0 @@
package wx
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
logger2 "geekai/logger"
"geekai/store/model"
"github.com/eatmoreapple/openwechat"
"github.com/skip2/go-qrcode"
"gorm.io/gorm"
"os"
"strconv"
)
// 微信收款机器人
var logger = logger2.GetLogger()
type Bot struct {
bot *openwechat.Bot
token string
db *gorm.DB
}
func NewWeChatBot(db *gorm.DB) *Bot {
bot := openwechat.DefaultBot(openwechat.Desktop)
return &Bot{
bot: bot,
db: db,
}
}
func (b *Bot) Run() error {
logger.Info("Starting WeChat Bot...")
// set message handler
b.bot.MessageHandler = func(msg *openwechat.Message) {
b.messageHandler(msg)
}
// scan code login callback
b.bot.UUIDCallback = b.qrCodeCallBack
debug, err := strconv.ParseBool(os.Getenv("APP_DEBUG"))
if debug {
reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json")
err = b.bot.HotLogin(reloadStorage, true)
} else {
err = b.bot.Login()
}
if err != nil {
return err
}
logger.Info("微信登录成功!")
return nil
}
// message handler
func (b *Bot) messageHandler(msg *openwechat.Message) {
sender, err := msg.Sender()
if err != nil {
return
}
// 只处理微信支付的推送消息
if sender.NickName == "微信支付" ||
msg.MsgType == openwechat.MsgTypeApp ||
msg.AppMsgType == openwechat.AppMsgTypeUrl {
// 解析支付金额
message := parseTransactionMessage(msg.Content)
transaction := extractTransaction(message)
logger.Infof("解析到收款信息:%+v", transaction)
if transaction.TransId != "" {
var item model.Reward
res := b.db.Where("tx_id = ?", transaction.TransId).First(&item)
if item.Id > 0 {
logger.Error("当前交易 ID 己经存在!")
return
}
res = b.db.Create(&model.Reward{
TxId: transaction.TransId,
Amount: transaction.Amount,
Remark: transaction.Remark,
Status: false,
})
if res.Error != nil {
logger.Errorf("交易保存失败: %v", res.Error)
}
}
}
}
func (b *Bot) qrCodeCallBack(uuid string) {
logger.Info("请使用微信扫描下面二维码登录")
q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium)
logger.Info(q.ToString(true))
}

View File

@@ -1,112 +0,0 @@
package wx
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/xml"
"net/url"
"strconv"
"strings"
)
// Message 转账消息
type Message struct {
Des string
Url string
}
// Transaction 解析后的交易信息
type Transaction struct {
TransId string `json:"trans_id"` // 微信转账交易 ID
Amount float64 `json:"amount"` // 微信转账交易金额
Remark string `json:"remark"` // 转账备注
}
// 解析微信转账消息
func parseTransactionMessage(xmlData string) *Message {
decoder := xml.NewDecoder(strings.NewReader(xmlData))
message := Message{}
for {
token, err := decoder.Token()
if err != nil {
break
}
switch se := token.(type) {
case xml.StartElement:
var value string
if se.Name.Local == "des" && message.Des == "" {
if err := decoder.DecodeElement(&value, &se); err == nil {
message.Des = strings.TrimSpace(value)
}
break
}
if se.Name.Local == "weapp_path" || se.Name.Local == "url" {
if err := decoder.DecodeElement(&value, &se); err == nil {
if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") {
message.Url = value
}
}
break
}
}
}
// 兼容旧版消息记录
if message.Url == "" {
var msg struct {
XMLName xml.Name `xml:"msg"`
AppMsg struct {
Des string `xml:"des"`
Url string `xml:"url"`
} `xml:"appmsg"`
}
if err := xml.Unmarshal([]byte(xmlData), &msg); err == nil {
message.Url = msg.AppMsg.Url
}
}
return &message
}
// 导出交易信息
func extractTransaction(message *Message) Transaction {
var tx = Transaction{}
// 导出交易金额和备注
lines := strings.Split(message.Des, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
// 解析收款金额
prefix := "收款金额¥"
if strings.HasPrefix(line, prefix) {
if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil {
tx.Amount = value
continue
}
}
// 解析收款备注
prefix = "付款方备注"
if strings.HasPrefix(line, prefix) {
tx.Remark = line[len(prefix):]
break
}
}
// 解析交易 ID
parse, err := url.Parse(message.Url)
if err == nil {
tx.TransId = parse.Query().Get("id")
if tx.TransId == "" {
tx.TransId = parse.Query().Get("trans_id")
}
}
return tx
}

View File

@@ -81,54 +81,6 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms
// 自动将 VIP 会员的算力补充到每月赠送的最大值 // 自动将 VIP 会员的算力补充到每月赠送的最大值
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) { func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Info("开始进行月底账号盘点...") logger.Info("开始进行月底账号盘点...")
var users []model.User
res := e.db.Where("vip", 1).Where("status", 1).Find(&users)
if res.Error != nil {
return "No vip users found"
}
var sysConfig model.Config
res = e.db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return "error with get system config: " + res.Error.Error()
}
var config types.SystemConfig
err := utils.JsonDecode(sysConfig.Config, &config)
if err != nil {
return "error with decode system config: " + err.Error()
}
for _, u := range users {
// 处理过期的 VIP
if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() {
u.Vip = false
e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
continue
}
if u.Power < config.VipMonthPower {
power := config.VipMonthPower - u.Power
// update user
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power))
// 记录算力变动日志
if tx.Error == nil {
var user model.User
e.db.Where("id", u.Id).First(&user)
e.db.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerRecharge,
Amount: power,
Mark: types.PowerAdd,
Balance: user.Power,
Model: "系统盘点",
Remark: fmt.Sprintf("VIP会员每月算力派发%d", config.VipMonthPower),
CreatedAt: time.Now(),
})
}
}
}
logger.Info("月底盘点完成!")
return "success" return "success"
} }

View File

@@ -29,15 +29,9 @@ func NewLevelDB() (*LevelDB, error) {
} }
func (db *LevelDB) Put(key string, value interface{}) error { func (db *LevelDB) Put(key string, value interface{}) error {
var byteData []byte byteData, err := json.Marshal(value)
if v, ok := value.(string); ok { if err != nil {
byteData = []byte(v) return err
} else {
b, err := json.Marshal(value)
if err != nil {
return err
}
byteData = b
} }
return db.driver.Put([]byte(key), byteData, nil) return db.driver.Put([]byte(key), byteData, nil)
} }

View File

@@ -0,0 +1,12 @@
package model
import "time"
type AppType struct {
Id uint `gorm:"primarykey"`
Name string
Icon string
Enabled bool
SortNum int
CreatedAt time.Time
}

View File

@@ -4,16 +4,17 @@ import "gorm.io/gorm"
type ChatMessage struct { type ChatMessage struct {
BaseModel BaseModel
ChatId string // 会话 ID ChatId string // 会话 ID
UserId uint // 用户 ID UserId uint // 用户 ID
RoleId uint // 角色 ID RoleId uint // 角色 ID
Model string // AI模型 Model string // AI模型
Type string Type string
Icon string Icon string
Tokens int Tokens int
Content string TotalTokens int // 总 token 消耗
UseContext bool // 是否可以作为聊天上下文 Content string
DeletedAt gorm.DeletedAt UseContext bool // 是否可以作为聊天上下文
DeletedAt gorm.DeletedAt
} }
func (ChatMessage) TableName() string { func (ChatMessage) TableName() string {

View File

@@ -12,4 +12,5 @@ type ChatModel struct {
MaxContext int // 最大上下文长度 MaxContext int // 最大上下文长度
Temperature float32 // 模型温度 Temperature float32 // 模型温度
KeyId int // 绑定 API KEY ID KeyId int // 绑定 API KEY ID
Type string // 模型类型
} }

View File

@@ -2,6 +2,7 @@ package model
type ChatRole struct { type ChatRole struct {
BaseModel BaseModel
Tid int
Key string `gorm:"column:marker;unique"` // 角色唯一标识 Key string `gorm:"column:marker;unique"` // 角色唯一标识
Name string // 角色名称 Name string // 角色名称
Context string `gorm:"column:context_json"` // 角色语料信息 json Context string `gorm:"column:context_json"` // 角色语料信息 json

View File

@@ -6,6 +6,7 @@ type DallJob struct {
Id uint `gorm:"primarykey;column:id"` Id uint `gorm:"primarykey;column:id"`
UserId uint UserId uint
Prompt string Prompt string
TaskInfo string // 原始任务信息
ImgURL string ImgURL string
OrgURL string OrgURL string
Publish bool Publish bool

View File

@@ -7,6 +7,7 @@ type MidJourneyJob struct {
Type string Type string
UserId int UserId int
TaskId string TaskId string
TaskInfo string // 原始任务信息
ChannelId string ChannelId string
MessageId string MessageId string
ReferenceId string ReferenceId string

View File

@@ -2,7 +2,6 @@ package model
import ( import (
"geekai/core/types" "geekai/core/types"
"gorm.io/gorm"
) )
// Order 充值订单 // Order 充值订单
@@ -18,6 +17,6 @@ type Order struct {
Status types.OrderStatus Status types.OrderStatus
Remark string Remark string
PayTime int64 PayTime int64
PayWay string // 支付方式 PayWay string // 支付渠道
DeletedAt gorm.DeletedAt PayType string // 支付类型
} }

16
api/store/model/redeem.go Normal file
View File

@@ -0,0 +1,16 @@
package model
import "time"
// 兑换码
type Redeem struct {
Id uint `gorm:"primarykey;column:id"`
UserId uint // 用户 ID
Name string // 名称
Power int // 算力
Code string // 兑换码
Enabled bool // 启用状态
RedeemedAt int64 // 兑换时间
CreatedAt time.Time
}

View File

@@ -1,13 +0,0 @@
package model
// 用户打赏
type Reward struct {
BaseModel
UserId uint // 用户 ID
TxId string // 交易ID
Amount float64 // 打赏金额
Remark string // 打赏备注
Status bool // 核销状态
Exchange string // 众筹兑换详情JSON
}

View File

@@ -7,6 +7,7 @@ type SdJob struct {
Type string Type string
UserId int UserId int
TaskId string TaskId string
TaskInfo string // 原始任务信息
ImgURL string ImgURL string
Progress int Progress int
Prompt string Prompt string

Some files were not shown because too many files have changed in this diff Show More