Compare commits

..

188 Commits

Author SHA1 Message Date
RockYang
49b5906bc7 fix: can not change user's power in admin console 2024-03-25 11:40:03 +08:00
RockYang
3075bfb7fc fix: fix bug for update user's power in admin page did not work 2024-03-23 16:21:37 +08:00
RockYang
82e06fad33 No need to login with Stable-Diffusion page and Invite page 2024-03-23 15:45:37 +08:00
RockYang
4a9028747b fix: fix code highlight error when add formule detecting 2024-03-22 19:09:04 +08:00
RockYang
4a8ff0ccf0 chore: correct prompt messages 2024-03-22 18:27:57 +08:00
RockYang
99341f0484 feat: add chart for admin dashbord 2024-03-22 16:57:30 +08:00
RockYang
f58ac29ad0 feat: integrate xxl-job-admin to implements automatic task scheduling 2024-03-22 13:47:16 +08:00
RockYang
7060edb3e5 feat: save prompt in power log for dalle-3 2024-03-21 15:55:39 +08:00
RockYang
41ae411f9b feat: add manager list page in console page 2024-03-21 15:24:28 +08:00
RockYang
79b7fee47c feat: add powerlog page for admin console 2024-03-21 13:46:39 +08:00
RockYang
0044bf10af opt: optimize the formula show styles 2024-03-21 11:04:12 +08:00
RockYang
e9348d3611 always parse authorization token for all request 2024-03-20 21:11:52 +08:00
RockYang
b9236e09a7 fixed conflicts 2024-03-20 20:40:22 +08:00
RockYang
09b38d5f42 update README.md.
Signed-off-by: RockYang <yangjian102621@gmail.com>
2024-03-20 20:39:44 +08:00
RockYang
7bb539a06e feat: Async loading midjourney job for mobile MidJourney page 2024-03-20 18:39:14 +08:00
RockYang
5cdada8265 feat: h5 payment for payjs is ready 2024-03-20 17:46:39 +08:00
RockYang
4147c217b1 feat: payment for mobile pages is ready 2024-03-20 16:14:02 +08:00
RockYang
8dda639b23 feat: the power log page is ready 2024-03-20 14:14:30 +08:00
RockYang
8487d2c9eb feat: no need login refactor with member and chatApps page 2024-03-19 18:59:02 +08:00
RockYang
c5e583b215 feat: load preview page do not require user to login 2024-03-19 18:25:01 +08:00
RockYang
549f618cff feat: optimize login dialog 2024-03-19 10:47:13 +08:00
RockYang
e9a3510346 feat: refactoring adjustments for reward page is ready 2024-03-18 18:28:34 +08:00
RockYang
30e6e963b3 feat: refactoring adjustments for member pages 2024-03-18 16:59:07 +08:00
RockYang
c72d963f45 feat: The 'chat_models' field of user table, holds the model IDS in place of the model values 2024-03-18 15:37:46 +08:00
RockYang
172d498618 remove new-ui files 2024-03-18 12:01:34 +08:00
RockYang
313993532e chore: adjust page styles 2024-03-18 06:46:08 +08:00
RockYang
e53db3582c 重构主体工作完成 2024-03-15 18:35:10 +08:00
RockYang
72c6bd3f77 restore new ui files 2024-03-15 11:13:02 +08:00
廖彦棋
ca8b349df3 fix(ui): 交互修复调整 2024-03-15 11:07:41 +08:00
RockYang
1b206c3640 feat: refactor user list page for new UI 2024-03-15 09:29:19 +08:00
廖彦棋
c60276fc9f fix(ui): 用户管理有效期传参调整,权限标识补充 2024-03-15 09:25:06 +08:00
廖彦棋
d00a3167c0 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:49:18 +08:00
廖彦棋
6b1cd8c30c refactor(ui): 无权限页面调整 2024-03-14 17:49:15 +08:00
吴汉强
46f12dc9ad feat(ui): 后台首页去掉权限判断 2024-03-14 17:24:40 +08:00
廖彦棋
a3e1d8ae21 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:19:41 +08:00
廖彦棋
72a066b93e feat(ui): 无权限判断 2024-03-14 17:19:39 +08:00
吴汉强
0327a829ac Merge remote-tracking branch 'origin/ui' into ui 2024-03-14 17:12:53 +08:00
吴汉强
882e9b8819 feat(ui): 新增 sql 2024-03-14 17:12:48 +08:00
廖彦棋
ef58cfadaa Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:06:15 +08:00
吴汉强
bf958d6113 feat(ui): 403,没权限 2024-03-14 16:41:38 +08:00
吴汉强
71611273d7 feat(ui): 网站配置需授权,去掉 2024-03-14 16:28:49 +08:00
廖彦棋
b27c654311 fix(ui): 调整 2024-03-14 16:14:49 +08:00
吴汉强
90930ea9f9 feat(ui): 后端加权限验证 2024-03-14 15:39:12 +08:00
廖彦棋
1ab2185ff1 feat(ui): 角色管理 2024-03-14 15:25:17 +08:00
廖彦棋
0f2f978d4c feat(ui): 新增系统分类菜单 2024-03-14 15:17:53 +08:00
廖彦棋
f61963b0b0 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 10:56:56 +08:00
廖彦棋
2aa413960d fix(ui): 修复 2024-03-14 10:56:54 +08:00
吴汉强
aa4bbba5ec Merge remote-tracking branch 'origin/ui' into ui 2024-03-14 10:28:39 +08:00
吴汉强
eba61fea2d feat(ui): 登录接口返回权限 2024-03-14 10:28:32 +08:00
廖彦棋
34e3455128 feat(ui): 管理后台新增权限及部分组合式函数优化 2024-03-14 10:27:09 +08:00
廖彦棋
07dca3e739 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 17:30:26 +08:00
廖彦棋
4cb4b145f9 feat(ui): web移动端初始化 2024-03-13 17:30:24 +08:00
吴汉强
1ed417cb69 Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 17:24:38 +08:00
吴汉强
6cf91a84ca feat(ui): 增加角色管理,管理员方法新增角色关联 2024-03-13 17:24:30 +08:00
chenzifan
0b566980fc Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 14:40:45 +08:00
chenzifan
f86176b342 refactor: remove api change to post request 2024-03-13 14:40:38 +08:00
吴汉强
c700b32670 Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 11:41:07 +08:00
吴汉强
22641b452a feat(ui): 增加系统权限管理 2024-03-13 11:41:01 +08:00
廖彦棋
d3fbb8c19e fix(ui): ui调整 2024-03-13 09:54:20 +08:00
RockYang
e3bb69ff10 docs: update sql file 2024-03-13 08:49:40 +08:00
RockYang
770360c614 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 08:48:10 +08:00
chenzifan
f302a0478f fix: 删除系统管理员失效的问题 2024-03-13 08:47:17 +08:00
chenzifan
a88697b43a Merge remote-tracking branch 'origin/ui' into ui
# Conflicts:
#	api/handler/admin/admin_user_handler.go
2024-03-13 08:46:16 +08:00
chenzifan
cc6f140812 fix: 删除系统管理员失效的问题 2024-03-13 08:45:09 +08:00
廖彦棋
424f2b3bdc Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 08:44:50 +08:00
廖彦棋
ec0c13a600 feat(ui): 调整 2024-03-13 08:44:48 +08:00
chenzf@pvc123.com
a1f03bec4c feat: 超级管理员不支持修改和删除 2024-03-12 21:16:05 +08:00
RockYang
b5bd4a5e0e Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-12 18:07:24 +08:00
RockYang
7c2e49bfdb fix conflicts 2024-03-12 18:07:19 +08:00
chenzifan
f80fe6d041 feat: 增加系统管理员 2024-03-12 18:06:49 +08:00
RockYang
72f80a96bc fix conflicts 2024-03-12 18:03:24 +08:00
RockYang
2de655a1cf refactor: use power replace calls for front pages 2024-03-12 17:47:06 +08:00
RockYang
da2bd4a501 refactor: 重构项目,为所有的 AI 工具都引入算力,采用算力统一结算各个工具的调用次数和权限 2024-03-12 15:40:44 +08:00
廖彦棋
e0aa62c40d Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-12 08:37:29 +08:00
廖彦棋
9d26a892d1 refactor(ui): 调整 2024-03-12 08:37:27 +08:00
huangqj
4ece7f2847 fix(ui):环境变量 2024-03-11 16:10:49 +08:00
廖彦棋
32368caf1b feat(ui): 新增系统管理员 2024-03-11 15:59:15 +08:00
廖彦棋
e91f54e79e fix(ui): 删除冗余 2024-03-11 14:23:11 +08:00
廖彦棋
bb8f4c57c4 fix(ui): 删除冗余 2024-03-11 14:22:39 +08:00
RockYang
43bfac99b6 feat: replace Tools param with Function param for OpenAI chat API 2024-03-11 14:09:19 +08:00
廖彦棋
be379b6d63 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-11 13:52:24 +08:00
廖彦棋
17f3c9b840 fix(ui): type 2024-03-11 13:52:22 +08:00
chenzifan
24de97fac2 feat: 优化后台UI 2024-03-11 13:51:26 +08:00
chenzifan
bf27b44fee Merge remote-tracking branch 'origin/ui' into ui 2024-03-11 13:46:46 +08:00
廖彦棋
1802b4fe4d refactor(ui): 调整优化 2024-03-11 13:46:08 +08:00
廖彦棋
241a5c7bc9 feat(ui): 细节优化 2024-03-11 12:02:20 +08:00
廖彦棋
557d547bf1 feat(ui): 上传功能补充 2024-03-11 11:41:50 +08:00
廖彦棋
2e7b75affb feat(ui): 登录新增验证码及记住密码功能 2024-03-11 10:49:13 +08:00
廖彦棋
bc21a1d443 feat(ul): 顶部信息 2024-03-11 09:00:00 +08:00
huangqj
3fc9e10a24 feat(ui):看板图表 样式调整 2024-03-11 08:35:21 +08:00
chenzifan
5fa1aa2060 Merge remote-tracking branch 'origin/ui' into ui 2024-03-11 08:01:54 +08:00
廖彦棋
ff4b267858 fix(ui): 细节调整 2024-03-08 17:46:48 +08:00
huangqj
a590d0497f feat(ui):细节调整 2024-03-08 10:24:38 +08:00
廖彦棋
ac30d906f0 feat(ui): prettier 2024-03-08 09:59:40 +08:00
廖彦棋
5bc071e038 refactor(ui): 登录页重构 2024-03-08 09:45:09 +08:00
廖彦棋
88b956cf98 refactor(ui): 优化 2024-03-08 09:12:39 +08:00
huangqj
f725cf4661 feat(ui):路由 2024-03-08 08:35:41 +08:00
廖彦棋
057cc1e8a6 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-07 18:03:08 +08:00
廖彦棋
de122735b8 feat(ui): 过期跳转登录 2024-03-07 18:03:06 +08:00
huangqj
e87ede981c feat(ui):apiKey 语言模型 角色管理 产品 2024-03-07 17:58:25 +08:00
廖彦棋
606fb498e1 feat(ui): 新增系统设置 2024-03-07 17:24:50 +08:00
廖彦棋
a0c06e40a4 feat(ui): 对话管理 2024-03-07 15:32:32 +08:00
huangqj
aba8f57279 feat(ui):用户 2024-03-07 15:05:01 +08:00
huangqj
960286a350 feat(ui):simpleTable 2024-03-07 14:59:45 +08:00
huangqj
8c93fa51f6 feat(ui):searchtable 2024-03-07 14:59:16 +08:00
huangqj
cb0e7d64ff feat(ui):用户 2024-03-07 14:03:55 +08:00
廖彦棋
8e7413da97 feat(ui): 函数管理 2024-03-07 11:40:57 +08:00
廖彦棋
a36f14eb94 feat(ui): 新增弹窗及时间格式化 2024-03-07 09:23:45 +08:00
chenzifan
f2f9f6e488 Merge branch 'main' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-07 08:37:54 +08:00
chenzifan
85068b8ca2 feat: 增加系统用户管理 2024-03-07 08:37:48 +08:00
廖彦棋
f2cfcfeefc fix(ui): ts类型 2024-03-06 18:20:07 +08:00
RockYang
755273a898 feat: update changelogs 2024-03-06 17:58:17 +08:00
廖彦棋
d4a24a0f1d feat(ui): 新增登录 2024-03-06 17:54:38 +08:00
RockYang
92281fcbb7 feat: Mj and sd jobs data loading in pages 2024-03-06 17:31:54 +08:00
RockYang
636db4afcc add prompt translating function for mobile midjourney page 2024-03-06 16:22:03 +08:00
huangqj
ba25b8755e feat(ui):simpleTable 2024-03-06 15:33:37 +08:00
廖彦棋
6399d13a49 feat(ui): 新增请求方法及表格 2024-03-06 13:55:38 +08:00
廖彦棋
06fa54fd25 feat(ui): 管理后台基础配置 2024-03-06 10:23:55 +08:00
廖彦棋
a335b965d0 chore(ui): 目录结构调整 2024-03-06 09:32:47 +08:00
廖彦棋
725adaa7d0 feat(ui): 初始化 2024-03-06 09:27:11 +08:00
chenzifan
7e7e81e974 refactor: 初始化UI重构 2024-03-06 08:57:46 +08:00
RockYang
8cfe6bfc17 Merge branch 'dev' of gitee.com:blackfox/chatgpt-plus-pro into dev 2024-03-04 08:34:12 +08:00
RockYang
33de83f2ac feat: add removing order button in admin order list page 2024-03-03 19:27:22 +08:00
RockYang
3f856afec8 fix: fix major bugs for unauthorized access to data 2024-03-03 10:40:32 +08:00
RockYang
02a9c422fe fix: fixed bug image preview im mobile chat session page 2024-02-29 15:41:45 +08:00
RockYang
ca69341024 feat: add draw same image for midjourney page 2024-02-29 11:44:09 +08:00
RockYang
169bf069ce opt: add logs for mj-plus api error 2024-02-28 15:50:42 +08:00
RockYang
1bee0ab04d opt: replace proxy url for discord image url 2024-02-27 17:45:57 +08:00
RockYang
440d91dd0e feat: add change password in with mobile page 2024-02-27 15:36:20 +08:00
RockYang
8168e246a8 feat: add image preview for mobile chat page 2024-02-26 18:11:37 +08:00
RockYang
2ef07574ae feat: replace http polling with webscoket notify in sd image page 2024-02-26 15:45:54 +08:00
RockYang
37392f2bb2 chore: replace 'token' with power 2024-02-23 18:11:57 +08:00
RockYang
a80cd3848e docs: update mj-plus api domain 2024-02-23 15:41:02 +08:00
RockYang
db6ed84451 opt: remove global keyup event bind 2024-02-23 11:43:27 +08:00
RockYang
4463cc5963 docs: update config.toml 2024-02-22 17:53:39 +08:00
RockYang
d316158fe2 docs: add database sql file for v3.2.7 2024-02-22 17:28:22 +08:00
RockYang
e02a8d7586 feat: allow to view chat message in manager console 2024-02-22 17:16:44 +08:00
RockYang
9988dff885 feat: download image which ai generated in dialog and replace the image url 2024-02-20 18:38:03 +08:00
RockYang
35ef5674ff feat: image-wall page for mobile is ready 2024-02-20 17:38:18 +08:00
RockYang
976da45bce feat: allow user config third-party platform openai and mj api key 2024-02-20 11:23:55 +08:00
RockYang
c83ac48bd2 feat: added delete file function 2024-02-19 16:43:03 +08:00
RockYang
3d159a833e fix: verifycation component touch event coordinates misplace in iphone browser 2024-02-19 14:04:50 +08:00
RockYang
4b09878bdd fix: fix bug for regenerate button did not work 2024-02-19 11:22:42 +08:00
RockYang
b0162e6a92 fix: Upscale and Variation task overrite each other 2024-02-16 18:08:29 +08:00
RockYang
8ab15e5dc4 feat: midjourney mobile page all function is ready 2024-02-16 15:55:04 +08:00
RockYang
d2ac807252 feat: mobile mj list page is ready 2024-02-15 18:11:22 +08:00
RockYang
0af01f6f1f feat: add mj image list component for mobile page. fixed bug for html tag escape 2024-02-15 11:39:04 +08:00
RockYang
013b319fab feat: add functions mj page for mobile 2024-01-31 07:24:35 +08:00
RockYang
2899ba5949 opt: add default extension for mj image 2024-01-30 21:46:17 +08:00
RockYang
a558b7e104 feat: mj for mobile page payout is ready 2024-01-30 18:34:01 +08:00
RockYang
7a833e2233 add docs and github link 2024-01-30 16:18:27 +08:00
RockYang
bf65746d00 opt: enable use cdn url for mj-plus 2024-01-28 21:56:25 +08:00
RockYang
f08a7862de feat: LaTeX parse is ready 2024-01-26 18:04:53 +08:00
RockYang
023a2c2f09 feat: add model field for chat_item and and chat_history data table 2024-01-26 16:54:00 +08:00
RockYang
1bcd0f4c1a feat: add err_msg field for mj and sd jobs 2024-01-26 14:50:36 +08:00
RockYang
a0f3bc8ccb feat: blend and swap face function for midjourney-plus is ready 2024-01-26 11:57:08 +08:00
RockYang
dea72738c1 feat: add blend and swapface task implements for midjourney 2024-01-25 18:50:24 +08:00
RockYang
a1d1fe7763 opt: refactor chat session page for mobile device 2024-01-25 14:07:10 +08:00
RockYang
a39ed9764c opt: optimize chat list page for mobile 2024-01-24 18:23:24 +08:00
RockYang
aaa5ba99aa feat: use vant replace element-plus as mobile UI framework 2024-01-24 17:34:30 +08:00
RockYang
2113508b6d feat: add websocket heartbeat message for mj page 2024-01-24 09:33:04 +08:00
RockYang
7fe4212684 add v3.2.6 database sql file 2024-01-23 18:00:49 +08:00
RockYang
8bdda64794 doc: update config sample file 2024-01-23 17:56:22 +08:00
RockYang
ec08c24dca fix: auto fill apiURL when platform changed for ApiKey add page 2024-01-23 17:30:54 +08:00
RockYang
a992a5b3b3 feat: merge sms branch,add DuanXinBao sms service implemetation 2024-01-23 16:16:47 +08:00
RockYang
0f05970141 opt: add heartbeat message for websocket connects 2024-01-22 18:42:51 +08:00
whale_fall
e5e762efcd feat: 添加支持多个短信服务商支持 添加短信宝服务商支持,同时添加配置示例 2024-01-22 16:38:44 +08:00
RockYang
b3d0c1ef9c fix: fix bug for wechat transfer message parse failed 2024-01-22 16:10:08 +08:00
RockYang
397078f7ff feat: HuPiPay order check function is ready 2024-01-22 15:17:26 +08:00
RockYang
3ad8065e20 opt: verify the order in notify callback 2024-01-22 13:58:25 +08:00
RockYang
66c7717f04 chore: print error detail when call http api failed with mj 2024-01-21 22:30:24 +08:00
RockYang
412f8ecc6c opt: add image upload support for md-editor-3 2024-01-19 18:43:13 +08:00
RockYang
51dcf642b3 fixed conflicts 2024-01-19 18:21:49 +08:00
RockYang
bfeea555b2 feat: system notice function is ready 2024-01-19 18:19:51 +08:00
RockYang
479f94c372 feat: system notice function is ready 2024-01-19 18:18:10 +08:00
RockYang
0140713e86 fix: fixed bug for img_call increased when upscale task run failed 2024-01-19 17:10:52 +08:00
RockYang
15b2ec9721 feat: add system config item for wechat qrcode 2024-01-19 16:58:13 +08:00
RockYang
c9cd082855 chore: optimize variable name 2024-01-19 11:26:22 +08:00
RockYang
d7c002890c Merge branch 'main' into dev 2024-01-19 10:09:18 +08:00
RockYang
348dd22279 Merge branch 'main' of gitee.com:blackfox/chatgpt-plus 2024-01-19 10:09:01 +08:00
RockYang
3e99b4cbf6 !4 添加支持阿里旗下的大模型 通义千问对话
Merge pull request !4 from 鲸落/qwen
2024-01-19 02:06:17 +00:00
whale_fall
6968da3ac7 feat: 添加支持阿里的通义千问对话 2024-01-19 09:52:16 +08:00
RockYang
bf1c1b84c3 feat: add image publish function, ONLY published image show in image wall page 2024-01-19 06:52:23 +08:00
RockYang
c70314d930 update change log 2024-01-18 18:11:25 +08:00
RockYang
9104ca8e49 feat: add system config disable user registeration 2024-01-18 17:24:02 +08:00
RockYang
2af33b3630 opt: compatible wechat old message format for parsing wechat transfer message 2024-01-18 16:58:20 +08:00
RockYang
654e795545 opt: optimize order query alg, reduce polling times 2024-01-18 09:39:36 +08:00
RockYang
c62ba2451e update config 2024-01-16 15:24:06 +08:00
241 changed files with 14332 additions and 8260 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
deploy
docs
api/static
web/node_modules
desktop

View File

@@ -1,4 +1,49 @@
# 更新日志
## v4.0.0
非兼容版本重大重构引入算力概念将系统中所有的能力AI对话MJ绘画SD绘画DALL绘画全部使用算力来兑换。
只要你的算力值余额不为0你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力...
* 功能重构:重构整体系统,全部采用算力来进行结算
* 功能优化SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
* 功能优化:移动端聊天页面图片支持预览和放大功能
* 功能优化MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
* 功能优化:**PC端不登录也可以预览功能只有在发起操作的时候才需要登录**
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
* 功能新增支持H5支付
* 功能优化:支持数学公式的识别和美化输出
* 功能新增:新增算力消费日志功能
* 功能优化:整合 XXL-JOB 实现订单清理每日算力派发VIP 算力重置等任务
* 功能新增管理后台新增7日内新增用户和新增订单统计
## v3.2.7
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
* Bug修复修复 issue [
管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532)
* 功能优化:在对话和聊天记录表中新增冗余字段 model存储对话模型
* Bug修复IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
* Bug修复重新生成按钮功能失效问题
* Bug修复对话输入HTML标签不显示的问题
* 功能优化gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
* 功能新增:新增删除文件功能
* Bug修复解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
* 功能新增:后台管理新怎对话查看和检索功能
## v3.2.6
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
* 功能优化:兼用旧版本微信收款消息解析
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
* 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
* 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
* 功能新增:新增网站公告,可以在管理后台自定义配置
* 功能新增:新增阿里通义千问大模型支持
* Bug修复修复 MJ 放大任务失败时候 img_call 会增加的 Bug
* 功能优化新增虎皮椒和PayJS订单状态校验功能增加安全性
* Bug修复修复微信转账交易 ID 提取失败 Bug
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
* 功能新增:新增短信宝短信平台发送平台集成
## v3.2.5
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅 Plus 账号了!!!

View File

@@ -73,9 +73,11 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
**演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。
```shell
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.5-400fea2598.sh)"
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.7-6c232bdaf8.sh)"
```
最新版本的一键部署脚本请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/install/)。
目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址
* 前端访问地址http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。
@@ -145,6 +147,3 @@ KEY。
![打赏](docs/imgs/donate.png)
![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/chatgpt-plus&type=Date)

View File

@@ -10,10 +10,6 @@ WeChatBot = false
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
MaxAge = 86400
[Manager]
Username = "admin"
Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改
[Redis] # redis 配置信息
Host = "localhost"
Port = 6379
@@ -25,19 +21,28 @@ WeChatBot = false
AppId = ""
Token = ""
[SmsConfig] # 阿里云短信服务配置
AccessKey = ""
AccessSecret = ""
Product = "Dysmsapi"
Domain = "dysmsapi.aliyuncs.com"
Sign = ""
CodeTempId = ""
[SMS] # Sms 配置,用于发送短信
Active = "Ali" # 当前启用的短信服务,默认使用阿里云
[SMS.Bao]
Username = ""
Password = ""
Domain = "api.smsbao.com"
Sign = "【极客学长】"
CodeTemplate = "您的验证码是{code}。5分钟有效若非本人操作请忽略本短信。"
[SMS.Ali]
AccessKey = ""
AccessSecret = ""
Product = "Dysmsapi"
Domain = "dysmsapi.aliyuncs.com"
Sign = ""
CodeTempId = ""
[OSS] # OSS 配置,用于存储 MJ 绘画图片
Active = "local" # 默认使用本地文件存储引擎
[OSS.Local]
BasePath = "./static/upload" # 本地文件上传根路径
BaseURL = "http://localhost:5678/static/upload" # 本地上传文件 URL 如果是线上,则直接设置为 /static/upload 即可
BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP
[OSS.Minio]
Endpoint = "" # 如 172.22.11.200:9000
AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
@@ -51,6 +56,13 @@ WeChatBot = false
AccessSecret = ""
Bucket = ""
Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
[OSS.AliYun]
Endpoint = "oss-cn-hangzhou.aliyuncs.com"
AccessKey = ""
AccessSecret = ""
Bucket = "chatgpt-plus"
SubDir = ""
Domain = ""
[[MjConfigs]]
Enabled = false
@@ -59,20 +71,17 @@ WeChatBot = false
GuildId = ""
ChanelId = ""
UseCDN = false #是否使用反向代理访问设置为true下面的设置才会生效
DiscordAPI = "https://mj.r9it.com:8001" # discord API 反代地址
DiscordCDN = "https://mj.r9it.com:8002" # mj 图片反代地址
DiscordGateway = "wss://mj.r9it.com:8003" # discord 机器人反代地址
DiscordAPI = "" # discord API 反代地址
DiscordCDN = "" # mj 图片反代地址
DiscordGateway = "" # discord 机器人反代地址
[[MjConfigs]]
[[MjPlusConfigs]]
Enabled = false
UserToken = ""
BotToken = ""
GuildId = ""
ChanelId = ""
UseCDN = false #是否使用反向代理访问设置为true下面的设置才会生效
DiscordAPI = "https://mj.r9it.com:8001" # discord API 反代地址
DiscordCDN = "https://mj.r9it.com:8002" # mj 图片反代地址
DiscordGateway = "wss://mj.r9it.com:8003" # discord 机器人反代地址
ApiURL = "https://api.chat-plus.net"
CdnURL = "" # CND 加速的 URL如果有的话就设置
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
ApiKey = "sk-xxx"
NotifyURL = "https://ai.r9it.com/api/mj/notify" # 这里需要改成你的域名
[[SdConfigs]]
Enabled = false
@@ -80,18 +89,6 @@ WeChatBot = false
ApiKey = ""
Txt2ImgJsonPath = "res/sd/text2img.json"
[[SdConfigs]]
Enabled = false
ApiURL = ""
ApiKey = ""
Txt2ImgJsonPath = "res/sd/text2img.json"
[[SdConfigs]]
Enabled = false
ApiURL = ""
ApiKey = ""
Txt2ImgJsonPath = "res/text2img.json"
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动
Enabled = false # 是否启用 XXL JOB 服务
ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
@@ -114,9 +111,9 @@ WeChatBot = false
[HuPiPayConfig]
Enabled = false
Name = "wechat"
AppId = "201906161477"
AppSecret = "7f403199d510fb2c6f0b9f2311800e7c"
PayURL = "https://api.xunhupay.com/payment/do.html"
AppId = ""
AppSecret = ""
ApiURL = "https://api.xunhupay.com"
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
[SmtpConfig] # 注意阿里云服务器禁用了25号端口所以如果需要使用邮件功能请别用阿里云服务器
@@ -131,5 +128,5 @@ WeChatBot = false
Name = "wechat" # 请不要改动
AppId = "" # 商户 ID
PrivateKey = "" # 秘钥
ApiURL = "https://payjs.cn/api/native"
ApiURL = "https://payjs.cn"
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的

View File

@@ -28,10 +28,9 @@ type AppServer struct {
Debug bool
Config *types.AppConfig
Engine *gin.Engine
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
ChatConfig *types.ChatConfig // chat config cache
SysConfig *types.SystemConfig // system config cache
SysConfig *types.SystemConfig // system config cache
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
@@ -47,7 +46,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
Debug: false,
Config: appConfig,
Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []interface{}](),
ChatContexts: types.NewLMap[string, []types.Message](),
ChatSession: types.NewLMap[string, *types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
@@ -69,23 +68,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
}
func (s *AppServer) Run(db *gorm.DB) error {
// load chat config from database
var chatConfig model.Config
res := db.Where("marker", "chat").First(&chatConfig)
if res.Error != nil {
return res.Error
}
err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig)
if err != nil {
return err
}
// load system configs
var sysConfig model.Config
res = db.Where("marker", "system").First(&sysConfig)
res := db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return res.Error
}
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil {
return err
}
@@ -143,73 +132,64 @@ func corsMiddleware() gin.HandlerFunc {
// 用户授权验证
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/user/resetPass" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/mj/jobs" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/jobs" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
c.Request.URL.Path == "/api/admin/config/get" {
c.Next()
return
}
var tokenString string
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
if isAdminApi { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader)
} else if c.Request.URL.Path == "/api/chat/new" {
tokenString = c.Query("token")
} else {
tokenString = c.GetHeader(types.UserAuthHeader)
}
if tokenString == "" {
resp.ERROR(c, "You should put Authorization in request headers")
c.Abort()
return
if needLogin(c) {
resp.ERROR(c, "You should put Authorization in request headers")
c.Abort()
return
} else { // 直接放行
c.Next()
return
}
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
if isAdminApi {
return []byte(s.Config.AdminSession.SecretKey), nil
} else {
return []byte(s.Config.Session.SecretKey), nil
}
return []byte(s.Config.Session.SecretKey), nil
})
if err != nil {
if err != nil && needLogin(c) {
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
if !ok || !token.Valid && needLogin(c) {
resp.NotAuth(c, "Token is invalid")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
resp.NotAuth(c, "Token is expired")
c.Abort()
return
}
key := fmt.Sprintf("users/%v", claims["user_id"])
if _, err := client.Get(context.Background(), key).Result(); err != nil {
if isAdminApi {
key = fmt.Sprintf("admin/%v", claims["user_id"])
}
if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
resp.NotAuth(c, "Token is not found in redis")
c.Abort()
return
@@ -218,6 +198,36 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
}
}
func needLogin(c *gin.Context) bool {
if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/user/resetPass" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/admin/login/captcha" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" ||
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/model/list" ||
c.Request.URL.Path == "/api/mj/imgWall" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/imgWall" ||
c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/config/get" ||
c.Request.URL.Path == "/api/product/list" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
return false
}
return true
}
// 统一参数处理
func parameterHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {

View File

@@ -16,7 +16,6 @@ func NewDefaultConfig() *types.AppConfig {
return &types.AppConfig{
Listen: "0.0.0.0:5678",
ProxyURL: "",
Manager: types.Manager{Username: "admin", Password: "admin123"},
StaticDir: "./static",
StaticUrl: "http://localhost/5678/static",
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},

View File

@@ -12,6 +12,9 @@ type ApiRequest struct {
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
ToolChoice string `json:"tool_choice,omitempty"`
Input map[string]interface{} `json:"input,omitempty"` //兼容阿里通义千问
Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问
}
type Message struct {
@@ -51,10 +54,14 @@ type ChatSession struct {
}
type ChatModel struct {
Id uint `json:"id"`
Platform Platform `json:"platform"`
Value string `json:"value"`
Weight int `json:"weight"`
Id uint `json:"id"`
Platform Platform `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
}
type ApiError struct {
@@ -69,23 +76,36 @@ type ApiError struct {
const PromptMsg = "prompt" // prompt message
const ReplyMsg = "reply" // reply message
var ModelToTokens = map[string]int{
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"chatglm_pro": 32768, // 清华智普
"chatglm_std": 16384,
"chatglm_lite": 4096,
"ernie_bot_turbo": 8192, // 文心一言
"general": 8192, // 科大讯飞
"general2": 8192,
"general3": 8192,
// PowerType 算力日志类型
type PowerType int
const (
PowerRecharge = PowerType(1) // 充值
PowerConsume = PowerType(2) // 消费
PowerRefund = PowerType(3) // 任务SD,MJ执行失败退款
PowerInvite = PowerType(4) // 邀请奖励
PowerReward = PowerType(5) // 众筹
PowerGift = PowerType(6) // 系统赠送
)
func (t PowerType) String() string {
switch t {
case PowerRecharge:
return "充值"
case PowerConsume:
return "消费"
case PowerRefund:
return "退款"
case PowerReward:
return "众筹"
}
return "其他"
}
func GetModelMaxToken(model string) int {
if token, ok := ModelToTokens[model]; ok {
return token
}
return 4096
}
type PowerMark int
const (
PowerSub = PowerMark(0)
PowerAdd = PowerMark(1)
)

View File

@@ -8,18 +8,17 @@ type AppConfig struct {
Path string `toml:"-"`
Listen string
Session Session
AdminSession Session
ProxyURL string
MysqlDns string // mysql 连接地址
Manager Manager // 后台管理员账户信息
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
SmsConfig AliYunSmsConfig // AliYun send message service config
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
MjConfigs []MidJourneyConfig // mj AI draw service pool
MjPlusConfigs []MidJourneyPlusConfig // MJ plus config
ImgCdnURL string // 图片反代加速地址
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
@@ -38,16 +37,6 @@ type SmtpConfig struct {
Password string // 发件人邮箱密码
}
// JPayConfig PayJs 支付配置
type JPayConfig struct {
Enabled bool
Name string // 支付名称,默认 wechat
AppId string // 商户 ID
PrivateKey string // 私钥
ApiURL string // API 网关
NotifyURL string // 异步回调地址
}
type ChatPlusApiConfig struct {
ApiURL string
AppId string
@@ -61,6 +50,7 @@ type MidJourneyConfig struct {
GuildId string // Server ID
ChanelId string // Chanel ID
UseCDN bool
ImgCdnURL string // 图片反代加速地址
DiscordAPI string
DiscordGateway string
}
@@ -73,21 +63,14 @@ type StableDiffusionConfig struct {
}
type MidJourneyPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
CdnURL string // CDN 加速地址
ApiKey string
NotifyURL string // 任务进度更新回调地址
}
type AliYunSmsConfig struct {
AccessKey string
AccessSecret string
Product string
Domain string
Sign string // 短信签名
CodeTempId string // 验证码短信模板 ID
}
type AlipayConfig struct {
Enabled bool // 是否启用该支付通道
SandBox bool // 是否沙盒环境
@@ -98,6 +81,7 @@ type AlipayConfig struct {
AlipayPublicKey string // 支付宝公钥文件路径
RootCert string // Root 秘钥路径
NotifyURL string // 异步通知回调
ReturnURL string // 支付成功返回地址
}
type HuPiPayConfig struct { //虎皮椒第四方支付配置
@@ -105,8 +89,20 @@ type HuPiPayConfig struct { //虎皮椒第四方支付配置
Name string // 支付名称wechat/alipay
AppId string // App ID
AppSecret string // app 密钥
PayURL string // 支付网关
ApiURL string // 支付网关
NotifyURL string // 异步通知回调
ReturnURL string // 支付成功返回地址
}
// JPayConfig PayJs 支付配置
type JPayConfig struct {
Enabled bool
Name string // 支付名称,默认 wechat
AppId string // 商户 ID
PrivateKey string // 私钥
ApiURL string // API 网关
NotifyURL string // 异步回调地址
ReturnURL string // 支付成功返回地址
}
type XXLConfig struct { // XXL 任务调度配置
@@ -129,26 +125,6 @@ func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port)
}
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
}
// ChatConfig 系统默认的聊天配置
type ChatConfig struct {
OpenAI ModelAPIConfig `json:"open_ai"`
Azure ModelAPIConfig `json:"azure"`
ChatGML ModelAPIConfig `json:"chat_gml"`
Baidu ModelAPIConfig `json:"baidu"`
XunFei ModelAPIConfig `json:"xun_fei"`
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
ContextDeep int `json:"context_deep"` // 上下文深度
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
}
type Platform string
const OpenAI = Platform("OpenAI")
@@ -156,42 +132,34 @@ const Azure = Platform("Azure")
const ChatGLM = Platform("ChatGLM")
const Baidu = Platform("Baidu")
const XunFei = Platform("XunFei")
// UserChatConfig 用户的聊天配置
type UserChatConfig struct {
ApiKeys map[Platform]string `json:"api_keys"`
}
type InviteReward struct {
ChatCalls int `json:"chat_calls"`
ImgCalls int `json:"img_calls"`
}
type ModelAPIConfig struct {
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
}
const QWen = Platform("QWen")
type SystemConfig struct {
Title string `json:"title"`
AdminTitle string `json:"admin_title"`
InitChatCalls int `json:"init_chat_calls"` // 新用户注册赠送对话次数
InitImgCalls int `json:"init_img_calls"` // 新用户注册赠送绘图次数
VipMonthCalls int `json:"vip_month_calls"` // VIP 会员每月赠送的对话次数
VipMonthImgCalls int `json:"vip_month_img_calls"` // VIP 会员每月赠送绘图次数
Title string `json:"title,omitempty"`
AdminTitle string `json:"admin_title,omitempty"`
Logo string `json:"logo,omitempty"`
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
RegisterWays []string `json:"register_ways"` // 注册方式:支持手机,邮箱注册
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机,邮箱注册,账号密码注册
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
RewardImg string `json:"reward_img"` // 众筹收款二维码地址
EnabledReward bool `json:"enabled_reward"` // 启用众筹功能
ChatCallPrice float64 `json:"chat_call_price"` // 对话单次调用费用
ImgCallPrice float64 `json:"img_call_price"` // 绘图单次调用费用
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
OrderPayTimeout int `json:"order_pay_timeout"` //订单支付超时时间
DefaultModels []string `json:"default_models"` // 默认开通的 AI 模型
OrderPayInfoText string `json:"order_pay_info_text"` // 订单支付页面说明文字
InviteChatCalls int `json:"invite_chat_calls"` // 邀请用户注册奖励对话次数
InviteImgCalls int `json:"invite_img_calls"` // 邀请用户注册奖励绘图次数
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
VipInfoText string `json:"vip_info_text"` // 会员页面充值说明
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
ShowDemoNotice bool `json:"show_demo_notice"` // 显示演示站公告
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
EnableContext bool `json:"enable_context,omitempty"`
ContextDeep int `json:"context_deep,omitempty"`
}

View File

@@ -9,7 +9,7 @@ type MKey interface {
string | int | uint
}
type MValue interface {
*WsClient | *ChatSession | context.CancelFunc | []interface{}
*WsClient | *ChatSession | context.CancelFunc | []Message
}
type LMap[K MKey, T MValue] struct {
lock sync.RWMutex

View File

@@ -9,10 +9,9 @@ const (
)
type OrderRemark struct {
Days int `json:"days"` // 有效期
Calls int `json:"calls"` // 增加对话次
ImgCalls int `json:"img_calls"` // 增加绘图次数
Name string `json:"name"` // 产品名称
Days int `json:"days"` // 有效期
Power int `json:"power"` // 增加算力点
Name string `json:"name"` // 产品名称
Price float64 `json:"price"`
Discount float64 `json:"discount"`
}

26
api/core/types/sms.go Normal file
View File

@@ -0,0 +1,26 @@
package types
type SMSConfig struct {
Active string
Ali SmsConfigAli
Bao SmsConfigBao
}
// SmsConfigAli 阿里云短信平台配置
type SmsConfigAli struct {
AccessKey string
AccessSecret string
Product string
Domain string
Sign string // 短信签名
CodeTempId string // 验证码短信模板 ID
}
// SmsConfigBao 短信宝平台配置
type SmsConfigBao struct {
Username string //短信宝平台注册的用户名
Password string //短信宝平台注册的密码
Domain string //域名
Sign string // 短信签名
CodeTemplate string // 验证码短信模板 匹配
}

View File

@@ -9,13 +9,17 @@ func (t TaskType) String() string {
const (
TaskImage = TaskType("image")
TaskBlend = TaskType("blend")
TaskSwapFace = TaskType("swapFace")
TaskUpscale = TaskType("upscale")
TaskVariation = TaskType("variation")
)
// MjTask MidJourney 任务
type MjTask struct {
Id int `json:"id"`
Id uint `json:"id"`
TaskId string `json:"task_id"`
ImgArr []string `json:"img_arr"`
ChannelId string `json:"channel_id"`
SessionId string `json:"session_id"`
Type TaskType `json:"type"`

View File

@@ -30,6 +30,7 @@ const (
Success = BizCode(0)
Failed = BizCode(1)
NotAuthorized = BizCode(400) // 未授权
NotPermission = BizCode(403) // 没有权限
OkMsg = "Success"
ErrorMsg = "系统开小差了"

View File

@@ -27,6 +27,16 @@ require github.com/xxl-job/xxl-job-executor-go v1.2.0
require github.com/bg5t/mydiscordgo v0.28.1
require (
github.com/mojocn/base64Captcha v1.3.1
github.com/shopspring/decimal v1.3.1
)
require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect
)
require (
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/bytedance/sonic v1.9.1 // indirect

View File

@@ -65,6 +65,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
@@ -129,6 +131,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
@@ -166,6 +170,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
@@ -227,6 +233,8 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=

View File

@@ -5,10 +5,15 @@ import (
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/mojocn/base64Captcha"
"time"
"github.com/gin-gonic/gin"
@@ -17,47 +22,88 @@ import (
var logger = logger2.GetLogger()
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
Captcha string `json:"captcha"` // 验证码
CaptchaId string `json:"captcha_id"` // 验证码id
}
const SuperManagerID = 1
type ManagerHandler struct {
handler.BaseHandler
db *gorm.DB
redis *redis.Client
}
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
h := ManagerHandler{db: db, redis: client}
h.App = app
return &h
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data types.Manager
var data Manager
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
manager := h.App.Config.Manager
if data.Username == manager.Username && data.Password == manager.Password {
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": manager.Username,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := "users/" + manager.Username
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
resp.SUCCESS(c, tokenString)
} else {
resp.ERROR(c, "用户名或者密码错误")
// add captcha
if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
resp.ERROR(c, "验证码错误!")
return
}
var manager model.AdminUser
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil {
resp.ERROR(c, "请检查用户名或者密码是否填写正确")
return
}
password := utils.GenPassword(data.Password, manager.Salt)
if password != manager.Password {
resp.ERROR(c, "用户名或密码错误")
return
}
// 超级管理员默认是ID:1
if manager.Id != SuperManagerID && manager.Status == false {
resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员")
return
}
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": manager.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := fmt.Sprintf("admin/%d", manager.Id)
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
// 更新最后登录时间和IP
manager.LastLoginIp = c.ClientIP()
manager.LastLoginAt = time.Now().Unix()
h.DB.Updates(&manager)
var result = struct {
IsSuperAdmin bool `json:"is_super_admin"`
Token string `json:"token"`
}{
IsSuperAdmin: manager.Id == 1,
Token: tokenString,
}
resp.SUCCESS(c, result)
}
// Logout 注销
@@ -72,10 +118,155 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
// Session 会话检测
func (h *ManagerHandler) Session(c *gin.Context) {
token := c.GetHeader(types.AdminAuthHeader)
if token == "" {
id := h.GetLoginUserId(c)
key := fmt.Sprintf("admin/%d", id)
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c)
} else {
resp.SUCCESS(c)
return
}
var manager model.AdminUser
res := h.DB.Where("id", id).First(&manager)
if res.Error != nil {
resp.NotAuth(c)
return
}
resp.SUCCESS(c, manager)
}
// List 数据列表
func (h *ManagerHandler) List(c *gin.Context) {
var items []model.AdminUser
res := h.DB.Find(&items)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
users := make([]vo.AdminUser, 0)
for _, item := range items {
var u vo.AdminUser
err := utils.CopyObject(item, &u)
if err != nil {
continue
}
u.Id = item.Id
u.CreatedAt = item.CreatedAt.Unix()
users = append(users, u)
}
resp.SUCCESS(c, users)
}
func (h *ManagerHandler) Save(c *gin.Context) {
var data struct {
Username string `json:"username"`
Password string `json:"password"`
Status bool `json:"status"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.AdminUser
res := h.DB.Where("username", data.Username).First(&user)
if res.Error == nil {
resp.ERROR(c, "用户名已存在")
return
}
// 生成密码
salt := utils.RandString(8)
password := utils.GenPassword(data.Password, salt)
res = h.DB.Save(&model.AdminUser{
Username: data.Username,
Password: password,
Salt: salt,
Status: data.Status,
})
if res.Error != nil {
resp.ERROR(c, "failed with update database")
return
}
resp.SUCCESS(c)
}
// Remove 删除管理员
func (h *ManagerHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if id == SuperManagerID {
resp.ERROR(c, "超级管理员不能删除")
return
}
res := h.DB.Where("id", id).Delete(&model.AdminUser{})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}
// Enable 启用/禁用
func (h *ManagerHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.AdminUser{}).Where("id", data.Id).UpdateColumn("status", data.Enabled)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}
// ResetPass 重置密码
func (h *ManagerHandler) ResetPass(c *gin.Context) {
id := h.GetLoginUserId(c)
if id != SuperManagerID {
resp.ERROR(c, "只有超级管理员能够进行该操作")
return
}
var data struct {
Id int `json:"id"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.AdminUser
res := h.DB.Where("id", data.Id).First(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -14,13 +14,10 @@ import (
type ApiKeyHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
h := ApiKeyHandler{db: db}
h.App = app
return &h
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
}
func (h *ApiKeyHandler) Save(c *gin.Context) {
@@ -32,7 +29,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
Value string `json:"value"`
ApiURL string `json:"api_url"`
Enabled bool `json:"enabled"`
UseProxy bool `json:"use_proxy"`
ProxyURL string `json:"proxy_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -41,16 +38,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey := model.ApiKey{}
if data.Id > 0 {
h.db.Find(&apiKey, data.Id)
h.DB.Find(&apiKey, data.Id)
}
apiKey.Platform = data.Platform
apiKey.Value = data.Value
apiKey.Type = data.Type
apiKey.ApiURL = data.ApiURL
apiKey.Enabled = data.Enabled
apiKey.UseProxy = data.UseProxy
apiKey.ProxyURL = data.ProxyURL
apiKey.Name = data.Name
res := h.db.Save(&apiKey)
res := h.DB.Save(&apiKey)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -68,9 +65,14 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
}
func (h *ApiKeyHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)
res := h.db.Find(&items)
res := h.DB.Find(&items)
if res.Error == nil {
for _, item := range items {
var key vo.ApiKey
@@ -100,7 +102,7 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -109,10 +111,15 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
}
func (h *ApiKeyHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.ApiKey{})
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.ApiKey{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -0,0 +1,39 @@
package admin
import (
"chatplus/core"
"chatplus/handler"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"github.com/mojocn/base64Captcha"
)
type CaptchaHandler struct {
handler.BaseHandler
}
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
}
type CaptchaVo struct {
CaptchaId string `json:"captcha_id"`
PicPath string `json:"pic_path"`
}
// GetCaptcha 获取验证码
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
var captchaVo CaptchaVo
driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
// b64s是图片的base64编码
id, b64s, err := cp.Generate()
if err != nil {
resp.ERROR(c, "生成验证码错误!")
return
}
captchaVo.CaptchaId = id
captchaVo.PicPath = b64s
resp.SUCCESS(c, captchaVo)
}

View File

@@ -0,0 +1,266 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatHandler struct {
handler.BaseHandler
}
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
type chatItemVo struct {
Username string `json:"username"`
UserId uint `json:"user_id"`
ChatId string `json:"chat_id"`
Title string `json:"title"`
Role vo.ChatRole `json:"role"`
Model string `json:"model"`
Token int `json:"token"`
CreatedAt int64 `json:"created_at"`
MsgNum int `json:"msg_num"` // 消息数量
}
func (h *ChatHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var data struct {
Title string `json:"title"`
UserId uint `json:"user_id"`
Model string `json:"model"`
CreateAt []string `json:"created_time"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Title != "" {
session = session.Where("title LIKE ?", "%"+data.Title+"%")
}
if data.UserId > 0 {
session = session.Where("user_id = ?", data.UserId)
}
if data.Model != "" {
session = session.Where("model = ?", data.Model)
}
if len(data.CreateAt) == 2 {
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.ChatItem{}).Count(&total)
var items []model.ChatItem
var list = make([]chatItemVo, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
userIds := make([]uint, 0)
chatIds := make([]string, 0)
roleIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
chatIds = append(chatIds, item.ChatId)
roleIds = append(roleIds, item.RoleId)
}
var messages []model.ChatMessage
var users []model.User
var roles []model.ChatRole
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
h.DB.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", roleIds).Find(&roles)
tokenMap := make(map[string]int)
userMap := make(map[uint]string)
msgMap := make(map[string]int)
roleMap := make(map[uint]vo.ChatRole)
for _, msg := range messages {
tokenMap[msg.ChatId] += msg.Tokens
msgMap[msg.ChatId] += 1
}
for _, user := range users {
userMap[user.Id] = user.Username
}
for _, r := range roles {
var roleVo vo.ChatRole
err := utils.CopyObject(r, &roleVo)
if err != nil {
continue
}
roleMap[r.Id] = roleVo
}
for _, item := range items {
list = append(list, chatItemVo{
UserId: item.UserId,
Username: userMap[item.UserId],
ChatId: item.ChatId,
Title: item.Title,
Model: item.Model,
Token: tokenMap[item.ChatId],
MsgNum: msgMap[item.ChatId],
Role: roleMap[item.RoleId],
CreatedAt: item.CreatedAt.Unix(),
})
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}
type chatMessageVo struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
Username string `json:"username"`
Content string `json:"content"`
Type string `json:"type"`
Model string `json:"model"`
Token int `json:"token"`
Icon string `json:"icon"`
CreatedAt int64 `json:"created_at"`
}
// Messages 读取聊天记录列表
func (h *ChatHandler) Messages(c *gin.Context) {
var data struct {
UserId uint `json:"user_id"`
Content string `json:"content"`
Model string `json:"model"`
CreateAt []string `json:"created_time"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Content != "" {
session = session.Where("content LIKE ?", "%"+data.Content+"%")
}
if data.UserId > 0 {
session = session.Where("user_id = ?", data.UserId)
}
if data.Model != "" {
session = session.Where("model = ?", data.Model)
}
if len(data.CreateAt) == 2 {
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.ChatMessage{}).Count(&total)
var items []model.ChatMessage
var list = make([]chatMessageVo, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
userIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
}
var users []model.User
h.DB.Where("id IN ?", userIds).Find(&users)
userMap := make(map[uint]string)
for _, user := range users {
userMap[user.Id] = user.Username
}
for _, item := range items {
list = append(list, chatMessageVo{
Id: item.Id,
UserId: item.UserId,
Username: userMap[item.UserId],
Content: item.Content,
Model: item.Model,
Token: item.Tokens,
Icon: item.Icon,
Type: item.Type,
CreatedAt: item.CreatedAt.Unix(),
})
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}
// History 获取聊天历史记录
func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
} else {
for _, item := range items {
var v vo.HistoryMessage
err := utils.CopyObject(item, &v)
v.CreatedAt = item.CreatedAt.Unix()
v.UpdatedAt = item.UpdatedAt.Unix()
if err == nil {
messages = append(messages, v)
}
}
}
resp.SUCCESS(c, messages)
}
// RemoveChat 删除对话
func (h *ChatHandler) RemoveChat(c *gin.Context) {
chatId := h.GetTrim(c, "chat_id")
if chatId == "" {
resp.ERROR(c, "请传入 ChatId")
return
}
tx := h.DB.Begin()
// 删除聊天记录
res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
if res.Error != nil {
resp.ERROR(c, "failed to remove chat message")
return
}
// 删除对话
res = tx.Unscoped().Where("chat_id = ?", chatId).Delete(model.ChatItem{})
if res.Error != nil {
tx.Rollback() // 回滚
resp.ERROR(c, "failed to remove chat")
return
}
tx.Commit()
resp.SUCCESS(c)
}
// RemoveMessage 删除聊天记录
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
if tx.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -15,26 +15,26 @@ import (
type ChatModelHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db}
h.App = app
return &h
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ChatModelHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Open bool `json:"open"`
Platform string `json:"platform"`
Weight int `json:"weight"`
CreatedAt int64 `json:"created_at"`
Id uint `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Open bool `json:"open"`
Platform string `json:"platform"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -42,18 +42,21 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
}
item := model.ChatModel{
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
Weight: data.Weight}
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: data.Temperature,
Power: data.Power}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&item)
res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -72,7 +75,12 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
session := h.db.Session(&gorm.Session{})
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
if enable {
session = session.Where("enabled", enable)
@@ -109,7 +117,7 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -129,7 +137,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -141,13 +149,15 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
func (h *ChatModelHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -15,13 +15,10 @@ import (
type ChatRoleHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
h := ChatRoleHandler{db: db}
h.App = app
return &h
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// Save 创建或者更新某个角色
@@ -41,7 +38,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&role)
res := h.DB.Save(&role)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -53,9 +50,14 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
}
func (h *ChatRoleHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.ChatRole
var roles = make([]vo.ChatRole, 0)
res := h.db.Order("sort_num ASC").Find(&items)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
return
@@ -88,7 +90,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -110,7 +112,7 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -119,13 +121,18 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
}
func (h *ChatRoleHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Where("id = ?", id).Delete(&model.ChatRole{})
if data.Id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
if res.Error != nil {
resp.ERROR(c, "删除失败!")
return

View File

@@ -14,36 +14,38 @@ import (
type ConfigHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
h := ConfigHandler{db: db}
h.App = app
return &h
return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ConfigHandler) Update(c *gin.Context) {
var data struct {
Key string `json:"key"`
Config map[string]interface{} `json:"config"`
Key string `json:"key"`
Config struct {
types.SystemConfig
Content string `json:"content,omitempty"`
Updated bool `json:"updated,omitempty"`
} `json:"config"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
str := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: str}
res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key})
value := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: value}
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
if config.Id > 0 {
config.Config = str
res := h.db.Updates(&config)
config.Config = value
res := h.DB.Updates(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -51,12 +53,10 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// update config cache for AppServer
var cfg model.Config
h.db.Where("marker", data.Key).First(&cfg)
h.DB.Where("marker", data.Key).First(&cfg)
var err error
if data.Key == "system" {
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
} else if data.Key == "chat" {
err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig)
}
if err != nil {
resp.ERROR(c, "Failed to update config cache: "+err.Error())
@@ -70,20 +70,25 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
key := c.Query("key")
var config model.Config
res := h.db.Where("marker", key).First(&config)
res := h.DB.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
var m map[string]interface{}
err := utils.JsonDecode(config.Config, &m)
var value map[string]interface{}
err := utils.JsonDecode(config.Config, &value)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, m)
resp.SUCCESS(c, value)
}

View File

@@ -7,26 +7,25 @@ import (
"chatplus/store/model"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"gorm.io/gorm"
"time"
)
type DashboardHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
h := DashboardHandler{db: db}
h.App = app
return &h
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
type statsVo struct {
Users int64 `json:"users"`
Chats int64 `json:"chats"`
Tokens int `json:"tokens"`
Income float64 `json:"income"`
Users int64 `json:"users"`
Chats int64 `json:"chats"`
Tokens int `json:"tokens"`
Income float64 `json:"income"`
Chart map[string]map[string]float64 `json:"chart"`
}
func (h *DashboardHandler) Stats(c *gin.Context) {
@@ -35,37 +34,84 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
var userCount int64
now := time.Now()
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
res := h.db.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
if res.Error == nil {
stats.Users = userCount
}
// new chats statistic
var chatCount int64
res = h.db.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
if res.Error == nil {
stats.Chats = chatCount
}
// tokens took stats
var historyMessages []model.HistoryMessage
res = h.db.Where("created_at > ?", zeroTime).Find(&historyMessages)
var historyMessages []model.ChatMessage
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
for _, item := range historyMessages {
stats.Tokens += item.Tokens
}
// 众筹收入
var rewards []model.Reward
res = h.db.Where("created_at > ?", zeroTime).Find(&rewards)
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
for _, item := range rewards {
stats.Income += item.Amount
}
// 订单收入
var orders []model.Order
res = h.db.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
for _, item := range orders {
stats.Income += item.Amount
}
// 统计7天的订单的图表
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
var statsChart = make(map[string]map[string]float64)
//// 初始化
var userStatistic, historyMessagesStatistic, incomeStatistic = make(map[string]float64), make(map[string]float64), make(map[string]float64)
for i := 0; i < 7; i++ {
var initTime = time.Date(now.Year(), now.Month(), now.Day()-i, 0, 0, 0, 0, now.Location()).Format("2006-01-02")
userStatistic[initTime] = float64(0)
historyMessagesStatistic[initTime] = float64(0)
incomeStatistic[initTime] = float64(0)
}
// 统计用户7天增加的曲线
var users []model.User
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
if res.Error == nil {
for _, item := range users {
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
}
}
// 统计7天Token 消耗
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
for _, item := range historyMessages {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
}
// 浮点数相加?
// 统计最近7天的众筹
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
for _, item := range rewards {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
// 统计最近7天的订单
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
for _, item := range orders {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
statsChart["users"] = userStatistic
statsChart["historyMessage"] = historyMessagesStatistic
statsChart["orders"] = incomeStatistic
stats.Chart = statsChart
resp.SUCCESS(c, stats)
}

View File

@@ -17,13 +17,10 @@ import (
type FunctionHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
h := FunctionHandler{db: db}
h.App = app
return &h
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *FunctionHandler) Save(c *gin.Context) {
@@ -44,7 +41,7 @@ func (h *FunctionHandler) Save(c *gin.Context) {
Enabled: data.Enabled,
}
res := h.db.Save(&f)
res := h.DB.Save(&f)
if res.Error != nil {
resp.ERROR(c, "error with save data:"+res.Error.Error())
return
@@ -65,7 +62,7 @@ func (h *FunctionHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -74,8 +71,13 @@ func (h *FunctionHandler) Set(c *gin.Context) {
}
func (h *FunctionHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Function
res := h.db.Find(&items)
res := h.DB.Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
return
@@ -97,7 +99,7 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.db.Delete(&model.Function{Id: uint(id)})
res := h.DB.Delete(&model.Function{Id: uint(id)})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -8,24 +8,28 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type OrderHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db}
h.App = app
return &h
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *OrderHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var data struct {
OrderNo string `json:"order_no"`
Status int `json:"status"`
PayTime []string `json:"pay_time"`
Page int `json:"page"`
PageSize int `json:"page_size"`
@@ -35,7 +39,7 @@ func (h *OrderHandler) List(c *gin.Context) {
return
}
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if data.OrderNo != "" {
session = session.Where("order_no", data.OrderNo)
}
@@ -44,8 +48,9 @@ func (h *OrderHandler) List(c *gin.Context) {
end := utils.Str2stamp(data.PayTime[1] + " 00:00:00")
session = session.Where("pay_time >= ? AND pay_time <= ?", start, end)
}
session = session.Where("status = ?", types.OrderPaidSuccess)
if data.Status >= 0 {
session = session.Where("status", data.Status)
}
var total int64
session.Model(&model.Order{}).Count(&total)
var items []model.Order
@@ -74,7 +79,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
if id > 0 {
var item model.Order
res := h.db.First(&item, id)
res := h.DB.First(&item, id)
if res.Error != nil {
resp.ERROR(c, "记录不存在!")
return
@@ -85,7 +90,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
return
}
res = h.db.Where("id = ?", id).Delete(&model.Order{})
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -0,0 +1,71 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type PowerLogHandler struct {
handler.BaseHandler
}
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Username string `json:"username"`
Type int `json:"type"`
Model string `json:"model"`
Date []string `json:"date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Model != "" {
session = session.Where("model", data.Model)
}
if data.Type > 0 {
session = session.Where("type", data.Type)
}
if len(data.Date) == 2 {
start := data.Date[0] + " 00:00:00"
end := data.Date[1] + " 00:00:00"
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.PowerLog{}).Count(&total)
var items []model.PowerLog
var list = make([]vo.PowerLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var log vo.PowerLog
err := utils.CopyObject(item, &log)
if err != nil {
continue
}
log.Id = item.Id
log.CreatedAt = item.CreatedAt.Unix()
log.TypeStr = item.Type.String()
list = append(list, log)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}

View File

@@ -15,13 +15,10 @@ import (
type ProductHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db}
h.App = app
return &h
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ProductHandler) Save(c *gin.Context) {
@@ -32,8 +29,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
Discount float64 `json:"discount"`
Enabled bool `json:"enabled"`
Days int `json:"days"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
Power int `json:"power"`
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -46,14 +42,13 @@ func (h *ProductHandler) Save(c *gin.Context) {
Price: data.Price,
Discount: data.Discount,
Days: data.Days,
Calls: data.Calls,
ImgCalls: data.ImgCalls,
Power: data.Power,
Enabled: data.Enabled}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&item)
res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -72,7 +67,12 @@ func (h *ProductHandler) Save(c *gin.Context) {
// List 模型列表
func (h *ProductHandler) List(c *gin.Context) {
session := h.db.Session(&gorm.Session{})
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
if enable {
session = session.Where("enabled", enable)
@@ -108,7 +108,7 @@ func (h *ProductHandler) Enable(c *gin.Context) {
return
}
res := h.db.Model(&model.Product{}).Where("id = ?", data.Id).Update("enabled", data.Enabled)
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -128,7 +128,7 @@ func (h *ProductHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -142,7 +142,7 @@ func (h *ProductHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.Product{})
res := h.DB.Where("id = ?", id).Delete(&model.Product{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -2,6 +2,7 @@ package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
@@ -13,18 +14,20 @@ import (
type RewardHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db}
h.App = app
return &h
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RewardHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Reward
res := h.db.Order("id DESC").Find(&items)
res := h.DB.Order("id DESC").Find(&items)
var rewards = make([]vo.Reward, 0)
if res.Error == nil {
userIds := make([]uint, 0)
@@ -32,7 +35,7 @@ func (h *RewardHandler) List(c *gin.Context) {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.db.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
@@ -57,10 +60,15 @@ func (h *RewardHandler) List(c *gin.Context) {
}
func (h *RewardHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.Reward{})
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -0,0 +1,45 @@
package admin
import (
"chatplus/core"
"chatplus/handler"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type UploadHandler struct {
handler.BaseHandler
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
userId := 0
res := h.DB.Create(&model.File{
UserId: userId,
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}

View File

@@ -9,6 +9,7 @@ import (
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -16,17 +17,19 @@ import (
type UserHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
h := UserHandler{db: db}
h.App = app
return &h
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// List 用户列表
func (h *UserHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
username := h.GetTrim(c, "username")
@@ -36,7 +39,7 @@ func (h *UserHandler) List(c *gin.Context) {
var users = make([]vo.User, 0)
var total int64
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if username != "" {
session = session.Where("username LIKE ?", "%"+username+"%")
}
@@ -66,13 +69,12 @@ func (h *UserHandler) Save(c *gin.Context) {
Id uint `json:"id"`
Password string `json:"password"`
Username string `json:"username"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
ChatRoles []string `json:"chat_roles"`
ChatModels []string `json:"chat_models"`
ChatModels []int `json:"chat_models"`
ExpiredTime string `json:"expired_time"`
Status bool `json:"status"`
Vip bool `json:"vip"`
Power int `json:"power"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -82,18 +84,45 @@ func (h *UserHandler) Save(c *gin.Context) {
var res *gorm.DB
var userVo vo.User
if data.Id > 0 { // 更新
user.Id = data.Id
// 此处需要用 map 更新,用结构体无法更新 0 值
res = h.db.Model(&user).Updates(map[string]interface{}{
"username": data.Username,
"calls": data.Calls,
"img_calls": data.ImgCalls,
"status": data.Status,
"vip": data.Vip,
"chat_roles_json": utils.JsonEncode(data.ChatRoles),
"chat_models_json": utils.JsonEncode(data.ChatModels),
"expired_time": utils.Str2stamp(data.ExpiredTime),
})
res = h.DB.Where("id", data.Id).First(&user)
if res.Error != nil {
resp.ERROR(c, "user not found")
return
}
var oldPower = user.Power
user.Username = data.Username
user.Status = data.Status
user.Vip = data.Vip
user.Power = data.Power
user.ChatRoles = utils.JsonEncode(data.ChatRoles)
user.ChatModels = utils.JsonEncode(data.ChatModels)
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力日志
if oldPower != user.Power {
mark := types.PowerAdd
amount := user.Power - oldPower
if oldPower > user.Power {
mark = types.PowerSub
amount = oldPower - user.Power
}
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerGift,
Amount: amount,
Balance: user.Power,
Mark: mark,
Model: "管理员",
Remark: fmt.Sprintf("后台管理员强制修改用户算力,修改前:%d,修改后:%d, 管理员ID%d", oldPower, user.Power, h.GetLoginUserId(c)),
CreatedAt: time.Now(),
})
}
} else {
salt := utils.RandString(8)
u := model.User{
@@ -102,21 +131,13 @@ func (h *UserHandler) Save(c *gin.Context) {
Password: utils.GenPassword(data.Password, salt),
Avatar: "/images/avatar/user.png",
Salt: salt,
Power: data.Power,
Status: true,
ChatRoles: utils.JsonEncode(data.ChatRoles),
ChatModels: utils.JsonEncode(data.ChatModels),
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
ChatConfig: utils.JsonEncode(types.UserChatConfig{
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
Calls: data.Calls,
ImgCalls: data.ImgCalls,
}
res = h.db.Create(&u)
res = h.DB.Create(&u)
_ = utils.CopyObject(u, &userVo)
userVo.Id = u.Id
userVo.CreatedAt = u.CreatedAt.Unix()
@@ -143,7 +164,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
var user model.User
res := h.db.First(&user, data.Id)
res := h.DB.First(&user, data.Id)
if res.Error != nil {
resp.ERROR(c, "No user found")
return
@@ -151,7 +172,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
} else {
@@ -161,36 +182,32 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
func (h *UserHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
tx := h.db.Begin()
res := h.db.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
res = h.db.Where("user_id = ?", id).Delete(&model.ChatItem{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除聊天历史记录
res = h.db.Where("user_id = ?", id).Delete(&model.HistoryMessage{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除登录日志
res = h.db.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
tx.Commit()
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
// 删除用户
res := h.DB.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
// 删除聊天历史记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
// 删除登录日志
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
// 删除算力日志
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
// 删除众筹日志
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
// 删除绘图任务
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
// 删除订单
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
resp.SUCCESS(c)
}
@@ -198,10 +215,10 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
var total int64
h.db.Model(&model.UserLoginLog{}).Count(&total)
h.DB.Model(&model.UserLoginLog{}).Count(&total)
offset := (page - 1) * pageSize
var items []model.UserLoginLog
res := h.db.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "获取数据失败")
return

View File

@@ -4,8 +4,11 @@ import (
"chatplus/core"
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/utils"
"errors"
"fmt"
"gorm.io/gorm"
"strings"
"github.com/gin-gonic/gin"
@@ -15,6 +18,7 @@ var logger = logger2.GetLogger()
type BaseHandler struct {
App *core.AppServer
DB *gorm.DB
}
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
@@ -57,3 +61,27 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
}
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
}
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
return h.GetLoginUserId(c) > 0
}
func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
value, exists := c.Get(types.LoginUserCache)
if exists {
return value.(model.User), nil
}
userId, ok := c.Get(types.LoginUserID)
if !ok {
return model.User{}, errors.New("user not login")
}
var user model.User
res := h.DB.First(&user, userId)
// 更新缓存
if res.Error == nil {
c.Set(types.LoginUserCache, user)
}
return user, res.Error
}

View File

@@ -12,37 +12,34 @@ import (
type ChatModelHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db}
h.App = app
return &h
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0)
// 只加载用户订阅的 AI 模型
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
var res *gorm.DB
// 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) {
res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items)
} else {
user, _ := h.GetLoginUser(c)
var models []int
err := utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res = h.DB.Where("enabled = ?", true).Where(
h.DB.Where("id IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
}
var models []string
err = utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res := h.db.Where("enabled = ?", true).Where(
h.db.Where("value IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var cm vo.ChatModel

View File

@@ -14,27 +14,25 @@ import (
type ChatRoleHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
handler := &ChatRoleHandler{db: db}
handler.App = app
return handler
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List get user list
// List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) {
all := h.GetBool(c, "all")
userId := h.GetLoginUserId(c)
var roles []model.ChatRole
res := h.db.Where("enable", true).Order("sort_num ASC").Find(&roles)
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
if res.Error != nil {
resp.ERROR(c, "No roles found,"+res.Error.Error())
return
}
// 获取所有角色
if all {
if userId == 0 || all {
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
@@ -49,13 +47,8 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.NotAuth(c)
return
}
var user model.User
h.db.First(&user, userId)
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
@@ -80,7 +73,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
// UpdateRole 更新用户聊天角色
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
@@ -94,7 +87,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
return
}
res := h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
if res.Error != nil {
logger.Error("添加应用失败:", err)
resp.ERROR(c, "更新数据库失败!")

View File

@@ -19,7 +19,7 @@ import (
// 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -103,8 +103,6 @@ func (h *ChatHandler) sendAzureMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
@@ -113,64 +111,64 @@ func (h *ChatHandler) sendAzureMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -181,7 +179,8 @@ func (h *ChatHandler) sendAzureMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -36,7 +36,7 @@ type baiduResp struct {
// 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -128,9 +128,6 @@ func (h *ChatHandler) sendBaiduMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -138,63 +135,63 @@ func (h *ChatHandler) sendBaiduMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -205,7 +202,8 @@ func (h *ChatHandler) sendBaiduMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -6,6 +6,7 @@ import (
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
@@ -16,6 +17,7 @@ import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"
@@ -26,26 +28,31 @@ import (
)
const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
const ErrImg = "![](/images/wx.png)"
var ErrImg = "![](/images/wx.png)"
var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
db *gorm.DB
redis *redis.Client
redis *redis.Client
uploadManager *oss.UploaderManager
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client) *ChatHandler {
h := ChatHandler{
db: db,
redis: redis,
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
}
h.App = app
return &h
}
var chatConfig types.ChatConfig
func (h *ChatHandler) Init() {
// 如果后台有上传微信客服微信二维码,则覆盖
if h.App.SysConfig.WechatCardURL != "" {
ErrImg = fmt.Sprintf("![](%s)", h.App.SysConfig.WechatCardURL)
}
}
// ChatHandle 处理聊天 WebSocket 请求
func (h *ChatHandler) ChatHandle(c *gin.Context) {
@@ -63,7 +70,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
client := types.NewWsClient(ws)
// get model info
var chatModel model.ChatModel
res := h.db.First(&chatModel, modelId)
res := h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort()
@@ -72,7 +79,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session := h.App.ChatSession.Get(sessionId)
if session == nil {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
logger.Info("用户未登录")
c.Abort()
@@ -89,7 +96,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
// use old chat data override the chat model and role ID
var chat model.ChatItem
res = h.db.Where("chat_id=?", chatId).First(&chat)
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
if res.Error == nil {
chatModel.Id = chat.ModelId
roleId = int(chat.RoleId)
@@ -97,28 +104,24 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session.ChatId = chatId
session.Model = types.ChatModel{
Id: chatModel.Id,
Value: chatModel.Value,
Weight: chatModel.Weight,
Platform: types.Platform(chatModel.Platform)}
Id: chatModel.Id,
Name: chatModel.Name,
Value: chatModel.Value,
Power: chatModel.Power,
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
res = h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
// 初始化聊天配置
var config model.Config
h.db.Where("marker", "chat").First(&config)
err = utils.JsonDecode(config.Config, &chatConfig)
if err != nil {
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
c.Abort()
return
}
h.Init()
// 保存会话连接
h.App.ChatClients.Put(sessionId, client)
@@ -126,7 +129,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Error(err)
client.Close()
h.App.ChatClients.Delete(sessionId)
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
@@ -137,19 +139,30 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
return
}
message := string(msg)
logger.Info("Receive a message: ", message)
//utils.ReplyMessage(client, "这是一条测试消息!")
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 Chat 心跳消息:", message.Content)
continue
}
logger.Info("Receive a message: ", message.Content)
ctx, cancel := context.WithCancel(context.Background())
h.App.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, message, client)
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
} else {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
logger.Info("回答完毕: " + string(message))
logger.Infof("回答完毕: %v", message.Content)
}
}
@@ -166,9 +179,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
var user model.User
res := h.db.Model(&model.User{}).First(&user, session.UserId)
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
utils.ReplyMessage(ws, "非法用户,请联系管理员")
utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作")
return res.Error
}
var userVo vo.User
@@ -184,14 +197,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil
}
if userVo.Calls < session.Model.Weight {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数%d已不足以支付当前模型的单次对话需要消耗的对话额度%d", userVo.Calls, session.Model.Weight))
utils.ReplyMessage(ws, ErrImg)
return nil
}
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
if userVo.Power < session.Model.Power {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力%d已不足以支付当前模型的单次对话需要消耗的算力%d", userVo.Power, session.Model.Power))
utils.ReplyMessage(ws, ErrImg)
return nil
}
@@ -201,35 +208,34 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
utils.ReplyMessage(ws, ErrImg)
return nil
}
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > session.Model.MaxContext {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
}
var req = types.ApiRequest{
Model: session.Model.Value,
Stream: true,
}
switch session.Model.Platform {
case types.Azure:
req.Temperature = h.App.ChatConfig.Azure.Temperature
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
break
case types.ChatGLM:
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
break
case types.Baidu:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
// TODO 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
break
case types.OpenAI:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res := h.db.Where("enabled", true).Find(&items)
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
var tools = make([]interface{}, 0)
var functions = make([]interface{}, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
@@ -247,26 +253,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
"required": required,
},
})
functions = append(functions, gin.H{
"name": v.Name,
"description": v.Description,
"parameters": parameters,
"required": required,
})
}
//if len(tools) > 0 {
// req.Tools = tools
// req.ToolChoice = "auto"
//}
if len(functions) > 0 {
req.Functions = functions
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
case types.QWen:
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
case types.XunFei:
req.Temperature = h.App.ChatConfig.XunFei.Temperature
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
break
default:
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
@@ -274,40 +273,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
// 加载聊天上下文
var chatCtx []interface{}
if h.App.ChatConfig.EnableContext {
chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.SysConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
messages = h.App.ChatContexts.Get(session.ChatId)
} else {
// calculate the tokens of current request, to prevent to exceeding the max tokens num
tokens := req.MaxTokens
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks
// loading the role context
var messages []types.Message
err := utils.JsonDecode(role.Context, &messages)
if err == nil {
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
if tokens+tks >= types.GetModelMaxToken(req.Model) {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
}
// loading recent chat history as chat context
if chatConfig.ContextDeep > 0 {
var historyMessages []model.HistoryMessage
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
_ = utils.JsonDecode(role.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
break
}
tokens += msg.Tokens
ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg {
ms.Role = "assistant"
@@ -317,6 +295,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
}
}
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
// MaxContextLength = Response + Tool + Prompt + Context
tokens := req.MaxTokens // 最大响应长度
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
break
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
logger.Debugf("聊天上下文:%+v", chatCtx)
}
reqMgs := make([]interface{}, 0)
@@ -324,10 +325,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
reqMgs = append(reqMgs, m)
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
if session.Model.Platform == types.QWen {
req.Input = map[string]interface{}{"prompt": prompt}
if len(reqMgs) > 0 {
req.Input["messages"] = reqMgs
}
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
}
switch session.Model.Platform {
case types.Azure:
@@ -340,7 +348,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.QWen:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
@@ -363,9 +372,9 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
if data.Text == "" && data.ChatId != "" {
var item model.HistoryMessage
var item model.ChatMessage
userId, _ := c.Get(types.LoginUserID)
res := h.db.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -416,7 +425,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
@@ -434,15 +443,15 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
case types.Baidu:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
case types.QWen:
apiURL = apiKey.ApiURL
req.Messages = nil
break
default:
if req.Model == "gpt-4-all" || strings.HasPrefix(req.Model, "gpt-4-gizmo-g-") {
apiURL = "https://gpt.bemore.lol/v1/chat/completions"
} else {
apiURL = apiKey.ApiURL
}
apiURL = apiKey.ApiURL
}
// 更新 API KEY 的最后使用时间
h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value)
@@ -469,9 +478,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request = request.WithContext(ctx)
request.Header.Set("Content-Type", "application/json")
var proxyURL string
if h.App.Config.ProxyURL != "" && apiKey.UseProxy { // 使用代理
proxyURL = h.App.Config.ProxyURL
proxy, _ := url.Parse(proxyURL)
if apiKey.ProxyURL != "" { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
@@ -480,7 +488,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
switch platform {
case types.Azure:
request.Header.Set("api-key", apiKey.Value)
@@ -496,25 +504,63 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request.RequestURI = ""
case types.OpenAI:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
break
case types.QWen:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
request.Header.Set("X-DashScope-SSE", "enable")
break
}
return client.Do(request)
}
// 扣减用户的对话次数
func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) {
// 仅当用户没有导入自己的 API KEY 时才进行扣减
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
num := 1
if session.Model.Weight > 0 {
num = session.Model.Weight
}
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num))
// 扣减用户算力
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
power := 1
if session.Model.Power > 0 {
power = session.Model.Power
}
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userVo.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,
Amount: power,
Mark: types.PowerSub,
Balance: u.Power,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
CreatedAt: time.Now(),
})
}
}
func (h *ChatHandler) incUserTokenFee(userId uint, tokens int) {
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", tokens))
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("tokens", gorm.Expr("tokens + ?", tokens))
// 将AI回复消息中生成的图片链接下载到本地
func (h *ChatHandler) extractImgUrl(text string) string {
pattern := `!\[([^\]]*)]\(([^)]+)\)`
re := regexp.MustCompile(pattern)
matches := re.FindAllStringSubmatch(text, -1)
// 下载图片并替换链接地址
for _, match := range matches {
imageURL := match[2]
logger.Debug(imageURL)
// 对于相同地址的图片,已经被替换了,就不再重复下载了
if !strings.Contains(text, imageURL) {
continue
}
newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
if err != nil {
logger.Error("error with download image: ", err)
continue
}
text = strings.ReplaceAll(text, imageURL, newImgURL)
}
return text
}

View File

@@ -6,27 +6,29 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// List 获取会话列表
func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.ERROR(c, "The parameter 'user_id' is needed.")
if !h.IsLogin(c) {
resp.SUCCESS(c)
return
}
userId := h.GetLoginUserId(c)
var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
if res.Error == nil {
var roleIds = make([]uint, 0)
for _, chat := range chats {
roleIds = append(roleIds, chat.RoleId)
}
var roles []model.ChatRole
res = h.db.Find(&roles, roleIds)
res = h.DB.Find(&roles, roleIds)
if res.Error == nil {
roleMap := make(map[uint]model.ChatRole)
for _, role := range roles {
@@ -58,7 +60,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
@@ -70,14 +72,14 @@ func (h *ChatHandler) Update(c *gin.Context) {
// Clear 清空所有聊天记录
func (h *ChatHandler) Clear(c *gin.Context) {
// 获取当前登录用户所有的聊天会话
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
var chats []model.ChatItem
res := h.db.Where("user_id = ?", user.Id).Find(&chats)
res := h.DB.Where("user_id = ?", user.Id).Find(&chats)
if res.Error != nil {
resp.ERROR(c, "No chats found")
return
@@ -89,13 +91,13 @@ func (h *ChatHandler) Clear(c *gin.Context) {
// 清空会话上下文
h.App.ChatContexts.Delete(chat.ChatId)
}
err = h.db.Transaction(func(tx *gorm.DB) error {
res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
err = h.DB.Transaction(func(tx *gorm.DB) error {
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
if res.Error != nil {
return res.Error
}
res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.HistoryMessage{})
res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
if res.Error != nil {
return res.Error
}
@@ -116,9 +118,9 @@ func (h *ChatHandler) Clear(c *gin.Context) {
// History 获取聊天历史记录
func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.HistoryMessage
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.db.Where("chat_id = ?", chatId).Find(&items)
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
@@ -144,20 +146,20 @@ func (h *ChatHandler) Remove(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
}
// 删除当前会话的聊天记录
res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to remove chat from database.")
return
@@ -179,7 +181,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
}
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", chatId).First(&chatItem)
res := h.DB.Where("chat_id = ?", chatId).First(&chatItem)
if res.Error != nil {
resp.ERROR(c, "No chat found")
return

View File

@@ -20,7 +20,7 @@ import (
// 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -107,9 +107,6 @@ func (h *ChatHandler) sendChatGLMMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -117,63 +114,64 @@ func (h *ChatHandler) sendChatGLMMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -184,7 +182,8 @@ func (h *ChatHandler) sendChatGLMMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -20,7 +20,7 @@ import (
// OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -46,6 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
if response.Body != nil {
all, _ := io.ReadAll(response.Body)
logger.Error(string(all))
}
return err
} else {
defer response.Body.Close()
@@ -96,7 +100,7 @@ func (h *ChatHandler) sendOpenAiMessage(
}
if !utils.IsEmptyValue(tool) {
res := h.db.Where("name = ?", tool.Function.Name).First(&function)
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
@@ -169,9 +173,6 @@ func (h *ChatHandler) sendOpenAiMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -179,77 +180,77 @@ func (h *ChatHandler) sendOpenAiMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext && toolCall == false {
if h.App.SysConfig.EnableContext && toolCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
useContext := true
if toolCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: useContext,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
useContext := true
if toolCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyTokens += tokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -260,17 +261,20 @@ func (h *ChatHandler) sendOpenAiMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
return fmt.Errorf("error with decode response: %v", err)
}
@@ -278,7 +282,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {

View File

@@ -0,0 +1,240 @@
package chatimpl
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
type qWenResp struct {
Output struct {
FinishReason string `json:"finish_reason"`
Text string `json:"text"`
} `json:"output,omitempty"`
Usage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage,omitempty"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var content, lastText, newText string
var outPutStart = false
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") ||
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
var resp qWenResp
if len(contents) == 0 { // 发送消息头
if !outPutStart {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
outPutStart = true
continue
} else {
// 处理代码换行
content = "\n"
}
} else {
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", content)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if resp.Message != "" {
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
break
}
}
//通过比较 lastText上一次的文本和 currentText当前的文本
//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
//每次循环结束后lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
currentText := resp.Output.Text
if currentText != lastText {
// 提取新增文本
newText = strings.Replace(currentText, lastText, "", 1)
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(newText),
})
lastText = currentText // 更新 lastText
}
contents = append(contents, newText)
if resp.Output.FinishReason == "stop" {
break
}
} //end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg)
}
return nil
}

View File

@@ -50,15 +50,16 @@ type xunFeiResp struct {
}
var Model2URL = map[string]string{
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"generalv3.5": "v3.5",
}
// 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -68,13 +69,13 @@ func (h *ChatHandler) sendXunFeiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
@@ -86,6 +87,7 @@ func (h *ChatHandler) sendXunFeiMessage(
}
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)
@@ -166,9 +168,6 @@ func (h *ChatHandler) sendXunFeiMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -176,63 +175,64 @@ func (h *ChatHandler) sendXunFeiMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -243,7 +243,8 @@ func (h *ChatHandler) sendXunFeiMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
@@ -259,7 +260,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
"parameter": map[string]interface{}{
"chat": map[string]interface{}{
"domain": req.Model,
"temperature": float64(req.Temperature),
"temperature": req.Temperature,
"top_k": int64(6),
"max_tokens": int64(req.MaxTokens),
"auditing": "default",

View File

@@ -0,0 +1,39 @@
package handler
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ConfigHandler struct {
BaseHandler
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key")
var config model.Config
res := h.DB.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
var value map[string]interface{}
err := utils.JsonDecode(config.Config, &value)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, value)
}

View File

@@ -19,21 +19,18 @@ import (
type FunctionHandler struct {
BaseHandler
db *gorm.DB
config types.ChatPlusApiConfig
uploadManager *oss.UploaderManager
proxyURL string
}
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
return &FunctionHandler{
BaseHandler: BaseHandler{
App: server,
DB: db,
},
db: db,
config: config.ApiConfig,
uploadManager: manager,
proxyURL: config.ProxyURL,
}
}
@@ -192,68 +189,49 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
}
logger.Debugf("绘画参数:%+v", params)
// check img calls
var user model.User
tx := h.db.Where("id = ?", params["user_id"]).First(&user)
tx := h.DB.Where("id = ?", params["user_id"]).First(&user)
if tx.Error != nil {
resp.ERROR(c, "当前用户不存在!")
return
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "当前用户的绘图次数额度不足")
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return
}
prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY
var apiKey model.ApiKey
tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
return
}
// get image generation api URL
var conf model.Config
var chatConfig types.ChatConfig
tx = h.db.Where("marker", "chat").First(&conf)
if tx.Error != nil {
resp.ERROR(c, "error with get chat configs:"+tx.Error.Error())
return
}
err := utils.JsonDecode(conf.Config, &chatConfig)
if err != nil {
resp.ERROR(c, "error with decode chat config: "+err.Error())
return
}
// translate prompt
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
pt, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, params["prompt"]), h.App.Config.ProxyURL)
pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
if err == nil {
logger.Debugf("翻译绘画提示词,原文:%s译文%s", prompt, pt)
prompt = pt
}
imgNum := chatConfig.DallImgNum
if imgNum <= 0 {
imgNum = 1
}
var res imgRes
var errRes ErrRes
var request *req.Request
if apiKey.UseProxy && h.proxyURL != "" {
request = req.C().SetProxyURL(h.proxyURL).R()
if apiKey.ProxyURL != "" {
request = req.C().SetProxyURL(apiKey.ProxyURL).R()
} else {
request = req.C().R()
}
logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, h.proxyURL)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := request.SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: imgNum,
N: 1,
Size: "1024x1024",
}).
SetErrorResult(&errRes).
@@ -263,7 +241,8 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
return
}
// 更新 API KEY 的最后使用时间
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
logger.Debugf("%+v", res)
// 存储图片
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
if err != nil {
@@ -272,8 +251,24 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
}
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL)
// update user's img_calls
h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// 更新用户算力
tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
h.DB.Where("id", user.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: h.App.SysConfig.DallPower,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c, content)
}

View File

@@ -15,32 +15,29 @@ import (
// InviteHandler 用户邀请
type InviteHandler struct {
BaseHandler
db *gorm.DB
}
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
h := InviteHandler{db: db}
h.App = app
return &h
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Code 获取当前用户邀请码
func (h *InviteHandler) Code(c *gin.Context) {
userId := h.GetLoginUserId(c)
var inviteCode model.InviteCode
res := h.db.Where("user_id = ?", userId).First(&inviteCode)
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
// 如果邀请码不存在,则创建一个
if res.Error != nil {
code := strings.ToUpper(utils.RandString(8))
for {
res = h.db.Where("code = ?", code).First(&inviteCode)
res = h.DB.Where("code = ?", code).First(&inviteCode)
if res.Error != nil { // 不存在相同的邀请码则退出
break
}
}
inviteCode.UserId = userId
inviteCode.Code = code
h.db.Create(&inviteCode)
h.DB.Create(&inviteCode)
}
var codeVo vo.InviteCode
@@ -65,7 +62,7 @@ func (h *InviteHandler) List(c *gin.Context) {
return
}
userId := h.GetLoginUserId(c)
session := h.db.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
var total int64
session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog
@@ -91,6 +88,6 @@ func (h *InviteHandler) List(c *gin.Context) {
// Hits 访问邀请码
func (h *InviteHandler) Hits(c *gin.Context) {
code := c.Query("code")
h.db.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
resp.SUCCESS(c)
}

View File

@@ -13,42 +13,43 @@ import (
"chatplus/utils/resp"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
type MidJourneyHandler struct {
BaseHandler
db *gorm.DB
pool *mj.ServicePool
snowflake *service.Snowflake
uploader *oss.UploaderManager
}
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
h := MidJourneyHandler{
db: db,
return &MidJourneyHandler{
snowflake: snowflake,
pool: pool,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
h.App = app
return &h
}
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值")
if user.Power < h.App.SysConfig.MjPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return false
}
@@ -85,19 +86,20 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
// Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct {
SessionId string `json:"session_id"`
Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
Img string `json:"img"`
Tile bool `json:"tile"`
Quality float32 `json:"quality"`
Weight float32 `json:"weight"`
SessionId string `json:"session_id"`
TaskType string `json:"task_type"`
Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
ImgArr []string `json:"img_arr"`
Tile bool `json:"tile"`
Quality float32 `json:"quality"`
Weight float32 `json:"weight"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -120,11 +122,8 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
prompt += fmt.Sprintf(" --c %d", data.Chaos)
}
if data.Img != "" {
prompt = fmt.Sprintf("%s %s", data.Img, prompt)
if data.Weight > 0 {
prompt += fmt.Sprintf(" --iw %f", data.Weight)
}
if data.Weight > 0 {
prompt += fmt.Sprintf(" --iw %f", data.Weight)
}
if data.Raw {
prompt += " --style raw"
@@ -142,6 +141,11 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
prompt += fmt.Sprintf(" %s", data.Model)
}
// 处理融图和换脸的提示词
if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() {
prompt = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
// generate task id
@@ -151,31 +155,60 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
job := model.MidJourneyJob{
Type: types.TaskImage.String(),
Type: data.TaskType,
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: prompt,
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
opt := "绘图"
if data.TaskType == types.TaskBlend.String() {
job.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
opt = "融图"
} else if data.TaskType == types.TaskSwapFace.String() {
job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
opt = "换脸"
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{
Id: int(job.Id),
Id: job.Id,
TaskId: taskId,
SessionId: data.SessionId,
Type: types.TaskImage,
Prompt: fmt.Sprintf("%s %s", taskId, prompt),
Type: types.TaskType(data.TaskType),
Prompt: prompt,
UserId: userId,
ImgArr: data.ImgArr,
})
client := h.pool.Clients.Get(uint(job.UserId))
_ = client.Send([]byte("Task Updated"))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's img calls
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
@@ -215,13 +248,13 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{
Id: int(job.Id),
Id: job.Id,
SessionId: data.SessionId,
Type: types.TaskUpscale,
Prompt: data.Prompt,
@@ -233,7 +266,9 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
})
client := h.pool.Clients.Get(uint(job.UserId))
_ = client.Send([]byte("Task Updated"))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
@@ -261,15 +296,16 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{
Id: int(job.Id),
Id: job.Id,
SessionId: data.SessionId,
Type: types.TaskVariation,
Prompt: data.Prompt,
@@ -281,22 +317,64 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
})
client := h.pool.Clients.Get(uint(job.UserId))
_ = client.Send([]byte("Task Updated"))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's img calls
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
// ImgWall 照片墙
func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0)
userId := h.GetInt(c, "user_id", 0)
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
session := h.db.Session(&gorm.Session{})
if status == 1 {
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
@@ -304,6 +382,9 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish = ?", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
@@ -312,8 +393,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
var items []model.MidJourneyJob
res := session.Find(&items)
if res.Error != nil {
resp.ERROR(c, types.NoData)
return
return res.Error, nil
}
var jobs = make([]vo.MidJourneyJob, 0)
@@ -324,25 +404,21 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
continue
}
if job.Progress == -1 {
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
}
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
// 正在运行中任务使用代理访问图片
if h.App.Config.ImgCdnURL != "" {
job.ImgURL = strings.ReplaceAll(job.OrgURL, "https://cdn.discordapp.com", h.App.Config.ImgCdnURL)
} else {
// discord 服务器图片需要使用代理转发图片数据流
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
} else {
job.ImgURL = job.OrgURL
}
}
jobs = append(jobs, job)
}
resp.SUCCESS(c, jobs)
return nil, jobs
}
// Remove remove task image
@@ -358,7 +434,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
}
// remove job recode
res := h.db.Delete(&model.MidJourneyJob{Id: data.Id})
res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -371,7 +447,9 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
}
client := h.pool.Clients.Get(data.UserId)
_ = client.Send([]byte("Task Updated"))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
@@ -396,3 +474,23 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
resp.SUCCESS(c)
}
// Publish 发布图片到画廊显示
func (h *MidJourneyHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@@ -7,19 +7,17 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type OrderHandler struct {
BaseHandler
db *gorm.DB
}
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db}
h.App = app
return &h
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
func (h *OrderHandler) List(c *gin.Context) {
@@ -31,8 +29,8 @@ func (h *OrderHandler) List(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
user, _ := utils.GetLoginUser(c, h.db)
session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess)
userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
var total int64
session.Model(&model.Order{}).Count(&total)
var items []model.Order

View File

@@ -11,6 +11,7 @@ import (
"embed"
"encoding/base64"
"fmt"
"github.com/shopspring/decimal"
"math"
"net/http"
"net/url"
@@ -34,7 +35,6 @@ type PaymentHandler struct {
huPiPayService *payment.HuPiPayService
js *payment.PayJS
snowflake *service.Snowflake
db *gorm.DB
fs embed.FS
lock sync.Mutex
}
@@ -44,20 +44,21 @@ func NewPaymentHandler(
alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService,
js *payment.PayJS,
snowflake *service.Snowflake,
db *gorm.DB,
snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler {
h := PaymentHandler{
return &PaymentHandler{
alipayService: alipayService,
huPiPayService: huPiPayService,
js: js,
snowflake: snowflake,
fs: fs,
db: db,
lock: sync.Mutex{},
BaseHandler: BaseHandler{
App: server,
DB: db,
},
}
h.App = server
return &h
}
func (h *PaymentHandler) DoPay(c *gin.Context) {
@@ -70,14 +71,20 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
}
var order model.Order
res := h.db.Where("order_no = ?", orderNo).First(&order)
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
}
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
if order.Status == types.OrderPaidSuccess {
resp.ERROR(c, "This order had been paid, please do not pay twice")
return
}
// 更新扫码状态
h.db.Model(&order).UpdateColumn("status", types.OrderScanned)
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
if payWay == "alipay" { // 支付宝
// 生成支付链接
notifyURL := h.App.Config.AlipayConfig.NotifyURL
@@ -101,30 +108,12 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL,
WapName: "极客学长",
}
res, err := h.huPiPayService.Pay(params)
r, err := h.huPiPayService.Pay(params)
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
resp.ERROR(c, err.Error())
return
}
var r struct {
Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg,omitempty"`
}
err = utils.JsonDecode(res, &r)
if err != nil {
logger.Debugf(res)
resp.ERROR(c, "error with decode payment result: "+err.Error())
return
}
if r.ErrCode != 0 {
resp.ERROR(c, "error with generate pay url: "+r.ErrMsg)
return
}
c.Redirect(302, r.URL)
}
resp.ERROR(c, "Invalid operations")
@@ -141,7 +130,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
}
var order model.Order
res := h.db.Where("order_no = ?", data.OrderNo).First(&order)
res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
@@ -156,7 +145,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
for {
time.Sleep(time.Second)
var item model.Order
h.db.Where("order_no = ?", data.OrderNo).First(&item)
h.DB.Where("order_no = ?", data.OrderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
@@ -180,7 +169,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
}
var product model.Product
res := h.db.First(&product, data.ProductId)
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
@@ -192,7 +181,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
return
}
var user model.User
res = h.db.First(&user, data.UserId)
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
@@ -214,24 +203,25 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Calls: product.Calls,
ImgCalls: product.ImgCalls,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: product.Price - product.Discount,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.db.Create(&order)
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
@@ -287,10 +277,121 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
}
// Mobile 移动端支付
func (h *PaymentHandler) Mobile(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
var payWay string
var notifyURL, returnURL string
var payURL string
switch data.PayWay {
case "hupi":
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
params := payment.HuPiPayReq{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name,
NotifyURL: notifyURL,
ReturnURL: returnURL,
CallbackURL: returnURL,
WapName: "极客学长",
}
r, err := h.huPiPayService.Pay(params)
if err != nil {
logger.Error("error with generating Pay URL: ", err.Error())
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
return
}
payURL = r.URL
case "payjs":
payWay = PayWayJs
notifyURL = h.App.Config.JPayConfig.NotifyURL
returnURL = h.App.Config.JPayConfig.ReturnURL
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
params := url.Values{}
params.Add("total_fee", fmt.Sprintf("%d", totalFee))
params.Add("out_trade_no", orderNo)
params.Add("body", product.Name)
params.Add("notify_url", notifyURL)
params.Add("auto", "0")
payURL = h.js.PayH5(params)
case "alipay":
payWay = PayWayAlipay
notifyURL = h.App.Config.AlipayConfig.NotifyURL
returnURL = h.App.Config.AlipayConfig.ReturnURL
payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
if err != nil {
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
return
}
default:
resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
return
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
}
resp.SUCCESS(c, payURL)
}
// 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string) error {
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var order model.Order
res := h.db.Where("order_no = ?", orderNo).First(&order)
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
err := fmt.Errorf("error with fetch order: %v", res.Error)
logger.Error(err)
@@ -306,7 +407,7 @@ func (h *PaymentHandler) notify(orderNo string) error {
}
var user model.User
res = h.db.First(&user, order.UserId)
res = h.DB.First(&user, order.UserId)
if res.Error != nil {
err := fmt.Errorf("error with fetch user info: %v", res.Error)
logger.Error(err)
@@ -321,29 +422,33 @@ func (h *PaymentHandler) notify(orderNo string) error {
return err
}
var opt string
var power int
if user.Vip { // 已经是 VIP 用户
if remark.Days > 0 { // 只延期 VIP不增加调用次数
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
} else { // 充值点卡,直接增加次数即可
user.Calls += remark.Calls
user.ImgCalls += remark.ImgCalls
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
} else { // 非 VIP 用户
if remark.Days > 0 { // vip 套餐days > 0, calls == 0
} else { // 非 VIP 用户
if remark.Days > 0 { // vip 套餐days > 0, power == 0
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Calls += h.App.SysConfig.VipMonthCalls
user.ImgCalls += h.App.SysConfig.VipMonthImgCalls
user.Power += h.App.SysConfig.VipMonthPower
user.Vip = true
opt = "VIP充值"
power = h.App.SysConfig.VipMonthPower
} else { //点卡days == 0, calls > 0
user.Calls += remark.Calls
user.ImgCalls += remark.ImgCalls
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
}
// 更新用户信息
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
err := fmt.Errorf("error with update user info: %v", res.Error)
logger.Error(err)
@@ -353,7 +458,8 @@ func (h *PaymentHandler) notify(orderNo string) error {
// 更新订单状态
order.PayTime = time.Now().Unix()
order.Status = types.OrderPaidSuccess
res = h.db.Updates(&order)
order.TradeNo = tradeNo
res = h.DB.Updates(&order)
if res.Error != nil {
err := fmt.Errorf("error with update order info: %v", res.Error)
logger.Error(err)
@@ -361,7 +467,23 @@ func (h *PaymentHandler) notify(orderNo string) error {
}
// 更新产品销量
h.db.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
// 记录算力充值日志
if opt != "" {
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRecharge,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: order.PayWay,
Remark: fmt.Sprintf("%s金额%f订单号%s", opt, order.Amount, order.OrderNo),
CreatedAt: time.Now(),
})
}
return nil
}
@@ -389,10 +511,15 @@ func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
}
orderNo := c.Request.Form.Get("trade_order_id")
logger.Infof("收到订单支付回调,订单 NO%s", orderNo)
// TODO 是否要保存订单交易流水号
tradeNo := c.Request.Form.Get("open_order_id")
logger.Infof("收到虎皮椒订单支付回调,订单 NO%s交易流水号%s", orderNo, tradeNo)
err = h.notify(orderNo)
if err = h.huPiPayService.Check(tradeNo); err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
c.String(http.StatusOK, "fail")
return
@@ -409,16 +536,17 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
return
}
// TODO这里最好用支付宝的公钥签名签证一下交易真假
//res := h.alipayService.TradeVerify(c.Request.Form)
r := h.alipayService.TradeQuery(c.Request.Form.Get("out_trade_no"))
logger.Infof("验证支付结果:%+v", r)
if !r.Success() {
// TODO验证交易签名
res := h.alipayService.TradeVerify(c.Request.Form)
logger.Infof("验证支付结果:%+v", res)
if !res.Success() {
logger.Error("订单校验失败:", res.Message)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(r.OutTradeNo)
tradeNo := c.Request.Form.Get("trade_no")
err = h.notify(res.OutTradeNo, tradeNo)
if err != nil {
c.String(http.StatusOK, "fail")
return
@@ -443,7 +571,16 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
return
}
err = h.notify(orderNo)
// 校验订单支付状态
tradeNo := c.Request.Form.Get("payjs_order_id")
err = h.js.Check(tradeNo)
if err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
c.String(http.StatusOK, "fail")
return

View File

@@ -0,0 +1,67 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type PowerLogHandler struct {
BaseHandler
}
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Model string `json:"model"`
Date []string `json:"date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
userId := h.GetLoginUserId(c)
session = session.Where("user_id", userId)
if data.Model != "" {
session = session.Where("model", data.Model)
}
if len(data.Date) == 2 {
start := data.Date[0] + " 00:00:00"
end := data.Date[1] + " 00:00:00"
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.PowerLog{}).Count(&total)
var items []model.PowerLog
var list = make([]vo.PowerLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var log vo.PowerLog
err := utils.CopyObject(item, &log)
if err != nil {
continue
}
log.Id = item.Id
log.CreatedAt = item.CreatedAt.Unix()
log.TypeStr = item.Type.String()
list = append(list, log)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}

View File

@@ -12,20 +12,17 @@ import (
type ProductHandler struct {
BaseHandler
db *gorm.DB
}
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db}
h.App = app
return &h
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 模型列表
func (h *ProductHandler) List(c *gin.Context) {
var items []model.Product
var list = make([]vo.Product, 0)
res := h.db.Where("enabled", true).Order("sort_num ASC").Find(&items)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Product

View File

@@ -16,13 +16,10 @@ const translatePromptTemplate = "Translate the following painting prompt words i
type PromptHandler struct {
BaseHandler
db *gorm.DB
}
func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler {
h := &PromptHandler{db: db}
h.App = app
return h
return &PromptHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Rewrite translate and rewrite prompt with ChatGPT
@@ -35,7 +32,7 @@ func (h *PromptHandler) Rewrite(c *gin.Context) {
return
}
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(rewritePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(rewritePromptTemplate, data.Prompt))
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -53,7 +50,7 @@ func (h *PromptHandler) Translate(c *gin.Context) {
return
}
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, data.Prompt))
if err != nil {
resp.ERROR(c, err.Error())
return

View File

@@ -7,37 +7,35 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"math"
"strings"
"sync"
"time"
)
type RewardHandler struct {
BaseHandler
db *gorm.DB
lock sync.Mutex
}
func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db, lock: sync.Mutex{}}
h.App = server
return &h
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Verify 打赏码核销
func (h *RewardHandler) Verify(c *gin.Context) {
var data struct {
TxId string `json:"tx_id"`
Type string `json:"type"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.HACKER(c)
return
@@ -50,7 +48,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
defer h.lock.Unlock()
var item model.Reward
res := h.db.Where("tx_id = ?", data.TxId).First(&item)
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的众筹交易流水号!")
return
@@ -61,18 +59,13 @@ func (h *RewardHandler) Verify(c *gin.Context) {
return
}
tx := h.db.Begin()
tx := h.DB.Begin()
exchange := vo.RewardExchange{}
if data.Type == "chat" {
calls := math.Ceil(item.Amount / h.App.SysConfig.ChatCallPrice)
exchange.Calls = int(calls)
res = h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls + ?", calls))
} else if data.Type == "img" {
calls := math.Ceil(item.Amount / h.App.SysConfig.ImgCallPrice)
exchange.ImgCalls = int(calls)
res = h.db.Model(&user).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", calls))
}
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
exchange.Power = int(power)
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -81,13 +74,25 @@ func (h *RewardHandler) Verify(c *gin.Context) {
item.Status = true
item.UserId = user.Id
item.Exchange = utils.JsonEncode(exchange)
res = h.db.Updates(&item)
res = tx.Updates(&item)
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力充值日志
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerReward,
Amount: exchange.Power,
Balance: user.Power + exchange.Power,
Mark: types.PowerAdd,
Model: "众筹支付",
Remark: fmt.Sprintf("众筹充值算力,金额:%f价格%f", item.Amount, h.App.SysConfig.PowerPrice),
CreatedAt: time.Now(),
})
tx.Commit()
resp.SUCCESS(c)

View File

@@ -11,8 +11,11 @@ import (
"chatplus/utils/resp"
"encoding/base64"
"fmt"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
@@ -21,23 +24,44 @@ import (
type SdJobHandler struct {
BaseHandler
redis *redis.Client
db *gorm.DB
pool *sd.ServicePool
uploader *oss.UploaderManager
}
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler {
h := SdJobHandler{
db: db,
return &SdJobHandler{
pool: pool,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
h.App = app
return &h
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
@@ -48,8 +72,8 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值")
if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return false
}
@@ -116,9 +140,10 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Params: utils.JsonEncode(params),
Prompt: data.Prompt,
Progress: 0,
Power: h.App.SysConfig.SdPower,
CreatedAt: time.Now(),
}
res := h.db.Create(&job)
res := h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error())
return
@@ -133,21 +158,67 @@ func (h *SdJobHandler) Image(c *gin.Context) {
UserId: userId,
})
// update user's img calls
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
// JobList 获取 stable diffusion 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0)
userId := h.GetInt(c, "user_id", 0)
// ImgWall 照片墙
func (h *SdJobHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
session := h.db.Session(&gorm.Session{})
if status == 1 {
resp.SUCCESS(c, jobs)
}
// JobList 获取 SD 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
@@ -155,6 +226,9 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
@@ -163,8 +237,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
var items []model.SdJob
res := session.Find(&items)
if res.Error != nil {
resp.ERROR(c, types.NoData)
return
return res.Error, nil
}
var jobs = make([]vo.SdJob, 0)
@@ -175,18 +248,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
continue
}
if job.Progress == -1 {
h.db.Delete(&model.SdJob{Id: job.Id})
}
if item.Progress < 100 {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*5 {
h.db.Delete(&item)
// 退回绘图次数
h.db.Model(&model.User{}).Where("id = ?", item.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
continue
}
// 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, "")
if err == nil {
@@ -195,13 +257,15 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
}
jobs = append(jobs, job)
}
resp.SUCCESS(c, jobs)
return nil, jobs
}
// Remove remove task image
func (h *SdJobHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -210,7 +274,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
}
// remove job recode
res := h.db.Delete(&model.SdJob{Id: data.Id})
res := h.DB.Delete(&model.SdJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -222,5 +286,30 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err)
}
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
// Publish 发布/取消发布图片到画廊显示
func (h *SdJobHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@@ -4,6 +4,7 @@ import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/sms"
"chatplus/utils"
"chatplus/utils/resp"
"strings"
@@ -17,7 +18,7 @@ const CodeStorePrefix = "/verify/codes/"
type SmsHandler struct {
BaseHandler
redis *redis.Client
sms *service.AliYunSmsService
sms *sms.ServiceManager
smtp *service.SmtpService
captcha *service.CaptchaService
}
@@ -25,12 +26,15 @@ type SmsHandler struct {
func NewSmsHandler(
app *core.AppServer,
client *redis.Client,
sms *service.AliYunSmsService,
sms *sms.ServiceManager,
smtp *service.SmtpService,
captcha *service.CaptchaService) *SmsHandler {
handler := &SmsHandler{redis: client, sms: sms, captcha: captcha, smtp: smtp}
handler.App = app
return handler
return &SmsHandler{
redis: client,
sms: sms,
captcha: captcha,
smtp: smtp,
BaseHandler: BaseHandler{App: app}}
}
// SendCode 发送验证码
@@ -63,7 +67,8 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
resp.ERROR(c, "系统已禁用手机号注册!")
return
}
err = h.sms.SendVerifyCode(data.Receiver, code)
err = h.sms.GetService().SendVerifyCode(data.Receiver, code)
}
if err != nil {
resp.ERROR(c, err.Error())

View File

@@ -3,12 +3,6 @@ package handler
import (
"chatplus/service"
"chatplus/service/payment"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
@@ -21,208 +15,3 @@ type TestHandler struct {
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js}
}
type reqBody struct {
BotType string `json:"botType"`
Prompt string `json:"prompt"`
Base64Array []interface{} `json:"base64Array,omitempty"`
AccountFilter struct {
InstanceId string `json:"instanceId"`
Modes []interface{} `json:"modes"`
Remix bool `json:"remix"`
RemixAutoConsidered bool `json:"remixAutoConsidered"`
} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook"`
State string `json:"state,omitempty"`
}
type resBody struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
func (h *TestHandler) Test(c *gin.Context) {
image(c)
}
func upscale(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj/submit/action"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
body := map[string]string{
"customId": "MJ::JOB::upsample::1::c80a8eb1-f2d1-4f40-8785-97eb99b7ba0a",
"taskId": "1704880156226095",
"notifyHook": "http://r9it.com:6004/api/test/mj",
}
var res resBody
var resErr errRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+token).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&resErr).
Post(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
return
}
resp.SUCCESS(c, res)
}
type queryRes struct {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}
func query(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj/task/1704960661008372/fetch"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
var res queryRes
r, err := req.C().R().SetHeader("Authorization", "Bearer "+token).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+r.Status)
return
}
resp.SUCCESS(c, res)
}
type errRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
func image(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj-fast/mj/submit/imagine"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
body := reqBody{
BotType: "MID_JOURNEY",
Prompt: "一个中国美女,手上拿着一桶爆米花,脸上带着迷人的微笑,白色衣服 --s 750 --v 6",
NotifyHook: "http://r9it.com:6004/api/test/mj",
}
var res resBody
var resErr errRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+token).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&resErr).
Post(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
return
}
resp.SUCCESS(c, res)
}
type cbReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (h *TestHandler) Mj(c *gin.Context) {
var data cbReq
if err := c.ShouldBindJSON(&data); err != nil {
logger.Error(err)
}
logger.Debugf("任务ID%s,任务进度:%s,图片地址:%s, 最终提示词:%s", data.Id, data.Progress, data.ImageUrl, data.Properties.FinalPrompt)
apiURL := "https://api.openai1s.cn/mj/task/" + data.Id + "/fetch"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
var res queryRes
_, _ = req.C().R().SetHeader("Authorization", "Bearer "+token).
SetSuccessResult(&res).
Get(apiURL)
fmt.Println(res.State, ",", res.ImageUrl, ",", res.Progress)
}
func (h *TestHandler) initUserNickname(c *gin.Context) {
var users []model.User
tx := h.db.Find(&users)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
for _, u := range users {
u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
h.db.Updates(&u)
}
resp.SUCCESS(c)
}
func (h *TestHandler) initMjTaskId(c *gin.Context) {
var jobs []model.MidJourneyJob
tx := h.db.Find(&jobs)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
for _, job := range jobs {
id, _ := h.snowflake.Next(true)
job.TaskId = id
h.db.Updates(&job)
}
resp.SUCCESS(c)
}

View File

@@ -14,14 +14,11 @@ import (
type UploadHandler struct {
BaseHandler
db *gorm.DB
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
handler := &UploadHandler{db: db, uploaderManager: manager}
handler.App = app
return handler
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
@@ -32,9 +29,10 @@ func (h *UploadHandler) Upload(c *gin.Context) {
}
userId := h.GetLoginUserId(c)
res := h.db.Create(&model.File{
UserId: userId,
res := h.DB.Create(&model.File{
UserId: int(userId),
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
@@ -52,7 +50,7 @@ func (h *UploadHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
var items []model.File
var files = make([]vo.File, 0)
h.db.Debug().Where("user_id = ?", userId).Find(&items)
h.DB.Where("user_id = ?", userId).Find(&items)
if len(items) > 0 {
for _, v := range items {
var file vo.File
@@ -68,3 +66,29 @@ func (h *UploadHandler) List(c *gin.Context) {
resp.SUCCESS(c, files)
}
// Remove remove files
func (h *UploadHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0)
var file model.File
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
if tx.Error != nil || file.Id == 0 {
resp.ERROR(c, "file not existed")
return
}
// remove database
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
if tx.Error != nil || tx.RowsAffected == 0 {
resp.ERROR(c, "failed to update database")
return
}
// remove files
objectKey := file.ObjKey
if objectKey == "" {
objectKey = file.URL
}
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
resp.SUCCESS(c)
}

View File

@@ -21,7 +21,6 @@ import (
type UserHandler struct {
BaseHandler
db *gorm.DB
searcher *xdb.Searcher
redis *redis.Client
}
@@ -31,15 +30,14 @@ func NewUserHandler(
db *gorm.DB,
searcher *xdb.Searcher,
client *redis.Client) *UserHandler {
handler := &UserHandler{db: db, searcher: searcher, redis: client}
handler.App = app
return handler
return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client}
}
// Register user register
func (h *UserHandler) Register(c *gin.Context) {
// parameters process
var data struct {
RegWay string `json:"reg_way"`
Username string `json:"username"`
Password string `json:"password"`
Code string `json:"code"`
@@ -57,8 +55,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// 检查验证码
var key string
if utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") ||
utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" {
key = CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
@@ -70,7 +67,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// 验证邀请码
inviteCode := model.InviteCode{}
if data.InviteCode != "" {
res := h.db.Where("code = ?", data.InviteCode).First(&inviteCode)
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
if res.Error != nil {
resp.ERROR(c, "无效的邀请码")
return
@@ -79,8 +76,8 @@ func (h *UserHandler) Register(c *gin.Context) {
// check if the username is exists
var item model.User
res := h.db.Where("username = ?", data.Username).First(&item)
if res.RowsAffected > 0 {
res := h.DB.Where("username = ?", data.Username).First(&item)
if item.Id > 0 {
resp.ERROR(c, "该用户名已经被注册")
return
}
@@ -95,18 +92,10 @@ func (h *UserHandler) Register(c *gin.Context) {
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
ChatConfig: utils.JsonEncode(types.UserChatConfig{
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
Calls: h.App.SysConfig.InitChatCalls,
ImgCalls: h.App.SysConfig.InitImgCalls,
Power: h.App.SysConfig.InitPower,
}
res = h.db.Create(&user)
res = h.DB.Create(&user)
if res.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(res.Error)
@@ -116,21 +105,32 @@ func (h *UserHandler) Register(c *gin.Context) {
// 记录邀请关系
if data.InviteCode != "" {
// 增加邀请数量
h.db.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InviteChatCalls > 0 {
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("calls", gorm.Expr("calls + ?", h.App.SysConfig.InviteChatCalls))
}
if h.App.SysConfig.InviteImgCalls > 0 {
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", h.App.SysConfig.InviteImgCalls))
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 {
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
// 记录邀请算力充值日志
var inviter model.User
h.DB.Where("id", inviteCode.UserId).First(&inviter)
h.DB.Create(&model.PowerLog{
UserId: inviter.Id,
Username: inviter.Username,
Type: types.PowerInvite,
Amount: h.App.SysConfig.InvitePower,
Balance: inviter.Power,
Mark: types.PowerAdd,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
CreatedAt: time.Now(),
})
}
// 添加邀请记录
h.db.Create(&model.InviteLog{
h.DB.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Reward: utils.JsonEncode(types.InviteReward{ChatCalls: h.App.SysConfig.InviteChatCalls, ImgCalls: h.App.SysConfig.InviteImgCalls}),
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
})
}
@@ -166,7 +166,7 @@ func (h *UserHandler) Login(c *gin.Context) {
return
}
var user model.User
res := h.db.Where("username = ?", data.Username).First(&user)
res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户名不存在")
return
@@ -186,9 +186,9 @@ func (h *UserHandler) Login(c *gin.Context) {
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.db.Model(&user).Updates(user)
h.DB.Model(&user).Updates(user)
h.db.Create(&model.UserLoginLog{
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
@@ -233,7 +233,7 @@ func (h *UserHandler) Logout(c *gin.Context) {
// Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err == nil {
var userVo vo.User
err := utils.CopyObject(user, &userVo)
@@ -249,27 +249,23 @@ func (h *UserHandler) Session(c *gin.Context) {
}
type userProfile struct {
Id uint `json:"id"`
Nickname string `json:"nickname"`
Username string `json:"username"`
Avatar string `json:"avatar"`
ChatConfig types.UserChatConfig `json:"chat_config"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
TotalTokens int64 `json:"total_tokens"`
Tokens int64 `json:"tokens"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
Id uint `json:"id"`
Nickname string `json:"nickname"`
Username string `json:"username"`
Avatar string `json:"avatar"`
Power int `json:"power"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
}
func (h *UserHandler) Profile(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
h.DB.First(&user, user.Id)
var profile userProfile
err = utils.CopyObject(user, &profile)
if err != nil {
@@ -289,15 +285,15 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
h.DB.First(&user, user.Id)
user.Avatar = data.Avatar
user.Nickname = data.Nickname
res := h.db.Updates(&user)
res := h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新用户信息失败")
return
@@ -322,21 +318,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
password := utils.GenPassword(data.OldPass, user.Salt)
logger.Info(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
if password != user.Password {
resp.ERROR(c, "原密码错误")
return
}
newPass := utils.GenPassword(data.Password, user.Salt)
res := h.db.Model(&user).UpdateColumn("password", newPass)
res := h.DB.Model(&user).UpdateColumn("password", newPass)
if res.Error != nil {
logger.Error("更新数据库失败: ", res.Error)
resp.ERROR(c, "更新数据库失败")
@@ -359,7 +355,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
var user model.User
res := h.db.Where("username", data.Username).First(&user)
res := h.DB.Where("username", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户不存在!")
return
@@ -375,7 +371,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
} else {
@@ -405,19 +401,19 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
// 检查手机号是否被其他账号绑定
var item model.User
res := h.db.Where("username = ?", data.Username).First(&item)
res := h.DB.Where("username = ?", data.Username).First(&item)
if res.Error == nil {
resp.ERROR(c, "该账号已经被其他账号绑定")
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
res = h.db.Model(&user).UpdateColumn("username", data.Username)
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return

View File

@@ -12,6 +12,7 @@ import (
"chatplus/service/oss"
"chatplus/service/payment"
"chatplus/service/sd"
"chatplus/service/sms"
"chatplus/service/wx"
"chatplus/store"
"context"
@@ -59,11 +60,13 @@ func main() {
}
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
logger.Info("Loading config file: ", configFile)
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
}
}()
if !debug {
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
}
}()
}
app := fx.New(
// 初始化配置应用配置
@@ -122,6 +125,8 @@ func main() {
fx.Provide(handler.NewPaymentHandler),
fx.Provide(handler.NewOrderHandler),
fx.Provide(handler.NewProductHandler),
fx.Provide(handler.NewConfigHandler),
fx.Provide(handler.NewPowerLogHandler),
fx.Provide(admin.NewConfigHandler),
fx.Provide(admin.NewAdminHandler),
@@ -133,9 +138,11 @@ func main() {
fx.Provide(admin.NewChatModelHandler),
fx.Provide(admin.NewProductHandler),
fx.Provide(admin.NewOrderHandler),
fx.Provide(admin.NewChatHandler),
fx.Provide(admin.NewPowerLogHandler),
// 创建服务
fx.Provide(service.NewAliYunSmsService),
fx.Provide(sms.NewSendServiceManager),
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
return service.NewCaptchaService(config.ApiConfig)
}),
@@ -215,6 +222,7 @@ func main() {
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
s.Engine.POST("/api/upload", h.Upload)
s.Engine.GET("/api/upload/list", h.List)
s.Engine.GET("/api/upload/remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
group := s.Engine.Group("/api/sms/")
@@ -236,14 +244,23 @@ func main() {
group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("notify", h.Notify)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd")
group.Any("client", h.Client)
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/")
group.GET("get", h.Get)
}),
// 管理后台控制器
@@ -257,13 +274,18 @@ func main() {
group.POST("login", h.Login)
group.GET("logout", h.Logout)
group.GET("session", h.Session)
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("enable", h.Enable)
group.GET("remove", h.Remove)
group.POST("resetPass", h.ResetPass)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
group := s.Engine.Group("/api/admin/apikey/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.POST("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
group := s.Engine.Group("/api/admin/user/")
@@ -279,12 +301,12 @@ func main() {
group.POST("save", h.Save)
group.POST("sort", h.Sort)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.POST("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
group := s.Engine.Group("/api/admin/reward/")
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.POST("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
group := s.Engine.Group("/api/admin/dashboard/")
@@ -308,6 +330,7 @@ func main() {
group.GET("payWays", h.GetPayWays)
group.POST("query", h.OrderQuery)
group.POST("qrcode", h.PayQrcode)
group.POST("mobile", h.Mobile)
group.POST("alipay/notify", h.AlipayNotify)
group.POST("hupipay/notify", h.HuPiPayNotify)
group.POST("payjs/notify", h.PayJsNotify)
@@ -359,6 +382,18 @@ func main() {
group.GET("token", h.GenToken)
}),
// 验证码
fx.Provide(admin.NewCaptchaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
group := s.Engine.Group("/api/admin/login/")
group.GET("captcha", h.GetCaptcha)
}),
fx.Provide(admin.NewUploadHandler),
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
s.Engine.POST("/api/admin/upload", h.Upload)
}),
fx.Provide(handler.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
group := s.Engine.Group("/api/function/")
@@ -366,11 +401,21 @@ func main() {
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
}),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
s.Engine.GET("/api/test", h.Test)
s.Engine.POST("/api/test/mj", h.Mj)
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/")
group.POST("list", h.List)
group.POST("message", h.Messages)
group.GET("history", h.History)
group.GET("remove", h.RemoveChat)
group.GET("message/remove", h.RemoveMessage)
}),
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
group := s.Engine.Group("/api/powerLog/")
group.POST("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
group := s.Engine.Group("/api/admin/powerLog/")
group.POST("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
err := s.Run(db)
@@ -378,7 +423,6 @@ func main() {
log.Fatal(err)
}
}),
// 注册生命周期回调函数
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
lifecycle.Append(fx.Hook{

View File

@@ -2,6 +2,7 @@ package mj
import (
"chatplus/core/types"
"errors"
"fmt"
"time"
@@ -11,13 +12,12 @@ import (
// MidJourney client
type Client struct {
client *req.Client
Config types.MidJourneyConfig
imgCdnURL string
apiURL string
client *req.Client
Config types.MidJourneyConfig
apiURL string
}
func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *Client {
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second)
var apiURL string
// set proxy URL
@@ -30,10 +30,10 @@ func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *C
}
}
return &Client{client: client, Config: config, apiURL: apiURL, imgCdnURL: imgCdnURL}
return &Client{client: client, Config: config, apiURL: apiURL}
}
func (c *Client) Imagine(prompt string) error {
func (c *Client) Imagine(task types.MjTask) error {
interactionsReq := &InteractionsRequest{
Type: 2,
ApplicationID: ApplicationID,
@@ -49,7 +49,7 @@ func (c *Client) Imagine(prompt string) error {
{
"type": 3,
"name": "prompt",
"value": prompt,
"value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt),
},
},
"application_command": map[string]any{
@@ -88,20 +88,28 @@ func (c *Client) Imagine(prompt string) error {
return nil
}
func (c *Client) Blend(task types.MjTask) error {
return errors.New("function not implemented")
}
func (c *Client) SwapFace(task types.MjTask) error {
return errors.New("function not implemented")
}
// Upscale 放大指定的图片
func (c *Client) Upscale(index int, messageId string, hash string) error {
func (c *Client) Upscale(task types.MjTask) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
MessageFlags: flags,
MessageID: task.MessageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
@@ -120,19 +128,19 @@ func (c *Client) Upscale(index int, messageId string, hash string) error {
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(index int, messageId string, hash string) error {
func (c *Client) Variation(task types.MjTask) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
MessageFlags: flags,
MessageID: task.MessageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}

View File

@@ -3,9 +3,14 @@ package plus
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
"encoding/base64"
"errors"
"fmt"
"github.com/imroc/req/v3"
"io"
"github.com/gin-gonic/gin"
)
var logger = logger2.GetLogger()
@@ -13,16 +18,27 @@ var logger = logger2.GetLogger()
// Client MidJourney Plus Client
type Client struct {
Config types.MidJourneyPlusConfig
apiURL string
}
func NewClient(config types.MidJourneyPlusConfig) *Client {
return &Client{Config: config}
var apiURL string
if config.CdnURL != "" {
apiURL = config.CdnURL
} else {
apiURL = config.ApiURL
}
if config.Mode == "" {
config.Mode = "fast"
}
return &Client{Config: config, apiURL: apiURL}
}
type ImageReq struct {
BotType string `json:"botType"`
Prompt string `json:"prompt"`
Base64Array []interface{} `json:"base64Array,omitempty"`
BotType string `json:"botType"`
Prompt string `json:"prompt,omitempty"`
Dimensions string `json:"dimensions,omitempty"`
Base64Array []string `json:"base64Array,omitempty"`
AccountFilter struct {
InstanceId string `json:"instanceId"`
Modes []interface{} `json:"modes"`
@@ -47,12 +63,23 @@ type ErrRes struct {
} `json:"error"`
}
func (c *Client) Imagine(prompt string) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.Config.ApiURL)
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
NotifyHook: c.Config.NotifyURL,
BotType: "MID_JOURNEY",
Prompt: task.Prompt,
NotifyHook: c.Config.NotifyURL,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
var res ImageRes
var errRes ErrRes
@@ -63,9 +90,103 @@ func (c *Client) Imagine(prompt string) (ImageRes, error) {
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
NotifyHook: c.Config.NotifyURL,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"notifyHook": c.Config.NotifyURL,
"state": "",
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
@@ -74,13 +195,13 @@ func (c *Client) Imagine(prompt string) (ImageRes, error) {
}
// Upscale 放大指定的图片
func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, error) {
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
"taskId": messageId,
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
"notifyHook": c.Config.NotifyURL,
}
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
@@ -101,13 +222,13 @@ func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, er
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(index int, messageId string, hash string) (ImageRes, error) {
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
"taskId": messageId,
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
"notifyHook": c.Config.NotifyURL,
}
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
@@ -153,7 +274,7 @@ type QueryRes struct {
}
func (c *Client) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.Config.ApiURL, taskId)
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).

View File

@@ -58,49 +58,54 @@ func (s *Service) Run() {
}
// if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.Name {
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
time.Sleep(time.Second)
continue
}
//if task.ChannelId != "" && task.ChannelId != s.Name {
// logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
// s.taskQueue.RPush(task)
// time.Sleep(time.Second)
// continue
//}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
index := strings.Index(task.Prompt, " ")
res, err = s.Client.Imagine(task.Prompt[index+1:])
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task.Index, task.MessageId, task.MessageHash)
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task.Index, task.MessageId, task.MessageHash)
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
var job model.MidJourneyJob
s.db.Where("id = ?", task.Id).First(&job)
if err != nil || (res.Code != 1 && res.Code != 22) {
logger.Error("绘画任务执行失败:", err)
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.ErrMsg = errMsg
// update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(task.UserId)
// restore img_call quota
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
// TODO: 任务提交失败,加入队列重试
continue
}
logger.Infof("任务提交成功:%+v", res)
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
s.taskStartTimes[int(task.Id)] = time.Now()
atomic.AddInt32(&s.HandledTaskNum, 1)
// 更新任务 ID/频道
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{
"task_id": res.Result,
"channel_id": s.Name,
})
job.TaskId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
}
@@ -140,26 +145,54 @@ type CBReq struct {
} `json:"properties"`
}
func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error {
job.Progress = utils.IntValue(strings.Replace(data.Progress, "%", "", 1), 0)
job.Prompt = data.Properties.FinalPrompt
if data.ImageUrl != "" {
job.OrgURL = data.ImageUrl
}
job.UseProxy = true
job.MessageId = data.Id
logger.Debugf("JOB: %+v", job)
res := s.db.Updates(&job)
if res.Error != nil {
return fmt.Errorf("error with update job: %v", res.Error)
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
if data.Status == "SUCCESS" {
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": task.FailReason,
})
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
if s.Client.Config.CdnURL != "" {
job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
} else {
job.OrgURL = task.ImageUrl
}
}
job.MessageId = task.Id
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
if task.Status == "SUCCESS" {
// release lock task
atomic.AddInt32(&s.HandledTaskNum, -1)
}
s.notifyQueue.RPush(job.UserId)
// 通知前端更新任务进度
if oldProgress != job.Progress {
s.notifyQueue.RPush(job.UserId)
}
return nil
}
func GetImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}

View File

@@ -6,11 +6,9 @@ import (
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"fmt"
"github.com/go-redis/redis/v8"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
@@ -35,7 +33,6 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
if config.Enabled == false {
continue
}
config.ApiURL = "https://api.chat-plus.net"
client := plus.NewClient(config)
name := fmt.Sprintf("mj-service-plus-%d", k)
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
@@ -52,7 +49,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
continue
}
// create mj client
client := NewClient(config, appConfig.ProxyURL, appConfig.ImgCdnURL)
client := NewClient(config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
@@ -96,6 +93,9 @@ func (p *ServicePool) CheckTaskNotify() {
continue
}
client := p.Clients.Get(userId)
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
@@ -118,30 +118,33 @@ func (p *ServicePool) DownloadImages() {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
var imgURL string
var err error
if v.UseProxy {
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if task.ImageUrl != "" {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(task.ImageUrl, false)
}
if len(task.Buttons) > 0 {
v.Hash = getImageHash(task.Buttons[0].CustomId)
}
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = plus.GetImageHash(task.Buttons[0].CustomId)
}
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
} else {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
}
if err != nil {
logger.Error("error with download image: ", err)
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
p.db.Updates(&v)
client := p.Clients.Get(uint(v.UserId))
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
@@ -165,7 +168,7 @@ func (p *ServicePool) HasAvailableService() bool {
}
func (p *ServicePool) Notify(data plus.CBReq) error {
logger.Infof("收到任务回调:%+v", data)
logger.Debugf("收到任务回调:%+v", data)
var job model.MidJourneyJob
res := p.db.Where("task_id = ?", data.Id).First(&job)
if res.Error != nil {
@@ -177,7 +180,7 @@ func (p *ServicePool) Notify(data plus.CBReq) error {
return nil
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
return servicePlus.Notify(data, job)
return servicePlus.Notify(job)
}
return nil
@@ -193,46 +196,38 @@ func (p *ServicePool) SyncTaskProgress() {
continue
}
for _, v := range items {
// 30 分钟还没完成的任务直接删除
if time.Now().Sub(v.CreatedAt) > time.Minute*30 {
p.db.Delete(&v)
// 退回绘图次数
p.db.Model(&model.User{}).Where("id = ?", v.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
continue
}
if !strings.HasPrefix(v.ChannelId, "mj-service-plus") {
continue
}
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
task, err := servicePlus.Client.QueryTask(v.TaskId)
if err != nil {
for _, job := range items {
// 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
p.db.Delete(&job)
// 略过 Upscale 任务
if job.Type != types.TaskUpscale.String() {
continue
}
if len(task.Buttons) > 0 {
v.Hash = getImageHash(task.Buttons[0].CustomId)
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if tx.Error == nil && tx.RowsAffected > 0 {
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
oldProgress := v.Progress
v.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
v.Prompt = task.PromptEn
if task.ImageUrl != "" {
v.OrgURL = task.ImageUrl
}
v.UseProxy = true
v.MessageId = task.Id
}
p.db.Updates(&v)
if !strings.HasPrefix(job.ChannelId, "mj-service-plus") {
continue
}
if task.Status == "SUCCESS" {
// release lock task
atomic.AddInt32(&servicePlus.HandledTaskNum, -1)
}
// 通知前端更新任务进度
if oldProgress != v.Progress {
p.notifyQueue.RPush(v.UserId)
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}
}
@@ -251,11 +246,3 @@ func (p *ServicePool) getServicePlus(name string) *plus.Service {
}
return nil
}
func getImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}

View File

@@ -65,28 +65,39 @@ func (s *Service) Run() {
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type {
case types.TaskImage:
err = s.client.Imagine(task.Prompt)
err = s.client.Imagine(task)
break
case types.TaskUpscale:
err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
err = s.client.Upscale(task)
break
case types.TaskVariation:
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
err = s.client.Variation(task)
break
case types.TaskBlend:
err = s.client.Blend(task)
break
case types.TaskSwapFace:
err = s.client.SwapFace(task)
break
}
if err != nil {
logger.Error("绘画任务执行失败:", err)
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
s.db.Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(task.UserId)
// restore img_call quota
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
if task.Type.String() != types.TaskUpscale.String() {
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
}
continue
}
logger.Infof("Task Executed: %+v", task)
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
s.taskStartTimes[int(task.Id)] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
}
@@ -126,6 +137,12 @@ func (s *Service) Notify(data CBReq) {
} else {
tx = tx.Where("task_id = ?", split[0])
}
// fixed: 修复 U/V 操作任务混淆覆盖的 Bug
if strings.Contains(data.Prompt, "** - Image #") { // for upscale
tx = tx.Where("type = ?", types.TaskUpscale.String())
} else if strings.Contains(data.Prompt, "** - Variations (Strong)") { // for Variations
tx = tx.Where("type = ?", types.TaskVariation.String())
}
res = tx.First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
@@ -138,10 +155,11 @@ func (s *Service) Notify(data CBReq) {
job.Progress = data.Progress
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL
if s.client.Config.UseCDN {
job.UseProxy = true
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.imgCdnURL)
job.OrgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.ImgCdnURL)
} else {
job.OrgURL = data.Image.URL
}
res = s.db.Updates(&job)

View File

@@ -8,8 +8,8 @@ const (
type InteractionsRequest struct {
Type int `json:"type"`
ApplicationID string `json:"application_id"`
MessageFlags *int `json:"message_flags,omitempty"`
MessageID *string `json:"message_id,omitempty"`
MessageFlags int `json:"message_flags,omitempty"`
MessageID string `json:"message_id,omitempty"`
GuildID string `json:"guild_id"`
ChannelID string `json:"channel_id"`
SessionID string `json:"session_id"`

View File

@@ -5,11 +5,13 @@ import (
"chatplus/core/types"
"chatplus/utils"
"fmt"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/gin-gonic/gin"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/gin-gonic/gin"
)
type AliYunOss struct {
@@ -66,10 +68,11 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
return File{
Name: file.Filename,
URL: fmt.Sprintf("%s/%s", s.config.Domain, objectKey),
Ext: fileExt,
Size: file.Size,
Name: file.Filename,
ObjKey: objectKey,
URL: fmt.Sprintf("%s/%s", s.config.Domain, objectKey),
Ext: fileExt,
Size: file.Size,
}, nil
}
@@ -88,7 +91,7 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := filepath.Ext(parse.Path)
fileExt := utils.GetImgExt(parse.Path)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
@@ -99,9 +102,14 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
}
func (s AliYunOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL)
key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
return s.bucket.DeleteObject(key)
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else {
objectKey = fileURL
}
return s.bucket.DeleteObject(objectKey)
}
var _ Uploader = AliYunOss{}

View File

@@ -4,11 +4,12 @@ import (
"chatplus/core/types"
"chatplus/utils"
"fmt"
"github.com/gin-gonic/gin"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
)
type LocalStorage struct {
@@ -29,7 +30,7 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
return File{}, fmt.Errorf("error with get form: %v", err)
}
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename)
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
if err != nil {
return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
}
@@ -41,10 +42,11 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
ext := filepath.Ext(file.Filename)
return File{
Name: file.Filename,
URL: utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, path),
Ext: ext,
Size: file.Size,
Name: file.Filename,
ObjKey: path,
URL: utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, path),
Ext: ext,
Size: file.Size,
}, nil
}
@@ -54,7 +56,7 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
filename := filepath.Base(parse.Path)
filePath, err := utils.GenUploadPath(s.config.BasePath, filename)
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
if err != nil {
return "", fmt.Errorf("error with generate image dir: %v", err)
}
@@ -72,6 +74,9 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
}
func (s LocalStorage) Delete(fileURL string) error {
if _, err := os.Stat(fileURL); err == nil {
return os.Remove(fileURL)
}
filePath := strings.Replace(fileURL, s.config.BaseURL, s.config.BasePath, 1)
return os.Remove(filePath)
}

View File

@@ -5,13 +5,14 @@ import (
"chatplus/utils"
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
type MiniOss struct {
@@ -77,7 +78,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
defer fileReader.Close()
fileExt := filepath.Ext(file.Filename)
fileExt := utils.GetImgExt(file.Filename)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Content-Type"),
@@ -87,17 +88,23 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
return File{
Name: file.Filename,
URL: fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key),
Ext: fileExt,
Size: file.Size,
Name: file.Filename,
ObjKey: info.Key,
URL: fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key),
Ext: fileExt,
Size: file.Size,
}, nil
}
func (s MiniOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL)
key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
return s.client.RemoveObject(context.Background(), s.config.Bucket, key, minio.RemoveObjectOptions{})
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else {
objectKey = fileURL
}
return s.client.RemoveObject(context.Background(), s.config.Bucket, objectKey, minio.RemoveObjectOptions{})
}
var _ Uploader = MiniOss{}

View File

@@ -6,12 +6,14 @@ import (
"chatplus/utils"
"context"
"fmt"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"github.com/qiniu/go-sdk/v7/storage"
"net/url"
"path/filepath"
"time"
)
type QinNiuOss struct {
@@ -74,10 +76,11 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
return File{
Name: file.Filename,
URL: fmt.Sprintf("%s/%s", s.config.Domain, ret.Key),
Ext: fileExt,
Size: file.Size,
Name: file.Filename,
ObjKey: key,
URL: fmt.Sprintf("%s/%s", s.config.Domain, ret.Key),
Ext: fileExt,
Size: file.Size,
}, nil
}
@@ -97,7 +100,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := filepath.Ext(parse.Path)
fileExt := utils.GetImgExt(parse.Path)
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
ret := storage.PutRet{}
extra := storage.PutExtra{}
@@ -110,9 +113,15 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
}
func (s QinNiuOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL)
key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
return s.manager.Delete(s.config.Bucket, key)
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else {
objectKey = fileURL
}
return s.manager.Delete(s.config.Bucket, objectKey)
}
var _ Uploader = QinNiuOss{}

View File

@@ -2,11 +2,17 @@ package oss
import "github.com/gin-gonic/gin"
const Local = "LOCAL"
const Minio = "MINIO"
const QiNiu = "QINIU"
const AliYun = "ALIYUN"
type File struct {
Name string `json:"name"`
Size int64 `json:"size"`
URL string `json:"url"`
Ext string `json:"ext"`
Name string `json:"name"`
ObjKey string `json:"obj_key"`
Size int64 `json:"size"`
URL string `json:"url"`
Ext string `json:"ext"`
}
type Uploader interface {
PutFile(ctx *gin.Context, name string) (File, error)

View File

@@ -9,11 +9,6 @@ type UploaderManager struct {
handler Uploader
}
const Local = "LOCAL"
const Minio = "MINIO"
const QiNiu = "QINIU"
const AliYun = "ALIYUN"
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
active := Local
if config.OSS.Active != "" {

View File

@@ -4,6 +4,8 @@ import (
"chatplus/core/types"
"chatplus/utils"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
@@ -17,14 +19,14 @@ import (
type HuPiPayService struct {
appId string
appSecret string
host string
apiURL string
}
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
return &HuPiPayService{
appId: config.HuPiPayConfig.AppId,
appSecret: config.HuPiPayConfig.AppSecret,
host: config.HuPiPayConfig.PayURL,
apiURL: config.HuPiPayConfig.ApiURL,
}
}
@@ -42,8 +44,16 @@ type HuPiPayReq struct {
NonceStr string `json:"nonce_str"`
}
type HuPiResp struct {
Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg,omitempty"`
}
// Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params HuPiPayReq) (string, error) {
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10)
params.AppId = s.appId
@@ -55,38 +65,98 @@ func (s *HuPiPayService) Pay(params HuPiPayReq) (string, error) {
for k, v := range m {
data.Add(k, fmt.Sprintf("%v", v))
}
encode = utils.JsonEncode(params)
m = make(map[string]string)
_ = utils.JsonDecode(encode, &m)
data.Add("hash", s.Sign(m))
resp, err := http.PostForm(s.host, data)
// 生成签名
data.Add("hash", s.Sign(data))
// 发送支付请求
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return "error", err
return HuPiResp{}, fmt.Errorf("error with requst api: %v", err)
}
defer resp.Body.Close()
all, err := io.ReadAll(resp.Body)
if err != nil {
return "error", err
return HuPiResp{}, fmt.Errorf("error with reading response: %v", err)
}
return string(all), err
var res HuPiResp
err = utils.JsonDecode(string(all), &res)
if err != nil {
return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err)
}
if res.ErrCode != 0 {
return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
}
return res, nil
}
// Sign 签名方法
func (s *HuPiPayService) Sign(params map[string]string) string {
var data string
keys := make([]string, 0, 0)
func (s *HuPiPayService) Sign(params url.Values) string {
params.Del(`Sign`)
var keys = make([]string, 0, 0)
for key := range params {
keys = append(keys, key)
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
//拼接
for _, k := range keys {
data = fmt.Sprintf("%s%s=%s&", data, k, params[k])
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += s.appSecret
md5bs := md5.Sum([]byte(src))
return hex.EncodeToString(md5bs[:])
}
// Check 校验订单状态
func (s *HuPiPayService) Check(tradeNo string) error {
data := url.Values{}
data.Add("appid", s.appId)
data.Add("open_order_id", tradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp)
data.Add("nonce_str", stamp)
data.Add("hash", s.Sign(data))
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ErrCode int `json:"errcode"`
Data struct {
Status string `json:"status"`
OpenOrderId string `json:"open_order_id"`
} `json:"data,omitempty"`
ErrMsg string `json:"errmsg"`
Hash string `json:"hash"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ErrCode == 0 && r.Data.Status == "OD" {
return nil
} else {
logger.Debugf("%+v", r)
return errors.New("order not paid" + r.ErrMsg)
}
data = strings.Trim(data, "&")
data = fmt.Sprintf("%s%s", data, s.appSecret)
m := md5.New()
m.Write([]byte(data))
sign := fmt.Sprintf("%x", m.Sum(nil))
return sign
}

View File

@@ -5,6 +5,7 @@ import (
"chatplus/utils"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
@@ -28,16 +29,17 @@ type JPayReq struct {
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"body"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"callback_url"`
}
type JPayReps struct {
CodeUrl string `json:"code_url"`
OutTradeNo string `json:"out_trade_no"`
OrderId string `json:"payjs_order_id"`
Qrcode string `json:"qrcode"`
ReturnCode int `json:"return_code"`
ReturnMsg string `json:"return_msg"`
Sign string `json:"sign"`
Sign string `json:"Sign"`
TotalFee string `json:"total_fee"`
CodeUrl string `json:"code_url,omitempty"`
Qrcode string `json:"qrcode,omitempty"`
}
func (r JPayReps) IsOK() bool {
@@ -55,10 +57,11 @@ func (js *PayJS) Pay(param JPayReq) JPayReps {
}
p.Add("mchid", js.config.AppId)
p.Add("sign", sign(p, js.config.PrivateKey))
p.Add("sign", js.sign(p))
cli := http.Client{}
r, err := cli.PostForm(js.config.ApiURL, p)
apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
r, err := cli.PostForm(apiURL, p)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
@@ -76,7 +79,13 @@ func (js *PayJS) Pay(param JPayReq) JPayReps {
return data
}
func sign(params url.Values, priKey string) string {
func (js *PayJS) PayH5(p url.Values) string {
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
}
func (js *PayJS) sign(params url.Values) string {
params.Del(`sign`)
var keys = make([]string, 0, 0)
for key := range params {
@@ -94,9 +103,46 @@ func sign(params url.Values, priKey string) string {
}
}
var src = strings.Join(pList, "&")
src += "&key=" + priKey
src += "&key=" + js.config.PrivateKey
md5bs := md5.Sum([]byte(src))
md5res := hex.EncodeToString(md5bs[:])
return strings.ToUpper(md5res)
}
// Check 查询订单支付状态
// @param tradeNo 支付平台交易 ID
func (js *PayJS) Check(tradeNo string) error {
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
params := url.Values{}
params.Add("payjs_order_id", tradeNo)
params.Add("sign", js.sign(params))
data := strings.NewReader(params.Encode())
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
defer resp.Body.Close()
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ReturnCode int `json:"return_code"`
Status int `json:"status"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ReturnCode == 1 && r.Status == 1 {
return nil
} else {
logger.Errorf("PayJs 支付验证响应:%s", string(body))
return errors.New("order not paid")
}
}

View File

@@ -4,20 +4,26 @@ import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"fmt"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.SdConfigs {
if config.Enabled == false {
@@ -26,7 +32,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%d", k)
service := NewService(name, 1, 300, config, queue, db, manager)
service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager)
// run sd service
go func() {
service.Run()
@@ -36,8 +42,11 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
}
return &ServicePool{
taskQueue: queue,
services: services,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}
@@ -47,6 +56,66 @@ func (p *ServicePool) PushTask(task types.SdTask) {
p.taskQueue.RPush(task)
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
if err != nil {
continue
}
client := p.Clients.Get(userId)
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
p.db.Delete(&job)
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
}
}
}()
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0

View File

@@ -24,6 +24,7 @@ type Service struct {
httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
name string // service name
@@ -33,12 +34,13 @@ type Service struct {
taskTimeout int64
}
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, queue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
return &Service{
name: name,
config: config,
httpClient: req.C(),
taskQueue: queue,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
db: db,
uploadManager: manager,
taskTimeout: timeout,
@@ -66,13 +68,16 @@ func (s *Service) Run() {
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err)
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
// restore img_call quota
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
// release task num
atomic.AddInt32(&s.handledTaskNum, -1)
// 通知前端,任务失败
s.notifyQueue.RPush(task.UserId)
continue
}
@@ -296,8 +301,12 @@ func (s *Service) callback(data CBReq) {
} else { // 任务失败
logger.Error("任务执行失败:", data.Message)
// update the task progress
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumn("progress", -1)
// restore img_calls
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", data.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": data.Message,
})
}
// 发送更新状态信号
s.notifyQueue.RPush(data.UserId)
}

View File

@@ -1,4 +1,4 @@
package service
package sms
import (
"chatplus/core/types"
@@ -7,22 +7,23 @@ import (
)
type AliYunSmsService struct {
config *types.AliYunSmsConfig
config *types.SmsConfigAli
client *dysmsapi.Client
}
func NewAliYunSmsService(config *types.AppConfig) (*AliYunSmsService, error) {
func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) {
config := &appConfig.SMS.Ali
// 创建阿里云短信客户端
client, err := dysmsapi.NewClientWithAccessKey(
"cn-hangzhou",
config.SmsConfig.AccessKey,
config.SmsConfig.AccessSecret)
config.AccessKey,
config.AccessSecret)
if err != nil {
return nil, fmt.Errorf("failed to create client: %v", err)
}
return &AliYunSmsService{
config: &config.SmsConfig,
config: config,
client: client,
}, nil
}
@@ -46,8 +47,7 @@ func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error {
if response.Code != "OK" {
return fmt.Errorf("failed to send SMS:%v", response.Message)
}
return nil
}
var _ SmsService = &AliYunSmsService{}
var _ Service = &AliYunSmsService{}

72
api/service/sms/bao.go Normal file
View File

@@ -0,0 +1,72 @@
package sms
import (
"chatplus/core/types"
"chatplus/utils"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
)
type BaoSmsService struct {
config *types.SmsConfigBao
}
func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService {
config := appConfig.SMS.Bao
if config.Domain == "" { // use default domain
config.Domain = "api.smsbao.com"
logger.Infof("Using default domain for SMS-BAO: %s", config.Domain)
}
return &BaoSmsService{
config: &config,
}
}
var errMsg = map[string]string{
"0": "短信发送成功",
"-1": "参数不全",
"-2": "服务器空间不支持请确认支持curl或者fsocket联系您的空间商解决或者更换空间",
"30": "密码错误",
"40": "账号不存在",
"41": "余额不足",
"42": "账户已过期",
"43": "IP地址限制",
"50": "内容含有敏感词",
}
func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error {
content := fmt.Sprintf("%s%s", s.config.Sign, s.config.CodeTemplate)
content = strings.ReplaceAll(content, "{code}", strconv.Itoa(code))
password := utils.Md5(s.config.Password)
params := url.Values{}
params.Set("u", s.config.Username)
params.Set("p", password)
params.Set("m", mobile)
params.Set("c", content)
apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode())
response, err := http.Get(apiURL)
if err != nil {
return err
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return err
}
result := string(body)
logger.Debugf("send SmsBao result: %v", errMsg[result])
if result != "0" {
return fmt.Errorf("failed to send SMS:%v", errMsg[result])
}
return nil
}
var _ Service = &BaoSmsService{}

View File

@@ -0,0 +1,8 @@
package sms
const Ali = "ALI"
const Bao = "BAO"
type Service interface {
SendVerifyCode(mobile string, code int) error
}

View File

@@ -0,0 +1,39 @@
package sms
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"strings"
)
type ServiceManager struct {
handler Service
}
var logger = logger2.GetLogger()
func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) {
active := Ali
if config.SMS.Active != "" {
active = strings.ToUpper(config.SMS.Active)
}
var handler Service
switch active {
case Ali:
client, err := NewAliYunSmsService(config)
if err != nil {
return nil, err
}
handler = client
break
case Bao:
handler = NewSmsBaoSmsService(config)
break
}
return &ServiceManager{handler: handler}, nil
}
func (m *ServiceManager) GetService() Service {
return m.handler
}

View File

@@ -1,5 +0,0 @@
package service
type SmsService interface {
SendVerifyCode(mobile string, code int) error
}

View File

@@ -64,12 +64,12 @@ func (b *Bot) messageHandler(msg *openwechat.Message) {
msg.AppMsgType == openwechat.AppMsgTypeUrl {
// 解析支付金额
message := parseTransactionMessage(msg.Content)
if message.Url != "" {
transaction := extractTransaction(message)
logger.Infof("解析到收款信息:%+v", transaction)
transaction := extractTransaction(message)
logger.Infof("解析到收款信息:%+v", transaction)
if transaction.TransId != "" {
var item model.Reward
res := b.db.Where("tx_id = ?", transaction.TransId).First(&item)
if res.Error == nil {
if item.Id > 0 {
logger.Error("当前交易 ID 己经存在!")
return
}

View File

@@ -39,14 +39,30 @@ func parseTransactionMessage(xmlData string) *Message {
}
break
}
if se.Name.Local == "weapp_path" && !strings.Contains(message.Url, "customerDetails.html") {
if se.Name.Local == "weapp_path" || se.Name.Local == "url" {
if err := decoder.DecodeElement(&value, &se); err == nil {
message.Url = strings.TrimSpace(value)
if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") {
message.Url = value
}
}
break
}
}
}
// 兼容旧版消息记录
if message.Url == "" {
var msg struct {
XMLName xml.Name `xml:"msg"`
AppMsg struct {
Des string `xml:"des"`
Url string `xml:"url"`
} `xml:"appmsg"`
}
if err := xml.Unmarshal([]byte(xmlData), &msg); err == nil {
message.Url = msg.AppMsg.Url
}
}
return &message
}
@@ -80,6 +96,10 @@ func extractTransaction(message *Message) Transaction {
parse, err := url.Parse(message.Url)
if err == nil {
tx.TransId = parse.Query().Get("id")
if tx.TransId == "" {
tx.TransId = parse.Query().Get("trans_id")
}
}
return tx
}

View File

@@ -39,13 +39,14 @@ func NewXXLJobExecutor(config *types.AppConfig, db *gorm.DB) *XXLJobExecutor {
func (e *XXLJobExecutor) Run() error {
e.executor.RegTask("ClearOrders", e.ClearOrders)
e.executor.RegTask("ResetVipCalls", e.ResetVipCalls)
e.executor.RegTask("ResetVipPower", e.ResetVipPower)
e.executor.RegTask("ResetUserPower", e.ResetUserPower)
return e.executor.Run()
}
// ClearOrders 清理未支付的订单,如果没有抛出异常则表示执行成功
func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Debug("执行清理未支付订单...")
logger.Info("执行清理未支付订单...")
var sysConfig model.Config
res := e.db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
@@ -64,15 +65,17 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms
timeout := time.Now().Unix() - int64(config.OrderPayTimeout)
start := utils.Stamp2str(timeout)
// 这里不是用软删除,而是永久删除订单
res = e.db.Unscoped().Where("status != ? AND created_at < ?", types.OrderPaidSuccess, start).Delete(&model.Order{})
return fmt.Sprintf("Clear order successfully, affect rows: %d", res.RowsAffected)
res = e.db.Unscoped().Where("status IN ? AND created_at < ?", []types.OrderStatus{types.OrderNotPaid, types.OrderScanned}, start).Delete(&model.Order{})
logger.Infof("Clear order successfully, affect rows: %d", res.RowsAffected)
return "success"
}
// ResetVipCalls 清理过期的 VIP 会员
func (e *XXLJobExecutor) ResetVipCalls(cxt context.Context, param *xxl.RunReq) (msg string) {
// ResetVipPower 重置VIP会员算力
// 自动将 VIP 会员的算力补充到每月赠送的最大值
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Info("开始进行月底账号盘点...")
var users []model.User
res := e.db.Where("vip = ?", 1).Find(&users)
res := e.db.Where("vip", 1).Where("status", 1).Find(&users)
if res.Error != nil {
return "No vip users found"
}
@@ -89,60 +92,92 @@ func (e *XXLJobExecutor) ResetVipCalls(cxt context.Context, param *xxl.RunReq) (
return "error with decode system config: " + err.Error()
}
// 获取本月月初时间
currentTime := time.Now()
year, month, _ := currentTime.Date()
firstOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, currentTime.Location()).Unix()
for _, u := range users {
// 账号到期,直接清零
if u.ExpiredTime <= currentTime.Unix() {
logger.Info("账号过期:", u.Username)
u.Calls = 0
// 处理过期的 VIP
if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() {
u.Vip = false
} else {
if u.Calls <= 0 {
u.Calls = 0
}
if u.ImgCalls <= 0 {
u.ImgCalls = 0
}
// 如果该用户当月有充值点卡,则将点卡中未用完的点数结余到下个月
var orders []model.Order
e.db.Debug().Where("user_id = ? AND pay_time > ?", u.Id, firstOfMonth).Find(&orders)
var calls = 0
var imgCalls = 0
for _, o := range orders {
var remark types.OrderRemark
err = utils.JsonDecode(o.Remark, &remark)
if err != nil {
continue
}
if remark.Days > 0 { // 会员续费
continue
}
calls += remark.Calls
imgCalls += remark.ImgCalls
}
if u.Calls > calls { // 本月套餐没有用完
u.Calls = calls + config.VipMonthCalls
} else {
u.Calls = u.Calls + config.VipMonthCalls
}
if u.ImgCalls > imgCalls { // 本月套餐没有用完
u.ImgCalls = imgCalls + config.VipMonthImgCalls
} else {
u.ImgCalls = u.ImgCalls + config.VipMonthImgCalls
}
logger.Infof("%s 点卡结余:%d", u.Username, calls)
e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
continue
}
u.Tokens = 0
// update user
e.db.Updates(&u)
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", config.VipMonthPower))
// 记录算力变动日志
if tx.Error == nil {
var user model.User
e.db.Where("id", u.Id).First(&user)
e.db.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerRecharge,
Amount: config.VipMonthPower,
Mark: types.PowerAdd,
Balance: user.Power,
Model: "系统盘点",
Remark: fmt.Sprintf("VIP会员每月算力派发%d", config.VipMonthPower),
CreatedAt: time.Now(),
})
}
}
logger.Info("月底盘点完成!")
return "success"
}
func (e *XXLJobExecutor) ResetUserPower(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Info("今日算力派发开始:", time.Now())
var users []model.User
res := e.db.Where("status", 1).Find(&users)
if res.Error != nil {
return "No matching users"
}
var sysConfig model.Config
res = e.db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return "error with get system config: " + res.Error.Error()
}
var config types.SystemConfig
err := utils.JsonDecode(sysConfig.Config, &config)
if err != nil {
return "error with decode system config: " + err.Error()
}
if config.DailyPower <= 0 {
return "success"
}
var counter = 0
var totalPower = 0
for _, u := range users {
if u.Power >= config.DailyPower {
continue
}
var power = config.DailyPower - u.Power
// update user
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power))
// 记录算力充值日志
if tx.Error == nil {
var user model.User
e.db.Where("id", u.Id).First(&user)
e.db.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerGift,
Amount: power,
Mark: types.PowerAdd,
Balance: user.Power,
Model: "系统赠送",
Remark: fmt.Sprintf("系统每日算力派发,今日额度:%d", config.DailyPower),
CreatedAt: time.Now(),
})
}
counter++
totalPower += power
}
logger.Infof("今日派发算力结束!累计派发 %d 人,累计派发算力:%d", counter, totalPower)
return "success"
}
type customLogger struct{}
func (l *customLogger) Info(format string, a ...interface{}) {

View File

@@ -0,0 +1,11 @@
package model
type AdminUser struct {
BaseModel
Username string
Password string
Salt string // 密码盐
Status bool `gorm:"default:true"` // 当前状态
LastLoginAt int64 // 最后登录时间
LastLoginIp string // 最后登录 IP
}

View File

@@ -9,6 +9,6 @@ type ApiKey struct {
Value string // API Key 的值
ApiURL string // 当前 KEY 的 API 地址
Enabled bool // 是否启用
UseProxy bool // 是否使用代理访问 API URL
ProxyURL string // 代理地址
LastUsedAt int64 // 最后使用时间
}

View File

@@ -2,11 +2,12 @@ package model
import "gorm.io/gorm"
type HistoryMessage struct {
type ChatMessage struct {
BaseModel
ChatId string // 会话 ID
UserId uint // 用户 ID
RoleId uint // 角色 ID
Model string // AI模型
Type string
Icon string
Tokens int
@@ -15,6 +16,6 @@ type HistoryMessage struct {
DeletedAt gorm.DeletedAt
}
func (HistoryMessage) TableName() string {
func (ChatMessage) TableName() string {
return "chatgpt_chat_history"
}

View File

@@ -7,7 +7,8 @@ type ChatItem struct {
ChatId string `gorm:"column:chat_id;unique"` // 会话 ID
UserId uint // 用户 ID
RoleId uint // 角色 ID
ModelId uint // 会话模型
ModelId uint // 模型 ID
Model string // 模型
Title string // 会话标题
DeletedAt gorm.DeletedAt
}

View File

@@ -2,11 +2,14 @@ package model
type ChatModel struct {
BaseModel
Platform string
Name string
Value string // API Key 的值
SortNum int
Enabled bool
Weight int // 对话权重,每次对话扣减多少次对话额度
Open bool // 是否开放模型给所有人使用
Platform string
Name string
Value string // API Key 的值
SortNum int
Enabled bool
Power int // 每次对话消耗算力
Open bool // 是否开放模型给所有人使用
MaxTokens int // 最大响应长度
MaxContext int // 最大上下文长度
Temperature float32 // 模型温度
}

View File

@@ -4,8 +4,9 @@ import "time"
type File struct {
Id uint `gorm:"primarykey;column:id"`
UserId uint
UserId int
Name string
ObjKey string
URL string
Ext string
Size int64

View File

@@ -10,6 +10,6 @@ type InviteLog struct {
UserId uint
Username string
InviteCode string
Reward string `gorm:"column:reward_json"` // 邀请奖励
Remark string
CreatedAt time.Time
}

View File

@@ -15,7 +15,10 @@ type MidJourneyJob struct {
Hash string // message hash
Progress int
Prompt string
UseProxy bool // 是否使用反代加载图片
UseProxy bool // 是否使用反代加载图片
Publish bool //是否发布图片到画廊
ErrMsg string // 报错信息
Power int // 消耗算力
CreatedAt time.Time
}

View File

@@ -12,6 +12,7 @@ type Order struct {
ProductId uint
Username string
OrderNo string
TradeNo string
Subject string
Amount float64
Status types.OrderStatus

View File

@@ -0,0 +1,20 @@
package model
import (
"chatplus/core/types"
"time"
)
// PowerLog 算力消费日志
type PowerLog struct {
Id uint `gorm:"primarykey;column:id"`
UserId uint
Username string
Type types.PowerType
Amount int
Balance int
Model string // 模型
Remark string // 备注
Mark types.PowerMark // 资金类型
CreatedAt time.Time
}

View File

@@ -7,8 +7,7 @@ type Product struct {
Price float64
Discount float64
Days int
Calls int
ImgCalls int
Power int
Enabled bool
Sales int
SortNum int

View File

@@ -11,6 +11,9 @@ type SdJob struct {
Progress int
Prompt string
Params string
Publish bool //是否发布图片到画廊
ErrMsg string // 报错信息
Power int // 消耗算力
CreatedAt time.Time
}

View File

@@ -7,9 +7,7 @@ type User struct {
Password string
Avatar string
Salt string // 密码盐
TotalTokens int64 // 总消耗 tokens
Calls int // 剩余对话次数
ImgCalls int // 剩余绘图次数
Power int // 剩余算力
ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json
ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色
ChatModels string `gorm:"column:chat_models_json"` // AI 模型,不同的用户拥有不同的聊天模型
@@ -18,5 +16,4 @@ type User struct {
LastLoginAt int64 // 最后登录时间
LastLoginIp string // 最后登录 IP
Vip bool // 是否 VIP 会员
Tokens int
}

View File

@@ -11,7 +11,7 @@ import (
func NewGormConfig() *gorm.Config {
return &gorm.Config{
Logger: logger.Default.LogMode(logger.Warn),
Logger: logger.Default.LogMode(logger.Silent),
NamingStrategy: schema.NamingStrategy{
TablePrefix: "chatgpt_", // 设置表前缀
SingularTable: false, // 使用单数表名形式

View File

@@ -0,0 +1,10 @@
package vo
type AdminUser struct {
BaseVo
Username string `json:"username"`
Status bool `json:"status"` // 当前状态
LastLoginAt int64 `json:"last_login_at"` // 最后登录时间
LastLoginIp string `json:"last_login_ip"` // 最后登录 IP
RoleIds interface{} `json:"role_ids"` //角色ids
}

View File

@@ -9,6 +9,6 @@ type ApiKey struct {
Value string `json:"value"` // API Key 的值
ApiURL string `json:"api_url"`
Enabled bool `json:"enabled"`
UseProxy bool `json:"use_proxy"`
ProxyURL string `json:"proxy_url"`
LastUsedAt int64 `json:"last_used_at"` // 最后使用时间
}

View File

@@ -5,6 +5,7 @@ type HistoryMessage struct {
ChatId string `json:"chat_id"`
UserId uint `json:"user_id"`
RoleId uint `json:"role_id"`
Model string `json:"model"`
Type string `json:"type"`
Icon string `json:"icon"`
Tokens int `json:"tokens"`

View File

@@ -7,5 +7,6 @@ type ChatItem struct {
RoleId uint `json:"role_id"`
ChatId string `json:"chat_id"`
ModelId uint `json:"model_id"`
Model string `json:"model"`
Title string `json:"title"`
}

View File

@@ -2,11 +2,14 @@ package vo
type ChatModel struct {
BaseVo
Platform string `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Weight int `json:"weight"`
Open bool `json:"open"`
Platform string `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Power int `json:"power"`
Open bool `json:"open"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
}

Some files were not shown because too many files have changed in this diff Show More