Compare commits

...

369 Commits

Author SHA1 Message Date
RockYang
fbfa2a71a9 release v4.0.3 2024-04-15 08:26:07 +08:00
RockYang
a7237fe62f Merge branch 'main' of gitee.com:blackfox/chatgpt-plus 2024-04-07 18:33:55 +08:00
RockYang
c3c454b7d7 Merge branch 'main' of github.com:yangjian102621/chatgpt-plus 2024-04-07 18:32:20 +08:00
RockYang
d4d708d44b update database files 2024-04-07 18:26:45 +08:00
RockYang
7f0b6a3a46 set the enable status for adding new api key with default value true 2024-04-07 10:17:10 +08:00
RockYang
c2a7c089d2 update change log 2024-04-07 08:40:43 +08:00
RockYang
df5bd4df60 Merge branch 'dev' 2024-04-07 08:03:14 +08:00
RockYang
79b6010104 optimize ngin configuration for chat-plus.conf 2024-04-07 08:02:47 +08:00
RockYang
97b0a98793 fix 404 error with remove api keys 2024-04-07 08:01:25 +08:00
RockYang
6a3e26b566 update change log 2024-04-05 06:58:13 +08:00
RockYang
0355c37bef feat: stable diffusion image drawing on mobile is ready 2024-04-03 18:13:48 +08:00
RockYang
9b7ee538c4 feat: midjourney role and style consistency is ready 2024-04-02 19:01:28 +08:00
RockYang
d900a3d08e update README.md.
Signed-off-by: RockYang <yangjian102621@gmail.com>
2024-04-02 10:00:45 +00:00
RockYang
cdf5b66729 Update README.md 2024-04-02 17:59:53 +08:00
RockYang
1cff4b63cd feat: image preview for stable-diffusion task is ready 2024-04-02 17:24:38 +08:00
RockYang
da14309ef9 feat: support midjourney --cref and --sref for role consistency 2024-04-02 14:59:53 +08:00
RockYang
fbb216fe3b feat: update menu icons, add version in site titles 2024-04-01 18:20:00 +08:00
RockYang
95efbd5659 show power for chat and imaging page 2024-03-31 20:49:12 +08:00
RockYang
4596c1049c opt: change the relative path with absolute path for midjourney image uploading 2024-03-31 17:45:22 +08:00
RockYang
b35d95f0c7 Merge branch 'main' into dev 2024-03-30 11:57:31 +08:00
RockYang
01419df998 remove dead code 2024-03-30 11:57:23 +08:00
RockYang
a6c00c42fa fixed conflicts 2024-03-29 18:10:59 +08:00
RockYang
4cc9db7115 update readme file 2024-03-29 18:06:55 +08:00
RockYang
4f1ed54059 update databases 2024-03-29 17:43:38 +08:00
RockYang
8227a73e35 feat: support custom menu 2024-03-29 15:41:58 +08:00
RockYang
adfd8c1939 add tip for mj image buttons 2024-03-28 21:41:49 +08:00
RockYang
8eed7ff534 fix: fix overflow hidden for admin page 2024-03-28 18:51:09 +08:00
RockYang
c79c4e74d0 fix: fix overflow hidden for mobile page 2024-03-28 18:13:33 +08:00
RockYang
f1855fd0a1 fix: use slide captcha for iphone 2024-03-28 15:00:53 +08:00
RockYang
1f964c74e9 feat: add mj_action_power system config item 2024-03-28 09:53:41 +08:00
RockYang
4fb2c5803c feat: change midjourney origin implements, replace midjourney bot with midjourney-proxy 2024-03-27 18:57:15 +08:00
RockYang
b5947545cb feat: auto translate and rewrite prompt for midjourney and stable-diffusion 2024-03-27 13:45:52 +08:00
RockYang
342b76f666 feat: stable-diffusion refactored, replace websocket api with sdapi 2024-03-26 18:23:08 +08:00
RockYang
49b5906bc7 fix: can not change user's power in admin console 2024-03-25 11:40:03 +08:00
RockYang
3075bfb7fc fix: fix bug for update user's power in admin page did not work 2024-03-23 16:21:37 +08:00
RockYang
82e06fad33 No need to login with Stable-Diffusion page and Invite page 2024-03-23 15:45:37 +08:00
RockYang
4a9028747b fix: fix code highlight error when add formule detecting 2024-03-22 19:09:04 +08:00
RockYang
4a8ff0ccf0 chore: correct prompt messages 2024-03-22 18:27:57 +08:00
RockYang
99341f0484 feat: add chart for admin dashbord 2024-03-22 16:57:30 +08:00
RockYang
f58ac29ad0 feat: integrate xxl-job-admin to implements automatic task scheduling 2024-03-22 13:47:16 +08:00
RockYang
7060edb3e5 feat: save prompt in power log for dalle-3 2024-03-21 15:55:39 +08:00
RockYang
41ae411f9b feat: add manager list page in console page 2024-03-21 15:24:28 +08:00
RockYang
79b7fee47c feat: add powerlog page for admin console 2024-03-21 13:46:39 +08:00
RockYang
0044bf10af opt: optimize the formula show styles 2024-03-21 11:04:12 +08:00
RockYang
e9348d3611 always parse authorization token for all request 2024-03-20 21:11:52 +08:00
RockYang
b9236e09a7 fixed conflicts 2024-03-20 20:40:22 +08:00
RockYang
09b38d5f42 update README.md.
Signed-off-by: RockYang <yangjian102621@gmail.com>
2024-03-20 20:39:44 +08:00
RockYang
7bb539a06e feat: Async loading midjourney job for mobile MidJourney page 2024-03-20 18:39:14 +08:00
RockYang
5cdada8265 feat: h5 payment for payjs is ready 2024-03-20 17:46:39 +08:00
RockYang
4147c217b1 feat: payment for mobile pages is ready 2024-03-20 16:14:02 +08:00
RockYang
8dda639b23 feat: the power log page is ready 2024-03-20 14:14:30 +08:00
RockYang
8487d2c9eb feat: no need login refactor with member and chatApps page 2024-03-19 18:59:02 +08:00
RockYang
c5e583b215 feat: load preview page do not require user to login 2024-03-19 18:25:01 +08:00
RockYang
549f618cff feat: optimize login dialog 2024-03-19 10:47:13 +08:00
RockYang
e9a3510346 feat: refactoring adjustments for reward page is ready 2024-03-18 18:28:34 +08:00
RockYang
30e6e963b3 feat: refactoring adjustments for member pages 2024-03-18 16:59:07 +08:00
RockYang
c72d963f45 feat: The 'chat_models' field of user table, holds the model IDS in place of the model values 2024-03-18 15:37:46 +08:00
RockYang
172d498618 remove new-ui files 2024-03-18 12:01:34 +08:00
RockYang
313993532e chore: adjust page styles 2024-03-18 06:46:08 +08:00
RockYang
e53db3582c 重构主体工作完成 2024-03-15 18:35:10 +08:00
RockYang
72c6bd3f77 restore new ui files 2024-03-15 11:13:02 +08:00
廖彦棋
ca8b349df3 fix(ui): 交互修复调整 2024-03-15 11:07:41 +08:00
RockYang
1b206c3640 feat: refactor user list page for new UI 2024-03-15 09:29:19 +08:00
廖彦棋
c60276fc9f fix(ui): 用户管理有效期传参调整,权限标识补充 2024-03-15 09:25:06 +08:00
廖彦棋
d00a3167c0 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:49:18 +08:00
廖彦棋
6b1cd8c30c refactor(ui): 无权限页面调整 2024-03-14 17:49:15 +08:00
吴汉强
46f12dc9ad feat(ui): 后台首页去掉权限判断 2024-03-14 17:24:40 +08:00
廖彦棋
a3e1d8ae21 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:19:41 +08:00
廖彦棋
72a066b93e feat(ui): 无权限判断 2024-03-14 17:19:39 +08:00
吴汉强
0327a829ac Merge remote-tracking branch 'origin/ui' into ui 2024-03-14 17:12:53 +08:00
吴汉强
882e9b8819 feat(ui): 新增 sql 2024-03-14 17:12:48 +08:00
廖彦棋
ef58cfadaa Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:06:15 +08:00
吴汉强
bf958d6113 feat(ui): 403,没权限 2024-03-14 16:41:38 +08:00
吴汉强
71611273d7 feat(ui): 网站配置需授权,去掉 2024-03-14 16:28:49 +08:00
廖彦棋
b27c654311 fix(ui): 调整 2024-03-14 16:14:49 +08:00
吴汉强
90930ea9f9 feat(ui): 后端加权限验证 2024-03-14 15:39:12 +08:00
廖彦棋
1ab2185ff1 feat(ui): 角色管理 2024-03-14 15:25:17 +08:00
廖彦棋
0f2f978d4c feat(ui): 新增系统分类菜单 2024-03-14 15:17:53 +08:00
廖彦棋
f61963b0b0 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 10:56:56 +08:00
廖彦棋
2aa413960d fix(ui): 修复 2024-03-14 10:56:54 +08:00
吴汉强
aa4bbba5ec Merge remote-tracking branch 'origin/ui' into ui 2024-03-14 10:28:39 +08:00
吴汉强
eba61fea2d feat(ui): 登录接口返回权限 2024-03-14 10:28:32 +08:00
廖彦棋
34e3455128 feat(ui): 管理后台新增权限及部分组合式函数优化 2024-03-14 10:27:09 +08:00
廖彦棋
07dca3e739 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 17:30:26 +08:00
廖彦棋
4cb4b145f9 feat(ui): web移动端初始化 2024-03-13 17:30:24 +08:00
吴汉强
1ed417cb69 Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 17:24:38 +08:00
吴汉强
6cf91a84ca feat(ui): 增加角色管理,管理员方法新增角色关联 2024-03-13 17:24:30 +08:00
chenzifan
0b566980fc Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 14:40:45 +08:00
chenzifan
f86176b342 refactor: remove api change to post request 2024-03-13 14:40:38 +08:00
吴汉强
c700b32670 Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 11:41:07 +08:00
吴汉强
22641b452a feat(ui): 增加系统权限管理 2024-03-13 11:41:01 +08:00
廖彦棋
d3fbb8c19e fix(ui): ui调整 2024-03-13 09:54:20 +08:00
RockYang
e3bb69ff10 docs: update sql file 2024-03-13 08:49:40 +08:00
RockYang
770360c614 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 08:48:10 +08:00
chenzifan
f302a0478f fix: 删除系统管理员失效的问题 2024-03-13 08:47:17 +08:00
chenzifan
a88697b43a Merge remote-tracking branch 'origin/ui' into ui
# Conflicts:
#	api/handler/admin/admin_user_handler.go
2024-03-13 08:46:16 +08:00
chenzifan
cc6f140812 fix: 删除系统管理员失效的问题 2024-03-13 08:45:09 +08:00
廖彦棋
424f2b3bdc Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 08:44:50 +08:00
廖彦棋
ec0c13a600 feat(ui): 调整 2024-03-13 08:44:48 +08:00
chenzf@pvc123.com
a1f03bec4c feat: 超级管理员不支持修改和删除 2024-03-12 21:16:05 +08:00
RockYang
b5bd4a5e0e Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-12 18:07:24 +08:00
RockYang
7c2e49bfdb fix conflicts 2024-03-12 18:07:19 +08:00
chenzifan
f80fe6d041 feat: 增加系统管理员 2024-03-12 18:06:49 +08:00
RockYang
72f80a96bc fix conflicts 2024-03-12 18:03:24 +08:00
RockYang
2de655a1cf refactor: use power replace calls for front pages 2024-03-12 17:47:06 +08:00
RockYang
da2bd4a501 refactor: 重构项目,为所有的 AI 工具都引入算力,采用算力统一结算各个工具的调用次数和权限 2024-03-12 15:40:44 +08:00
廖彦棋
e0aa62c40d Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-12 08:37:29 +08:00
廖彦棋
9d26a892d1 refactor(ui): 调整 2024-03-12 08:37:27 +08:00
huangqj
4ece7f2847 fix(ui):环境变量 2024-03-11 16:10:49 +08:00
廖彦棋
32368caf1b feat(ui): 新增系统管理员 2024-03-11 15:59:15 +08:00
廖彦棋
e91f54e79e fix(ui): 删除冗余 2024-03-11 14:23:11 +08:00
廖彦棋
bb8f4c57c4 fix(ui): 删除冗余 2024-03-11 14:22:39 +08:00
RockYang
43bfac99b6 feat: replace Tools param with Function param for OpenAI chat API 2024-03-11 14:09:19 +08:00
廖彦棋
be379b6d63 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-11 13:52:24 +08:00
廖彦棋
17f3c9b840 fix(ui): type 2024-03-11 13:52:22 +08:00
chenzifan
24de97fac2 feat: 优化后台UI 2024-03-11 13:51:26 +08:00
chenzifan
bf27b44fee Merge remote-tracking branch 'origin/ui' into ui 2024-03-11 13:46:46 +08:00
廖彦棋
1802b4fe4d refactor(ui): 调整优化 2024-03-11 13:46:08 +08:00
廖彦棋
241a5c7bc9 feat(ui): 细节优化 2024-03-11 12:02:20 +08:00
廖彦棋
557d547bf1 feat(ui): 上传功能补充 2024-03-11 11:41:50 +08:00
廖彦棋
2e7b75affb feat(ui): 登录新增验证码及记住密码功能 2024-03-11 10:49:13 +08:00
廖彦棋
bc21a1d443 feat(ul): 顶部信息 2024-03-11 09:00:00 +08:00
huangqj
3fc9e10a24 feat(ui):看板图表 样式调整 2024-03-11 08:35:21 +08:00
chenzifan
5fa1aa2060 Merge remote-tracking branch 'origin/ui' into ui 2024-03-11 08:01:54 +08:00
RockYang
be8a0ec184 opt: remove global keyup event bind 2024-03-10 09:45:59 +08:00
RockYang
b02e3aad95 docs: update config.toml 2024-03-10 09:45:59 +08:00
RockYang
08eca511ad docs: add database sql file for v3.2.7 2024-03-10 09:45:59 +08:00
RockYang
c34e911596 feat: allow to view chat message in manager console 2024-03-10 09:45:59 +08:00
RockYang
8a452c3072 feat: download image which ai generated in dialog and replace the image url 2024-03-10 09:45:59 +08:00
RockYang
13bfb14107 feat: image-wall page for mobile is ready 2024-03-10 09:45:59 +08:00
RockYang
4188b0969e feat: allow user config third-party platform openai and mj api key 2024-03-10 09:45:59 +08:00
RockYang
0c27795a10 feat: added delete file function 2024-03-10 09:45:59 +08:00
RockYang
d05693c5c1 fix: verifycation component touch event coordinates misplace in iphone browser 2024-03-10 09:45:59 +08:00
RockYang
c0b2063b38 fix: fix bug for regenerate button did not work 2024-03-10 09:45:59 +08:00
RockYang
4d183747b1 fix: Upscale and Variation task overrite each other 2024-03-10 09:45:59 +08:00
RockYang
08fe1b2f75 feat: midjourney mobile page all function is ready 2024-03-10 09:45:59 +08:00
RockYang
db3e8a267e feat: mobile mj list page is ready 2024-03-10 09:45:59 +08:00
RockYang
8fc62682c4 feat: add mj image list component for mobile page. fixed bug for html tag escape 2024-03-10 09:45:59 +08:00
RockYang
75031914a3 feat: add functions mj page for mobile 2024-03-10 09:45:59 +08:00
RockYang
a4c9fdd95a opt: add default extension for mj image 2024-03-10 09:45:59 +08:00
RockYang
6a9bfeb5aa feat: mj for mobile page payout is ready 2024-03-10 09:45:59 +08:00
RockYang
e654766f60 add docs and github link 2024-03-10 09:45:59 +08:00
RockYang
0ef6955f96 opt: enable use cdn url for mj-plus 2024-03-10 09:45:59 +08:00
RockYang
b4501557c9 feat: LaTeX parse is ready 2024-03-10 09:45:59 +08:00
RockYang
a2ed99e6cb feat: add model field for chat_item and and chat_history data table 2024-03-10 09:45:59 +08:00
RockYang
6bd6bb3885 feat: add err_msg field for mj and sd jobs 2024-03-10 09:45:59 +08:00
RockYang
399cf65fc9 feat: blend and swap face function for midjourney-plus is ready 2024-03-10 09:45:59 +08:00
RockYang
24906a6df1 feat: add blend and swapface task implements for midjourney 2024-03-10 09:45:59 +08:00
RockYang
d772bbebe6 opt: refactor chat session page for mobile device 2024-03-10 09:45:59 +08:00
RockYang
14988853a3 opt: optimize chat list page for mobile 2024-03-10 09:45:59 +08:00
RockYang
7b3f16ac9f feat: use vant replace element-plus as mobile UI framework 2024-03-10 09:45:59 +08:00
RockYang
82b2755c18 feat: add websocket heartbeat message for mj page 2024-03-10 09:45:59 +08:00
廖彦棋
ff4b267858 fix(ui): 细节调整 2024-03-08 17:46:48 +08:00
huangqj
a590d0497f feat(ui):细节调整 2024-03-08 10:24:38 +08:00
廖彦棋
ac30d906f0 feat(ui): prettier 2024-03-08 09:59:40 +08:00
廖彦棋
5bc071e038 refactor(ui): 登录页重构 2024-03-08 09:45:09 +08:00
廖彦棋
88b956cf98 refactor(ui): 优化 2024-03-08 09:12:39 +08:00
huangqj
f725cf4661 feat(ui):路由 2024-03-08 08:35:41 +08:00
廖彦棋
057cc1e8a6 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-07 18:03:08 +08:00
廖彦棋
de122735b8 feat(ui): 过期跳转登录 2024-03-07 18:03:06 +08:00
huangqj
e87ede981c feat(ui):apiKey 语言模型 角色管理 产品 2024-03-07 17:58:25 +08:00
廖彦棋
606fb498e1 feat(ui): 新增系统设置 2024-03-07 17:24:50 +08:00
廖彦棋
a0c06e40a4 feat(ui): 对话管理 2024-03-07 15:32:32 +08:00
huangqj
aba8f57279 feat(ui):用户 2024-03-07 15:05:01 +08:00
huangqj
960286a350 feat(ui):simpleTable 2024-03-07 14:59:45 +08:00
huangqj
8c93fa51f6 feat(ui):searchtable 2024-03-07 14:59:16 +08:00
huangqj
cb0e7d64ff feat(ui):用户 2024-03-07 14:03:55 +08:00
廖彦棋
8e7413da97 feat(ui): 函数管理 2024-03-07 11:40:57 +08:00
廖彦棋
a36f14eb94 feat(ui): 新增弹窗及时间格式化 2024-03-07 09:23:45 +08:00
chenzifan
f2f9f6e488 Merge branch 'main' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-07 08:37:54 +08:00
chenzifan
85068b8ca2 feat: 增加系统用户管理 2024-03-07 08:37:48 +08:00
廖彦棋
f2cfcfeefc fix(ui): ts类型 2024-03-06 18:20:07 +08:00
RockYang
755273a898 feat: update changelogs 2024-03-06 17:58:17 +08:00
廖彦棋
d4a24a0f1d feat(ui): 新增登录 2024-03-06 17:54:38 +08:00
RockYang
92281fcbb7 feat: Mj and sd jobs data loading in pages 2024-03-06 17:31:54 +08:00
RockYang
636db4afcc add prompt translating function for mobile midjourney page 2024-03-06 16:22:03 +08:00
huangqj
ba25b8755e feat(ui):simpleTable 2024-03-06 15:33:37 +08:00
廖彦棋
6399d13a49 feat(ui): 新增请求方法及表格 2024-03-06 13:55:38 +08:00
廖彦棋
06fa54fd25 feat(ui): 管理后台基础配置 2024-03-06 10:23:55 +08:00
廖彦棋
a335b965d0 chore(ui): 目录结构调整 2024-03-06 09:32:47 +08:00
廖彦棋
725adaa7d0 feat(ui): 初始化 2024-03-06 09:27:11 +08:00
chenzifan
7e7e81e974 refactor: 初始化UI重构 2024-03-06 08:57:46 +08:00
RockYang
8cfe6bfc17 Merge branch 'dev' of gitee.com:blackfox/chatgpt-plus-pro into dev 2024-03-04 08:34:12 +08:00
RockYang
33de83f2ac feat: add removing order button in admin order list page 2024-03-03 19:27:22 +08:00
RockYang
3f856afec8 fix: fix major bugs for unauthorized access to data 2024-03-03 10:40:32 +08:00
RockYang
4e4dc4cb73 Update README.md 2024-03-01 23:08:05 +08:00
RockYang
02a9c422fe fix: fixed bug image preview im mobile chat session page 2024-02-29 15:41:45 +08:00
RockYang
ca69341024 feat: add draw same image for midjourney page 2024-02-29 11:44:09 +08:00
RockYang
169bf069ce opt: add logs for mj-plus api error 2024-02-28 15:50:42 +08:00
RockYang
1bee0ab04d opt: replace proxy url for discord image url 2024-02-27 17:45:57 +08:00
RockYang
440d91dd0e feat: add change password in with mobile page 2024-02-27 15:36:20 +08:00
RockYang
8168e246a8 feat: add image preview for mobile chat page 2024-02-26 18:11:37 +08:00
RockYang
2ef07574ae feat: replace http polling with webscoket notify in sd image page 2024-02-26 15:45:54 +08:00
RockYang
37392f2bb2 chore: replace 'token' with power 2024-02-23 18:11:57 +08:00
RockYang
a80cd3848e docs: update mj-plus api domain 2024-02-23 15:41:02 +08:00
RockYang
db6ed84451 opt: remove global keyup event bind 2024-02-23 11:43:27 +08:00
RockYang
4463cc5963 docs: update config.toml 2024-02-22 17:53:39 +08:00
RockYang
d316158fe2 docs: add database sql file for v3.2.7 2024-02-22 17:28:22 +08:00
RockYang
e02a8d7586 feat: allow to view chat message in manager console 2024-02-22 17:16:44 +08:00
RockYang
9988dff885 feat: download image which ai generated in dialog and replace the image url 2024-02-20 18:38:03 +08:00
RockYang
35ef5674ff feat: image-wall page for mobile is ready 2024-02-20 17:38:18 +08:00
RockYang
976da45bce feat: allow user config third-party platform openai and mj api key 2024-02-20 11:23:55 +08:00
RockYang
c83ac48bd2 feat: added delete file function 2024-02-19 16:43:03 +08:00
RockYang
3d159a833e fix: verifycation component touch event coordinates misplace in iphone browser 2024-02-19 14:04:50 +08:00
RockYang
4b09878bdd fix: fix bug for regenerate button did not work 2024-02-19 11:22:42 +08:00
RockYang
b0162e6a92 fix: Upscale and Variation task overrite each other 2024-02-16 18:08:29 +08:00
RockYang
8ab15e5dc4 feat: midjourney mobile page all function is ready 2024-02-16 15:55:04 +08:00
RockYang
d2ac807252 feat: mobile mj list page is ready 2024-02-15 18:11:22 +08:00
RockYang
0af01f6f1f feat: add mj image list component for mobile page. fixed bug for html tag escape 2024-02-15 11:39:04 +08:00
RockYang
013b319fab feat: add functions mj page for mobile 2024-01-31 07:24:35 +08:00
RockYang
2899ba5949 opt: add default extension for mj image 2024-01-30 21:46:17 +08:00
RockYang
a558b7e104 feat: mj for mobile page payout is ready 2024-01-30 18:34:01 +08:00
RockYang
7a833e2233 add docs and github link 2024-01-30 16:18:27 +08:00
RockYang
bf65746d00 opt: enable use cdn url for mj-plus 2024-01-28 21:56:25 +08:00
RockYang
f08a7862de feat: LaTeX parse is ready 2024-01-26 18:04:53 +08:00
RockYang
023a2c2f09 feat: add model field for chat_item and and chat_history data table 2024-01-26 16:54:00 +08:00
RockYang
1bcd0f4c1a feat: add err_msg field for mj and sd jobs 2024-01-26 14:50:36 +08:00
RockYang
a0f3bc8ccb feat: blend and swap face function for midjourney-plus is ready 2024-01-26 11:57:08 +08:00
RockYang
dea72738c1 feat: add blend and swapface task implements for midjourney 2024-01-25 18:50:24 +08:00
RockYang
a1d1fe7763 opt: refactor chat session page for mobile device 2024-01-25 14:07:10 +08:00
RockYang
a39ed9764c opt: optimize chat list page for mobile 2024-01-24 18:23:24 +08:00
RockYang
aaa5ba99aa feat: use vant replace element-plus as mobile UI framework 2024-01-24 17:34:30 +08:00
RockYang
2113508b6d feat: add websocket heartbeat message for mj page 2024-01-24 09:33:04 +08:00
RockYang
7fe4212684 add v3.2.6 database sql file 2024-01-23 18:00:49 +08:00
RockYang
8bdda64794 doc: update config sample file 2024-01-23 17:56:22 +08:00
RockYang
ec08c24dca fix: auto fill apiURL when platform changed for ApiKey add page 2024-01-23 17:30:54 +08:00
RockYang
a992a5b3b3 feat: merge sms branch,add DuanXinBao sms service implemetation 2024-01-23 16:16:47 +08:00
RockYang
0f05970141 opt: add heartbeat message for websocket connects 2024-01-22 18:42:51 +08:00
whale_fall
e5e762efcd feat: 添加支持多个短信服务商支持 添加短信宝服务商支持,同时添加配置示例 2024-01-22 16:38:44 +08:00
RockYang
b3d0c1ef9c fix: fix bug for wechat transfer message parse failed 2024-01-22 16:10:08 +08:00
RockYang
397078f7ff feat: HuPiPay order check function is ready 2024-01-22 15:17:26 +08:00
RockYang
3ad8065e20 opt: verify the order in notify callback 2024-01-22 13:58:25 +08:00
RockYang
66c7717f04 chore: print error detail when call http api failed with mj 2024-01-21 22:30:24 +08:00
RockYang
412f8ecc6c opt: add image upload support for md-editor-3 2024-01-19 18:43:13 +08:00
RockYang
51dcf642b3 fixed conflicts 2024-01-19 18:21:49 +08:00
RockYang
bfeea555b2 feat: system notice function is ready 2024-01-19 18:19:51 +08:00
RockYang
479f94c372 feat: system notice function is ready 2024-01-19 18:18:10 +08:00
RockYang
0140713e86 fix: fixed bug for img_call increased when upscale task run failed 2024-01-19 17:10:52 +08:00
RockYang
15b2ec9721 feat: add system config item for wechat qrcode 2024-01-19 16:58:13 +08:00
RockYang
c9cd082855 chore: optimize variable name 2024-01-19 11:26:22 +08:00
RockYang
d7c002890c Merge branch 'main' into dev 2024-01-19 10:09:18 +08:00
RockYang
348dd22279 Merge branch 'main' of gitee.com:blackfox/chatgpt-plus 2024-01-19 10:09:01 +08:00
RockYang
3e99b4cbf6 !4 添加支持阿里旗下的大模型 通义千问对话
Merge pull request !4 from 鲸落/qwen
2024-01-19 02:06:17 +00:00
whale_fall
6968da3ac7 feat: 添加支持阿里的通义千问对话 2024-01-19 09:52:16 +08:00
RockYang
bf1c1b84c3 feat: add image publish function, ONLY published image show in image wall page 2024-01-19 06:52:23 +08:00
RockYang
c70314d930 update change log 2024-01-18 18:11:25 +08:00
RockYang
9104ca8e49 feat: add system config disable user registeration 2024-01-18 17:24:02 +08:00
RockYang
2af33b3630 opt: compatible wechat old message format for parsing wechat transfer message 2024-01-18 16:58:20 +08:00
RockYang
654e795545 opt: optimize order query alg, reduce polling times 2024-01-18 09:39:36 +08:00
RockYang
c62ba2451e update config 2024-01-16 15:24:06 +08:00
RockYang
d72d1b8a99 docs: update change log 2024-01-16 15:05:54 +08:00
RockYang
b939d6016b docs: update config files 2024-01-16 14:38:18 +08:00
RockYang
36a2626ccc opt: optimize markdown image parser, identify image and blockquote tags 2024-01-16 10:13:00 +08:00
RockYang
bd057a4cc9 feat: attachments manage function is ready 2024-01-15 18:48:01 +08:00
RockYang
dc24a8c781 feat: gpt-4-gizmo-g-* model is supported 2024-01-15 15:03:05 +08:00
RockYang
59fa21779b feat: gpt-4-all model is ready 2024-01-15 14:07:24 +08:00
RockYang
a140671aad opt: optimize vip recharge logic 2024-01-15 11:01:57 +08:00
RockYang
5fe8990fb4 Merge branch 'dev' 2024-01-15 10:29:01 +08:00
RockYang
12799b7159 opt: optimize vip recharge logic 2024-01-15 10:28:46 +08:00
RockYang
9929746b1d feat: add asynchronously pull midjourney task progress in case the synchronization callback is fails 2024-01-12 18:24:28 +08:00
RockYang
d70035ff0c feat: midjourney plus service is ready 2024-01-11 18:16:48 +08:00
RockYang
eec90274d8 feat: update video tutorial 2024-01-10 08:50:13 +08:00
RockYang
e8fff55c42 feat: update video tutorial 2024-01-10 08:48:05 +08:00
RockYang
3cf3cdd705 remove api key for hupipay 2024-01-09 17:16:27 +08:00
RockYang
9801fce659 fix: fixed bug for gorm insert record failed and Error is not nil 2024-01-08 18:10:32 +08:00
RockYang
4c1f51110b feat: change mobile field to username 2024-01-08 17:34:09 +08:00
RockYang
913d538587 add changelog 2024-01-08 12:01:58 +08:00
RockYang
9e704365fc chore: do not close pop window when click model 2024-01-08 11:01:19 +08:00
RockYang
485bdbc56a fix: function call 兼容中转 API 2024-01-07 22:32:59 +08:00
RockYang
7000168fd4 opt: add support to disable code verify 2024-01-07 17:31:26 +08:00
RockYang
5694f97a6b feat: payjs payment channel is ready 2024-01-07 14:36:02 +08:00
RockYang
b677d3fac7 fix: add user failed in admin user list page 2024-01-07 10:49:36 +08:00
RockYang
dc6719cf54 release v3.2.4 2024-01-06 21:09:19 +08:00
RockYang
7de5b55091 chore: remove useless system config items 2024-01-06 17:38:55 +08:00
RockYang
76c5101092 chore: error recover is enable ONLY in debug mode 2024-01-06 17:16:02 +08:00
RockYang
2f8d2f4854 feat: payjs service is ready 2024-01-06 15:53:30 +08:00
RockYang
b1ee34ba0c chore: rename bind username api 2024-01-05 18:21:47 +08:00
RockYang
069ad6a09a feat: email registration function is ready 2024-01-05 18:17:11 +08:00
RockYang
bf1403c818 feat: update api key last_use_time after dalle3 call 2024-01-04 18:15:00 +08:00
RockYang
bcc622a24d feat: support dall-e3 api mirrors, add name field for ApiKey 2024-01-04 16:29:57 +08:00
RockYang
a06a81a415 feat: refactor LLM api request code, get API URL from ApiKey object 2024-01-04 14:51:33 +08:00
RockYang
d1950acd01 feat: api key manage page funciton is ready 2024-01-04 10:48:04 +08:00
RockYang
039b70eed2 fix: fixed bug for concurrency risk for getting token for chat histroy with issue #92 2024-01-04 09:03:19 +08:00
RockYang
d8e4308b1b fix: add unique key for MidJourney task_id 2024-01-03 18:06:10 +08:00
RockYang
434fbb3463 feat: show notice in chat page 2024-01-03 15:19:24 +08:00
RockYang
de3eb8969c feat: fixed bug for wechat bot to parse transactions. enable user to exchange reward with img_calls 2024-01-03 11:15:54 +08:00
RockYang
fbd6eac877 fix: fixed chat export page styles 2024-01-02 11:32:36 +08:00
RockYang
1fecab177b fix: fixed for img_call repeated reductions 2024-01-01 18:54:48 +08:00
RockYang
b1b385c455 feat: add switch for enable|disable chat role 2023-12-29 17:51:56 +08:00
RockYang
3c6e86d04b feat: add nickname field for user 2023-12-29 17:39:37 +08:00
RockYang
3d2035d08a feat: add authorization for local function call 2023-12-29 17:21:29 +08:00
RockYang
da86f916d8 update changelog 2023-12-29 11:53:37 +08:00
RockYang
e7a07f7e92 feat: add router for function manager 2023-12-29 11:22:26 +08:00
RockYang
b01e6387fc fix: restore user img_calls quota when image task run failed 2023-12-29 10:41:29 +08:00
RockYang
d86aca0f5d merge pull request #72 2023-12-29 10:09:37 +08:00
RockYang
09414fe36a Merge branch 'main' of github.com:yangjian102621/chatgpt-plus 2023-12-29 09:39:52 +08:00
RockYang
df0e7508db Merge pull request #72 from Unclesimonlau/main
重新设计了移动端web页面,新增了移动端CSS,增加移动端SD绘图页面
2023-12-29 09:39:33 +08:00
RockYang
92b1f01118 merge pull request #71 2023-12-29 09:31:25 +08:00
RockYang
8fb8bd932b Merge pull request #71 from JingHong0202/main
fix: Azure Api request failure after changing the API-version parameter
2023-12-29 09:27:23 +08:00
RockYang
3f74b94784 chore: remove dead code 2023-12-29 09:02:55 +08:00
RockYang
e9467341fa feat: function manager refactor is ready 2023-12-28 18:14:38 +08:00
JingHong
131e051ddc Merge branch 'yangjian102621:main' into main 2023-12-27 23:33:46 +08:00
UncleSimonlau
f626fe3166 Merge branch 'main' of https://github.com/Unclesimonlau/chatgpt-plus 2023-12-26 14:19:25 +08:00
RockYang
6bc57b6132 fix: fixed bug #70, XunFei 1.5 url version map error 2023-12-26 14:19:00 +08:00
RockYang
d972e97c88 fix: fixed bug #70, XunFei 1.5 url version map error 2023-12-25 08:54:17 +08:00
RockYang
3991f4daec feat: function CRUD operation is ready 2023-12-24 22:12:12 +08:00
jinghong0202
f6b567d6fc fix: Azure Api 更换api-version参数后请求失败的问题 2023-12-24 08:36:34 +01:00
RockYang
8addba8203 feat: function add for admin page is ready 2023-12-23 22:30:27 +08:00
RockYang
3ab930a107 feat: support CDN reverse proxy for MidJourney and OpenAI API 2023-12-22 17:25:31 +08:00
RockYang
de512a5ea2 feat: add function list page in admin console 2023-12-21 18:06:09 +08:00
RockYang
113cfae2dc opt: optimize image compress alg, add cache control for image 2023-12-21 15:00:46 +08:00
RockYang
33aebf9cb5 feat: add funcitons manger page 2023-12-21 08:58:24 +08:00
RockYang
6e58ddf681 feat: auto translate image creating prompt 2023-12-19 18:54:19 +08:00
RockYang
cae5c049e4 fix: fixed bug for HuPiPay qrcode generation. set field 'openid' of result struct to Any data type 2023-12-19 11:31:57 +08:00
RockYang
ff76e4bd89 chore: update copyright information 2023-12-18 18:19:41 +08:00
RockYang
a0a506a3c4 feat: add remove action to remove task and images for MJ and SD task list page 2023-12-18 17:44:52 +08:00
RockYang
aa5a4a9977 opt: merge RAG branch 2023-12-18 16:41:40 +08:00
RockYang
abf4f061c1 opt: make sure the Upscale and Variation task is assign to the same mj service with Image task 2023-12-18 16:34:33 +08:00
RockYang
245cd3ee1a fix: fixed bug for mj service pool config pointer 2023-12-15 22:52:57 +08:00
RockYang
45cb29d9a0 feat: add img_calls field for recharge products 2023-12-15 16:56:56 +08:00
RockYang
d974b1ff0e chore: update default config.toml 2023-12-15 11:23:13 +08:00
RockYang
56269170cb opt: limit the image display size in reply component 2023-12-15 10:48:13 +08:00
RockYang
4290c4ca22 docs: update changelog 2023-12-15 09:04:02 +08:00
RockYang
7f7c8e831e docs: add sql file 2023-12-14 17:05:51 +08:00
RockYang
8f057ca9d1 refactor: refactor stable diffusion service, add service pool support 2023-12-14 16:48:54 +08:00
RockYang
4a56621ec3 chore: add sub dir support for OSS 2023-12-13 17:02:49 +08:00
RockYang
a398e7a550 refactor: add midjourney pool implementation, add translate prompt for mj drawing 2023-12-13 16:38:27 +08:00
RockYang
96816c12ca fix: fixed bug for aliyun OSS img url 2023-12-13 09:49:55 +08:00
RockYang
9984926f69 refactor mj service, add mj service pool support 2023-12-12 18:33:24 +08:00
RockYang
a2a6081027 opt: remove default value for stable-diffusion page 2023-12-12 09:59:20 +08:00
RockYang
5a10ed37a7 docs: update readme file 2023-12-12 09:52:28 +08:00
RockYang
1a9dd9de0b docs: update build config.toml 2023-12-12 07:25:36 +08:00
RockYang
0dae5bef71 docs: update changelog file 2023-12-11 17:17:17 +08:00
RockYang
b4413ed726 add translate api for midjourney 2023-12-11 17:01:02 +08:00
RockYang
5e1fe88b8b feat: add prompt translate handler 2023-12-11 06:56:00 +08:00
RockYang
91ed41b536 feat: add system config item for dall e3 generate image num 2023-12-10 17:13:25 +08:00
RockYang
024c0032eb chore: change default params for stable diffusion 2023-12-10 14:45:22 +08:00
RockYang
4a9f7e3bce feat: add HuPiPay payment support 2023-12-08 19:43:13 +08:00
RockYang
cf4dcc34ec feat: add image generation API URL in chat configurations 2023-12-07 16:31:32 +08:00
RockYang
4d612c15af docs: add arm64 build script 2023-12-07 15:44:20 +08:00
RockYang
8aec87cc02 fix: fixed bug for prompt code format, prevent xss attacks 2023-12-07 14:02:13 +08:00
RockYang
442e411cde opt: save chat ID when the chat websocket disconnect 2023-12-07 11:07:08 +08:00
RockYang
acec0194de feat: adjust task list component styles 2023-12-06 19:05:51 +08:00
RockYang
8557f5b94a Merge branch 'pr_3' into dev 2023-12-06 18:54:45 +08:00
RockYang
babef8baae feat: refactor midjourney image creating page 2023-12-06 18:54:30 +08:00
RockYang
efd4ab46f5 docs: update readme file 2023-12-06 14:44:06 +08:00
liyuwanglan
ae8239e5de 修改 2023-12-05 11:08:03 +08:00
RockYang
f0994ba457 docs: update comments 2023-11-30 17:35:56 +08:00
RockYang
dae91ed243 fix: fixed bug for upload image failed 2023-11-29 17:46:46 +08:00
RockYang
de42a428e6 opt: create new chat session when change role or model, fix bug for mobile no validate 2023-11-29 17:36:27 +08:00
RockYang
63c7041e1f docs: update change logs 2023-11-28 15:33:30 +08:00
RockYang
b1263ddc69 docs: 添加一键部署脚本 2023-11-28 15:25:51 +08:00
RockYang
7e50e17aaf opt: 缩略图生成算法 2023-11-28 14:50:19 +08:00
RockYang
a7265c4251 feat: 为大图片生成缩略图,加快前端图片加载速度 2023-11-28 12:04:02 +08:00
RockYang
6f39f639bd fix: fix bug for oss image domain 2023-11-28 07:27:18 +08:00
RockYang
a7db123437 docs: update build script 2023-11-27 18:24:52 +08:00
RockYang
241c714a8b Merge branch 'main' of github.com:yangjian102621/chatgpt-plus 2023-11-27 12:03:23 +08:00
RockYang
67ac3cfe32 feat: merge mysql and redis docker service to docker-compose.yaml file 2023-11-27 10:56:18 +08:00
RockYang
c926e0afcc Merge pull request #55 from openjst/main
修改code显示颜色样式
2023-11-27 10:27:13 +08:00
RockYang
5bc07e6d57 docs: make a full docker-compose.yaml 2023-11-27 07:21:37 +08:00
openjst
c3666a9a71 修改code显示颜色样式 2023-11-26 23:26:08 +08:00
RockYang
23b5ffa97d feat: implements image function replace Mj with DALL-E-3 2023-11-26 20:37:48 +08:00
RockYang
a2c7a75705 feat: add type field for api key 2023-11-24 18:05:59 +08:00
RockYang
d68f2ef12c feat: add support for registing use force use invite code 2023-11-24 12:02:28 +08:00
RockYang
67d30353f0 opt: optimize image preview for MidJourney image list page, only preview current image not for all images 2023-11-23 17:55:12 +08:00
RockYang
4813163eac opt: 增加中间件自动对HTTP请求的参数去掉首尾空格 2023-11-23 17:50:55 +08:00
RockYang
5c5210625e opt: optimize styles for invitation page 2023-11-23 17:40:15 +08:00
RockYang
a4a1eec30b feat: add invitation and promotion functions 2023-11-23 16:30:15 +08:00
RockYang
d35164506a docs: add database sql file for v3.1.9 2023-11-23 09:58:01 +08:00
330 changed files with 28091 additions and 7206 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
deploy
docs
api/static
web/node_modules
desktop

View File

@@ -1,6 +1,153 @@
# 更新日志
## v4.0.3 2024-04-15
* Bug修复修复MJ-PLUS 服务会自动删除10分钟前的任务问题
* Bug修复修复MJ 的 U/V 操作会强制使用 Fast 模式 Bug
## v4.0.2
* 功能新增:支持前端菜单可以配置
* 功能优化:在登录和注册界面标题显示软件版本号
* 功能优化MJ 绘画支持 --sref 和 --cref 图片一致性参数
* 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
* Bug修复解决因为图片上传使用相对路径而导致融图失败的问题
* 功能新增:手机端支持 Stable-Diffusion 绘画
* Bug修复修复管理后台 API KEY 删除失败的问题
## v4.0.1
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口SDAPI 兼容各种 stable-diffusion
发行版,稳定性更强一些
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API兼容
MJ-Plus 中转
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
* Bug修复修复 iphone 手机无法通过图形验证码的Bug使用滑动验证码替换
* Bug修复修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
## v4.0.0
非兼容版本重大重构引入算力概念将系统中所有的能力AI对话MJ绘画SD绘画DALL绘画全部使用算力来兑换。
只要你的算力值余额不为0你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力一次 GPT4 对话消耗10个算力。一次 MJ
对话消耗15个算力...
* 功能重构:重构整体系统,全部采用算力来进行结算
* 功能优化SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
* 功能优化:移动端聊天页面图片支持预览和放大功能
* 功能优化MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
* 功能优化:**PC端不登录也可以预览功能只有在发起操作的时候才需要登录**
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
* 功能新增支持H5支付
* 功能优化:支持数学公式的识别和美化输出
* 功能新增:新增算力消费日志功能
* 功能优化:整合 XXL-JOB 实现订单清理每日算力派发VIP 算力重置等任务
* 功能新增管理后台新增7日内新增用户和新增订单统计
## v3.2.7
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
* Bug修复修复 issue [
管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532)
* 功能优化:在对话和聊天记录表中新增冗余字段 model存储对话模型
* Bug修复IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
* Bug修复重新生成按钮功能失效问题
* Bug修复对话输入HTML标签不显示的问题
* 功能优化gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
* 功能新增:新增删除文件功能
* Bug修复解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
* 功能新增:后台管理新怎对话查看和检索功能
## v3.2.6
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
* 功能优化:兼用旧版本微信收款消息解析
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
* 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
* 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
* 功能新增:新增网站公告,可以在管理后台自定义配置
* 功能新增:新增阿里通义千问大模型支持
* Bug修复修复 MJ 放大任务失败时候 img_call 会增加的 Bug
* 功能优化新增虎皮椒和PayJS订单状态校验功能增加安全性
* Bug修复修复微信转账交易 ID 提取失败 Bug
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
* 功能新增:新增短信宝短信平台发送平台集成
## v3.2.5
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
Plus 账号了!!!
* 功能优化:增强 markdown 图片和引用块解析。
* 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
* 功能优化function call 兼用中转 API。
* Bug修复修复部分已知的 Bug。
## v3.2.4.1
* 功能新增:新增 PayJs 支付通道
* Bug修复紧急修复后台添加用户失败问题
* Bug修复紧急修复使用中转 API-KEY 无法绘图的问题
* Bug修复允许用户关闭手机和邮箱注册通道移除验证码依赖
## v3.2.4
* 功能新增:重磅更新,支持邮箱注册
* 功能优化:优化函数调用授权
* 功能优化:给用户表新增 nickname 字段
* 功能优化:管理后台给聊天角色增加启用/禁用开关
* Bug修复SD绘画出现重复扣减绘图次数
* 功能优化:优化聊天对话导出样式,适应移动端
* 功能新增:众筹核销可以选择兑换对话还是绘图的额度
* Bug修复修复[从历史记录获取reply有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92)
* Bug修复修复 MidJourney 绘图任务调度Bug为 task_id 建议唯一索引
* 功能重构:重构了 API KEY模块支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。
## v3.2.3
* 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现
* 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数
* Bug修复修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73)
* Bug修复确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道
* 功能新增:新增删除绘图任何和图片功能
* Bug修复修复虎皮椒支付二维码重复扫码时报错问题
* 功能优化:自动将 AI 绘画中的中文提示词翻译成英文
* 功能优化优化AI绘画的大图压缩算法新增图片缓存
* 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率
* Bug修复修复[Azure Api 更换api-version参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71)
* Bug修复修复科大讯飞 V1.5 API 请求失败的问题
* Bug修复绘图失败后自动恢复用户的剩余绘图次数
* 功能新增:为移动端新增 SD 绘图功能,分享功能
## v3.2.2
* 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务
* 功能新增AI绘画页面增加翻译和重写提示词功能
* 功能优化OSS上传组件支持在 Bucket 下设置二级目录
* Bug修复修复阿里云 OSS 访问路径错误
* 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket
## v3.2.1
* 功能优化:切换角色和模型的时候自动创建新的对话
* Bug修复修复文件上传失败No such file bug
* 功能新增MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数
* Bug修复[PC端对话在刷新后异常](https://github.com/yangjian102621/chatgpt-plus/issues/59)
* 功能新增:增加 arm64 架构打包脚本
* 功能新增:支持 dall-e3 绘图的 API 地址自定义配置
* 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道
## v3.2.0
* 功能新增:新增邀请注册功能
* 功能优化增加中间件自动对HTTP请求的参数去掉首尾空格
* 功能优化:增加中间件自动为大图片生成缩略图
* 功能优化MidJourney 页面图片加载优化,实现图片预览懒加载
* 功能新增:新增 DALL-E-3 绘画支持,并作为对话页面默认绘画插件
* Bug修复修复阿里云 OSS 域名设置不起做用的bug
* Bug修复修复MidJourney绘图失败后重复添加到队列的问题
## v3.1.9
* 功能新增:增加讯飞星火大模型 v3.0 支持
* 功能新增:新增找回密码功能
* 功能新增:支持 Markdown 代码复制功能

View File

@@ -9,7 +9,7 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
* 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型。
* 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
* 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
* 已集成支付宝支付功能,支持多种会员套餐和点卡购买功能。
* 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
* 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
绘画函数插件。
@@ -63,17 +63,42 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
![Mobile chat setting](/docs/imgs/mobile_user_profile.png)
![Mobile chat setting](/docs/imgs/mobile_pay.png)
### 7. 体验地址
### 体验地址
> 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/>
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
## 快速部署
**演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。
```shell
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v4.0.2-ba5a891bc0.sh)"
```
最新版本的一键部署脚本请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/install/)。
目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址
* 前端访问地址http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。
* 后台管理地址http://localhost:8080/admin
* 移动端地址http://localhost:8080/mobile
* 初始后台管理账号admin/admin123
* 初始前端体验账号18575670125/12345678
服务启动成功之后不能立刻使用,需要先登录管理后台 -> API-KEY 去添加一个 OpenAI 或者文心一言,科大讯飞等至少一个平台的 API
KEY。
![](https://ai.r9it.com/docs/images/env/admin_api_keys.png)
另外,如果您目前还没有 OpenAI 的 API KEY的推荐您去 https://api.chat-plus.net 购买,**无需魔法,高速稳定,且价格还远低于 OpenAI
官方**。
## 使用须知
1. 本项目基于 MIT 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
## 项目地址
* Github 地址https://github.com/yangjian102621/chatgpt-plus
@@ -84,22 +109,25 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
目前已经支持 Win/Linux/Mac/Android 客户端下载地址为https://github.com/yangjian102621/chatgpt-plus/releases/tag/v3.1.2
## TODOLIST
* [ ] 支持基于知识库的 AI 问答
* [ ] 会员邀请注册推广功能
* [ ] 微信支付功能
## 项目文档
请参考 [ChatGPT-Plus 文档](https://ai.r9it.com/docs/)
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
详细的部署和开发文档请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/)。
加微信进入微信讨论群可获取 **一键部署脚本添加好友时请注明来自Github!!!)。**
![微信名片](docs/imgs/wx.png)
## 参与贡献
个人的力量始终有限,任何形式的贡献都是欢迎的,包括但不限于贡献代码,优化文档,提交 issue 和 PR 等。
如果有兴趣的话,也可以加微信进入微信讨论群(**添加好友时请注明来自Github!!!**)。
![微信名片](docs/imgs/wx.png)
#### 特此声明:由于个人时间有限,不接受在微信或者微信群给开发者提 Bug有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合
### Commit 类型
@@ -119,6 +147,3 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
![打赏](docs/imgs/donate.png)
![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/chatgpt-plus&type=Date)

1
api/.gitignore vendored
View File

@@ -18,4 +18,3 @@ data
config.toml
static/upload
storage.json
certs/alipay/*

View File

@@ -1,19 +1,14 @@
SHELL=/usr/bin/env bash
NAME := chatgpt-plus
all: window linux darwin
all: amd64 arm64
amd64:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-linux main.go
.PHONY: amd64
window:
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/$(NAME)-amd64.exe main.go
.PHONY: window
linux:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-amd64-linux main.go
.PHONY: linux
darwin:
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/$(NAME)-amd64-darwin main.go
.PHONY: darwin
arm64:
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -o bin/$(NAME)-linux main.go
.PHONY: arm64
clean:
rm -rf bin/$(NAME)-*

View File

@@ -1,6 +1,6 @@
Listen = "0.0.0.0:5678"
ProxyURL = "" # 如 http://127.0.0.1:7777
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8&parseTime=True&loc=Local"
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
StaticDir = "./static" # 静态资源的目录
StaticUrl = "/static" # 静态资源访问 URL
AesEncryptKey = ""
@@ -10,10 +10,6 @@ WeChatBot = false
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
MaxAge = 86400
[Manager]
Username = "admin"
Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改
[Redis] # redis 配置信息
Host = "localhost"
Port = 6379
@@ -25,19 +21,28 @@ WeChatBot = false
AppId = ""
Token = ""
[SmsConfig] # 阿里云短信服务配置
AccessKey = ""
AccessSecret = ""
Product = "Dysmsapi"
Domain = "dysmsapi.aliyuncs.com"
Sign = ""
CodeTempId = ""
[SMS] # Sms 配置,用于发送短信
Active = "Ali" # 当前启用的短信服务,默认使用阿里云
[SMS.Bao]
Username = ""
Password = ""
Domain = "api.smsbao.com"
Sign = "【极客学长】"
CodeTemplate = "您的验证码是{code}。5分钟有效若非本人操作请忽略本短信。"
[SMS.Ali]
AccessKey = ""
AccessSecret = ""
Product = "Dysmsapi"
Domain = "dysmsapi.aliyuncs.com"
Sign = ""
CodeTempId = ""
[OSS] # OSS 配置,用于存储 MJ 绘画图片
Active = "local" # 默认使用本地文件存储引擎
[OSS.Local]
BasePath = "./static/upload" # 本地文件上传根路径
BaseURL = "http://localhost:5678/static/upload" # 本地上传文件 URL 如果是线上,则直接设置为 /static/upload 即可
BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP
[OSS.Minio]
Endpoint = "" # 如 172.22.11.200:9000
AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
@@ -51,19 +56,30 @@ WeChatBot = false
AccessSecret = ""
Bucket = ""
Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
[OSS.AliYun]
Endpoint = "oss-cn-hangzhou.aliyuncs.com"
AccessKey = ""
AccessSecret = ""
Bucket = "chatgpt-plus"
SubDir = ""
Domain = ""
[MjConfig]
Enabled = false
UserToken = ""
BotToken = ""
GuildId = ""
ChanelId = ""
[[MjProxyConfigs]]
Enabled = true
ApiURL = "http://midjourney-proxy:8082"
ApiKey = "sk-geekmaster"
[SdConfig]
[[MjPlusConfigs]]
Enabled = false
ApiURL = "http://172.22.11.200:7860"
ApiURL = "https://api.chat-plus.net"
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
ApiKey = "sk-xxx"
[[SdConfigs]]
Enabled = false
ApiURL = ""
ApiKey = ""
Txt2ImgJsonPath = "res/text2img.json"
Txt2ImgJsonPath = "res/sd/text2img.json"
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动
Enabled = false # 是否启用 XXL JOB 服务
@@ -82,4 +98,27 @@ WeChatBot = false
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
NotifyURL = "http://r9it.com:6004/api/payment/alipay/notify" # 支付异步回调地址
NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
[HuPiPayConfig]
Enabled = false
Name = "wechat"
AppId = ""
AppSecret = ""
ApiURL = "https://api.xunhupay.com"
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
[SmtpConfig] # 注意阿里云服务器禁用了25号端口所以如果需要使用邮件功能请别用阿里云服务器
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"
From = "test@163.com" # 发件邮箱人地址
Password = "" #邮箱 stmp 服务授权码
[JPayConfig] # PayJs 支付配置
Enabled = false
Name = "wechat" # 请不要改动
AppId = "" # 商户 ID
PrivateKey = "" # 秘钥
ApiURL = "https://payjs.cn"
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的

View File

@@ -1,8 +1,8 @@
package core
import (
"bytes"
"chatplus/core/types"
"chatplus/service/fun"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
@@ -11,9 +11,14 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/nfnt/resize"
"gorm.io/gorm"
"image"
"image/jpeg"
"io"
"log"
"net/http"
"os"
"runtime/debug"
"strings"
"time"
@@ -23,31 +28,28 @@ type AppServer struct {
Debug bool
Config *types.AppConfig
Engine *gin.Engine
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
ChatConfig *types.ChatConfig // chat config cache
SysConfig *types.SystemConfig // system config cache
SysConfig *types.SystemConfig // system config cache
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
ChatSession *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
Functions map[string]fun.Function
}
func NewServer(appConfig *types.AppConfig, functions map[string]fun.Function) *AppServer {
func NewServer(appConfig *types.AppConfig) *AppServer {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
return &AppServer{
Debug: false,
Config: appConfig,
Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []interface{}](),
ChatContexts: types.NewLMap[string, []types.Message](),
ChatSession: types.NewLMap[string, *types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
Functions: functions,
}
}
@@ -57,30 +59,22 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
logger.Info("Enabled debug mode")
}
s.Engine.Use(corsMiddleware())
s.Engine.Use(staticResourceMiddleware())
s.Engine.Use(authorizeMiddleware(s, client))
s.Engine.Use(parameterHandlerMiddleware())
s.Engine.Use(errorHandler)
// 添加静态资源访问
s.Engine.Static("/static", s.Config.StaticDir)
}
func (s *AppServer) Run(db *gorm.DB) error {
// load chat config from database
var chatConfig model.Config
res := db.Where("marker", "chat").First(&chatConfig)
if res.Error != nil {
return res.Error
}
err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig)
if err != nil {
return err
}
// load system configs
var sysConfig model.Config
res = db.Where("marker", "system").First(&sysConfig)
res := db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return res.Error
}
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil {
return err
}
@@ -138,70 +132,64 @@ func corsMiddleware() gin.HandlerFunc {
// 用户授权验证
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/user/resetPass" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/mj/jobs" ||
c.Request.URL.Path == "/api/sd/jobs" ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
c.Request.URL.Path == "/api/admin/config/get" {
c.Next()
return
}
var tokenString string
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
if isAdminApi { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader)
} else if c.Request.URL.Path == "/api/chat/new" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/sd/client" {
} else if c.Request.URL.Path == "/api/chat/new" {
tokenString = c.Query("token")
} else {
tokenString = c.GetHeader(types.UserAuthHeader)
}
if tokenString == "" {
resp.ERROR(c, "You should put Authorization in request headers")
c.Abort()
return
if needLogin(c) {
resp.ERROR(c, "You should put Authorization in request headers")
c.Abort()
return
} else { // 直接放行
c.Next()
return
}
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
if isAdminApi {
return []byte(s.Config.AdminSession.SecretKey), nil
} else {
return []byte(s.Config.Session.SecretKey), nil
}
return []byte(s.Config.Session.SecretKey), nil
})
if err != nil {
if err != nil && needLogin(c) {
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
if !ok || !token.Valid && needLogin(c) {
resp.NotAuth(c, "Token is invalid")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
resp.NotAuth(c, "Token is expired")
c.Abort()
return
}
key := fmt.Sprintf("users/%v", claims["user_id"])
if _, err := client.Get(context.Background(), key).Result(); err != nil {
if isAdminApi {
key = fmt.Sprintf("admin/%v", claims["user_id"])
}
if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
resp.NotAuth(c, "Token is not found in redis")
c.Abort()
return
@@ -209,3 +197,161 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
c.Set(types.LoginUserID, claims["user_id"])
}
}
func needLogin(c *gin.Context) bool {
if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/user/resetPass" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/admin/login/captcha" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" ||
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/model/list" ||
c.Request.URL.Path == "/api/mj/imgWall" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/imgWall" ||
c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/config/get" ||
c.Request.URL.Path == "/api/product/list" ||
c.Request.URL.Path == "/api/menu/list" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
return false
}
return true
}
// 统一参数处理
func parameterHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// GET 参数处理
params := c.Request.URL.Query()
for key, values := range params {
for i, value := range values {
params[key][i] = strings.TrimSpace(value)
}
}
// update get parameters
c.Request.URL.RawQuery = params.Encode()
// skip file upload requests
contentType := c.Request.Header.Get("Content-Type")
if strings.Contains(contentType, "multipart/form-data") {
c.Next()
return
}
if strings.Contains(contentType, "application/json") {
// process POST JSON request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.Next()
return
}
// 还原请求体
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 将请求体解析为 JSON
var jsonData map[string]interface{}
if err := c.ShouldBindJSON(&jsonData); err != nil {
c.Next()
return
}
// 对 JSON 数据中的字符串值去除两端空格
trimJSONStrings(jsonData)
// 更新请求体
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
}
c.Next()
}
}
// 递归对 JSON 数据中的字符串值去除两端空格
func trimJSONStrings(data interface{}) {
switch v := data.(type) {
case map[string]interface{}:
for key, value := range v {
switch valueType := value.(type) {
case string:
v[key] = strings.TrimSpace(valueType)
case map[string]interface{}, []interface{}:
trimJSONStrings(value)
}
}
case []interface{}:
for i, value := range v {
switch valueType := value.(type) {
case string:
v[i] = strings.TrimSpace(valueType)
case map[string]interface{}, []interface{}:
trimJSONStrings(value)
}
}
}
}
// 静态资源中间件
func staticResourceMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
url := c.Request.URL.String()
// 拦截生成缩略图请求
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
r := strings.SplitAfter(url, "imageView2")
size := strings.Split(r[1], "/")
if len(size) != 8 {
c.String(http.StatusNotFound, "invalid thumb args")
return
}
with := utils.IntValue(size[3], 0)
height := utils.IntValue(size[5], 0)
quality := utils.IntValue(size[7], 75)
// 打开图片文件
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
file, err := os.Open(filePath)
if err != nil {
c.String(http.StatusNotFound, "Image not found")
return
}
defer file.Close()
// 解码图片
img, _, err := image.Decode(file)
if err != nil {
c.String(http.StatusInternalServerError, "Error decoding image")
return
}
var newImg image.Image
if height == 0 || with == 0 {
// 固定宽度,高度自适应
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
} else {
// 生成缩略图
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
}
var buffer bytes.Buffer
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
if err != nil {
log.Fatal(err)
}
// 设置图片缓存有效期为一年 (365天)
c.Header("Cache-Control", "max-age=31536000, public")
// 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
c.Abort() // 中断请求
}
c.Next()
}
}

View File

@@ -14,13 +14,11 @@ var logger = logger2.GetLogger()
func NewDefaultConfig() *types.AppConfig {
return &types.AppConfig{
Listen: "0.0.0.0:5678",
ProxyURL: "",
Manager: types.Manager{Username: "admin", Password: "admin123"},
StaticDir: "./static",
StaticUrl: "http://localhost/5678/static",
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
AesEncryptKey: utils.RandString(24),
Listen: "0.0.0.0:5678",
ProxyURL: "",
StaticDir: "./static",
StaticUrl: "http://localhost/5678/static",
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
Session: types.Session{
SecretKey: utils.RandString(64),
MaxAge: 86400,
@@ -33,8 +31,6 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload",
},
},
MjConfig: types.MidJourneyConfig{Enabled: false},
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
WeChatBot: false,
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
}

View File

@@ -8,7 +8,13 @@ type ApiRequest struct {
Stream bool `json:"stream"`
Messages []interface{} `json:"messages,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
Functions []Function `json:"functions,omitempty"`
Tools []interface{} `json:"tools,omitempty"`
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
ToolChoice string `json:"tool_choice,omitempty"`
Input map[string]interface{} `json:"input,omitempty"` //兼容阿里通义千问
Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问
}
type Message struct {
@@ -27,10 +33,14 @@ type ChoiceItem struct {
}
type Delta struct {
Role string `json:"role"`
Name string `json:"name"`
Content interface{} `json:"content"`
FunctionCall FunctionCall `json:"function_call,omitempty"`
Role string `json:"role"`
Name string `json:"name"`
Content interface{} `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FunctionCall struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
} `json:"function_call,omitempty"`
}
// ChatSession 聊天会话对象
@@ -44,10 +54,14 @@ type ChatSession struct {
}
type ChatModel struct {
Id uint `json:"id"`
Platform Platform `json:"platform"`
Value string `json:"value"`
Weight int `json:"weight"`
Id uint `json:"id"`
Platform Platform `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
}
type ApiError struct {
@@ -61,17 +75,37 @@ type ApiError struct {
const PromptMsg = "prompt" // prompt message
const ReplyMsg = "reply" // reply message
const MjMsg = "mj"
var ModelToTokens = map[string]int{
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"chatglm_pro": 32768, // 清华智普
"chatglm_std": 16384,
"chatglm_lite": 4096,
"ernie_bot_turbo": 8192, // 文心一言
"general": 8192, // 科大讯飞
"general2": 8192,
// PowerType 算力日志类型
type PowerType int
const (
PowerRecharge = PowerType(1) // 充值
PowerConsume = PowerType(2) // 消费
PowerRefund = PowerType(3) // 任务SD,MJ执行失败退款
PowerInvite = PowerType(4) // 邀请奖励
PowerReward = PowerType(5) // 众筹
PowerGift = PowerType(6) // 系统赠送
)
func (t PowerType) String() string {
switch t {
case PowerRecharge:
return "充值"
case PowerConsume:
return "消费"
case PowerRefund:
return "退款"
case PowerReward:
return "众筹"
}
return "其他"
}
type PowerMark int
const (
PowerSub = PowerMark(0)
PowerAdd = PowerMark(1)
)

View File

@@ -5,25 +5,36 @@ import (
)
type AppConfig struct {
Path string `toml:"-"`
Listen string
Session Session
ProxyURL string
MysqlDns string // mysql 连接地址
Manager Manager // 后台管理员账户信息
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
AesEncryptKey string
SmsConfig AliYunSmsConfig // AliYun send message service config
OSS OSSConfig // OSS config
MjConfig MidJourneyConfig // mj 绘画配置
WeChatBot bool // 是否启用微信机器人
SdConfig StableDiffusionConfig // sd 绘画配置
Path string `toml:"-"`
Listen string
Session Session
AdminSession Session
ProxyURL string
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
MjProxyConfigs []MjProxyConfig // MJ proxy config
MjPlusConfigs []MjPlusConfig // MJ plus config
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig
AlipayConfig AlipayConfig
XXLConfig XXLConfig
AlipayConfig AlipayConfig
HuPiPayConfig HuPiPayConfig
SmtpConfig SmtpConfig // 邮件发送配置
JPayConfig JPayConfig // payjs 支付配置
}
type SmtpConfig struct {
Host string
Port int
AppName string // 应用名称
From string // 发件人邮箱地址
Password string // 发件人邮箱密码
}
type ChatPlusApiConfig struct {
@@ -32,36 +43,29 @@ type ChatPlusApiConfig struct {
Token string
}
type MidJourneyConfig struct {
Enabled bool
UserToken string
BotToken string
GuildId string // Server ID
ChanelId string // Chanel ID
}
type WeChatConfig struct {
type MjProxyConfig struct {
Enabled bool
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type StableDiffusionConfig struct {
Enabled bool
ApiURL string
ApiKey string
Txt2ImgJsonPath string
Enabled bool
Model string // 模型名称
ApiURL string
ApiKey string
}
type AliYunSmsConfig struct {
AccessKey string
AccessSecret string
Product string
Domain string
Sign string // 短信签名
CodeTempId string // 验证码短信模板 ID
type MjPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type AlipayConfig struct {
Enabled bool // 是否启用该服务
Enabled bool // 是否启用该支付通道
SandBox bool // 是否沙盒环境
AppId string // 应用 ID
UserId string // 支付宝用户 ID
@@ -70,6 +74,28 @@ type AlipayConfig struct {
AlipayPublicKey string // 支付宝公钥文件路径
RootCert string // Root 秘钥路径
NotifyURL string // 异步通知回调
ReturnURL string // 支付成功返回地址
}
type HuPiPayConfig struct { //虎皮椒第四方支付配置
Enabled bool // 是否启用该支付通道
Name string // 支付名称wechat/alipay
AppId string // App ID
AppSecret string // app 密钥
ApiURL string // 支付网关
NotifyURL string // 异步通知回调
ReturnURL string // 支付成功返回地址
}
// JPayConfig PayJs 支付配置
type JPayConfig struct {
Enabled bool
Name string // 支付名称,默认 wechat
AppId string // 商户 ID
PrivateKey string // 私钥
ApiURL string // API 网关
NotifyURL string // 异步回调地址
ReturnURL string // 支付成功返回地址
}
type XXLConfig struct { // XXL 任务调度配置
@@ -92,25 +118,6 @@ func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port)
}
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
}
// ChatConfig 系统默认的聊天配置
type ChatConfig struct {
OpenAI ModelAPIConfig `json:"open_ai"`
Azure ModelAPIConfig `json:"azure"`
ChatGML ModelAPIConfig `json:"chat_gml"`
Baidu ModelAPIConfig `json:"baidu"`
XunFei ModelAPIConfig `json:"xun_fei"`
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
ContextDeep int `json:"context_deep"` // 上下文深度
}
type Platform string
const OpenAI = Platform("OpenAI")
@@ -118,34 +125,35 @@ const Azure = Platform("Azure")
const ChatGLM = Platform("ChatGLM")
const Baidu = Platform("Baidu")
const XunFei = Platform("XunFei")
// UserChatConfig 用户的聊天配置
type UserChatConfig struct {
ApiKeys map[Platform]string `json:"api_keys"`
}
type ModelAPIConfig struct {
ApiURL string `json:"api_url,omitempty"`
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
ApiKey string `json:"api_key"`
}
const QWen = Platform("QWen")
type SystemConfig struct {
Title string `json:"title"`
AdminTitle string `json:"admin_title"`
Models []string `json:"models"`
UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用
InitImgCalls int `json:"init_img_calls"`
VipMonthCalls int `json:"vip_month_calls"` // 会员每个赠送的调用次数
EnabledRegister bool `json:"enabled_register"`
EnabledMsg bool `json:"enabled_msg"` // 启用短信验证码服务
EnabledDraw bool `json:"enabled_draw"` // 启动 AI 绘画功能
RewardImg string `json:"reward_img"` // 众筹收款二维码地址
EnabledFunction bool `json:"enabled_function"` // 启用 API 函数功能
EnabledReward bool `json:"enabled_reward"` // 启用众筹功能
EnabledAlipay bool `json:"enabled_alipay"` // 是否启用支付宝支付通道
OrderPayTimeout int `json:"order_pay_timeout"` //订单支付超时时间
DefaultModels []string `json:"default_models"` // 默认开通的 AI 模型
OrderPayInfoText string `json:"order_pay_info_text"` // 订单支付页面说明文字
Title string `json:"title,omitempty"`
AdminTitle string `json:"admin_title,omitempty"`
Logo string `json:"logo,omitempty"`
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机,邮箱注册,账号密码注册
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
EnableContext bool `json:"enable_context,omitempty"`
ContextDeep int `json:"context_deep,omitempty"`
}

View File

@@ -1,8 +1,11 @@
package types
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
type ToolCall struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
}
type Function struct {
@@ -21,72 +24,3 @@ type Property struct {
Type string `json:"type"`
Description string `json:"description"`
}
const (
FuncZaoBao = "zao_bao" // 每日早报
FuncHeadLine = "headline" // 今日头条
FuncWeibo = "weibo_hot" // 微博热搜
FuncMidJourney = "mid_journey" // MJ 绘画
)
var InnerFunctions = []Function{
{
Name: FuncZaoBao,
Description: "每日早报,获取当天全球的热门新闻事件列表",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: FuncWeibo,
Description: "新浪微博热搜榜,微博当日热搜榜单",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: FuncHeadLine,
Description: "今日头条,给用户推荐当天的头条新闻,周榜热文",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: FuncMidJourney,
Description: "AI 绘画工具,使用 MJ MidJourney API 进行 AI 绘画",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"prompt": {
Type: "string",
Description: "提示词,如果该参数中有中文的话,则需要翻译成英文。提示词中的参数作为提示的一部分,不要删除",
},
},
Required: []string{},
},
},
}

View File

@@ -6,10 +6,10 @@ import (
)
type MKey interface {
string | int
string | int | uint
}
type MValue interface {
*WsClient | *ChatSession | context.CancelFunc | []interface{}
*WsClient | *ChatSession | context.CancelFunc | []Message
}
type LMap[K MKey, T MValue] struct {
lock sync.RWMutex

View File

@@ -10,7 +10,7 @@ const (
type OrderRemark struct {
Days int `json:"days"` // 有效期
Calls int `json:"calls"` // 增加调用次
Power int `json:"power"` // 增加算力点
Name string `json:"name"` // 产品名称
Price float64 `json:"price"`
Discount float64 `json:"discount"`

View File

@@ -12,6 +12,7 @@ type MiniOssConfig struct {
AccessKey string
AccessSecret string
Bucket string
SubDir string
UseSSL bool
Domain string
}
@@ -21,6 +22,7 @@ type QiNiuOssConfig struct {
AccessKey string
AccessSecret string
Bucket string
SubDir string
Domain string
}
@@ -29,6 +31,7 @@ type AliYunOssConfig struct {
AccessKey string
AccessSecret string
Bucket string
SubDir string
Domain string
}

26
api/core/types/sms.go Normal file
View File

@@ -0,0 +1,26 @@
package types
type SMSConfig struct {
Active string
Ali SmsConfigAli
Bao SmsConfigBao
}
// SmsConfigAli 阿里云短信平台配置
type SmsConfigAli struct {
AccessKey string
AccessSecret string
Product string
Domain string
Sign string // 短信签名
CodeTempId string // 验证码短信模板 ID
}
// SmsConfigBao 短信宝平台配置
type SmsConfigBao struct {
Username string //短信宝平台注册的用户名
Password string //短信宝平台注册的密码
Domain string //域名
Sign string // 短信签名
CodeTemplate string // 验证码短信模板 匹配
}

View File

@@ -9,30 +9,24 @@ func (t TaskType) String() string {
const (
TaskImage = TaskType("image")
TaskBlend = TaskType("blend")
TaskSwapFace = TaskType("swapFace")
TaskUpscale = TaskType("upscale")
TaskVariation = TaskType("variation")
TaskTxt2Img = TaskType("text2img")
)
// TaskSrc 任务来源
type TaskSrc string
const (
TaskSrcChat = TaskSrc("chat") // 来自聊天页面
TaskSrcImg = TaskSrc("img") // 专业绘画页面
)
// MjTask MidJourney 任务
type MjTask struct {
Id int `json:"id"`
Id uint `json:"id"`
TaskId string `json:"task_id"`
ImgArr []string `json:"img_arr"`
ChannelId string `json:"channel_id"`
SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
ChatId string `json:"chat_id,omitempty"`
RoleId int `json:"role_id,omitempty"`
Icon string `json:"icon,omitempty"`
NegPrompt string `json:"neg_prompt,omitempty"`
Params string `json:"full_prompt"`
Index int `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`
@@ -42,28 +36,26 @@ type MjTask struct {
type SdTask struct {
Id int `json:"id"` // job 数据库ID
SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"`
}
type SdTaskParams struct {
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
NegativePrompt string `json:"negative_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // 启用高清修复
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
HdScale int `json:"hd_scale"` // 放大倍数
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // 启用高清修复
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
HdScale int `json:"hd_scale"` // 放大倍数
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
}

View File

@@ -30,6 +30,7 @@ const (
Success = BizCode(0)
Failed = BizCode(1)
NotAuthorized = BizCode(400) // 未授权
NotPermission = BizCode(403) // 没有权限
OkMsg = "Success"
ErrorMsg = "系统开小差了"

View File

@@ -6,7 +6,6 @@ require (
github.com/BurntSushi/toml v1.1.0
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
github.com/bwmarrin/discordgo v0.27.1
github.com/eatmoreapple/openwechat v1.2.1
github.com/gin-gonic/gin v1.9.1
github.com/go-redis/redis/v8 v8.11.5
@@ -26,6 +25,18 @@ require (
require github.com/xxl-job/xxl-job-executor-go v1.2.0
require (
github.com/mojocn/base64Captcha v1.3.1
github.com/shopspring/decimal v1.3.1
github.com/syndtr/goleveldb v1.0.0
)
require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect
)
require (
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/bytedance/sonic v1.9.1 // indirect

View File

@@ -7,8 +7,6 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY=
github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -29,6 +27,7 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
@@ -65,10 +64,15 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -76,7 +80,6 @@ github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9S
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -84,6 +87,7 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -129,12 +133,17 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A=
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
@@ -166,6 +175,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
@@ -190,6 +201,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
@@ -219,7 +232,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
@@ -227,10 +239,13 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -239,11 +254,13 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -290,12 +307,15 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -5,10 +5,15 @@ import (
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/mojocn/base64Captcha"
"time"
"github.com/gin-gonic/gin"
@@ -17,47 +22,88 @@ import (
var logger = logger2.GetLogger()
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
Captcha string `json:"captcha"` // 验证码
CaptchaId string `json:"captcha_id"` // 验证码id
}
const SuperManagerID = 1
type ManagerHandler struct {
handler.BaseHandler
db *gorm.DB
redis *redis.Client
}
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
h := ManagerHandler{db: db, redis: client}
h.App = app
return &h
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data types.Manager
var data Manager
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
manager := h.App.Config.Manager
if data.Username == manager.Username && data.Password == manager.Password {
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": manager.Username,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := "users/" + manager.Username
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
resp.SUCCESS(c, tokenString)
} else {
resp.ERROR(c, "用户名或者密码错误")
// add captcha
if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
resp.ERROR(c, "验证码错误!")
return
}
var manager model.AdminUser
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil {
resp.ERROR(c, "请检查用户名或者密码是否填写正确")
return
}
password := utils.GenPassword(data.Password, manager.Salt)
if password != manager.Password {
resp.ERROR(c, "用户名或密码错误")
return
}
// 超级管理员默认是ID:1
if manager.Id != SuperManagerID && manager.Status == false {
resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员")
return
}
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": manager.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := fmt.Sprintf("admin/%d", manager.Id)
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
// 更新最后登录时间和IP
manager.LastLoginIp = c.ClientIP()
manager.LastLoginAt = time.Now().Unix()
h.DB.Updates(&manager)
var result = struct {
IsSuperAdmin bool `json:"is_super_admin"`
Token string `json:"token"`
}{
IsSuperAdmin: manager.Id == 1,
Token: tokenString,
}
resp.SUCCESS(c, result)
}
// Logout 注销
@@ -72,10 +118,155 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
// Session 会话检测
func (h *ManagerHandler) Session(c *gin.Context) {
token := c.GetHeader(types.AdminAuthHeader)
if token == "" {
id := h.GetLoginUserId(c)
key := fmt.Sprintf("admin/%d", id)
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c)
} else {
resp.SUCCESS(c)
return
}
var manager model.AdminUser
res := h.DB.Where("id", id).First(&manager)
if res.Error != nil {
resp.NotAuth(c)
return
}
resp.SUCCESS(c, manager)
}
// List 数据列表
func (h *ManagerHandler) List(c *gin.Context) {
var items []model.AdminUser
res := h.DB.Find(&items)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
users := make([]vo.AdminUser, 0)
for _, item := range items {
var u vo.AdminUser
err := utils.CopyObject(item, &u)
if err != nil {
continue
}
u.Id = item.Id
u.CreatedAt = item.CreatedAt.Unix()
users = append(users, u)
}
resp.SUCCESS(c, users)
}
func (h *ManagerHandler) Save(c *gin.Context) {
var data struct {
Username string `json:"username"`
Password string `json:"password"`
Status bool `json:"status"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.AdminUser
res := h.DB.Where("username", data.Username).First(&user)
if res.Error == nil {
resp.ERROR(c, "用户名已存在")
return
}
// 生成密码
salt := utils.RandString(8)
password := utils.GenPassword(data.Password, salt)
res = h.DB.Save(&model.AdminUser{
Username: data.Username,
Password: password,
Salt: salt,
Status: data.Status,
})
if res.Error != nil {
resp.ERROR(c, "failed with update database")
return
}
resp.SUCCESS(c)
}
// Remove 删除管理员
func (h *ManagerHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if id == SuperManagerID {
resp.ERROR(c, "超级管理员不能删除")
return
}
res := h.DB.Where("id", id).Delete(&model.AdminUser{})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}
// Enable 启用/禁用
func (h *ManagerHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.AdminUser{}).Where("id", data.Id).UpdateColumn("status", data.Enabled)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}
// ResetPass 重置密码
func (h *ManagerHandler) ResetPass(c *gin.Context) {
id := h.GetLoginUserId(c)
if id != SuperManagerID {
resp.ERROR(c, "只有超级管理员能够进行该操作")
return
}
var data struct {
Id int `json:"id"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.AdminUser
res := h.DB.Where("id", data.Id).First(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -14,20 +14,22 @@ import (
type ApiKeyHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
h := ApiKeyHandler{db: db}
h.App = app
return &h
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
}
func (h *ApiKeyHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Platform string `json:"platform"`
Name string `json:"name"`
Type string `json:"type"`
Value string `json:"value"`
ApiURL string `json:"api_url"`
Enabled bool `json:"enabled"`
ProxyURL string `json:"proxy_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -36,11 +38,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey := model.ApiKey{}
if data.Id > 0 {
h.db.Find(&apiKey, data.Id)
h.DB.Find(&apiKey, data.Id)
}
apiKey.Platform = data.Platform
apiKey.Value = data.Value
res := h.db.Debug().Save(&apiKey)
apiKey.Type = data.Type
apiKey.ApiURL = data.ApiURL
apiKey.Enabled = data.Enabled
apiKey.ProxyURL = data.ProxyURL
apiKey.Name = data.Name
res := h.DB.Save(&apiKey)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -58,9 +65,14 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
}
func (h *ApiKeyHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)
res := h.db.Find(&items)
res := h.DB.Find(&items)
if res.Error == nil {
for _, item := range items {
var key vo.ApiKey
@@ -78,15 +90,38 @@ func (h *ApiKeyHandler) List(c *gin.Context) {
resp.SUCCESS(c, keys)
}
func (h *ApiKeyHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
func (h *ApiKeyHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
Value interface{} `json:"value"`
}
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.ApiKey{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *ApiKeyHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id", id).Delete(&model.ApiKey{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,39 @@
package admin
import (
"chatplus/core"
"chatplus/handler"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"github.com/mojocn/base64Captcha"
)
type CaptchaHandler struct {
handler.BaseHandler
}
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
}
type CaptchaVo struct {
CaptchaId string `json:"captcha_id"`
PicPath string `json:"pic_path"`
}
// GetCaptcha 获取验证码
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
var captchaVo CaptchaVo
driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
// b64s是图片的base64编码
id, b64s, err := cp.Generate()
if err != nil {
resp.ERROR(c, "生成验证码错误!")
return
}
captchaVo.CaptchaId = id
captchaVo.PicPath = b64s
resp.SUCCESS(c, captchaVo)
}

View File

@@ -0,0 +1,266 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatHandler struct {
handler.BaseHandler
}
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
type chatItemVo struct {
Username string `json:"username"`
UserId uint `json:"user_id"`
ChatId string `json:"chat_id"`
Title string `json:"title"`
Role vo.ChatRole `json:"role"`
Model string `json:"model"`
Token int `json:"token"`
CreatedAt int64 `json:"created_at"`
MsgNum int `json:"msg_num"` // 消息数量
}
func (h *ChatHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var data struct {
Title string `json:"title"`
UserId uint `json:"user_id"`
Model string `json:"model"`
CreateAt []string `json:"created_time"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Title != "" {
session = session.Where("title LIKE ?", "%"+data.Title+"%")
}
if data.UserId > 0 {
session = session.Where("user_id = ?", data.UserId)
}
if data.Model != "" {
session = session.Where("model = ?", data.Model)
}
if len(data.CreateAt) == 2 {
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.ChatItem{}).Count(&total)
var items []model.ChatItem
var list = make([]chatItemVo, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
userIds := make([]uint, 0)
chatIds := make([]string, 0)
roleIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
chatIds = append(chatIds, item.ChatId)
roleIds = append(roleIds, item.RoleId)
}
var messages []model.ChatMessage
var users []model.User
var roles []model.ChatRole
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
h.DB.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", roleIds).Find(&roles)
tokenMap := make(map[string]int)
userMap := make(map[uint]string)
msgMap := make(map[string]int)
roleMap := make(map[uint]vo.ChatRole)
for _, msg := range messages {
tokenMap[msg.ChatId] += msg.Tokens
msgMap[msg.ChatId] += 1
}
for _, user := range users {
userMap[user.Id] = user.Username
}
for _, r := range roles {
var roleVo vo.ChatRole
err := utils.CopyObject(r, &roleVo)
if err != nil {
continue
}
roleMap[r.Id] = roleVo
}
for _, item := range items {
list = append(list, chatItemVo{
UserId: item.UserId,
Username: userMap[item.UserId],
ChatId: item.ChatId,
Title: item.Title,
Model: item.Model,
Token: tokenMap[item.ChatId],
MsgNum: msgMap[item.ChatId],
Role: roleMap[item.RoleId],
CreatedAt: item.CreatedAt.Unix(),
})
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}
type chatMessageVo struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
Username string `json:"username"`
Content string `json:"content"`
Type string `json:"type"`
Model string `json:"model"`
Token int `json:"token"`
Icon string `json:"icon"`
CreatedAt int64 `json:"created_at"`
}
// Messages 读取聊天记录列表
func (h *ChatHandler) Messages(c *gin.Context) {
var data struct {
UserId uint `json:"user_id"`
Content string `json:"content"`
Model string `json:"model"`
CreateAt []string `json:"created_time"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Content != "" {
session = session.Where("content LIKE ?", "%"+data.Content+"%")
}
if data.UserId > 0 {
session = session.Where("user_id = ?", data.UserId)
}
if data.Model != "" {
session = session.Where("model = ?", data.Model)
}
if len(data.CreateAt) == 2 {
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.ChatMessage{}).Count(&total)
var items []model.ChatMessage
var list = make([]chatMessageVo, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
userIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
}
var users []model.User
h.DB.Where("id IN ?", userIds).Find(&users)
userMap := make(map[uint]string)
for _, user := range users {
userMap[user.Id] = user.Username
}
for _, item := range items {
list = append(list, chatMessageVo{
Id: item.Id,
UserId: item.UserId,
Username: userMap[item.UserId],
Content: item.Content,
Model: item.Model,
Token: item.Tokens,
Icon: item.Icon,
Type: item.Type,
CreatedAt: item.CreatedAt.Unix(),
})
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}
// History 获取聊天历史记录
func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
} else {
for _, item := range items {
var v vo.HistoryMessage
err := utils.CopyObject(item, &v)
v.CreatedAt = item.CreatedAt.Unix()
v.UpdatedAt = item.UpdatedAt.Unix()
if err == nil {
messages = append(messages, v)
}
}
}
resp.SUCCESS(c, messages)
}
// RemoveChat 删除对话
func (h *ChatHandler) RemoveChat(c *gin.Context) {
chatId := h.GetTrim(c, "chat_id")
if chatId == "" {
resp.ERROR(c, "请传入 ChatId")
return
}
tx := h.DB.Begin()
// 删除聊天记录
res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
if res.Error != nil {
resp.ERROR(c, "failed to remove chat message")
return
}
// 删除对话
res = tx.Unscoped().Where("chat_id = ?", chatId).Delete(model.ChatItem{})
if res.Error != nil {
tx.Rollback() // 回滚
resp.ERROR(c, "failed to remove chat")
return
}
tx.Commit()
resp.SUCCESS(c)
}
// RemoveMessage 删除聊天记录
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
if tx.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -15,26 +15,26 @@ import (
type ChatModelHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db}
h.App = app
return &h
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ChatModelHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Open bool `json:"open"`
Platform string `json:"platform"`
Weight int `json:"weight"`
CreatedAt int64 `json:"created_at"`
Id uint `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Open bool `json:"open"`
Platform string `json:"platform"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -42,18 +42,21 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
}
item := model.ChatModel{
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
Weight: data.Weight}
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: data.Temperature,
Power: data.Power}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&item)
res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -72,7 +75,12 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
session := h.db.Session(&gorm.Session{})
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
if enable {
session = session.Where("enabled", enable)
@@ -109,7 +117,7 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -129,7 +137,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -141,13 +149,15 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
func (h *ChatModelHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -15,13 +15,10 @@ import (
type ChatRoleHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
h := ChatRoleHandler{db: db}
h.App = app
return &h
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// Save 创建或者更新某个角色
@@ -41,7 +38,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&role)
res := h.DB.Save(&role)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -53,9 +50,14 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
}
func (h *ChatRoleHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.ChatRole
var roles = make([]vo.ChatRole, 0)
res := h.db.Order("sort_num ASC").Find(&items)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
return
@@ -88,7 +90,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -98,14 +100,39 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
resp.SUCCESS(c)
}
func (h *ChatRoleHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
func (h *ChatRoleHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
Value interface{} `json:"value"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Where("id = ?", id).Delete(&model.ChatRole{})
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *ChatRoleHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
if res.Error != nil {
resp.ERROR(c, "删除失败!")
return

View File

@@ -14,36 +14,38 @@ import (
type ConfigHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
h := ConfigHandler{db: db}
h.App = app
return &h
return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ConfigHandler) Update(c *gin.Context) {
var data struct {
Key string `json:"key"`
Config map[string]interface{} `json:"config"`
Key string `json:"key"`
Config struct {
types.SystemConfig
Content string `json:"content,omitempty"`
Updated bool `json:"updated,omitempty"`
} `json:"config"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
str := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: str}
res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key})
value := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: value}
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
if config.Id > 0 {
config.Config = str
res := h.db.Updates(&config)
config.Config = value
res := h.DB.Updates(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -51,12 +53,10 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// update config cache for AppServer
var cfg model.Config
h.db.Where("marker", data.Key).First(&cfg)
h.DB.Where("marker", data.Key).First(&cfg)
var err error
if data.Key == "system" {
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
} else if data.Key == "chat" {
err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig)
}
if err != nil {
resp.ERROR(c, "Failed to update config cache: "+err.Error())
@@ -70,20 +70,25 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
key := c.Query("key")
var config model.Config
res := h.db.Where("marker", key).First(&config)
res := h.DB.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
var m map[string]interface{}
err := utils.JsonDecode(config.Config, &m)
var value map[string]interface{}
err := utils.JsonDecode(config.Config, &value)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, m)
resp.SUCCESS(c, value)
}

View File

@@ -7,26 +7,25 @@ import (
"chatplus/store/model"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"gorm.io/gorm"
"time"
)
type DashboardHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
h := DashboardHandler{db: db}
h.App = app
return &h
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
type statsVo struct {
Users int64 `json:"users"`
Chats int64 `json:"chats"`
Tokens int `json:"tokens"`
Income float64 `json:"income"`
Users int64 `json:"users"`
Chats int64 `json:"chats"`
Tokens int `json:"tokens"`
Income float64 `json:"income"`
Chart map[string]map[string]float64 `json:"chart"`
}
func (h *DashboardHandler) Stats(c *gin.Context) {
@@ -35,37 +34,84 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
var userCount int64
now := time.Now()
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
res := h.db.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
if res.Error == nil {
stats.Users = userCount
}
// new chats statistic
var chatCount int64
res = h.db.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
if res.Error == nil {
stats.Chats = chatCount
}
// tokens took stats
var historyMessages []model.HistoryMessage
res = h.db.Where("created_at > ?", zeroTime).Find(&historyMessages)
var historyMessages []model.ChatMessage
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
for _, item := range historyMessages {
stats.Tokens += item.Tokens
}
// 众筹收入
var rewards []model.Reward
res = h.db.Where("created_at > ?", zeroTime).Find(&rewards)
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
for _, item := range rewards {
stats.Income += item.Amount
}
// 订单收入
var orders []model.Order
res = h.db.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
for _, item := range orders {
stats.Income += item.Amount
}
// 统计7天的订单的图表
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
var statsChart = make(map[string]map[string]float64)
//// 初始化
var userStatistic, historyMessagesStatistic, incomeStatistic = make(map[string]float64), make(map[string]float64), make(map[string]float64)
for i := 0; i < 7; i++ {
var initTime = time.Date(now.Year(), now.Month(), now.Day()-i, 0, 0, 0, 0, now.Location()).Format("2006-01-02")
userStatistic[initTime] = float64(0)
historyMessagesStatistic[initTime] = float64(0)
incomeStatistic[initTime] = float64(0)
}
// 统计用户7天增加的曲线
var users []model.User
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
if res.Error == nil {
for _, item := range users {
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
}
}
// 统计7天Token 消耗
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
for _, item := range historyMessages {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
}
// 浮点数相加?
// 统计最近7天的众筹
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
for _, item := range rewards {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
// 统计最近7天的订单
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
for _, item := range orders {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
statsChart["users"] = userStatistic
statsChart["historyMessage"] = historyMessagesStatistic
statsChart["orders"] = incomeStatistic
stats.Chart = statsChart
resp.SUCCESS(c, stats)
}

View File

@@ -0,0 +1,126 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/golang-jwt/jwt/v5"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type FunctionHandler struct {
handler.BaseHandler
}
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *FunctionHandler) Save(c *gin.Context) {
var data vo.Function
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var f = model.Function{
Id: data.Id,
Name: data.Name,
Label: data.Label,
Description: data.Description,
Parameters: utils.JsonEncode(data.Parameters),
Action: data.Action,
Token: data.Token,
Enabled: data.Enabled,
}
res := h.DB.Save(&f)
if res.Error != nil {
resp.ERROR(c, "error with save data:"+res.Error.Error())
return
}
data.Id = f.Id
resp.SUCCESS(c, data)
}
func (h *FunctionHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
Value interface{} `json:"value"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *FunctionHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Function
res := h.DB.Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
return
}
functions := make([]vo.Function, 0)
for _, v := range items {
var f vo.Function
err := utils.CopyObject(v, &f)
if err != nil {
continue
}
functions = append(functions, f)
}
resp.SUCCESS(c, functions)
}
func (h *FunctionHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.DB.Delete(&model.Function{Id: uint(id)})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}
// GenToken generate function api access token
func (h *FunctionHandler) GenToken(c *gin.Context) {
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": 0,
"expired": 0,
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
logger.Error("error with generate token", err)
resp.ERROR(c)
return
}
resp.SUCCESS(c, tokenString)
}

View File

@@ -0,0 +1,121 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MenuHandler struct {
handler.BaseHandler
}
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *MenuHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Icon string `json:"icon"`
URL string `json:"url"`
SortNum int `json:"sort_num"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Save(&model.Menu{
Id: data.Id,
Name: data.Name,
Icon: data.Icon,
URL: data.URL,
SortNum: data.SortNum,
Enabled: data.Enabled,
})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
// List 数据列表
func (h *MenuHandler) List(c *gin.Context) {
var items []model.Menu
var list = make([]vo.Menu, 0)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Menu
err := utils.CopyObject(item, &product)
if err == nil {
list = append(list, product)
}
}
}
resp.SUCCESS(c, list)
}
func (h *MenuHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *MenuHandler) Sort(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
Sorts []int `json:"sorts"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
for index, id := range data.Ids {
res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}
func (h *MenuHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.DB.Where("id", id).Delete(&model.Menu{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -8,24 +8,28 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type OrderHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db}
h.App = app
return &h
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *OrderHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var data struct {
OrderNo string `json:"order_no"`
Status int `json:"status"`
PayTime []string `json:"pay_time"`
Page int `json:"page"`
PageSize int `json:"page_size"`
@@ -35,7 +39,7 @@ func (h *OrderHandler) List(c *gin.Context) {
return
}
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if data.OrderNo != "" {
session = session.Where("order_no", data.OrderNo)
}
@@ -44,6 +48,9 @@ func (h *OrderHandler) List(c *gin.Context) {
end := utils.Str2stamp(data.PayTime[1] + " 00:00:00")
session = session.Where("pay_time >= ? AND pay_time <= ?", start, end)
}
if data.Status >= 0 {
session = session.Where("status", data.Status)
}
var total int64
session.Model(&model.Order{}).Count(&total)
var items []model.Order
@@ -72,7 +79,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
if id > 0 {
var item model.Order
res := h.db.First(&item, id)
res := h.DB.First(&item, id)
if res.Error != nil {
resp.ERROR(c, "记录不存在!")
return
@@ -83,7 +90,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
return
}
res = h.db.Where("id = ?", id).Delete(&model.Order{})
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -0,0 +1,77 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type PowerLogHandler struct {
handler.BaseHandler
}
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Username string `json:"username"`
Type int `json:"type"`
Model string `json:"model"`
Date []string `json:"date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Model != "" {
session = session.Where("model", data.Model)
}
if data.Type > 0 {
session = session.Where("type", data.Type)
}
if len(data.Date) == 2 {
start := data.Date[0] + " 00:00:00"
end := data.Date[1] + " 00:00:00"
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.PowerLog{}).Count(&total)
var items []model.PowerLog
var list = make([]vo.PowerLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var log vo.PowerLog
err := utils.CopyObject(item, &log)
if err != nil {
continue
}
log.Id = item.Id
log.CreatedAt = item.CreatedAt.Unix()
log.TypeStr = item.Type.String()
list = append(list, log)
}
}
// 统计消费算力总和
var totalPower float64
if len(data.Date) == 2 {
session.Where("mark", 0).Select("SUM(amount) as total_sum").Scan(&totalPower)
}
resp.SUCCESS(c, gin.H{"data": vo.NewPage(total, data.Page, data.PageSize, list), "stat": totalPower})
}

View File

@@ -15,13 +15,10 @@ import (
type ProductHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db}
h.App = app
return &h
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ProductHandler) Save(c *gin.Context) {
@@ -32,7 +29,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
Discount float64 `json:"discount"`
Enabled bool `json:"enabled"`
Days int `json:"days"`
Calls int `json:"calls"`
Power int `json:"power"`
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -40,12 +37,18 @@ func (h *ProductHandler) Save(c *gin.Context) {
return
}
item := model.Product{Name: data.Name, Price: data.Price, Discount: data.Discount, Days: data.Days, Calls: data.Calls, Enabled: data.Enabled}
item := model.Product{
Name: data.Name,
Price: data.Price,
Discount: data.Discount,
Days: data.Days,
Power: data.Power,
Enabled: data.Enabled}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&item)
res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -62,16 +65,11 @@ func (h *ProductHandler) Save(c *gin.Context) {
resp.SUCCESS(c, itemVo)
}
// List 模型列表
// List 数据列表
func (h *ProductHandler) List(c *gin.Context) {
session := h.db.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
if enable {
session = session.Where("enabled", enable)
}
var items []model.Product
var list = make([]vo.Product, 0)
res := session.Order("sort_num ASC").Find(&items)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Product
@@ -100,7 +98,7 @@ func (h *ProductHandler) Enable(c *gin.Context) {
return
}
res := h.db.Model(&model.Product{}).Where("id = ?", data.Id).Update("enabled", data.Enabled)
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -120,7 +118,7 @@ func (h *ProductHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -134,7 +132,7 @@ func (h *ProductHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.Product{})
res := h.DB.Where("id", id).Delete(&model.Product{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -2,6 +2,7 @@ package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
@@ -13,18 +14,20 @@ import (
type RewardHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db}
h.App = app
return &h
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RewardHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Reward
res := h.db.Order("id DESC").Find(&items)
res := h.DB.Order("id DESC").Find(&items)
var rewards = make([]vo.Reward, 0)
if res.Error == nil {
userIds := make([]uint, 0)
@@ -32,7 +35,7 @@ func (h *RewardHandler) List(c *gin.Context) {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.db.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
@@ -46,7 +49,7 @@ func (h *RewardHandler) List(c *gin.Context) {
}
r.Id = v.Id
r.Username = userMap[v.UserId].Mobile
r.Username = userMap[v.UserId].Username
r.CreatedAt = v.CreatedAt.Unix()
r.UpdatedAt = v.UpdatedAt.Unix()
rewards = append(rewards, r)
@@ -55,3 +58,21 @@ func (h *RewardHandler) List(c *gin.Context) {
resp.SUCCESS(c, rewards)
}
func (h *RewardHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,45 @@
package admin
import (
"chatplus/core"
"chatplus/handler"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type UploadHandler struct {
handler.BaseHandler
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
userId := 0
res := h.DB.Create(&model.File{
UserId: userId,
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}

View File

@@ -8,35 +8,40 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type UserHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
h := UserHandler{db: db}
h.App = app
return &h
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// List 用户列表
func (h *UserHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
mobile := h.GetTrim(c, "mobile")
username := h.GetTrim(c, "username")
offset := (page - 1) * pageSize
var items []model.User
var users = make([]vo.User, 0)
var total int64
session := h.db.Session(&gorm.Session{})
if mobile != "" {
session = session.Where("mobile LIKE ?", "%"+mobile+"%")
session := h.DB.Session(&gorm.Session{})
if username != "" {
session = session.Where("username LIKE ?", "%"+username+"%")
}
session.Model(&model.User{}).Count(&total)
@@ -63,14 +68,13 @@ func (h *UserHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Password string `json:"password"`
Mobile string `json:"mobile"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
Username string `json:"username"`
ChatRoles []string `json:"chat_roles"`
ChatModels []string `json:"chat_models"`
ChatModels []int `json:"chat_models"`
ExpiredTime string `json:"expired_time"`
Status bool `json:"status"`
Vip bool `json:"vip"`
Power int `json:"power"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -80,40 +84,60 @@ func (h *UserHandler) Save(c *gin.Context) {
var res *gorm.DB
var userVo vo.User
if data.Id > 0 { // 更新
user.Id = data.Id
// 此处需要用 map 更新,用结构体无法更新 0 值
res = h.db.Model(&user).Updates(map[string]interface{}{
"mobile": data.Mobile,
"calls": data.Calls,
"img_calls": data.ImgCalls,
"status": data.Status,
"vip": data.Vip,
"chat_roles_json": utils.JsonEncode(data.ChatRoles),
"chat_models_json": utils.JsonEncode(data.ChatModels),
"expired_time": utils.Str2stamp(data.ExpiredTime),
})
res = h.DB.Where("id", data.Id).First(&user)
if res.Error != nil {
resp.ERROR(c, "user not found")
return
}
var oldPower = user.Power
user.Username = data.Username
user.Status = data.Status
user.Vip = data.Vip
user.Power = data.Power
user.ChatRoles = utils.JsonEncode(data.ChatRoles)
user.ChatModels = utils.JsonEncode(data.ChatModels)
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力日志
if oldPower != user.Power {
mark := types.PowerAdd
amount := user.Power - oldPower
if oldPower > user.Power {
mark = types.PowerSub
amount = oldPower - user.Power
}
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerGift,
Amount: amount,
Balance: user.Power,
Mark: mark,
Model: "管理员",
Remark: fmt.Sprintf("后台管理员强制修改用户算力,修改前:%d,修改后:%d, 管理员ID%d", oldPower, user.Power, h.GetLoginUserId(c)),
CreatedAt: time.Now(),
})
}
} else {
salt := utils.RandString(8)
u := model.User{
Mobile: data.Mobile,
Username: data.Username,
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
Password: utils.GenPassword(data.Password, salt),
Avatar: "/images/avatar/user.png",
Salt: salt,
Power: data.Power,
Status: true,
ChatRoles: utils.JsonEncode(data.ChatRoles),
ChatModels: utils.JsonEncode(data.ChatModels),
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
ChatConfig: utils.JsonEncode(types.UserChatConfig{
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
Calls: data.Calls,
ImgCalls: data.ImgCalls,
}
res = h.db.Create(&u)
res = h.DB.Create(&u)
_ = utils.CopyObject(u, &userVo)
userVo.Id = u.Id
userVo.CreatedAt = u.CreatedAt.Unix()
@@ -140,7 +164,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
var user model.User
res := h.db.First(&user, data.Id)
res := h.DB.First(&user, data.Id)
if res.Error != nil {
resp.ERROR(c, "No user found")
return
@@ -148,7 +172,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
} else {
@@ -158,36 +182,32 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
func (h *UserHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
tx := h.db.Begin()
res := h.db.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
res = h.db.Where("user_id = ?", id).Delete(&model.ChatItem{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除聊天历史记录
res = h.db.Where("user_id = ?", id).Delete(&model.HistoryMessage{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除登录日志
res = h.db.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
tx.Commit()
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
// 删除用户
res := h.DB.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
// 删除聊天历史记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
// 删除登录日志
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
// 删除算力日志
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
// 删除众筹日志
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
// 删除绘图任务
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
// 删除订单
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
resp.SUCCESS(c)
}
@@ -195,10 +215,10 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
var total int64
h.db.Model(&model.UserLoginLog{}).Count(&total)
h.DB.Model(&model.UserLoginLog{}).Count(&total)
offset := (page - 1) * pageSize
var items []model.UserLoginLog
res := h.db.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "获取数据失败")
return

View File

@@ -4,8 +4,11 @@ import (
"chatplus/core"
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/utils"
"errors"
"fmt"
"gorm.io/gorm"
"strings"
"github.com/gin-gonic/gin"
@@ -15,6 +18,7 @@ var logger = logger2.GetLogger()
type BaseHandler struct {
App *core.AppServer
DB *gorm.DB
}
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
@@ -49,3 +53,35 @@ func (h *BaseHandler) GetUserKey(c *gin.Context) string {
}
return fmt.Sprintf("users/%v", userId)
}
func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
userId, ok := c.Get(types.LoginUserID)
if !ok {
return 0
}
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
}
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
return h.GetLoginUserId(c) > 0
}
func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
value, exists := c.Get(types.LoginUserCache)
if exists {
return value.(model.User), nil
}
userId, ok := c.Get(types.LoginUserID)
if !ok {
return model.User{}, errors.New("user not login")
}
var user model.User
res := h.DB.First(&user, userId)
// 更新缓存
if res.Error == nil {
c.Set(types.LoginUserCache, user)
}
return user, res.Error
}

View File

@@ -45,3 +45,33 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
}
}
// SlideGet 获取滑动验证图片
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
data, err := h.service.SlideGet()
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// SlideCheck 滑动验证结果校验
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
var data struct {
Key string `json:"key"`
X int `json:"x"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if h.service.SlideCheck(data) {
resp.SUCCESS(c)
} else {
resp.ERROR(c)
}
}

View File

@@ -12,37 +12,34 @@ import (
type ChatModelHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db}
h.App = app
return &h
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0)
// 只加载用户订阅的 AI 模型
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
var res *gorm.DB
// 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) {
res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items)
} else {
user, _ := h.GetLoginUser(c)
var models []int
err := utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res = h.DB.Where("enabled = ?", true).Where(
h.DB.Where("id IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
}
var models []string
err = utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res := h.db.Where("enabled = ?", true).Where(
h.db.Where("value IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var cm vo.ChatModel

View File

@@ -14,27 +14,26 @@ import (
type ChatRoleHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
handler := &ChatRoleHandler{db: db}
handler.App = app
return handler
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List get user list
// List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) {
all := h.GetBool(c, "all")
userId := h.GetLoginUserId(c)
var roles []model.ChatRole
res := h.db.Where("enable", true).Order("sort_num ASC").Find(&roles)
var roleVos = make([]vo.ChatRole, 0)
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
if res.Error != nil {
resp.ERROR(c, "No roles found,"+res.Error.Error())
resp.SUCCESS(c, roleVos)
return
}
// 获取所有角色
if all {
if userId == 0 || all {
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
@@ -49,21 +48,15 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.NotAuth(c)
return
}
var user model.User
h.db.First(&user, userId)
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
resp.ERROR(c, "角色解析失败!")
return
}
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
if !utils.ContainsStr(roleKeys, r.Key) {
continue
@@ -80,7 +73,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
// UpdateRole 更新用户聊天角色
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
@@ -94,7 +87,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
return
}
res := h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
if res.Error != nil {
logger.Error("添加应用失败:", err)
resp.ERROR(c, "更新数据库失败!")

View File

@@ -9,7 +9,7 @@ import (
"context"
"encoding/json"
"fmt"
"gorm.io/gorm"
"html/template"
"io"
"strings"
"time"
@@ -19,7 +19,7 @@ import (
// 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -29,7 +29,7 @@ func (h *ChatHandler) sendAzureMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
@@ -56,9 +56,6 @@ func (h *ChatHandler) sendAzureMessage(
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var functionCall = false
var functionName string
var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
@@ -68,34 +65,17 @@ func (h *ChatHandler) sendAzureMessage(
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
if err != nil { // 数据解析出错
logger.Error(err, line)
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
break
}
fun := responseBody.Choices[0].Delta.FunctionCall
if functionCall && fun.Name == "" {
arguments = append(arguments, fun.Arguments)
if len(responseBody.Choices) == 0 {
continue
}
if !utils.IsEmptyValue(fun) {
functionName = fun.Name
f := h.App.Functions[functionName]
if f != nil {
functionCall = true
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
continue
}
}
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
@@ -121,54 +101,8 @@ func (h *ChatHandler) sendAzureMessage(
}
}
if functionCall { // 调用函数完成任务
var params map[string]interface{}
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
// for creating image, check if the user's img_calls > 0
if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
utils.ReplyMessage(ws, ErrImg)
} else {
f := h.App.Functions[functionName]
if functionName == types.FuncMidJourney {
params["user_id"] = userVo.Id
params["role_id"] = role.Id
params["chat_id"] = session.ChatId
params["icon"] = "/images/avatar/mid_journey.png"
params["session_id"] = session.SessionId
}
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
content := data
if functionName == types.FuncMidJourney {
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
h.mjService.ChatClients.Put(session.SessionId, ws)
// update user's img_calls
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: content,
})
contents = append(contents, content)
}
}
}
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
@@ -177,77 +111,64 @@ func (h *ChatHandler) sendAzureMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext && functionCall == false {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
useContext := true
if functionCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: prompt,
Tokens: promptToken,
UseContext: useContext,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if functionCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: useContext,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -258,7 +179,8 @@ func (h *ChatHandler) sendAzureMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -9,6 +9,7 @@ import (
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"net/http"
"strings"
@@ -35,7 +36,7 @@ type baiduResp struct {
// 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -45,7 +46,7 @@ func (h *ChatHandler) sendBaiduMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
@@ -127,9 +128,6 @@ func (h *ChatHandler) sendBaiduMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -137,63 +135,63 @@ func (h *ChatHandler) sendBaiduMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: prompt,
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -204,7 +202,8 @@ func (h *ChatHandler) sendBaiduMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -6,7 +6,7 @@ import (
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/service/mj"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
@@ -15,39 +15,44 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"net/url"
"strings"
"time"
)
const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
const ErrImg = "![](/images/wx.png)"
var ErrImg = "![](/images/wx.png)"
var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
db *gorm.DB
redis *redis.Client
mjService *mj.Service
redis *redis.Client
uploadManager *oss.UploaderManager
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, service *mj.Service) *ChatHandler {
h := ChatHandler{
db: db,
redis: redis,
mjService: service,
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
}
h.App = app
return &h
}
var chatConfig types.ChatConfig
func (h *ChatHandler) Init() {
// 如果后台有上传微信客服微信二维码,则覆盖
if h.App.SysConfig.WechatCardURL != "" {
ErrImg = fmt.Sprintf("![](%s)", h.App.SysConfig.WechatCardURL)
}
}
// ChatHandle 处理聊天 WebSocket 请求
func (h *ChatHandler) ChatHandle(c *gin.Context) {
@@ -65,7 +70,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
client := types.NewWsClient(ws)
// get model info
var chatModel model.ChatModel
res := h.db.First(&chatModel, modelId)
res := h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort()
@@ -74,7 +79,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session := h.App.ChatSession.Get(sessionId)
if session == nil {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
logger.Info("用户未登录")
c.Abort()
@@ -83,7 +88,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session = &types.ChatSession{
SessionId: sessionId,
ClientIP: c.ClientIP(),
Username: user.Mobile,
Username: user.Username,
UserId: user.Id,
}
h.App.ChatSession.Put(sessionId, session)
@@ -91,7 +96,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
// use old chat data override the chat model and role ID
var chat model.ChatItem
res = h.db.Where("chat_id=?", chatId).First(&chat)
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
if res.Error == nil {
chatModel.Id = chat.ModelId
roleId = int(chat.RoleId)
@@ -99,28 +104,24 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session.ChatId = chatId
session.Model = types.ChatModel{
Id: chatModel.Id,
Value: chatModel.Value,
Weight: chatModel.Weight,
Platform: types.Platform(chatModel.Platform)}
Id: chatModel.Id,
Name: chatModel.Name,
Value: chatModel.Value,
Power: chatModel.Power,
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
res = h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
// 初始化聊天配置
var config model.Config
h.db.Where("marker", "chat").First(&config)
err = utils.JsonDecode(config.Config, &chatConfig)
if err != nil {
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
c.Abort()
return
}
h.Init()
// 保存会话连接
h.App.ChatClients.Put(sessionId, client)
@@ -128,7 +129,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Error(err)
client.Close()
h.App.ChatClients.Delete(sessionId)
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
@@ -139,19 +139,30 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
return
}
message := string(msg)
logger.Info("Receive a message: ", message)
//utils.ReplyMessage(client, "这是一条测试消息!")
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 Chat 心跳消息:", message.Content)
continue
}
logger.Info("Receive a message: ", message.Content)
ctx, cancel := context.WithCancel(context.Background())
h.App.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, message, client)
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
} else {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
logger.Info("回答完毕: " + string(message))
logger.Infof("回答完毕: %v", message.Content)
}
}
@@ -159,16 +170,18 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
defer func() {
if r := recover(); r != nil {
logger.Error("Recover message from error: ", r)
}
}()
if !h.App.Debug {
defer func() {
if r := recover(); r != nil {
logger.Error("Recover message from error: ", r)
}
}()
}
var user model.User
res := h.db.Model(&model.User{}).First(&user, session.UserId)
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
utils.ReplyMessage(ws, "非法用户,请联系管理员")
utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作")
return res.Error
}
var userVo vo.User
@@ -184,14 +197,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil
}
if userVo.Calls < session.Model.Weight {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数%d已不足以支付当前模型的单次对话需要消耗的对话额度%d", userVo.Calls, session.Model.Weight))
utils.ReplyMessage(ws, ErrImg)
return nil
}
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
if userVo.Power < session.Model.Power {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力%d已不足以支付当前模型的单次对话需要消耗的算力%d", userVo.Power, session.Model.Power))
utils.ReplyMessage(ws, ErrImg)
return nil
}
@@ -201,39 +208,64 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
utils.ReplyMessage(ws, ErrImg)
return nil
}
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > session.Model.MaxContext {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
}
var req = types.ApiRequest{
Model: session.Model.Value,
Stream: true,
}
switch session.Model.Platform {
case types.Azure:
req.Temperature = h.App.ChatConfig.Azure.Temperature
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
break
case types.ChatGLM:
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
break
case types.Baidu:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
// TODO 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
case types.OpenAI:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
if h.App.SysConfig.EnabledFunction {
var functions = make([]types.Function, 0)
for _, f := range types.InnerFunctions {
if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney {
continue
}
functions = append(functions, f)
}
req.Functions = functions
var items []model.Function
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
case types.XunFei:
req.Temperature = h.App.ChatConfig.XunFei.Temperature
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
var tools = make([]interface{}, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
required := parameters["required"]
delete(parameters, "required")
tools = append(tools, gin.H{
"type": "function",
"function": gin.H{
"name": v.Name,
"description": v.Description,
"parameters": parameters,
"required": required,
},
})
}
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
case types.QWen:
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
break
default:
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
@@ -241,43 +273,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
// 加载聊天上下文
var chatCtx []interface{}
if h.App.ChatConfig.EnableContext {
chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.SysConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
messages = h.App.ChatContexts.Get(session.ChatId)
} else {
// calculate the tokens of current request, to prevent to exceeding the max tokens num
tokens := req.MaxTokens
for _, f := range types.InnerFunctions {
tks, _ := utils.CalcTokens(utils.JsonEncode(f), req.Model)
tokens += tks
}
// loading the role context
var messages []types.Message
err := utils.JsonDecode(role.Context, &messages)
if err == nil {
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
if tokens+tks >= types.ModelToTokens[req.Model] {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
}
// loading recent chat history as chat context
if chatConfig.ContextDeep > 0 {
var historyMessages []model.HistoryMessage
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
_ = utils.JsonDecode(role.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] {
break
}
tokens += msg.Tokens
ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg {
ms.Role = "assistant"
@@ -287,6 +295,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
}
}
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
// MaxContextLength = Response + Tool + Prompt + Context
tokens := req.MaxTokens // 最大响应长度
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
break
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
logger.Debugf("聊天上下文:%+v", chatCtx)
}
reqMgs := make([]interface{}, 0)
@@ -294,10 +325,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
reqMgs = append(reqMgs, m)
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
if session.Model.Platform == types.QWen {
req.Input = map[string]interface{}{"prompt": prompt}
if len(reqMgs) > 0 {
req.Input["messages"] = reqMgs
}
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
}
switch session.Model.Platform {
case types.Azure:
@@ -310,7 +348,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.QWen:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
@@ -322,8 +361,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
// Tokens 统计 token 数量
func (h *ChatHandler) Tokens(c *gin.Context) {
var data struct {
Text string `json:"text"`
Model string `json:"model"`
Text string `json:"text"`
Model string `json:"model"`
ChatId string `json:"chat_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -331,10 +371,10 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
}
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
if data.Text == "" {
var item model.HistoryMessage
if data.Text == "" && data.ChatId != "" {
var item model.ChatMessage
userId, _ := c.Get(types.LoginUserID)
res := h.db.Where("user_id = ?", userId).Last(&item)
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -384,39 +424,37 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
var apiURL string
switch platform {
case types.Azure:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break
case types.ChatGLM:
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu:
apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
case types.QWen:
apiURL = apiKey.ApiURL
req.Messages = nil
break
default:
apiURL = h.App.ChatConfig.OpenAI.ApiURL
apiURL = apiKey.ApiURL
}
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if platform == types.Baidu {
token, err := h.getBaiduToken(*apiKey)
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
}
@@ -424,6 +462,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
}
logger.Debugf(utils.JsonEncode(req))
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
@@ -437,9 +477,9 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request = request.WithContext(ctx)
request.Header.Set("Content-Type", "application/json")
proxyURL := h.App.Config.ProxyURL
if proxyURL != "" && platform == types.OpenAI { // 使用代理
proxy, _ := url.Parse(proxyURL)
var proxyURL string
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
@@ -448,42 +488,79 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
switch platform {
case types.Azure:
request.Header.Set("api-key", *apiKey)
request.Header.Set("api-key", apiKey.Value)
break
case types.ChatGLM:
token, err := h.getChatGLMToken(*apiKey)
token, err := h.getChatGLMToken(apiKey.Value)
if err != nil {
return nil, err
}
logger.Info(token)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
case types.Baidu:
request.RequestURI = ""
case types.OpenAI:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
break
case types.QWen:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
request.Header.Set("X-DashScope-SSE", "enable")
break
}
return client.Do(request)
}
// 扣减用户的对话次数
func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) {
// 仅当用户没有导入自己的 API KEY 时才进行扣减
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
num := 1
if session.Model.Weight > 0 {
num = session.Model.Weight
}
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num))
// 扣减用户算力
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
power := 1
if session.Model.Power > 0 {
power = session.Model.Power
}
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userVo.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,
Amount: power,
Mark: types.PowerSub,
Balance: u.Power,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
CreatedAt: time.Now(),
})
}
}
func (h *ChatHandler) incUserTokenFee(userId uint, tokens int) {
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", tokens))
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("tokens", gorm.Expr("tokens + ?", tokens))
// 将AI回复消息中生成的图片链接下载到本地
func (h *ChatHandler) extractImgUrl(text string) string {
pattern := `!\[([^\]]*)]\(([^)]+)\)`
re := regexp.MustCompile(pattern)
matches := re.FindAllStringSubmatch(text, -1)
// 下载图片并替换链接地址
for _, match := range matches {
imageURL := match[2]
logger.Debug(imageURL)
// 对于相同地址的图片,已经被替换了,就不再重复下载了
if !strings.Contains(text, imageURL) {
continue
}
newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
if err != nil {
logger.Error("error with download image: ", err)
continue
}
text = strings.ReplaceAll(text, imageURL, newImgURL)
}
return text
}

View File

@@ -6,27 +6,29 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// List 获取会话列表
func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.ERROR(c, "The parameter 'user_id' is needed.")
if !h.IsLogin(c) {
resp.SUCCESS(c)
return
}
userId := h.GetLoginUserId(c)
var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
if res.Error == nil {
var roleIds = make([]uint, 0)
for _, chat := range chats {
roleIds = append(roleIds, chat.RoleId)
}
var roles []model.ChatRole
res = h.db.Find(&roles, roleIds)
res = h.DB.Find(&roles, roleIds)
if res.Error == nil {
roleMap := make(map[uint]model.ChatRole)
for _, role := range roles {
@@ -58,7 +60,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
@@ -70,14 +72,14 @@ func (h *ChatHandler) Update(c *gin.Context) {
// Clear 清空所有聊天记录
func (h *ChatHandler) Clear(c *gin.Context) {
// 获取当前登录用户所有的聊天会话
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
var chats []model.ChatItem
res := h.db.Where("user_id = ?", user.Id).Find(&chats)
res := h.DB.Where("user_id = ?", user.Id).Find(&chats)
if res.Error != nil {
resp.ERROR(c, "No chats found")
return
@@ -89,13 +91,13 @@ func (h *ChatHandler) Clear(c *gin.Context) {
// 清空会话上下文
h.App.ChatContexts.Delete(chat.ChatId)
}
err = h.db.Transaction(func(tx *gorm.DB) error {
res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
err = h.DB.Transaction(func(tx *gorm.DB) error {
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
if res.Error != nil {
return res.Error
}
res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.HistoryMessage{})
res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
if res.Error != nil {
return res.Error
}
@@ -116,9 +118,9 @@ func (h *ChatHandler) Clear(c *gin.Context) {
// History 获取聊天历史记录
func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.HistoryMessage
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.db.Where("chat_id = ?", chatId).Find(&items)
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
@@ -144,20 +146,20 @@ func (h *ChatHandler) Remove(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
}
// 删除当前会话的聊天记录
res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to remove chat from database.")
return
@@ -179,7 +181,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
}
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", chatId).First(&chatItem)
res := h.DB.Where("chat_id = ?", chatId).First(&chatItem)
if res.Error != nil {
resp.ERROR(c, "No chat found")
return

View File

@@ -10,6 +10,7 @@ import (
"encoding/json"
"fmt"
"github.com/golang-jwt/jwt/v5"
"html/template"
"io"
"strings"
"time"
@@ -19,7 +20,7 @@ import (
// 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -29,7 +30,7 @@ func (h *ChatHandler) sendChatGLMMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
@@ -106,9 +107,6 @@ func (h *ChatHandler) sendChatGLMMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -116,63 +114,64 @@ func (h *ChatHandler) sendChatGLMMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: prompt,
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -183,7 +182,8 @@ func (h *ChatHandler) sendChatGLMMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -9,16 +9,18 @@ import (
"context"
"encoding/json"
"fmt"
"gorm.io/gorm"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
req2 "github.com/imroc/req/v3"
)
// OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -28,7 +30,7 @@ func (h *ChatHandler) sendOpenAiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
@@ -44,6 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
if response.Body != nil {
all, _ := io.ReadAll(response.Body)
logger.Error(string(all))
}
return err
} else {
defer response.Body.Close()
@@ -55,8 +61,8 @@ func (h *ChatHandler) sendOpenAiMessage(
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var functionCall = false
var functionName string
var function model.Function
var toolCall = false
var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
@@ -74,24 +80,37 @@ func (h *ChatHandler) sendOpenAiMessage(
break
}
var tool types.ToolCall
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
tool = responseBody.Choices[0].Delta.ToolCalls[0]
if toolCall && tool.Function.Name == "" {
arguments = append(arguments, tool.Function.Arguments)
continue
}
}
// 兼容 Function Call
fun := responseBody.Choices[0].Delta.FunctionCall
if functionCall && fun.Name == "" {
if fun.Name != "" {
tool = *new(types.ToolCall)
tool.Function.Name = fun.Name
} else if toolCall {
arguments = append(arguments, fun.Arguments)
continue
}
if !utils.IsEmptyValue(fun) {
functionName = fun.Name
f := h.App.Functions[functionName]
if f != nil {
functionCall = true
if !utils.IsEmptyValue(tool) {
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)})
}
continue
}
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
if responseBody.Choices[0].FinishReason == "tool_calls" ||
responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break
}
@@ -120,55 +139,40 @@ func (h *ChatHandler) sendOpenAiMessage(
}
}
if functionCall { // 调用函数完成任务
if toolCall { // 调用函数完成任务
var params map[string]interface{}
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
// for creating image, check if the user's img_calls > 0
if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
utils.ReplyMessage(ws, ErrImg)
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
params["user_id"] = userVo.Id
var apiRes types.BizVo
r, err := req2.C().R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", function.Token).
SetBody(params).
SetSuccessResult(&apiRes).Post(function.Action)
errMsg := ""
if err != nil {
errMsg = err.Error()
} else if r.IsErrorState() {
errMsg = r.Status
}
if errMsg != "" || apiRes.Code != types.Success {
msg := "调用函数工具出错:" + apiRes.Message + errMsg
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
f := h.App.Functions[functionName]
if functionName == types.FuncMidJourney {
params["user_id"] = userVo.Id
params["role_id"] = role.Id
params["chat_id"] = session.ChatId
params["icon"] = "/images/avatar/mid_journey.png"
params["session_id"] = session.SessionId
}
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
content := data
if functionName == types.FuncMidJourney {
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
h.mjService.ChatClients.Put(session.SessionId, ws)
// update user's img_calls
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: content,
})
contents = append(contents, content)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: apiRes.Data,
})
contents = append(contents, utils.InterfaceToString(apiRes.Data))
}
}
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -176,77 +180,77 @@ func (h *ChatHandler) sendOpenAiMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext && functionCall == false {
if h.App.SysConfig.EnableContext && toolCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
useContext := true
if functionCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: prompt,
Tokens: promptToken,
UseContext: useContext,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if functionCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: useContext,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
useContext := true
if toolCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyTokens += tokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -257,17 +261,20 @@ func (h *ChatHandler) sendOpenAiMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
return fmt.Errorf("error with decode response: %v", err)
}
@@ -275,7 +282,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {

View File

@@ -0,0 +1,240 @@
package chatimpl
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
type qWenResp struct {
Output struct {
FinishReason string `json:"finish_reason"`
Text string `json:"text"`
} `json:"output,omitempty"`
Usage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage,omitempty"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var content, lastText, newText string
var outPutStart = false
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") ||
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
var resp qWenResp
if len(contents) == 0 { // 发送消息头
if !outPutStart {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
outPutStart = true
continue
} else {
// 处理代码换行
content = "\n"
}
} else {
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", content)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if resp.Message != "" {
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
break
}
}
//通过比较 lastText上一次的文本和 currentText当前的文本
//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
//每次循环结束后lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
currentText := resp.Output.Text
if currentText != lastText {
// 提取新增文本
newText = strings.Replace(currentText, lastText, "", 1)
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(newText),
})
lastText = currentText // 更新 lastText
}
contents = append(contents, newText)
if resp.Output.FinishReason == "stop" {
break
}
} //end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg)
}
return nil
}

View File

@@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
"html/template"
"io"
"net/http"
"net/url"
@@ -49,15 +50,16 @@ type xunFeiResp struct {
}
var Model2URL = map[string]string{
"generalv1": "1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"generalv3.5": "v3.5",
}
// 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -66,29 +68,26 @@ func (h *ChatHandler) sendXunFeiMessage(
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
if apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ?", session.Model.Platform).Order("last_used_at ASC").First(&key)
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
apiKey = key.Value
var apiKey model.ApiKey
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
key := strings.Split(apiKey, "|")
key := strings.Split(apiKey.Value, "|")
if len(key) != 3 {
utils.ReplyMessage(ws, "非法的 API KEY")
return nil
}
apiURL := strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", Model2URL[req.Model], 1)
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)
@@ -169,9 +168,6 @@ func (h *ChatHandler) sendXunFeiMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -179,63 +175,64 @@ func (h *ChatHandler) sendXunFeiMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: prompt,
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -246,7 +243,8 @@ func (h *ChatHandler) sendXunFeiMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
@@ -262,7 +260,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
"parameter": map[string]interface{}{
"chat": map[string]interface{}{
"domain": req.Model,
"temperature": float64(req.Temperature),
"temperature": req.Temperature,
"top_k": int64(6),
"max_tokens": int64(req.MaxTokens),
"auditing": "default",

View File

@@ -0,0 +1,39 @@
package handler
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ConfigHandler struct {
BaseHandler
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key")
var config model.Config
res := h.DB.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
var value map[string]interface{}
err := utils.JsonDecode(config.Config, &value)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, value)
}

View File

@@ -0,0 +1,274 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"strings"
"time"
)
type FunctionHandler struct {
BaseHandler
config types.ChatPlusApiConfig
uploadManager *oss.UploaderManager
}
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
return &FunctionHandler{
BaseHandler: BaseHandler{
App: server,
DB: db,
},
config: config.ApiConfig,
uploadManager: manager,
}
}
type resVo struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data struct {
Title string `json:"title"`
UpdatedAt string `json:"updated_at"`
Items []dataItem `json:"items"`
} `json:"data"`
}
type dataItem struct {
Title string `json:"title"`
Url string `json:"url"`
Remark string `json:"remark"`
}
// check authorization
func (h *FunctionHandler) checkAuth(c *gin.Context) error {
tokenString := c.GetHeader(types.UserAuthHeader)
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(h.App.Config.Session.SecretKey), nil
})
if err != nil {
return fmt.Errorf("error with parse auth token: %v", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return errors.New("token is invalid")
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
return errors.New("token is expired")
}
return nil
}
// WeiBo 微博热搜
func (h *FunctionHandler) WeiBo(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
resp.ERROR(c, err.Error())
return
}
if h.config.Token == "" {
resp.ERROR(c, "无效的 API Token")
return
}
url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
var res resVo
r, err := req.C().R().
SetHeader("AppId", h.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
return
}
if res.Code != types.Success {
resp.ERROR(c, res.Message)
return
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
for i, v := range res.Data.Items {
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
}
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
}
// ZaoBao 今日早报
func (h *FunctionHandler) ZaoBao(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
resp.ERROR(c, err.Error())
return
}
if h.config.Token == "" {
resp.ERROR(c, "无效的 API Token")
return
}
url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
var res resVo
r, err := req.C().R().
SetHeader("AppId", h.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
return
}
if res.Code != types.Success {
resp.ERROR(c, res.Message)
return
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
for _, v := range res.Data.Items {
builder = append(builder, v.Title)
}
builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
}
type imgRes struct {
Created int64 `json:"created"`
Data []struct {
RevisedPrompt string `json:"revised_prompt"`
Url string `json:"url"`
} `json:"data"`
}
type ErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
// Dall3 DallE3 AI 绘图
func (h *FunctionHandler) Dall3(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
resp.ERROR(c, err.Error())
return
}
var params map[string]interface{}
if err := c.ShouldBindJSON(&params); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
logger.Debugf("绘画参数:%+v", params)
var user model.User
tx := h.DB.Where("id = ?", params["user_id"]).First(&user)
if tx.Error != nil {
resp.ERROR(c, "当前用户不存在!")
return
}
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return
}
prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY
var apiKey model.ApiKey
tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
return
}
// translate prompt
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
if err == nil {
logger.Debugf("翻译绘画提示词,原文:%s译文%s", prompt, pt)
prompt = pt
}
var res imgRes
var errRes ErrRes
var request *req.Request
if len(apiKey.ProxyURL) > 5 {
request = req.C().SetProxyURL(apiKey.ProxyURL).R()
} else {
request = req.C().R()
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := request.SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
if r.IsErrorState() {
resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message)
return
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
logger.Debugf("%+v", res)
// 存储图片
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
if err != nil {
resp.ERROR(c, "下载图片失败: "+err.Error())
return
}
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL)
// 更新用户算力
tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
h.DB.Where("id", user.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: h.App.SysConfig.DallPower,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c, content)
}

View File

@@ -0,0 +1,93 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"strings"
)
// InviteHandler 用户邀请
type InviteHandler struct {
BaseHandler
}
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Code 获取当前用户邀请码
func (h *InviteHandler) Code(c *gin.Context) {
userId := h.GetLoginUserId(c)
var inviteCode model.InviteCode
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
// 如果邀请码不存在,则创建一个
if res.Error != nil {
code := strings.ToUpper(utils.RandString(8))
for {
res = h.DB.Where("code = ?", code).First(&inviteCode)
if res.Error != nil { // 不存在相同的邀请码则退出
break
}
}
inviteCode.UserId = userId
inviteCode.Code = code
h.DB.Create(&inviteCode)
}
var codeVo vo.InviteCode
err := utils.CopyObject(inviteCode, &codeVo)
if err != nil {
resp.ERROR(c, "拷贝对象失败")
return
}
resp.SUCCESS(c, codeVo)
}
// List Log 用户邀请记录
func (h *InviteHandler) List(c *gin.Context) {
var data struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
var total int64
session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog
var list = make([]vo.InviteLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var v vo.InviteLog
err := utils.CopyObject(item, &v)
if err == nil {
v.Id = item.Id
v.CreatedAt = item.CreatedAt.Unix()
list = append(list, v)
} else {
logger.Error(err)
}
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}
// Hits 访问邀请码
func (h *InviteHandler) Hits(c *gin.Context) {
code := c.Query("code")
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,36 @@
package handler
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MenuHandler struct {
BaseHandler
}
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 数据列表
func (h *MenuHandler) List(c *gin.Context) {
var items []model.Menu
var list = make([]vo.Menu, 0)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Menu
err := utils.CopyObject(item, &product)
if err == nil {
list = append(list, product)
}
}
}
resp.SUCCESS(c, list)
}

View File

@@ -3,66 +3,57 @@ package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/mj"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
type MidJourneyHandler struct {
BaseHandler
redis *redis.Client
db *gorm.DB
mjService *mj.Service
pool *mj.ServicePool
snowflake *service.Snowflake
uploader *oss.UploaderManager
}
func NewMidJourneyHandler(
app *core.AppServer,
client *redis.Client,
db *gorm.DB,
mjService *mj.Service) *MidJourneyHandler {
h := MidJourneyHandler{
redis: client,
db: db,
mjService: mjService,
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
return &MidJourneyHandler{
snowflake: snowflake,
pool: pool,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
h.App = app
return &h
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
h.mjService.Clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
}
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值")
if user.Power < h.App.SysConfig.MjPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return false
}
if !h.pool.HasAvailableService() {
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
return false
}
@@ -70,97 +61,180 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
}
// Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) {
if !h.App.Config.MjConfig.Enabled {
resp.ERROR(c, "MidJourney service is disabled")
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
// Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct {
SessionId string `json:"session_id"`
Prompt string `json:"prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
Img string `json:"img"`
Weight float32 `json:"weight"`
SessionId string `json:"session_id"`
TaskType string `json:"task_type"`
Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
ImgArr []string `json:"img_arr"`
Tile bool `json:"tile"`
Quality float32 `json:"quality"`
Iw float32 `json:"iw"`
CRef string `json:"cref"` //生成角色一致的图像
SRef string `json:"sref"` //生成风格一致的图像
Cw int `json:"cw"` // 参考程度
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if !h.checkLimits(c) {
if !h.preCheck(c) {
return
}
var prompt = data.Prompt
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
prompt += " --ar " + data.Rate
var params = ""
if data.Rate != "" && !strings.Contains(params, "--ar") {
params += " --ar " + data.Rate
}
if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
prompt += fmt.Sprintf(" --seed %d", data.Seed)
if data.Seed > 0 && !strings.Contains(params, "--seed") {
params += fmt.Sprintf(" --seed %d", data.Seed)
}
if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
prompt += fmt.Sprintf(" --s %d", data.Stylize)
if data.Stylize > 0 && !strings.Contains(params, "--s") && !strings.Contains(params, "--stylize") {
params += fmt.Sprintf(" --s %d", data.Stylize)
}
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
prompt += fmt.Sprintf(" --c %d", data.Chaos)
if data.Chaos > 0 && !strings.Contains(params, "--c") && !strings.Contains(params, "--chaos") {
params += fmt.Sprintf(" --c %d", data.Chaos)
}
if data.Img != "" {
prompt = fmt.Sprintf("%s %s", data.Img, prompt)
if data.Weight > 0 {
prompt += fmt.Sprintf(" --iw %f", data.Weight)
}
if len(data.ImgArr) > 0 && data.Iw > 0 {
params += fmt.Sprintf(" --iw %.2f", data.Iw)
}
if data.Raw {
prompt += " --style raw"
params += " --style raw"
}
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
prompt += data.Model
if data.Quality > 0 {
params += fmt.Sprintf(" --q %.2f", data.Quality)
}
if data.Tile {
params += " --tile "
}
if data.CRef != "" {
params += fmt.Sprintf(" --cref %s", data.CRef)
if data.Cw > 0 {
params += fmt.Sprintf(" --cw %d", data.Cw)
} else {
params += " --cw 100"
}
}
if data.SRef != "" {
params += fmt.Sprintf(" --sref %s", data.CRef)
}
if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
params += fmt.Sprintf(" %s", data.Model)
}
// 处理融图和换脸的提示词
if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() {
params = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
}
// 如果本地图片上传的是相对地址,处理成绝对地址
for k, v := range data.ImgArr {
if !strings.HasPrefix(v, "http") {
data.ImgArr[k] = fmt.Sprintf("http://localhost:5678/%s", strings.TrimLeft(v, "/"))
}
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
// generate task id
taskId, err := h.snowflake.Next(true)
if err != nil {
resp.ERROR(c, "error with generate task id: "+err.Error())
return
}
job := model.MidJourneyJob{
Type: types.TaskImage.String(),
Type: data.TaskType,
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: prompt,
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil {
opt := "绘图"
if data.TaskType == types.TaskBlend.String() {
job.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
opt = "融图"
} else if data.TaskType == types.TaskSwapFace.String() {
job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
opt = "换脸"
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.mjService.PushTask(types.MjTask{
Id: int(job.Id),
h.pool.PushTask(types.MjTask{
Id: job.Id,
TaskId: taskId,
SessionId: data.SessionId,
Src: types.TaskSrcImg,
Type: types.TaskImage,
Prompt: prompt,
Type: types.TaskType(data.TaskType),
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
})
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
type reqVo struct {
Src string `json:"src"`
Index int `json:"index"`
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"`
SessionId string `json:"session_id"`
@@ -178,64 +252,60 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
return
}
if !h.checkLimits(c) {
if !h.preCheck(c) {
return
}
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := types.TaskSrc(data.Src)
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: types.TaskUpscale.String(),
UserId: userId,
Hash: data.MessageHash,
Progress: 0,
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error == nil {
jobId = int(job.Id)
} else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
taskId, _ := h.snowflake.Next(true)
job := model.MidJourneyJob{
Type: types.TaskUpscale.String(),
ReferenceId: data.MessageId,
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
h.mjService.PushTask(types.MjTask{
Id: jobId,
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{
Id: job.Id,
SessionId: data.SessionId,
Src: src,
Type: types.TaskUpscale,
Prompt: data.Prompt,
UserId: userId,
RoleId: data.RoleId,
Icon: data.Icon,
ChatId: data.ChatId,
ChannelId: data.ChannelId,
Index: data.Index,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
})
if src == types.TaskSrcChat {
wsClient := h.App.ChatClients.Get(data.SessionId)
if wsClient != nil {
content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
utils.ReplyMessage(wsClient, content)
if h.mjService.ChatClients.Get(data.SessionId) == nil {
h.mjService.ChatClients.Put(data.SessionId, wsClient)
}
}
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
@@ -248,79 +318,100 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
return
}
if !h.checkLimits(c) {
if !h.preCheck(c) {
return
}
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := types.TaskSrc(data.Src)
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: types.TaskVariation.String(),
UserId: userId,
ImgURL: "",
Hash: data.MessageHash,
Progress: 0,
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error == nil {
jobId = int(job.Id)
} else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
taskId, _ := h.snowflake.Next(true)
job := model.MidJourneyJob{
Type: types.TaskVariation.String(),
ChannelId: data.ChannelId,
ReferenceId: data.MessageId,
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
h.mjService.PushTask(types.MjTask{
Id: jobId,
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{
Id: job.Id,
SessionId: data.SessionId,
Src: src,
Type: types.TaskVariation,
Prompt: data.Prompt,
UserId: userId,
RoleId: data.RoleId,
Icon: data.Icon,
ChatId: data.ChatId,
Index: data.Index,
ChannelId: data.ChannelId,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
})
if src == types.TaskSrcChat {
// 从聊天窗口发送的请求,记录客户端信息
wsClient := h.mjService.ChatClients.Get(data.SessionId)
if wsClient != nil {
content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
utils.ReplyMessage(wsClient, content)
if h.mjService.Clients.Get(data.SessionId) == nil {
h.mjService.Clients.Put(data.SessionId, wsClient)
}
}
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0)
userId := h.GetInt(c, "user_id", 0)
// ImgWall 照片墙
func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
session := h.db.Session(&gorm.Session{})
if status == 1 {
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
@@ -328,6 +419,9 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish = ?", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
@@ -336,8 +430,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
var items []model.MidJourneyJob
res := session.Find(&items)
if res.Error != nil {
resp.ERROR(c, types.NoData)
return
return res.Error, nil
}
var jobs = make([]vo.MidJourneyJob, 0)
@@ -347,20 +440,73 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
if err != nil {
continue
}
if item.Progress < 100 {
// 30 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
h.db.Delete(&item)
continue
}
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
// discord 服务器图片需要使用代理转发图片数据流
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
} else {
job.ImgURL = job.OrgURL
}
}
jobs = append(jobs, job)
}
resp.SUCCESS(c, jobs)
return nil, jobs
}
// Remove remove task image
func (h *MidJourneyHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// remove job recode
res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
// Publish 发布图片到画廊显示
func (h *MidJourneyHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@@ -7,19 +7,17 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type OrderHandler struct {
BaseHandler
db *gorm.DB
}
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db}
h.App = app
return &h
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
func (h *OrderHandler) List(c *gin.Context) {
@@ -31,8 +29,8 @@ func (h *OrderHandler) List(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
user, _ := utils.GetLoginUser(c, h.db)
session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess)
userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
var total int64
session.Model(&model.Order{}).Count(&total)
var items []model.Order

View File

@@ -11,70 +11,115 @@ import (
"embed"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"github.com/shopspring/decimal"
"math"
"net/http"
"net/url"
"sync"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
const (
PayWayAlipay = "支付宝"
PayWayWechat = "微信支付"
PayWayXunHu = "虎皮椒"
PayWayJs = "PayJS"
)
// PaymentHandler 支付服务回调 handler
type PaymentHandler struct {
BaseHandler
alipayService *payment.AlipayService
snowflake *service.Snowflake
db *gorm.DB
fs embed.FS
lock sync.Mutex
alipayService *payment.AlipayService
huPiPayService *payment.HuPiPayService
js *payment.PayJS
snowflake *service.Snowflake
fs embed.FS
lock sync.Mutex
}
func NewPaymentHandler(server *core.AppServer, alipayService *payment.AlipayService, snowflake *service.Snowflake, db *gorm.DB, fs embed.FS) *PaymentHandler {
h := PaymentHandler{lock: sync.Mutex{}}
h.App = server
h.alipayService = alipayService
h.snowflake = snowflake
h.db = db
h.fs = fs
return &h
func NewPaymentHandler(
server *core.AppServer,
alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService,
js *payment.PayJS,
db *gorm.DB,
snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler {
return &PaymentHandler{
alipayService: alipayService,
huPiPayService: huPiPayService,
js: js,
snowflake: snowflake,
fs: fs,
lock: sync.Mutex{},
BaseHandler: BaseHandler{
App: server,
DB: db,
},
}
}
func (h *PaymentHandler) Alipay(c *gin.Context) {
func (h *PaymentHandler) DoPay(c *gin.Context) {
orderNo := h.GetTrim(c, "order_no")
payWay := h.GetTrim(c, "pay_way")
if orderNo == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
var order model.Order
res := h.db.Where("order_no = ?", orderNo).First(&order)
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
}
// 更新扫码状态
h.db.Model(&order).UpdateColumn("status", types.OrderScanned)
// 生成支付链接
notifyURL := h.App.Config.AlipayConfig.NotifyURL
returnURL := "" // 关闭同步回跳
amount := fmt.Sprintf("%.2f", order.Amount)
uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject)
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
if order.Status == types.OrderPaidSuccess {
resp.ERROR(c, "This order had been paid, please do not pay twice")
return
}
c.Redirect(302, uri)
// 更新扫码状态
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
if payWay == "alipay" { // 支付宝
// 生成支付链接
notifyURL := h.App.Config.AlipayConfig.NotifyURL
returnURL := "" // 关闭同步回跳
amount := fmt.Sprintf("%.2f", order.Amount)
uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject)
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
return
}
c.Redirect(302, uri)
return
} else if payWay == "hupi" { // 虎皮椒支付
params := payment.HuPiPayReq{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", order.Amount),
Title: order.Subject,
NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL,
WapName: "极客学长",
}
r, err := h.huPiPayService.Pay(params)
if err != nil {
resp.ERROR(c, err.Error())
return
}
c.Redirect(302, r.URL)
}
resp.ERROR(c, "Invalid operations")
}
// OrderQuery 单状态查询
// OrderQuery 查询订单状态
func (h *PaymentHandler) OrderQuery(c *gin.Context) {
var data struct {
OrderNo string `json:"order_no"`
@@ -85,7 +130,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
}
var order model.Order
res := h.db.Where("order_no = ?", data.OrderNo).First(&order)
res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
@@ -100,7 +145,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
for {
time.Sleep(time.Second)
var item model.Order
h.db.Where("order_no = ?", data.OrderNo).First(&item)
h.DB.Where("order_no = ?", data.OrderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
@@ -111,16 +156,12 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
resp.SUCCESS(c, gin.H{"status": order.Status})
}
// AlipayQrcode 生成支付宝支付 URL 二维码
func (h *PaymentHandler) AlipayQrcode(c *gin.Context) {
if !h.App.SysConfig.EnabledAlipay || h.alipayService == nil {
resp.ERROR(c, "当前支付通道已经关闭,请联系管理员开通!")
return
}
// PayQrcode 生成支付 URL 二维码
func (h *PaymentHandler) PayQrcode(c *gin.Context) {
var data struct {
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -128,62 +169,105 @@ func (h *PaymentHandler) AlipayQrcode(c *gin.Context) {
}
var product model.Product
res := h.db.First(&product, data.ProductId)
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next()
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
res = h.db.First(&user, data.UserId)
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
var payWay string
var notifyURL string
switch data.PayWay {
case "hupi":
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
case "payjs":
payWay = PayWayJs
notifyURL = h.App.Config.JPayConfig.NotifyURL
default:
payWay = PayWayAlipay
notifyURL = h.App.Config.AlipayConfig.NotifyURL
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Calls: product.Calls,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
order := model.Order{
UserId: user.Id,
Mobile: user.Mobile,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: product.Price - product.Discount,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: PayWayAlipay,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.db.Create(&order)
if res.Error != nil {
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
}
// 生成二维码图片
file, err := h.fs.Open("res/img/alipay.jpg")
// PayJs 单独处理,只能用官方生成二维码
if data.PayWay == "payjs" {
params := payment.JPayReq{
TotalFee: int(math.Ceil(order.Amount * 100)),
OutTradeNo: order.OrderNo,
Subject: product.Name,
}
r := h.js.Pay(params)
if r.IsOK() {
resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
return
} else {
resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg)
return
}
}
var logo string
if data.PayWay == "alipay" {
logo = "res/img/alipay.jpg"
} else if data.PayWay == "hupi" {
if h.App.Config.HuPiPayConfig.Name == "wechat" {
logo = "res/img/wechat-pay.jpg"
} else {
logo = "res/img/alipay.jpg"
}
}
file, err := h.fs.Open(logo)
if err != nil {
resp.ERROR(c, err.Error())
resp.ERROR(c, "error with open qrcode log file: "+err.Error())
return
}
parse, err := url.Parse(h.App.Config.AlipayConfig.NotifyURL)
parse, err := url.Parse(notifyURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
imageURL := fmt.Sprintf("%s://%s/api/payment/alipay?order_no=%s", parse.Scheme, parse.Host, orderNo)
imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s", parse.Scheme, parse.Host, orderNo, data.PayWay)
imgData, err := utils.GenQrcode(imageURL, 400, file)
if err != nil {
resp.ERROR(c, err.Error())
@@ -193,6 +277,252 @@ func (h *PaymentHandler) AlipayQrcode(c *gin.Context) {
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
}
// Mobile 移动端支付
func (h *PaymentHandler) Mobile(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
var payWay string
var notifyURL, returnURL string
var payURL string
switch data.PayWay {
case "hupi":
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
params := payment.HuPiPayReq{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name,
NotifyURL: notifyURL,
ReturnURL: returnURL,
CallbackURL: returnURL,
WapName: "极客学长",
}
r, err := h.huPiPayService.Pay(params)
if err != nil {
logger.Error("error with generating Pay URL: ", err.Error())
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
return
}
payURL = r.URL
case "payjs":
payWay = PayWayJs
notifyURL = h.App.Config.JPayConfig.NotifyURL
returnURL = h.App.Config.JPayConfig.ReturnURL
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
params := url.Values{}
params.Add("total_fee", fmt.Sprintf("%d", totalFee))
params.Add("out_trade_no", orderNo)
params.Add("body", product.Name)
params.Add("notify_url", notifyURL)
params.Add("auto", "0")
payURL = h.js.PayH5(params)
case "alipay":
payWay = PayWayAlipay
notifyURL = h.App.Config.AlipayConfig.NotifyURL
returnURL = h.App.Config.AlipayConfig.ReturnURL
payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
if err != nil {
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
return
}
default:
resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
return
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
}
resp.SUCCESS(c, payURL)
}
// 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var order model.Order
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
err := fmt.Errorf("error with fetch order: %v", res.Error)
logger.Error(err)
return err
}
h.lock.Lock()
defer h.lock.Unlock()
// 已支付订单,直接返回
if order.Status == types.OrderPaidSuccess {
return nil
}
var user model.User
res = h.DB.First(&user, order.UserId)
if res.Error != nil {
err := fmt.Errorf("error with fetch user info: %v", res.Error)
logger.Error(err)
return err
}
var remark types.OrderRemark
err := utils.JsonDecode(order.Remark, &remark)
if err != nil {
err := fmt.Errorf("error with decode order remark: %v", err)
logger.Error(err)
return err
}
var opt string
var power int
if remark.Days > 0 { // VIP 充值
if user.ExpiredTime >= time.Now().Unix() {
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
opt = "VIP充值VIP 没到期,只延期不增加算力"
} else {
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Power += h.App.SysConfig.VipMonthPower
power = h.App.SysConfig.VipMonthPower
opt = "VIP充值"
}
user.Vip = true
} else { // 充值点卡,直接增加次数即可
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
// 更新用户信息
res = h.DB.Updates(&user)
if res.Error != nil {
err := fmt.Errorf("error with update user info: %v", res.Error)
logger.Error(err)
return err
}
// 更新订单状态
order.PayTime = time.Now().Unix()
order.Status = types.OrderPaidSuccess
order.TradeNo = tradeNo
res = h.DB.Updates(&order)
if res.Error != nil {
err := fmt.Errorf("error with update order info: %v", res.Error)
logger.Error(err)
return err
}
// 更新产品销量
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
// 记录算力充值日志
if opt != "" {
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRecharge,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: order.PayWay,
Remark: fmt.Sprintf("%s金额%f订单号%s", opt, order.Amount, order.OrderNo),
CreatedAt: time.Now(),
})
}
return nil
}
// GetPayWays 获取支付方式
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
data := gin.H{}
if h.App.Config.AlipayConfig.Enabled {
data["alipay"] = gin.H{"name": "alipay"}
}
if h.App.Config.HuPiPayConfig.Enabled {
data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name}
}
if h.App.Config.JPayConfig.Enabled {
data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name}
}
resp.SUCCESS(c, data)
}
// HuPiPayNotify 虎皮椒支付异步回调
func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
c.String(http.StatusOK, "fail")
return
}
orderNo := c.Request.Form.Get("trade_order_id")
tradeNo := c.Request.Form.Get("open_order_id")
logger.Infof("收到虎皮椒订单支付回调,订单 NO%s交易流水号%s", orderNo, tradeNo)
if err = h.huPiPayService.Check(tradeNo); err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
}
// AlipayNotify 支付宝支付回调
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
@@ -200,74 +530,55 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
return
}
// TODO这里最好用支付宝的公钥签名签证一下交易真假
//res := h.alipayService.TradeVerify(c.Request.Form)
r := h.alipayService.TradeQuery(c.Request.Form.Get("out_trade_no"))
logger.Infof("验证支付结果:%+v", r)
if !r.Success() {
// TODO验证交易签名
res := h.alipayService.TradeVerify(c.Request.Form)
logger.Infof("验证支付结果:%+v", res)
if !res.Success() {
logger.Error("订单校验失败:", res.Message)
c.String(http.StatusOK, "fail")
return
}
h.lock.Lock()
defer h.lock.Unlock()
var order model.Order
res := h.db.Where("order_no = ?", r.OutTradeNo).First(&order)
if res.Error != nil {
logger.Error(res.Error)
c.String(http.StatusOK, "fail")
return
}
var user model.User
res = h.db.First(&user, order.UserId)
if res.Error != nil {
logger.Error(res.Error)
c.String(http.StatusOK, "fail")
return
}
var remark types.OrderRemark
err = utils.JsonDecode(order.Remark, &remark)
tradeNo := c.Request.Form.Get("trade_no")
err = h.notify(res.OutTradeNo, tradeNo)
if err != nil {
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
}
// PayJsNotify PayJs 支付异步回调
func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
c.String(http.StatusOK, "fail")
return
}
orderNo := c.Request.Form.Get("out_trade_no")
returnCode := c.Request.Form.Get("return_code")
logger.Infof("收到订单支付回调,订单 NO%s支付结果代码%v", orderNo, returnCode)
// 支付失败
if returnCode != "1" {
return
}
// 校验订单支付状态
tradeNo := c.Request.Form.Get("payjs_order_id")
err = h.js.Check(tradeNo)
if err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
logger.Error(res.Error)
c.String(http.StatusOK, "fail")
return
}
// 1. 点卡days == 0, calls > 0
// 2. vip 套餐days > 0, calls == 0
if remark.Days > 0 {
if user.ExpiredTime > time.Now().Unix() {
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
} else {
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
}
user.Vip = true
} else if !user.Vip { // 充值点卡的非 VIP 用户
user.ExpiredTime = time.Now().AddDate(0, 0, 30).Unix()
}
if remark.Calls > 0 { // 充值点卡
user.Calls += remark.Calls
} else {
user.Calls += h.App.SysConfig.VipMonthCalls
}
// 更新用户信息
res = h.db.Updates(&user)
if res.Error != nil {
logger.Error(res.Error)
c.String(http.StatusOK, "fail")
return
}
// 更新订单状态
order.PayTime = time.Now().Unix()
order.Status = types.OrderPaidSuccess
h.db.Updates(&order)
// 更新产品销量
h.db.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
c.String(http.StatusOK, "success")
}

View File

@@ -0,0 +1,67 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type PowerLogHandler struct {
BaseHandler
}
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Model string `json:"model"`
Date []string `json:"date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
userId := h.GetLoginUserId(c)
session = session.Where("user_id", userId)
if data.Model != "" {
session = session.Where("model", data.Model)
}
if len(data.Date) == 2 {
start := data.Date[0] + " 00:00:00"
end := data.Date[1] + " 00:00:00"
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.PowerLog{}).Count(&total)
var items []model.PowerLog
var list = make([]vo.PowerLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var log vo.PowerLog
err := utils.CopyObject(item, &log)
if err != nil {
continue
}
log.Id = item.Id
log.CreatedAt = item.CreatedAt.Unix()
log.TypeStr = item.Type.String()
list = append(list, log)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}

View File

@@ -12,20 +12,17 @@ import (
type ProductHandler struct {
BaseHandler
db *gorm.DB
}
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db}
h.App = app
return &h
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 模型列表
func (h *ProductHandler) List(c *gin.Context) {
var items []model.Product
var list = make([]vo.Product, 0)
res := h.db.Where("enabled", true).Order("sort_num ASC").Find(&items)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Product

View File

@@ -4,22 +4,25 @@ import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"math"
"strings"
"sync"
"time"
)
type RewardHandler struct {
BaseHandler
db *gorm.DB
lock sync.Mutex
}
func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db}
h.App = server
return &h
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Verify 打赏码核销
@@ -32,11 +35,20 @@ func (h *RewardHandler) Verify(c *gin.Context) {
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.HACKER(c)
return
}
// 移除转账单号中间的空格,防止有人复制的时候多复制了空格
data.TxId = strings.ReplaceAll(data.TxId, " ", "")
h.lock.Lock()
defer h.lock.Unlock()
var item model.Reward
res := h.db.Where("tx_id = ?", data.TxId).First(&item)
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的众筹交易流水号!")
return
@@ -47,16 +59,13 @@ func (h *RewardHandler) Verify(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.HACKER(c)
return
}
tx := h.db.Begin()
calls := (item.Amount + 0.1) * 10
res = h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls + ?", calls))
tx := h.DB.Begin()
exchange := vo.RewardExchange{}
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
exchange.Power = int(power)
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -64,13 +73,26 @@ func (h *RewardHandler) Verify(c *gin.Context) {
// 更新核销状态
item.Status = true
item.UserId = user.Id
res = h.db.Updates(&item)
item.Exchange = utils.JsonEncode(exchange)
res = tx.Updates(&item)
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力充值日志
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerReward,
Amount: exchange.Power,
Balance: user.Power + exchange.Power,
Mark: types.PowerAdd,
Model: "众筹支付",
Remark: fmt.Sprintf("众筹充值算力,金额:%f价格%f", item.Amount, h.App.SysConfig.PowerPrice),
CreatedAt: time.Now(),
})
tx.Commit()
resp.SUCCESS(c)

View File

@@ -3,35 +3,45 @@ package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/oss"
"chatplus/service/sd"
"chatplus/store"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type SdJobHandler struct {
BaseHandler
redis *redis.Client
db *gorm.DB
service *sd.Service
redis *redis.Client
pool *sd.ServicePool
uploader *oss.UploaderManager
snowflake *service.Snowflake
leveldb *store.LevelDB
}
func NewSdJobHandler(app *core.AppServer, redisCli *redis.Client, db *gorm.DB, service *sd.Service) *SdJobHandler {
h := SdJobHandler{
redis: redisCli,
db: db,
service: service,
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
return &SdJobHandler{
pool: pool,
uploader: manager,
snowflake: snowflake,
leveldb: levelDB,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
h.App = app
return &h
}
// Client WebSocket 客户端,用于通知任务状态变更
@@ -39,25 +49,36 @@ func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
// 删除旧的连接
h.service.Clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值")
if !h.pool.HasAvailableService() {
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务")
return false
}
if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
@@ -67,11 +88,6 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
// Image 创建一个绘画任务
func (h *SdJobHandler) Image(c *gin.Context) {
if !h.App.Config.SdConfig.Enabled {
resp.ERROR(c, "Stable Diffusion service is disabled")
return
}
if !h.checkLimits(c) {
return
}
@@ -105,23 +121,29 @@ func (h *SdJobHandler) Image(c *gin.Context) {
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
params := types.SdTaskParams{
TaskId: fmt.Sprintf("task(%s)", utils.RandString(15)),
Prompt: data.Prompt,
NegativePrompt: data.NegativePrompt,
Steps: data.Steps,
Sampler: data.Sampler,
FaceFix: data.FaceFix,
CfgScale: data.CfgScale,
Seed: data.Seed,
Height: data.Height,
Width: data.Width,
HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
taskId, err := h.snowflake.Next(true)
if err != nil {
resp.ERROR(c, "error with generate task id: "+err.Error())
return
}
params := types.SdTaskParams{
TaskId: taskId,
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Steps: data.Steps,
Sampler: data.Sampler,
FaceFix: data.FaceFix,
CfgScale: data.CfgScale,
Seed: data.Seed,
Height: data.Height,
Width: data.Width,
HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
}
job := model.SdJob{
UserId: userId,
Type: types.TaskImage.String(),
@@ -129,45 +151,84 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Params: utils.JsonEncode(params),
Prompt: data.Prompt,
Progress: 0,
Started: false,
Power: h.App.SysConfig.SdPower,
CreatedAt: time.Now(),
}
res := h.db.Create(&job)
res := h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error())
return
}
h.service.PushTask(types.SdTask{
h.pool.PushTask(types.SdTask{
Id: int(job.Id),
SessionId: data.SessionId,
Src: types.TaskSrcImg,
Type: types.TaskImage,
Prompt: data.Prompt,
Params: params,
UserId: userId,
})
var jobVo vo.SdJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.service.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
// JobList 获取 stable diffusion 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0)
userId := h.GetInt(c, "user_id", 0)
// ImgWall 照片墙
func (h *SdJobHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
session := h.db.Session(&gorm.Session{})
if status == 1 {
resp.SUCCESS(c, jobs)
}
// JobList 获取 SD 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
@@ -175,6 +236,9 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
@@ -183,8 +247,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
var items []model.SdJob
res := session.Find(&items)
if res.Error != nil {
resp.ERROR(c, types.NoData)
return
return res.Error, nil
}
var jobs = make([]vo.SdJob, 0)
@@ -194,14 +257,69 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
if err != nil {
continue
}
if item.Progress < 100 {
// 30 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
h.db.Delete(&item)
continue
// 从 leveldb 中获取图片预览数据
imageData, err := h.leveldb.Get(item.TaskId)
if err == nil {
job.ImgURL = "data:image/png;base64," + string(imageData)
}
}
jobs = append(jobs, job)
}
resp.SUCCESS(c, jobs)
return nil, jobs
}
// Remove remove task image
func (h *SdJobHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// remove job recode
res := h.DB.Delete(&model.SdJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
// Publish 发布/取消发布图片到画廊显示
func (h *SdJobHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@@ -4,8 +4,11 @@ import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/sms"
"chatplus/utils"
"chatplus/utils/resp"
"strings"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
@@ -15,22 +18,31 @@ const CodeStorePrefix = "/verify/codes/"
type SmsHandler struct {
BaseHandler
redis *redis.Client
sms *service.AliYunSmsService
sms *sms.ServiceManager
smtp *service.SmtpService
captcha *service.CaptchaService
}
func NewSmsHandler(app *core.AppServer, client *redis.Client, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler {
handler := &SmsHandler{redis: client, sms: sms, captcha: captcha}
handler.App = app
return handler
func NewSmsHandler(
app *core.AppServer,
client *redis.Client,
sms *sms.ServiceManager,
smtp *service.SmtpService,
captcha *service.CaptchaService) *SmsHandler {
return &SmsHandler{
redis: client,
sms: sms,
captcha: captcha,
smtp: smtp,
BaseHandler: BaseHandler{App: app}}
}
// SendCode 发送验证码短信
// SendCode 发送验证码
func (h *SmsHandler) SendCode(c *gin.Context) {
var data struct {
Mobile string `json:"mobile"`
Key string `json:"key"`
Dots string `json:"dots"`
Receiver string `json:"receiver"` // 接收者
Key string `json:"key"`
Dots string `json:"dots"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -43,14 +55,28 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
}
code := utils.RandomNumber(6)
err := h.sms.SendVerifyCode(data.Mobile, code)
var err error
if strings.Contains(data.Receiver, "@") { // email
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") {
resp.ERROR(c, "系统已禁用邮箱注册!")
return
}
err = h.smtp.SendVerifyCode(data.Receiver, code)
} else {
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
resp.ERROR(c, "系统已禁用手机号注册!")
return
}
err = h.sms.GetService().SendVerifyCode(data.Receiver, code)
}
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 存储验证码,等待后面注册验证
_, err = h.redis.Set(c, CodeStorePrefix+data.Mobile, code, 0).Result()
_, err = h.redis.Set(c, CodeStorePrefix+data.Receiver, code, 0).Result()
if err != nil {
resp.ERROR(c, "验证码保存失败")
return
@@ -58,13 +84,3 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
resp.SUCCESS(c)
}
type statusVo struct {
EnabledMsgService bool `json:"enabled_msg_service"`
EnabledRegister bool `json:"enabled_register"`
}
// Status check if the message service is enabled
func (h *SmsHandler) Status(c *gin.Context) {
resp.SUCCESS(c, statusVo{EnabledMsgService: h.App.SysConfig.EnabledMsg, EnabledRegister: h.App.SysConfig.EnabledRegister})
}

View File

@@ -0,0 +1,17 @@
package handler
import (
"chatplus/service"
"chatplus/service/payment"
"gorm.io/gorm"
)
type TestHandler struct {
db *gorm.DB
snowflake *service.Snowflake
js *payment.PayJS
}
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js}
}

View File

@@ -3,29 +3,92 @@ package handler
import (
"chatplus/core"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type UploadHandler struct {
BaseHandler
db *gorm.DB
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
handler := &UploadHandler{db: db, uploaderManager: manager}
handler.App = app
return handler
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
fileURL, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, fileURL)
userId := h.GetLoginUserId(c)
res := h.DB.Create(&model.File{
UserId: int(userId),
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}
func (h *UploadHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
var items []model.File
var files = make([]vo.File, 0)
h.DB.Where("user_id = ?", userId).Find(&items)
if len(items) > 0 {
for _, v := range items {
var file vo.File
err := utils.CopyObject(v, &file)
if err != nil {
logger.Error(err)
continue
}
file.CreatedAt = v.CreatedAt.Unix()
files = append(files, file)
}
}
resp.SUCCESS(c, files)
}
// Remove remove files
func (h *UploadHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0)
var file model.File
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
if tx.Error != nil || file.Id == 0 {
resp.ERROR(c, "file not existed")
return
}
// remove database
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
if tx.Error != nil || tx.RowsAffected == 0 {
resp.ERROR(c, "failed to update database")
return
}
// remove files
objectKey := file.ObjKey
if objectKey == "" {
objectKey = file.URL
}
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
resp.SUCCESS(c)
}

View File

@@ -8,11 +8,12 @@ import (
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"strings"
"time"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/gin-gonic/gin"
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
"gorm.io/gorm"
@@ -20,7 +21,6 @@ import (
type UserHandler struct {
BaseHandler
db *gorm.DB
searcher *xdb.Searcher
redis *redis.Client
}
@@ -30,82 +30,112 @@ func NewUserHandler(
db *gorm.DB,
searcher *xdb.Searcher,
client *redis.Client) *UserHandler {
handler := &UserHandler{db: db, searcher: searcher, redis: client}
handler.App = app
return handler
return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client}
}
// Register user register
func (h *UserHandler) Register(c *gin.Context) {
// parameters process
var data struct {
Mobile string `json:"mobile"`
Password string `json:"password"`
Code string `json:"code"`
RegWay string `json:"reg_way"`
Username string `json:"username"`
Password string `json:"password"`
Code string `json:"code"`
InviteCode string `json:"invite_code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
data.Password = strings.TrimSpace(data.Password)
if len(data.Mobile) < 10 {
resp.ERROR(c, "请输入合法的手机号")
return
}
if len(data.Password) < 8 {
resp.ERROR(c, "密码长度不能少于8个字符")
return
}
// 检查验证码
key := CodeStorePrefix + data.Mobile
if h.App.SysConfig.EnabledMsg {
var key string
if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" {
key = CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
resp.ERROR(c, "验证码错误")
return
}
}
// 验证邀请码
inviteCode := model.InviteCode{}
if data.InviteCode != "" {
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
if res.Error != nil {
resp.ERROR(c, "无效的邀请码")
return
}
}
// check if the username is exists
var item model.User
res := h.db.Where("mobile = ?", data.Mobile).First(&item)
if res.RowsAffected > 0 {
resp.ERROR(c, "该手机号码已经被注册,请更换其他手机号")
res := h.DB.Where("username = ?", data.Username).First(&item)
if item.Id > 0 {
resp.ERROR(c, "该用户名已经被注册")
return
}
salt := utils.RandString(8)
user := model.User{
Username: data.Username,
Password: utils.GenPassword(data.Password, salt),
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
Mobile: data.Mobile,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
ChatConfig: utils.JsonEncode(types.UserChatConfig{
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
Calls: h.App.SysConfig.UserInitCalls,
ImgCalls: h.App.SysConfig.InitImgCalls,
Power: h.App.SysConfig.InitPower,
}
res = h.db.Create(&user)
res = h.DB.Create(&user)
if res.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(res.Error)
return
}
if h.App.SysConfig.EnabledMsg {
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
// 记录邀请关系
if data.InviteCode != "" {
// 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 {
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
// 记录邀请算力充值日志
var inviter model.User
h.DB.Where("id", inviteCode.UserId).First(&inviter)
h.DB.Create(&model.PowerLog{
UserId: inviter.Id,
Username: inviter.Username,
Type: types.PowerInvite,
Amount: h.App.SysConfig.InvitePower,
Balance: inviter.Power,
Mark: types.PowerAdd,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
CreatedAt: time.Now(),
})
}
// 添加邀请记录
h.DB.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
})
}
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
// 自动登录创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
@@ -128,7 +158,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// Login 用户登录
func (h *UserHandler) Login(c *gin.Context) {
var data struct {
Mobile string `json:"username"`
Username string `json:"username"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -136,7 +166,7 @@ func (h *UserHandler) Login(c *gin.Context) {
return
}
var user model.User
res := h.db.Where("mobile = ?", data.Mobile).First(&user)
res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户名不存在")
return
@@ -156,11 +186,11 @@ func (h *UserHandler) Login(c *gin.Context) {
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.db.Model(&user).Updates(user)
h.DB.Model(&user).Updates(user)
h.db.Create(&model.UserLoginLog{
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Mobile,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
@@ -203,7 +233,7 @@ func (h *UserHandler) Logout(c *gin.Context) {
// Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err == nil {
var userVo vo.User
err := utils.CopyObject(user, &userVo)
@@ -219,26 +249,23 @@ func (h *UserHandler) Session(c *gin.Context) {
}
type userProfile struct {
Id uint `json:"id"`
Mobile string `json:"mobile"`
Avatar string `json:"avatar"`
ChatConfig types.UserChatConfig `json:"chat_config"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
TotalTokens int64 `json:"total_tokens"`
Tokens int64 `json:"tokens"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
Id uint `json:"id"`
Nickname string `json:"nickname"`
Username string `json:"username"`
Avatar string `json:"avatar"`
Power int `json:"power"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
}
func (h *UserHandler) Profile(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
h.DB.First(&user, user.Id)
var profile userProfile
err = utils.CopyObject(user, &profile)
if err != nil {
@@ -258,15 +285,15 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
h.DB.First(&user, user.Id)
user.Avatar = data.Avatar
user.ChatConfig = utils.JsonEncode(data.ChatConfig)
res := h.db.Updates(&user)
user.Nickname = data.Nickname
res := h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新用户信息失败")
return
@@ -291,21 +318,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
password := utils.GenPassword(data.OldPass, user.Salt)
logger.Info(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
if password != user.Password {
resp.ERROR(c, "原密码错误")
return
}
newPass := utils.GenPassword(data.Password, user.Salt)
res := h.db.Model(&user).UpdateColumn("password", newPass)
res := h.DB.Model(&user).UpdateColumn("password", newPass)
if res.Error != nil {
logger.Error("更新数据库失败: ", res.Error)
resp.ERROR(c, "更新数据库失败")
@@ -318,9 +345,9 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
// ResetPass 重置密码
func (h *UserHandler) ResetPass(c *gin.Context) {
var data struct {
Mobile string
Code string // 验证码
Password string // 新密码
Username string `json:"username"`
Code string `json:"code"` // 验证码
Password string `json:"password"` // 新密码
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -328,25 +355,23 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
var user model.User
res := h.db.Where("mobile", data.Mobile).First(&user)
res := h.DB.Where("username", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户不存在!")
return
}
// 检查验证码
key := CodeStorePrefix + data.Mobile
if h.App.SysConfig.EnabledMsg {
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
return
}
key := CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
} else {
@@ -355,11 +380,11 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
}
// BindMobile 绑定手机
func (h *UserHandler) BindMobile(c *gin.Context) {
// BindUsername 重置账
func (h *UserHandler) BindUsername(c *gin.Context) {
var data struct {
Mobile string `json:"mobile"`
Code string `json:"code"`
Username string `json:"username"`
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -367,28 +392,28 @@ func (h *UserHandler) BindMobile(c *gin.Context) {
}
// 检查验证码
key := CodeStorePrefix + data.Mobile
key := CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
resp.ERROR(c, "验证码错误")
return
}
// 检查手机号是否被其他账号绑定
var item model.User
res := h.db.Where("mobile = ?", data.Mobile).First(&item)
res := h.DB.Where("username = ?", data.Username).First(&item)
if res.Error == nil {
resp.ERROR(c, "该手机号已经被其他账号绑定")
resp.ERROR(c, "该号已经被其他账号绑定")
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
res = h.db.Model(&user).UpdateColumn("mobile", data.Mobile)
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return

View File

@@ -8,16 +8,15 @@ import (
"chatplus/handler/chatimpl"
logger2 "chatplus/logger"
"chatplus/service"
"chatplus/service/fun"
"chatplus/service/mj"
"chatplus/service/oss"
"chatplus/service/payment"
"chatplus/service/sd"
"chatplus/service/sms"
"chatplus/service/wx"
"chatplus/store"
"context"
"embed"
"github.com/go-redis/redis/v8"
"io"
"log"
"os"
@@ -26,6 +25,8 @@ import (
"syscall"
"time"
"github.com/go-redis/redis/v8"
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
"go.uber.org/fx"
"gorm.io/gorm"
@@ -52,24 +53,24 @@ func (l *AppLifecycle) OnStop(context.Context) error {
return nil
}
func NewAppLifeCycle() *AppLifecycle {
return &AppLifecycle{}
}
func main() {
configFile := os.Getenv("CONFIG_FILE")
if configFile == "" {
configFile = "config.toml"
}
var debug bool
debugEnv := os.Getenv("DEBUG")
if debugEnv == "" {
debug = true
} else {
debug, _ = strconv.ParseBool(os.Getenv("DEBUG"))
}
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
logger.Info("Loading config file: ", configFile)
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
}
}()
if !debug {
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
}
}()
}
app := fx.New(
// 初始化配置应用配置
@@ -95,6 +96,7 @@ func main() {
fx.Provide(store.NewGormConfig),
fx.Provide(store.NewMysql),
fx.Provide(store.NewRedisClient),
fx.Provide(store.NewLevelDB),
fx.Provide(func() embed.FS {
return xdbFS
@@ -114,9 +116,6 @@ func main() {
return xdb.NewWithBuffer(cBuff)
}),
// 创建函数
fx.Provide(fun.NewFunctions),
// 创建控制器
fx.Provide(handler.NewChatRoleHandler),
fx.Provide(handler.NewUserHandler),
@@ -131,6 +130,8 @@ func main() {
fx.Provide(handler.NewPaymentHandler),
fx.Provide(handler.NewOrderHandler),
fx.Provide(handler.NewProductHandler),
fx.Provide(handler.NewConfigHandler),
fx.Provide(handler.NewPowerLogHandler),
fx.Provide(admin.NewConfigHandler),
fx.Provide(admin.NewAdminHandler),
@@ -142,15 +143,20 @@ func main() {
fx.Provide(admin.NewChatModelHandler),
fx.Provide(admin.NewProductHandler),
fx.Provide(admin.NewOrderHandler),
fx.Provide(admin.NewChatHandler),
fx.Provide(admin.NewPowerLogHandler),
// 创建服务
fx.Provide(service.NewAliYunSmsService),
fx.Provide(sms.NewSendServiceManager),
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
return service.NewCaptchaService(config.ApiConfig)
}),
fx.Provide(oss.NewUploaderManager),
fx.Provide(mj.NewService),
// 邮件服务
fx.Provide(service.NewSmtpService),
// 微信机器人服务
fx.Provide(wx.NewWeChatBot),
fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) {
@@ -162,36 +168,28 @@ func main() {
}
}),
// MidJourney 机器人
fx.Provide(mj.NewBot),
fx.Provide(mj.NewClient),
fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
if config.MjConfig.Enabled {
err := bot.Run()
if err != nil {
log.Fatal("MidJourney 服务启动失败:", err)
}
}
}),
fx.Invoke(func(config *types.AppConfig, mjService *mj.Service) {
if config.MjConfig.Enabled {
go func() {
mjService.Run()
}()
// MidJourney service pool
fx.Provide(mj.NewServicePool),
fx.Invoke(func(pool *mj.ServicePool) {
if pool.HasAvailableService() {
pool.DownloadImages()
pool.CheckTaskNotify()
pool.SyncTaskProgress()
}
}),
// Stable Diffusion 机器人
fx.Provide(sd.NewService),
fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
if config.SdConfig.Enabled {
go func() {
service.Run()
}()
fx.Provide(sd.NewServicePool),
fx.Invoke(func(pool *sd.ServicePool) {
if pool.HasAvailableService() {
pool.CheckTaskNotify()
pool.CheckTaskStatus()
}
}),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
fx.Provide(payment.NewPayJS),
fx.Provide(service.NewSnowflake),
fx.Provide(service.NewXXLJobExecutor),
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
@@ -217,7 +215,7 @@ func main() {
group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass)
group.POST("bind/mobile", h.BindMobile)
group.POST("bind/username", h.BindUsername)
group.POST("resetPass", h.ResetPass)
}),
fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) {
@@ -234,16 +232,19 @@ func main() {
}),
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
s.Engine.POST("/api/upload", h.Upload)
s.Engine.GET("/api/upload/list", h.List)
s.Engine.GET("/api/upload/remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
group := s.Engine.Group("/api/sms/")
group.GET("status", h.Status)
group.POST("code", h.SendCode)
}),
fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) {
group := s.Engine.Group("/api/captcha/")
group.GET("get", h.Get)
group.POST("check", h.Check)
group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck)
}),
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
group := s.Engine.Group("/api/reward/")
@@ -251,17 +252,27 @@ func main() {
}),
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
group := s.Engine.Group("/api/mj/")
group.Any("client", h.Client)
group.POST("image", h.Image)
group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.Any("client", h.Client)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd")
group.Any("client", h.Client)
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.Any("client", h.Client)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/")
group.GET("get", h.Get)
}),
// 管理后台控制器
@@ -275,11 +286,17 @@ func main() {
group.POST("login", h.Login)
group.GET("logout", h.Logout)
group.GET("session", h.Session)
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("enable", h.Enable)
group.GET("remove", h.Remove)
group.POST("resetPass", h.ResetPass)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
group := s.Engine.Group("/api/admin/apikey/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
@@ -295,11 +312,13 @@ func main() {
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
group.POST("set", h.Set)
group.POST("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
group := s.Engine.Group("/api/admin/reward/")
group.GET("list", h.List)
group.POST("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
group := s.Engine.Group("/api/admin/dashboard/")
@@ -319,10 +338,14 @@ func main() {
}),
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
group := s.Engine.Group("/api/payment/")
group.GET("alipay", h.Alipay)
group.GET("doPay", h.DoPay)
group.GET("payWays", h.GetPayWays)
group.POST("query", h.OrderQuery)
group.POST("alipay/qrcode", h.AlipayQrcode)
group.POST("qrcode", h.PayQrcode)
group.POST("mobile", h.Mobile)
group.POST("alipay/notify", h.AlipayNotify)
group.POST("hupipay/notify", h.HuPiPayNotify)
group.POST("payjs/notify", h.PayJsNotify)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
group := s.Engine.Group("/api/admin/product/")
@@ -346,13 +369,82 @@ func main() {
group.GET("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
err := s.Run(db)
if err != nil {
log.Fatal(err)
}
fx.Provide(handler.NewInviteHandler),
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
group := s.Engine.Group("/api/invite/")
group.GET("code", h.Code)
group.POST("list", h.List)
group.GET("hits", h.Hits)
}),
fx.Provide(admin.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
group := s.Engine.Group("/api/admin/function/")
group.POST("save", h.Save)
group.POST("set", h.Set)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("token", h.GenToken)
}),
// 验证码
fx.Provide(admin.NewCaptchaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
group := s.Engine.Group("/api/admin/login/")
group.GET("captcha", h.GetCaptcha)
}),
fx.Provide(admin.NewUploadHandler),
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
s.Engine.POST("/api/admin/upload", h.Upload)
}),
fx.Provide(handler.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
group := s.Engine.Group("/api/function/")
group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/")
group.POST("list", h.List)
group.POST("message", h.Messages)
group.GET("history", h.History)
group.GET("remove", h.RemoveChat)
group.GET("message/remove", h.RemoveMessage)
}),
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
group := s.Engine.Group("/api/powerLog/")
group.POST("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
group := s.Engine.Group("/api/admin/powerLog/")
group.POST("list", h.List)
}),
fx.Provide(admin.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
group := s.Engine.Group("/api/admin/menu/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}),
fx.Provide(handler.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
group := s.Engine.Group("/api/menu/")
group.GET("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
go func() {
err := s.Run(db)
if err != nil {
log.Fatal(err)
}
}()
}),
fx.Provide(NewAppLifeCycle),
// 注册生命周期回调函数
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
lifecycle.Append(fx.Hook{

View File

@@ -0,0 +1,38 @@
-----BEGIN CERTIFICATE-----
MIIDszCCApugAwIBAgIQICMRB0rBU2/rZJbfJGMYIzANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE
BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv
biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1
dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDYzNTQxWhcNMjQxMTA2MDYzNTQxWjCB
hDELMAkGA1UEBhMCQ04xHzAdBgNVBAoMFm1ib25meTkwMTVAc2FuZGJveC5jb20xDzANBgNVBAsM
BkFsaXBheTFDMEEGA1UEAww65pSv5LuY5a6dKOS4reWbvSnnvZHnu5zmioDmnK/mnInpmZDlhazl
j7gtMjA4ODcyMTAyMDc1MDU4MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKsoKcw5
sxaiyV7mpWzDtnQ1K518eQLP0+dJlZAf06aBep/Aj9DIqrba/k7DHt8dKQvILMLAMpN1+2IRxbaO
yxMa/laj3lZ1eHrB6F077O3D62oHcE3noZtXL0N1zZAxpmkNmYIHeLZS2oLMS4ANu47O/wpDC7BV
HjdpZugtdPJ4mxdCpM9GDdLs7W4s5QI4PUPK4skFNMFoKI+0cYP/9ju87UP//IHC/K510GWNl+Gn
Cvgag3AmiIB0utJNsGhxm6zT1T9tUWjW9iz/BxBKiPatsCX9VpPQzGnW7ZonRQtiZSokIlP2IPvl
H5DcwpWUz3/LUY0SmKxnKOEYeOOqCW8CAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3
DQEBCwUAA4IBAQAtgxF2EzjOndEFxBUD9tFwcSt6XKGggOp52oft1pvynPg4ALTLafOtfEPDrFBH
PwpYrSu9s9C8NJtaA2HrlCfBjIuwEFTXiN+HPvS0SwSPKt9AXEiTcOF8vDcGamEen8QI4fo5Jia7
2VRKkerkww5/+FzSaVO7ZUKuL80M1QJStmAZc8kPPwdYOTTW2bGf8BcmSDL6SPElBkt7tCCRd4sn
+jq4cZ0yb2i77rBZCwHcTvfTqIBblPwLv4uGvg3+83BxIB5w6Kqp06bKEAPmobFY5IVHa+ON0/qi
BXxXr+WQ3piKRVQEN64+PTAjSc67Ix1umvpLl3Ko6Ry7NJmpDcUn
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIDszCCApugAwIBAgIQIBkIGbgVxq210KxLJ+YA/TANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UE
BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxJTAjBgNVBAsMHENlcnRpZmljYXRpb24gQXV0
aG9yaXR5IHRlc3QxNjA0BgNVBAMMLUFudCBGaW5hbmNpYWwgQ2VydGlmaWNhdGlvbiBBdXRob3Jp
dHkgUjEgdGVzdDAeFw0xOTA4MTkxMTE2MDBaFw0yNDA4MDExMTE2MDBaMIGRMQswCQYDVQQGEwJD
TjEbMBkGA1UECgwSQW50IEZpbmFuY2lhbCB0ZXN0MSUwIwYDVQQLDBxDZXJ0aWZpY2F0aW9uIEF1
dGhvcml0eSB0ZXN0MT4wPAYDVQQDDDVBbnQgRmluYW5jaWFsIENlcnRpZmljYXRpb24gQXV0aG9y
aXR5IENsYXNzIDIgUjEgdGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMh4FKYO
ZyRQHD6eFbPKZeSAnrfjfU7xmS9Yoozuu+iuqZlb6Z0SPLUqqTZAFZejOcmr07ln/pwZxluqplxC
5+B48End4nclDMlT5HPrDr3W0frs6Xsa2ZNcyil/iKNB5MbGll8LRAxntsKvZZj6vUTMb705gYgm
VUMILwi/ZxKTQqBtkT/kQQ5y6nOZsj7XI5rYdz6qqOROrpvS/d7iypdHOMIM9Iz9DlL1mrCykbBi
t25y+gTeXmuisHUwqaRpwtCGK4BayCqxRGbNipe6W73EK9lBrrzNtTr9NaysesT/v+l25JHCL9tG
wpNr1oWFzk4IHVOg0ORiQ6SUgxZUTYcCAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3
DQEBCwUAA4IBAQBWThEoIaQoBX2YeRY/I8gu6TYnFXtyuCljANnXnM38ft+ikhE5mMNgKmJYLHvT
yWWWgwHoSAWEuml7EGbE/2AK2h3k0MdfiWLzdmpPCRG/RJHk6UB1pMHPilI+c0MVu16OPpKbg5Vf
LTv7dsAB40AzKsvyYw88/Ezi1osTXo6QQwda7uefvudirtb8FcQM9R66cJxl3kt1FXbpYwheIm/p
j1mq64swCoIYu4NrsUYtn6CV542DTQMI5QdXkn+PzUUly8F6kDp+KpMNd0avfWNL5+O++z+F5Szy
1CPta1D7EQ/eYmMP+mOQ35oifWIoFCpN6qQVBS/Hob1J/UUyg7BW
-----END CERTIFICATE-----

View File

@@ -0,0 +1,88 @@
-----BEGIN CERTIFICATE-----
MIIBszCCAVegAwIBAgIIaeL+wBcKxnswDAYIKoEcz1UBg3UFADAuMQswCQYDVQQG
EwJDTjEOMAwGA1UECgwFTlJDQUMxDzANBgNVBAMMBlJPT1RDQTAeFw0xMjA3MTQw
MzExNTlaFw00MjA3MDcwMzExNTlaMC4xCzAJBgNVBAYTAkNOMQ4wDAYDVQQKDAVO
UkNBQzEPMA0GA1UEAwwGUk9PVENBMFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAE
MPCca6pmgcchsTf2UnBeL9rtp4nw+itk1Kzrmbnqo05lUwkwlWK+4OIrtFdAqnRT
V7Q9v1htkv42TsIutzd126NdMFswHwYDVR0jBBgwFoAUTDKxl9kzG8SmBcHG5Yti
W/CXdlgwDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAQYwHQYDVR0OBBYEFEwysZfZ
MxvEpgXBxuWLYlvwl3ZYMAwGCCqBHM9VAYN1BQADSAAwRQIgG1bSLeOXp3oB8H7b
53W+CKOPl2PknmWEq/lMhtn25HkCIQDaHDgWxWFtnCrBjH16/W3Ezn7/U/Vjo5xI
pDoiVhsLwg==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIF0zCCA7ugAwIBAgIIH8+hjWpIDREwDQYJKoZIhvcNAQELBQAwejELMAkGA1UE
BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNVBAsMF0NlcnRpZmlj
YXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5jaWFsIENlcnRpZmlj
YXRpb24gQXV0aG9yaXR5IFIxMB4XDTE4MDMyMTEzNDg0MFoXDTM4MDIyODEzNDg0
MFowejELMAkGA1UEBhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNV
BAsMF0NlcnRpZmljYXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5j
aWFsIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFIxMIICIjANBgkqhkiG9w0BAQEF
AAOCAg8AMIICCgKCAgEAtytTRcBNuur5h8xuxnlKJetT65cHGemGi8oD+beHFPTk
rUTlFt9Xn7fAVGo6QSsPb9uGLpUFGEdGmbsQ2q9cV4P89qkH04VzIPwT7AywJdt2
xAvMs+MgHFJzOYfL1QkdOOVO7NwKxH8IvlQgFabWomWk2Ei9WfUyxFjVO1LVh0Bp
dRBeWLMkdudx0tl3+21t1apnReFNQ5nfX29xeSxIhesaMHDZFViO/DXDNW2BcTs6
vSWKyJ4YIIIzStumD8K1xMsoaZBMDxg4itjWFaKRgNuPiIn4kjDY3kC66Sl/6yTl
YUz8AybbEsICZzssdZh7jcNb1VRfk79lgAprm/Ktl+mgrU1gaMGP1OE25JCbqli1
Pbw/BpPynyP9+XulE+2mxFwTYhKAwpDIDKuYsFUXuo8t261pCovI1CXFzAQM2w7H
DtA2nOXSW6q0jGDJ5+WauH+K8ZSvA6x4sFo4u0KNCx0ROTBpLif6GTngqo3sj+98
SZiMNLFMQoQkjkdN5Q5g9N6CFZPVZ6QpO0JcIc7S1le/g9z5iBKnifrKxy0TQjtG
PsDwc8ubPnRm/F82RReCoyNyx63indpgFfhN7+KxUIQ9cOwwTvemmor0A+ZQamRe
9LMuiEfEaWUDK+6O0Gl8lO571uI5onYdN1VIgOmwFbe+D8TcuzVjIZ/zvHrAGUcC
AwEAAaNdMFswCwYDVR0PBAQDAgEGMAwGA1UdEwQFMAMBAf8wHQYDVR0OBBYEFF90
tATATwda6uWx2yKjh0GynOEBMB8GA1UdIwQYMBaAFF90tATATwda6uWx2yKjh0Gy
nOEBMA0GCSqGSIb3DQEBCwUAA4ICAQCVYaOtqOLIpsrEikE5lb+UARNSFJg6tpkf
tJ2U8QF/DejemEHx5IClQu6ajxjtu0Aie4/3UnIXop8nH/Q57l+Wyt9T7N2WPiNq
JSlYKYbJpPF8LXbuKYG3BTFTdOVFIeRe2NUyYh/xs6bXGr4WKTXb3qBmzR02FSy3
IODQw5Q6zpXj8prYqFHYsOvGCEc1CwJaSaYwRhTkFedJUxiyhyB5GQwoFfExCVHW
05ZFCAVYFldCJvUzfzrWubN6wX0DD2dwultgmldOn/W/n8at52mpPNvIdbZb2F41
T0YZeoWnCJrYXjq/32oc1cmifIHqySnyMnavi75DxPCdZsCOpSAT4j4lAQRGsfgI
kkLPGQieMfNNkMCKh7qjwdXAVtdqhf0RVtFILH3OyEodlk1HYXqX5iE5wlaKzDop
PKwf2Q3BErq1xChYGGVS+dEvyXc/2nIBlt7uLWKp4XFjqekKbaGaLJdjYP5b2s7N
1dM0MXQ/f8XoXKBkJNzEiM3hfsU6DOREgMc1DIsFKxfuMwX3EkVQM1If8ghb6x5Y
jXayv+NLbidOSzk4vl5QwngO/JYFMkoc6i9LNwEaEtR9PhnrdubxmrtM+RjfBm02
77q3dSWFESFQ4QxYWew4pHE0DpWbWy/iMIKQ6UZ5RLvB8GEcgt8ON7BBJeMc+Dyi
kT9qhqn+lw==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIICiDCCAgygAwIBAgIIQX76UsB/30owDAYIKoZIzj0EAwMFADB6MQswCQYDVQQG
EwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UECwwXQ2VydGlmaWNh
dGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNpYWwgQ2VydGlmaWNh
dGlvbiBBdXRob3JpdHkgRTEwHhcNMTkwNDI4MTYyMDQ0WhcNNDkwNDIwMTYyMDQ0
WjB6MQswCQYDVQQGEwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UE
CwwXQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNp
YWwgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkgRTEwdjAQBgcqhkjOPQIBBgUrgQQA
IgNiAASCCRa94QI0vR5Up9Yr9HEupz6hSoyjySYqo7v837KnmjveUIUNiuC9pWAU
WP3jwLX3HkzeiNdeg22a0IZPoSUCpasufiLAnfXh6NInLiWBrjLJXDSGaY7vaokt
rpZvAdmjXTBbMAsGA1UdDwQEAwIBBjAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBRZ
4ZTgDpksHL2qcpkFkxD2zVd16TAfBgNVHSMEGDAWgBRZ4ZTgDpksHL2qcpkFkxD2
zVd16TAMBggqhkjOPQQDAwUAA2gAMGUCMQD4IoqT2hTUn0jt7oXLdMJ8q4vLp6sg
wHfPiOr9gxreb+e6Oidwd2LDnC4OUqCWiF8CMAzwKs4SnDJYcMLf2vpkbuVE4dTH
Rglz+HGcTLWsFs4KxLsq7MuU+vJTBUeDJeDjdA==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIDxTCCAq2gAwIBAgIUEMdk6dVgOEIS2cCP0Q43P90Ps5YwDQYJKoZIhvcNAQEF
BQAwajELMAkGA1UEBhMCQ04xEzARBgNVBAoMCmlUcnVzQ2hpbmExHDAaBgNVBAsM
E0NoaW5hIFRydXN0IE5ldHdvcmsxKDAmBgNVBAMMH2lUcnVzQ2hpbmEgQ2xhc3Mg
MiBSb290IENBIC0gRzMwHhcNMTMwNDE4MDkzNjU2WhcNMzMwNDE4MDkzNjU2WjBq
MQswCQYDVQQGEwJDTjETMBEGA1UECgwKaVRydXNDaGluYTEcMBoGA1UECwwTQ2hp
bmEgVHJ1c3QgTmV0d29yazEoMCYGA1UEAwwfaVRydXNDaGluYSBDbGFzcyAyIFJv
b3QgQ0EgLSBHMzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOPPShpV
nJbMqqCw6Bz1kehnoPst9pkr0V9idOwU2oyS47/HjJXk9Rd5a9xfwkPO88trUpz5
4GmmwspDXjVFu9L0eFaRuH3KMha1Ak01citbF7cQLJlS7XI+tpkTGHEY5pt3EsQg
wykfZl/A1jrnSkspMS997r2Gim54cwz+mTMgDRhZsKK/lbOeBPpWtcFizjXYCqhw
WktvQfZBYi6o4sHCshnOswi4yV1p+LuFcQ2ciYdWvULh1eZhLxHbGXyznYHi0dGN
z+I9H8aXxqAQfHVhbdHNzi77hCxFjOy+hHrGsyzjrd2swVQ2iUWP8BfEQqGLqM1g
KgWKYfcTGdbPB1MCAwEAAaNjMGEwHQYDVR0OBBYEFG/oAMxTVe7y0+408CTAK8hA
uTyRMB8GA1UdIwQYMBaAFG/oAMxTVe7y0+408CTAK8hAuTyRMA8GA1UdEwEB/wQF
MAMBAf8wDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3DQEBBQUAA4IBAQBLnUTfW7hp
emMbuUGCk7RBswzOT83bDM6824EkUnf+X0iKS95SUNGeeSWK2o/3ALJo5hi7GZr3
U8eLaWAcYizfO99UXMRBPw5PRR+gXGEronGUugLpxsjuynoLQu8GQAeysSXKbN1I
UugDo9u8igJORYA+5ms0s5sCUySqbQ2R5z/GoceyI9LdxIVa1RjVX8pYOj8JFwtn
DJN3ftSFvNMYwRuILKuqUYSHc2GPYiHVflDh5nDymCMOQFcFG3WsEuB+EYQPFgIU
1DHmdZcz7Llx8UOZXX2JupWCYzK1XhJb+r4hK5ncf/w8qGtYlmyJpxk3hr1TfUJX
Yf4Zr0fJsGuv
-----END CERTIFICATE-----

View File

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDmTCCAoGgAwIBAgIQICMRB2LW76yahgdg3IFNPDANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE
BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv
biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1
dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDU0NjE5WhcNMjQxMTExMDU0NjE5WjBr
MQswCQYDVQQGEwJDTjEfMB0GA1UECgwWbWJvbmZ5OTAxNUBzYW5kYm94LmNvbTEPMA0GA1UECwwG
QWxpcGF5MSowKAYDVQQDDCEyMDg4NzIxMDIwNzUwNTgxLTkwMjEwMDAxMzE2NTgwMjMwggEiMA0G
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCxihQPf1Q+g9ArgM46shVqL5sbRha/df95D1PsWyEq
ANmWmG4zZ+ksYDVQrc4KzhSRoi56sm/7TDFYTmM6bW99e/nKW58WxyZB4ie5qA3F4n17psPyDqb8
IokcQmCphSFDaXQD6AoXoLNtTM0vAI2cWxAgebZ/vsrdj5Ntjt+Rp3NYMCk1i5xovHcfILzLEGbX
QXoT9fo5AhHotTWa6xHVLPUGY9qwLzQxHzBmvy5ZMfnOfJkm/mDisTSqAUB59F3dzU/1ARVkEZ1w
Mgb4XohWBw6iurQfbMnH2mIomAAwwZVFv+sXDbL9yMbSMo/SjVsTQprn0Q0EnwLo7nmmOM6HAgMB
AAGjEjAQMA4GA1UdDwEB/wQEAwIE8DANBgkqhkiG9w0BAQsFAAOCAQEAn3Y4/C1h9R6ONsBqX3/q
XfHX7yX1FM0Y1x48X3/Yxk6HivAkTukhhhVYVKJsbrbzRqHDp9vhAP/FR6o6pAevaYMmLov0VMXU
7oAuetgkaYEYkDuNen5/Hpdhqi2vTtdT+q9w8zHJd6MDQ0aoHgIxpLKw5vof2R1N4fwSgNXMiXE5
kmllKQMem/+on2p+Sj80/2asxryHIGlH87qPzkffv+kIOkZthbTApTFLLjdVri2QHGe8/cc4xy01
/9iR3IUzNahotT41lJ4bMevBY7XMAS3n5ekyABN/9ZRJqhWdXgmFCRN/u56qd6lDgu7R2M2QUoyc
LuW5DfgRItKlmUB7sw==
-----END CERTIFICATE-----

View File

@@ -0,0 +1 @@
MIIEpQIBAAKCAQEAsYoUD39UPoPQK4DOOrIVai+bG0YWv3X/eQ9T7FshKgDZlphuM2fpLGA1UK3OCs4UkaIuerJv+0wxWE5jOm1vfXv5ylufFscmQeInuagNxeJ9e6bD8g6m/CKJHEJgqYUhQ2l0A+gKF6CzbUzNLwCNnFsQIHm2f77K3Y+TbY7fkadzWDApNYucaLx3HyC8yxBm10F6E/X6OQIR6LU1musR1Sz1BmPasC80MR8wZr8uWTH5znyZJv5g4rE0qgFAefRd3c1P9QEVZBGdcDIG+F6IVgcOorq0H2zJx9piKJgAMMGVRb/rFw2y/cjG0jKP0o1bE0Ka59ENBJ8C6O55pjjOhwIDAQABAoIBAFetNfz1R7hbxjlFshMAkVzQR8wvT9qbvl+dtzdZRcaFhu89NecDIP7+QDYor0FcxoGpU0TazDyRQyk2BQD8vHt+9zv9BVLtZLJSqoWgPbUFBi1DjS8EF2ka8RVYnn35NhUhhd7L//ftL88Bh673mfembQ9srDjoEy1Z01feoABAnCMkNFl986DmEwnarvEufXSDIgeN4ioMxha4NvfIPuI0zpVdV1O9sv+SGC+VEWZBtN3GNsaf4zS/f8FVGvTiU/Abz0gSw/iwSPHclDWQDTN3yFHf/tfqlzh0mH0WfhnuOBFWXzK+R7fbnM+asI9ttvzRcfpzgRGXdPcNcOv/6cECgYEA3DVqpi1k8MYfJixju6SG5gfyhM4VFksFmCMaNPgtatDMBKLMTgV/Ej6LXREojcy29uZl83F09pVlpd41eG39ULIPktixA/BqErQ2UaWh6kOxifycpu22Jh0r09hax6UgVrcBrrnCJEjcFsuJlrZvXQSzc3PBxjWy5gjabS5h9iECgYEAzmVAIh2frF01Y95zsLueAhhZwCtPanm6kf7ivR4r1plIX3b2sNRhWGmEHFgaCE6Braa0ogQ73Hd26kw4ZW+D6QMGC/zjCBEzDLLf++SjdVUHiY5AR4WHqXzq1jdAlsVyo9R661oAOp3lhiJVGLNXkHyEfEVPHsaxJh4osYSbX6cCgYEAx32Qx0i6eDFTyLZQB46uMrgiaVN04QRH5iJuvGvUYT8UhGKjaU8rZfDJOh+wOH2rhxMEaz1uc3C2bERY9mfWI4Ob/jFWc7YZsiYWS3Mcsuhubw4tMECLUg39RWZsHw8ls8kIuixIh6yFzhTH6YQOcRswIrhMZG8DScfdcSmiz2ECgYEAkWP1t5KSpkLKl11etcKUXfl1T8+yk9jIOowIgRw92WAFAWq2AH67TCKYM7dEL1HOO9tRJ0hAOt/U3ttuZtYVYBEHM26jJ02mXm2rJrA7DS4mrxmL4lYH6LbcXqZxU0Qnq4zEQgIWYzRTORf6Rfof1uJAGaJhR9bDd4yLMfGt2cUCgYEAo216Y61xOHUTA4AF1eekk+r+uOcQgQDvLXfs9FkDdJLk0mPG48/+eIYpPFnANJ/riF/DWOp8WGEe2IzA9yUFexzDbNQK8ha9kGcxaSAyiCwzjZ/t9/+hScDSV8kNqWSRSisu/YOFleEHbokT6mbLZ+gdqES8mUUanaEBzRQYGxo=

BIN
api/res/img/wechat-pay.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.7 KiB

View File

@@ -1,67 +0,0 @@
{
"data": [
"task(m1wpaa4v60zedj8)",
"a cute cat",
"",
[],
20,
"DPM++ 2M Karras",
1,
1,
7,
512,
384,
true,
0.7,
2,
"ESRGAN_4x",
10,
0,
0,
"Use same checkpoint",
"Use same sampler",
"",
"",
[],
"None",
false,
"",
0.8,
-1,
false,
-1,
0,
0,
0,
false,
false,
"positive",
"comma",
0,
false,
false,
"",
"Seed",
"",
[],
"Nothing",
"",
[],
"Nothing",
"",
[],
true,
false,
false,
false,
0,
false,
[],
"",
"",
""
],
"event_data": null,
"fn_index": 96,
"session_hash": "kmb0ojjfhdj"
}

View File

@@ -60,3 +60,44 @@ func (s *CaptchaService) Check(data interface{}) bool {
return true
}
func (s *CaptchaService) SlideGet() (interface{}, error) {
if s.config.Token == "" {
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err)
}
if res.Code != types.Success {
return nil, fmt.Errorf("请求 API 失败:%s", res.Message)
}
return res.Data, nil
}
func (s *CaptchaService) SlideCheck(data interface{}) bool {
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return false
}
if res.Code != types.Success {
return false
}
return true
}

View File

@@ -1,42 +0,0 @@
package fun
import (
"chatplus/core/types"
"chatplus/service/mj"
"chatplus/utils"
)
// AI 绘画函数
type FuncMidJourney struct {
name string
service *mj.Service
}
func NewMidJourneyFunc(mjService *mj.Service) FuncMidJourney {
return FuncMidJourney{
name: "MidJourney AI 绘画",
service: mjService}
}
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
logger.Infof("MJ 绘画参数:%+v", params)
prompt := utils.InterfaceToString(params["prompt"])
f.service.PushTask(types.MjTask{
SessionId: utils.InterfaceToString(params["session_id"]),
Src: types.TaskSrcChat,
Type: types.TaskImage,
Prompt: prompt,
UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
Icon: utils.InterfaceToString(params["icon"]),
ChatId: utils.InterfaceToString(params["chat_id"]),
})
return prompt, nil
}
func (f FuncMidJourney) Name() string {
return f.name
}
var _ Function = &FuncMidJourney{}

View File

@@ -1,39 +0,0 @@
package fun
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/service/mj"
)
type Function interface {
Invoke(map[string]interface{}) (string, error)
Name() string
}
var logger = logger2.GetLogger()
type resVo struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data struct {
Title string `json:"title"`
UpdatedAt string `json:"updated_at"`
Items []dataItem `json:"items"`
} `json:"data"`
}
type dataItem struct {
Title string `json:"title"`
Url string `json:"url"`
Remark string `json:"remark"`
}
func NewFunctions(config *types.AppConfig, mjService *mj.Service) map[string]Function {
return map[string]Function{
types.FuncZaoBao: NewZaoBao(config.ApiConfig),
types.FuncWeibo: NewWeiboHot(config.ApiConfig),
types.FuncHeadLine: NewHeadLines(config.ApiConfig),
types.FuncMidJourney: NewMidJourneyFunc(mjService),
}
}

View File

@@ -1,58 +0,0 @@
package fun
import (
"chatplus/core/types"
"errors"
"fmt"
"github.com/imroc/req/v3"
"strings"
"time"
)
// 今日头条函数实现
type FuncHeadlines struct {
name string
config types.ChatPlusApiConfig
client *req.Client
}
func NewHeadLines(config types.ChatPlusApiConfig) FuncHeadlines {
return FuncHeadlines{
name: "今日头条",
config: config,
client: req.C().SetTimeout(10 * time.Second)}
}
func (f FuncHeadlines) Invoke(map[string]interface{}) (string, error) {
if f.config.Token == "" {
return "", errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/headline/fetch", f.config.ApiURL)
var res resVo
r, err := f.client.R().
SetHeader("AppId", f.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("%v%v", err, r.Err)
}
if res.Code != types.Success {
return "", errors.New(res.Message)
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
for i, v := range res.Data.Items {
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [%s]", i+1, v.Title, v.Url, v.Remark))
}
return strings.Join(builder, "\n\n"), nil
}
func (f FuncHeadlines) Name() string {
return f.name
}
var _ Function = &FuncHeadlines{}

View File

@@ -1,58 +0,0 @@
package fun
import (
"chatplus/core/types"
"errors"
"fmt"
"github.com/imroc/req/v3"
"strings"
"time"
)
// 微博热搜函数实现
type FuncWeiboHot struct {
name string
config types.ChatPlusApiConfig
client *req.Client
}
func NewWeiboHot(config types.ChatPlusApiConfig) FuncWeiboHot {
return FuncWeiboHot{
name: "微博热搜",
config: config,
client: req.C().SetTimeout(10 * time.Second)}
}
func (f FuncWeiboHot) Invoke(map[string]interface{}) (string, error) {
if f.config.Token == "" {
return "", errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/weibo/fetch", f.config.ApiURL)
var res resVo
r, err := f.client.R().
SetHeader("AppId", f.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("%v%v", err, r.Err)
}
if res.Code != types.Success {
return "", errors.New(res.Message)
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
for i, v := range res.Data.Items {
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
}
return strings.Join(builder, "\n\n"), nil
}
func (f FuncWeiboHot) Name() string {
return f.name
}
var _ Function = &FuncWeiboHot{}

View File

@@ -1,59 +0,0 @@
package fun
import (
"chatplus/core/types"
"errors"
"fmt"
"github.com/imroc/req/v3"
"strings"
"time"
)
// 每日早报函数实现
type FuncZaoBao struct {
name string
config types.ChatPlusApiConfig
client *req.Client
}
func NewZaoBao(config types.ChatPlusApiConfig) FuncZaoBao {
return FuncZaoBao{
name: "每日早报",
config: config,
client: req.C().SetTimeout(10 * time.Second)}
}
func (f FuncZaoBao) Invoke(map[string]interface{}) (string, error) {
if f.config.Token == "" {
return "", errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/zaobao/fetch", f.config.ApiURL)
var res resVo
r, err := f.client.R().
SetHeader("AppId", f.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("%v%v", err, r.Err)
}
if res.Code != types.Success {
return "", errors.New(res.Message)
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
for _, v := range res.Data.Items {
builder = append(builder, v.Title)
}
builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
return strings.Join(builder, "\n\n"), nil
}
func (f FuncZaoBao) Name() string {
return f.name
}
var _ Function = &FuncZaoBao{}

View File

@@ -1,213 +0,0 @@
package mj
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"
"net/http"
"net/url"
"regexp"
"strings"
)
// MidJourney 机器人
var logger = logger2.GetLogger()
type Bot struct {
config *types.MidJourneyConfig
bot *discordgo.Session
service *Service
}
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.MjConfig.BotToken)
if err != nil {
return nil, err
}
if config.ProxyURL != "" {
proxy, _ := url.Parse(config.ProxyURL)
discord.Client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
discord.Dialer = &websocket.Dialer{
Proxy: http.ProxyURL(proxy),
}
}
return &Bot{
config: &config.MjConfig,
bot: discord,
service: service,
}, nil
}
func (b *Bot) Run() error {
b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
b.bot.AddHandler(b.messageCreate)
b.bot.AddHandler(b.messageUpdate)
logger.Info("Starting MidJourney Bot...")
err := b.bot.Open()
if err != nil {
logger.Error("Error opening Discord connection:", err)
return err
}
logger.Info("Starting MidJourney Bot successfully!")
return nil
}
type TaskStatus string
const (
Start = TaskStatus("Started")
Running = TaskStatus("Running")
Stopped = TaskStatus("Stopped")
Finished = TaskStatus("Finished")
)
type Image struct {
URL string `json:"url"`
ProxyURL string `json:"proxy_url"`
Filename string `json:"filename"`
Width int `json:"width"`
Height int `json:"height"`
Size int `json:"size"`
Hash string `json:"hash"`
}
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
// ignore messages for other channels
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
return
}
// ignore messages for self
if m.Author.ID == s.State.User.ID {
return
}
logger.Debugf("CREATE: %s", utils.JsonEncode(m))
var referenceId = ""
if m.ReferencedMessage != nil {
referenceId = m.ReferencedMessage.ID
}
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
// parse content
req := CBReq{
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
Content: m.Content,
Progress: 0,
Status: Start}
b.service.Notify(req)
return
}
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
// ignore messages for other channels
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
return
}
// ignore messages for self
if m.Author.ID == s.State.User.ID {
return
}
logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
var referenceId = ""
if m.ReferencedMessage != nil {
referenceId = m.ReferencedMessage.ID
}
if strings.Contains(m.Content, "(Stopped)") {
req := CBReq{
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
Content: m.Content,
Progress: extractProgress(m.Content),
Status: Stopped}
b.service.Notify(req)
return
}
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
progress := extractProgress(content)
var status TaskStatus
if progress == 100 {
status = Finished
} else {
status = Running
}
for _, attachment := range attachments {
if attachment.Width == 0 || attachment.Height == 0 {
continue
}
image := Image{
URL: attachment.URL,
Height: attachment.Height,
ProxyURL: attachment.ProxyURL,
Width: attachment.Width,
Size: attachment.Size,
Filename: attachment.Filename,
Hash: extractHashFromFilename(attachment.Filename),
}
req := CBReq{
MessageId: messageId,
ReferenceId: referenceId,
Image: image,
Prompt: extractPrompt(content),
Content: content,
Progress: progress,
Status: status,
}
b.service.Notify(req)
break // only get one image
}
}
// extract prompt from string
func extractPrompt(input string) string {
pattern := `\*\*(.*?)\*\*`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(input)
if len(matches) > 1 {
return strings.TrimSpace(matches[1])
}
return ""
}
func extractProgress(input string) int {
pattern := `\((\d+)\%\)`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(input)
if len(matches) > 1 {
return utils.IntValue(matches[1], 0)
}
return 100
}
func extractHashFromFilename(filename string) string {
if !strings.HasSuffix(filename, ".png") {
return ""
}
index := strings.LastIndex(filename, "_")
if index != -1 {
return filename[index+1 : len(filename)-4]
}
return ""
}

View File

@@ -1,144 +1,61 @@
package mj
import (
"chatplus/core/types"
"fmt"
"github.com/imroc/req/v3"
"time"
)
import "chatplus/core/types"
// MidJourney client
type Client struct {
client *req.Client
config *types.MidJourneyConfig
type Client interface {
Imagine(task types.MjTask) (ImageRes, error)
Blend(task types.MjTask) (ImageRes, error)
SwapFace(task types.MjTask) (ImageRes, error)
Upscale(task types.MjTask) (ImageRes, error)
Variation(task types.MjTask) (ImageRes, error)
QueryTask(taskId string) (QueryRes, error)
}
func NewClient(config *types.AppConfig) *Client {
client := req.C().SetTimeout(10 * time.Second)
// set proxy URL
if config.ProxyURL != "" {
client.SetProxyURL(config.ProxyURL)
}
return &Client{client: client, config: &config.MjConfig}
type ImageReq struct {
BotType string `json:"botType,omitempty"`
Prompt string `json:"prompt,omitempty"`
Dimensions string `json:"dimensions,omitempty"`
Base64Array []string `json:"base64Array,omitempty"`
AccountFilter interface{} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook,omitempty"`
State string `json:"state,omitempty"`
}
func (c *Client) Imagine(prompt string) error {
interactionsReq := &InteractionsRequest{
Type: 2,
ApplicationID: ApplicationID,
GuildID: c.config.GuildId,
ChannelID: c.config.ChanelId,
SessionID: SessionID,
Data: map[string]any{
"version": "1166847114203123795",
"id": "938956540159881230",
"name": "imagine",
"type": "1",
"options": []map[string]any{
{
"type": 3,
"name": "prompt",
"value": prompt,
},
},
"application_command": map[string]any{
"id": "938956540159881230",
"application_id": ApplicationID,
"version": "1118961510123847772",
"default_permission": true,
"default_member_permissions": nil,
"type": 1,
"nsfw": false,
"name": "imagine",
"description": "Create images with Midjourney",
"dm_permission": true,
"options": []map[string]any{
{
"type": 3,
"name": "prompt",
"description": "The prompt to imagine",
"required": true,
},
},
"attachments": []any{},
},
},
}
url := "https://discord.com/api/v9/interactions"
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %w%v", err, r.Err)
}
return nil
type ImageRes struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
// Upscale 放大指定的图片
func (c *Client) Upscale(index int, messageId string, hash string) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.config.GuildId,
ChannelID: c.config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
url := "https://discord.com/api/v9/interactions"
var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}
return nil
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(index int, messageId string, hash string) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.config.GuildId,
ChannelID: c.config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
url := "https://discord.com/api/v9/interactions"
var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}
return nil
type QueryRes struct {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}

View File

@@ -0,0 +1,230 @@
package mj
import (
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"errors"
"fmt"
"github.com/imroc/req/v3"
"io"
"time"
"github.com/gin-gonic/gin"
)
// PlusClient MidJourney Plus ProxyClient
type PlusClient struct {
Config types.MjPlusConfig
apiURL string
client *req.Client
}
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
return &PlusClient{
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
}
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
var res ImageRes
var errRes ErrRes
r, err := c.client.SetTimeout(time.Minute).R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Upscale 放大指定的图片
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.Config.Mode, c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.Config.Mode, c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &PlusClient{}

204
api/service/mj/pool.go Normal file
View File

@@ -0,0 +1,204 @@
package mj
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"fmt"
"github.com/go-redis/redis/v8"
"time"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
for k, config := range appConfig.MjPlusConfigs {
if config.Enabled == false {
continue
}
cli := NewPlusClient(config)
name := fmt.Sprintf("mj-plus-service-%d", k)
service := NewService(name, taskQueue, notifyQueue, db, cli)
go func() {
service.Run()
}()
services = append(services, service)
}
for k, config := range appConfig.MjProxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := fmt.Sprintf("mj-proxy-service-%d", k)
service := NewService(name, taskQueue, notifyQueue, db, cli)
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
uploaderManager: manager,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
if err != nil {
continue
}
cli := p.Clients.Get(userId)
if cli == nil {
continue
}
err = cli.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
}()
}
func (p *ServicePool) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
var imgURL string
var err error
if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
} else {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
}
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
p.db.Updates(&v)
cli := p.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("progress < ?", 100).Find(&items)
if res.Error != nil {
continue
}
for _, job := range items {
// 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
// 删除任务
p.db.Delete(&job)
// 退回算力
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if tx.Error == nil && tx.RowsAffected > 0 {
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}
}
time.Sleep(time.Second)
}
}()
}
func (p *ServicePool) getService(name string) *Service {
for _, s := range p.services {
if s.Name == name {
return s
}
}
return nil
}

View File

@@ -0,0 +1,178 @@
package mj
import (
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"errors"
"fmt"
"github.com/imroc/req/v3"
"io"
)
// ProxyClient MidJourney Proxy Client
type ProxyClient struct {
Config types.MjProxyConfig
apiURL string
}
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
return &ProxyClient{Config: config, apiURL: config.ApiURL}
}
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
body := ImageReq{
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能请使用 MidJourney-Plus")
}
// Upscale 放大指定的图片
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "UPSCALE",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "VARIATION",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &ProxyClient{}

View File

@@ -2,248 +2,179 @@ package mj
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/service"
"chatplus/store"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
"strings"
"time"
"gorm.io/gorm"
)
// MJ 绘画服务
const RunningJobKey = "MidJourney_Running_Job"
// Service MJ 绘画服务
type Service struct {
client *Client // MJ 客户端
taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB
uploadManager *oss.UploaderManager
Clients *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
ChatClients *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
proxyURL string
Name string // service Name
Client Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
}
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
return &Service{
redis: redisCli,
db: db,
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
client: client,
uploadManager: manager,
Clients: types.NewLMap[string, *types.WsClient](),
ChatClients: types.NewLMap[string, *types.WsClient](),
proxyURL: config.ProxyURL,
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: cli,
}
}
func (s *Service) Run() {
logger.Info("Starting MidJourney job consumer.")
ctx := context.Background()
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for {
_, err := s.redis.Get(ctx, RunningJobKey).Result()
if err == nil { // 队列串行执行
time.Sleep(time.Second * 3)
continue
}
var task types.MjTask
err = s.taskQueue.LPop(&task)
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("Consuming Task: %+v", task)
switch task.Type {
case types.TaskImage:
err = s.client.Imagine(task.Prompt)
break
case types.TaskUpscale:
err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
break
case types.TaskVariation:
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
}
if err != nil {
logger.Error("绘画任务执行失败:", err)
if task.RetryCount <= 5 {
s.taskQueue.RPush(task)
}
task.RetryCount += 1
time.Sleep(time.Second * 3)
// 如果配置了多个中转平台的 API KEY
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
time.Sleep(time.Second)
continue
}
// 更新任务的执行状态
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
// 锁定任务执行通道直到任务超时5分钟
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
}
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt))
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
func (s *Service) PushTask(task types.MjTask) {
logger.Infof("add a new MidJourney Task: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Notify(data CBReq) {
taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result()
if err != nil { // 过期任务,丢弃
logger.Warn("任务已过期:", err)
return
}
var task types.MjTask
err = utils.JsonDecode(taskString, &task)
if err != nil { // 非标准任务,丢弃
logger.Warn("任务解析失败:", err)
return
}
var job model.MidJourneyJob
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
if res.Error == nil && data.Status == Finished {
logger.Warn("重复消息:", data.MessageId)
return
}
if task.Src == types.TaskSrcImg { // 绘画任务
var job model.MidJourneyJob
res := s.db.Where("id = ?", task.Id).First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Progress = data.Progress
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
// 任务完成,将最终的图片下载下来
if data.Progress == 100 {
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imgURL
} else {
// 临时图片直接保存,访问的时候使用代理进行转发
job.ImgURL = data.Image.URL
}
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil {
logger.Error("任务不存在任务ID", task.TaskId)
continue
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
if data.Progress < 100 {
image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
if err == nil {
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
// 推送任务到前端
client := s.Clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
} else if task.Src == types.TaskSrcChat { // 聊天任务
wsClient := s.ChatClients.Get(task.SessionId)
if data.Status == Finished {
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
utils.ReplyMessage(wsClient, content)
}
// download image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download image: ", err)
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
utils.ReplyMessage(wsClient, content)
}
return
}
tx := s.db.Begin()
data.Image.URL = imgURL
message := model.HistoryMessage{
UserId: uint(task.UserId),
ChatId: task.ChatId,
RoleId: uint(task.RoleId),
Type: types.MjMsg,
Icon: task.Icon,
Content: utils.JsonEncode(data),
Tokens: 0,
UseContext: false,
}
res = tx.Create(&message)
if res.Error != nil {
logger.Error("error with update database: ", err)
return
}
// save the job
job.UserId = task.UserId
job.Type = task.Type.String()
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Prompt = data.Prompt
job.ImgURL = imgURL
job.Progress = data.Progress
job.Hash = data.Image.Hash
job.CreatedAt = time.Now()
res = tx.Create(&job)
if res.Error != nil {
logger.Error("error with update database: ", err)
tx.Rollback()
return
}
tx.Commit()
}
if wsClient == nil { // 客户端断线,则丢弃
logger.Errorf("Client is offline: %+v", data)
return
}
if data.Status == Finished {
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
// 本次绘画完毕,移除客户端
s.ChatClients.Delete(task.SessionId)
} else {
// 使用代理临时转发图片
if data.Image.URL != "" {
image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
if err == nil {
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
if err != nil || (res.Code != 1 && res.Code != 22) {
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(task.UserId)
continue
}
logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
// 更新用户剩余绘图次数
// TODO: 放大图片是否需要消耗绘图次数?
if data.Status == Finished {
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// 解除任务锁定
s.redis.Del(context.Background(), RunningJobKey)
}
}
type CBReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(job.UserId)
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
s.notifyQueue.RPush(job.UserId)
}
return nil
}
func GetImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}

View File

@@ -1,34 +0,0 @@
package mj
const (
ApplicationID string = "936929561302675456"
SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe"
)
type InteractionsRequest struct {
Type int `json:"type"`
ApplicationID string `json:"application_id"`
MessageFlags *int `json:"message_flags,omitempty"`
MessageID *string `json:"message_id,omitempty"`
GuildID string `json:"guild_id"`
ChannelID string `json:"channel_id"`
SessionID string `json:"session_id"`
Data map[string]any `json:"data"`
Nonce string `json:"nonce,omitempty"`
}
type InteractionsResult struct {
Code int `json:"code"`
Message string
Error map[string]any
}
type CBReq struct {
MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"`
Image Image `json:"image"`
Content string `json:"content"`
Prompt string `json:"prompt"`
Status TaskStatus `json:"status"`
Progress int `json:"progress"`
}

View File

@@ -4,12 +4,15 @@ import (
"bytes"
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"fmt"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/gin-gonic/gin"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/gin-gonic/gin"
)
type AliYunOss struct {
@@ -32,6 +35,10 @@ func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
return nil, err
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return &AliYunOss{
config: config,
bucket: bucket,
@@ -40,28 +47,34 @@ func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
}
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) {
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
// 解析表单
file, err := ctx.FormFile(name)
if err != nil {
return "", err
return File{}, err
}
// 打开上传文件
src, err := file.Open()
if err != nil {
return "", err
return File{}, err
}
defer src.Close()
fileExt := filepath.Ext(file.Filename)
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件
err = s.bucket.PutObject(objectKey, src)
if err != nil {
return "", err
return File{}, err
}
return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Endpoint, objectKey), nil
return File{
Name: file.Filename,
ObjKey: objectKey,
URL: fmt.Sprintf("%s/%s", s.config.Domain, objectKey),
Ext: fileExt,
Size: file.Size,
}, nil
}
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
@@ -79,19 +92,39 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := filepath.Ext(parse.Path)
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
fileExt := utils.GetImgExt(parse.Path)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil {
return "", err
}
return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Endpoint, objectKey), nil
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
}
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
}
func (s AliYunOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL)
return s.bucket.DeleteObject(objectName)
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else {
objectKey = fileURL
}
return s.bucket.DeleteObject(objectKey)
}
var _ Uploader = AliYunOss{}

View File

@@ -3,6 +3,7 @@ package oss
import (
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"net/url"
@@ -23,23 +24,30 @@ func NewLocalStorage(config *types.AppConfig) LocalStorage {
}
}
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (string, error) {
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
file, err := ctx.FormFile(name)
if err != nil {
return "", fmt.Errorf("error with get form: %v", err)
return File{}, fmt.Errorf("error with get form: %v", err)
}
filePath, err := utils.GenUploadPath(s.config.BasePath, file.Filename)
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
if err != nil {
return "", fmt.Errorf("error with generate filename: %s", err.Error())
return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
}
// 将文件保存到指定路径
err = ctx.SaveUploadedFile(file, filePath)
err = ctx.SaveUploadedFile(file, path)
if err != nil {
return "", fmt.Errorf("error with save upload file: %s", err.Error())
return File{}, fmt.Errorf("error with save upload file: %s", err.Error())
}
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
ext := filepath.Ext(file.Filename)
return File{
Name: file.Filename,
ObjKey: path,
URL: utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, path),
Ext: ext,
Size: file.Size,
}, nil
}
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
@@ -48,7 +56,7 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
filename := filepath.Base(parse.Path)
filePath, err := utils.GenUploadPath(s.config.BasePath, filename)
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
if err != nil {
return "", fmt.Errorf("error with generate image dir: %v", err)
}
@@ -65,7 +73,24 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
err = os.WriteFile(filePath, imageData, 0644)
if err != nil {
return "", fmt.Errorf("error writing to file:%v", err)
}
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) Delete(fileURL string) error {
if _, err := os.Stat(fileURL); err == nil {
return os.Remove(fileURL)
}
filePath := strings.Replace(fileURL, s.config.BaseURL, s.config.BasePath, 1)
return os.Remove(filePath)
}

View File

@@ -4,14 +4,16 @@ import (
"chatplus/core/types"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
type MiniOss struct {
@@ -29,6 +31,9 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
if err != nil {
return MiniOss{}, err
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
}
@@ -48,7 +53,7 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := filepath.Ext(parse.Path)
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
@@ -62,33 +67,64 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
}
func (s MiniOss) PutFile(ctx *gin.Context, name string) (string, error) {
func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
file, err := ctx.FormFile(name)
if err != nil {
return "", fmt.Errorf("error with get form: %v", err)
return File{}, fmt.Errorf("error with get form: %v", err)
}
// Open the uploaded file
fileReader, err := file.Open()
if err != nil {
return "", fmt.Errorf("error opening file: %v", err)
return File{}, fmt.Errorf("error opening file: %v", err)
}
defer fileReader.Close()
fileExt := filepath.Ext(file.Filename)
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
fileExt := utils.GetImgExt(file.Filename)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Content-Type"),
})
if err != nil {
return "", fmt.Errorf("error uploading to MinIO: %v", err)
return File{}, fmt.Errorf("error uploading to MinIO: %v", err)
}
return File{
Name: file.Filename,
ObjKey: info.Key,
URL: fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key),
Ext: fileExt,
Size: file.Size,
}, nil
}
func (s MiniOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
objectKey,
strings.NewReader(string(imageData)),
int64(len(imageData)),
minio.PutObjectOptions{ContentType: "image/png"})
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
}
func (s MiniOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL)
return s.client.RemoveObject(context.Background(), s.config.Bucket, objectName, minio.RemoveObjectOptions{})
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else {
objectKey = fileURL
}
return s.client.RemoveObject(context.Background(), s.config.Bucket, objectKey, minio.RemoveObjectOptions{})
}
var _ Uploader = MiniOss{}

View File

@@ -5,13 +5,16 @@ import (
"chatplus/core/types"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"github.com/qiniu/go-sdk/v7/storage"
"net/url"
"path/filepath"
"time"
)
type QinNiuOss struct {
@@ -21,7 +24,6 @@ type QinNiuOss struct {
uploader *storage.FormUploader
manager *storage.BucketManager
proxyURL string
dir string
}
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
@@ -38,6 +40,9 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
putPolicy := storage.PutPolicy{
Scope: config.Bucket,
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return QinNiuOss{
config: config,
mac: mac,
@@ -45,34 +50,40 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
uploader: formUploader,
manager: storage.NewBucketManager(mac, &storeConfig),
proxyURL: appConfig.ProxyURL,
dir: "chatgpt-plus",
}
}
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) {
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
// 解析表单
file, err := ctx.FormFile(name)
if err != nil {
return "", err
return File{}, err
}
// 打开上传文件
src, err := file.Open()
if err != nil {
return "", err
return File{}, err
}
defer src.Close()
fileExt := filepath.Ext(file.Filename)
key := fmt.Sprintf("%s/%d%s", s.dir, time.Now().UnixMicro(), fileExt)
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件
ret := storage.PutRet{}
extra := storage.PutExtra{}
err = s.uploader.Put(ctx, &ret, s.putPolicy.UploadToken(s.mac), key, src, file.Size, &extra)
if err != nil {
return "", err
return File{}, err
}
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
return File{
Name: file.Filename,
ObjKey: key,
URL: fmt.Sprintf("%s/%s", s.config.Domain, ret.Key),
Ext: fileExt,
Size: file.Size,
}, nil
}
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
@@ -90,8 +101,8 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := filepath.Ext(parse.Path)
key := fmt.Sprintf("%s/%d%s", s.dir, time.Now().UnixMicro(), fileExt)
fileExt := utils.GetImgExt(parse.Path)
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
@@ -102,10 +113,32 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL)
key := fmt.Sprintf("%s/%s", s.dir, objectName)
return s.manager.Delete(s.config.Bucket, key)
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else {
objectKey = fileURL
}
return s.manager.Delete(s.config.Bucket, objectKey)
}
var _ Uploader = QinNiuOss{}

View File

@@ -2,8 +2,21 @@ package oss
import "github.com/gin-gonic/gin"
const Local = "LOCAL"
const Minio = "MINIO"
const QiNiu = "QINIU"
const AliYun = "ALIYUN"
type File struct {
Name string `json:"name"`
ObjKey string `json:"obj_key"`
Size int64 `json:"size"`
URL string `json:"url"`
Ext string `json:"ext"`
}
type Uploader interface {
PutFile(ctx *gin.Context, name string) (string, error)
PutFile(ctx *gin.Context, name string) (File, error)
PutImg(imageURL string, useProxy bool) (string, error)
PutBase64(imageData string) (string, error)
Delete(fileURL string) error
}

View File

@@ -9,11 +9,6 @@ type UploaderManager struct {
handler Uploader
}
const Local = "LOCAL"
const Minio = "MINIO"
const QiNiu = "QINIU"
const AliYun = "ALIYUN"
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
active := Local
if config.OSS.Active != "" {

View File

@@ -0,0 +1,162 @@
package payment
import (
"chatplus/core/types"
"chatplus/utils"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
type HuPiPayService struct {
appId string
appSecret string
apiURL string
}
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
return &HuPiPayService{
appId: config.HuPiPayConfig.AppId,
appSecret: config.HuPiPayConfig.AppSecret,
apiURL: config.HuPiPayConfig.ApiURL,
}
}
type HuPiPayReq struct {
AppId string `json:"appid"`
Version string `json:"version"`
TradeOrderId string `json:"trade_order_id"`
TotalFee string `json:"total_fee"`
Title string `json:"title"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
WapName string `json:"wap_name"`
CallbackURL string `json:"callback_url"`
Time string `json:"time"`
NonceStr string `json:"nonce_str"`
}
type HuPiResp struct {
Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg,omitempty"`
}
// Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10)
params.AppId = s.appId
params.Time = simple
params.NonceStr = simple
encode := utils.JsonEncode(params)
m := make(map[string]string)
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
data.Add(k, fmt.Sprintf("%v", v))
}
// 生成签名
data.Add("hash", s.Sign(data))
// 发送支付请求
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return HuPiResp{}, fmt.Errorf("error with requst api: %v", err)
}
defer resp.Body.Close()
all, err := io.ReadAll(resp.Body)
if err != nil {
return HuPiResp{}, fmt.Errorf("error with reading response: %v", err)
}
var res HuPiResp
err = utils.JsonDecode(string(all), &res)
if err != nil {
return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err)
}
if res.ErrCode != 0 {
return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
}
return res, nil
}
// Sign 签名方法
func (s *HuPiPayService) Sign(params url.Values) string {
params.Del(`Sign`)
var keys = make([]string, 0, 0)
for key := range params {
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += s.appSecret
md5bs := md5.Sum([]byte(src))
return hex.EncodeToString(md5bs[:])
}
// Check 校验订单状态
func (s *HuPiPayService) Check(tradeNo string) error {
data := url.Values{}
data.Add("appid", s.appId)
data.Add("open_order_id", tradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp)
data.Add("nonce_str", stamp)
data.Add("hash", s.Sign(data))
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ErrCode int `json:"errcode"`
Data struct {
Status string `json:"status"`
OpenOrderId string `json:"open_order_id"`
} `json:"data,omitempty"`
ErrMsg string `json:"errmsg"`
Hash string `json:"hash"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ErrCode == 0 && r.Data.Status == "OD" {
return nil
} else {
logger.Debugf("%+v", r)
return errors.New("order not paid" + r.ErrMsg)
}
}

View File

@@ -0,0 +1,148 @@
package payment
import (
"chatplus/core/types"
"chatplus/utils"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
)
type PayJS struct {
config *types.JPayConfig
}
func NewPayJS(appConfig *types.AppConfig) *PayJS {
return &PayJS{
config: &appConfig.JPayConfig,
}
}
type JPayReq struct {
TotalFee int `json:"total_fee"`
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"body"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"callback_url"`
}
type JPayReps struct {
OutTradeNo string `json:"out_trade_no"`
OrderId string `json:"payjs_order_id"`
ReturnCode int `json:"return_code"`
ReturnMsg string `json:"return_msg"`
Sign string `json:"Sign"`
TotalFee string `json:"total_fee"`
CodeUrl string `json:"code_url,omitempty"`
Qrcode string `json:"qrcode,omitempty"`
}
func (r JPayReps) IsOK() bool {
return r.ReturnMsg == "SUCCESS"
}
func (js *PayJS) Pay(param JPayReq) JPayReps {
param.NotifyURL = js.config.NotifyURL
var p = url.Values{}
encode := utils.JsonEncode(param)
m := make(map[string]interface{})
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
p.Add(k, fmt.Sprintf("%v", v))
}
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
cli := http.Client{}
apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
r, err := cli.PostForm(apiURL, p)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
defer r.Body.Close()
bs, err := io.ReadAll(r.Body)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
var data JPayReps
err = utils.JsonDecode(string(bs), &data)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
return data
}
func (js *PayJS) PayH5(p url.Values) string {
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
}
func (js *PayJS) sign(params url.Values) string {
params.Del(`sign`)
var keys = make([]string, 0, 0)
for key := range params {
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += "&key=" + js.config.PrivateKey
md5bs := md5.Sum([]byte(src))
md5res := hex.EncodeToString(md5bs[:])
return strings.ToUpper(md5res)
}
// Check 查询订单支付状态
// @param tradeNo 支付平台交易 ID
func (js *PayJS) Check(tradeNo string) error {
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
params := url.Values{}
params.Add("payjs_order_id", tradeNo)
params.Add("sign", js.sign(params))
data := strings.NewReader(params.Encode())
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
defer resp.Body.Close()
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ReturnCode int `json:"return_code"`
Status int `json:"status"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ReturnCode == 1 && r.Status == 1 {
return nil
} else {
logger.Errorf("PayJs 支付验证响应:%s", string(body))
return errors.New("order not paid")
}
}

124
api/service/sd/pool.go Normal file
View File

@@ -0,0 +1,124 @@
package sd
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"fmt"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, levelDB *store.LevelDB) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
// create mj client and service
for _, config := range appConfig.SdConfigs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
service := NewService(name, config, taskQueue, notifyQueue, db, manager, levelDB)
// run sd service
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
if err != nil {
continue
}
client := p.Clients.Get(userId)
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
p.db.Delete(&job)
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
}
}
}()
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@@ -2,308 +2,218 @@ package sd
import (
"chatplus/core/types"
"chatplus/service"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"io"
"os"
"strconv"
"strings"
"time"
)
// SD 绘画服务
const RunningJobKey = "StableDiffusion_Running_Job"
type Service struct {
httpClient *req.Client
config *types.StableDiffusionConfig
config types.StableDiffusionConfig
taskQueue *store.RedisQueue
redis *redis.Client
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
Clients *types.LMap[string, *types.WsClient] // SD 绘画页面 websocket 连接池
name string // service name
leveldb *store.LevelDB
}
func NewService(config *types.AppConfig, redisCli *redis.Client, db *gorm.DB, manager *oss.UploaderManager) *Service {
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
return &Service{
config: &config.SdConfig,
name: name,
config: config,
httpClient: req.C(),
redis: redisCli,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
db: db,
leveldb: levelDB,
uploadManager: manager,
Clients: types.NewLMap[string, *types.WsClient](),
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
}
}
func (s *Service) Run() {
logger.Info("Starting StableDiffusion job consumer.")
ctx := context.Background()
for {
_, err := s.redis.Get(ctx, RunningJobKey).Result()
if err == nil { // 队列串行执行
time.Sleep(time.Second * 3)
continue
}
var task types.SdTask
err = s.taskQueue.LPop(&task)
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("Consuming Task: %+v", task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err)
if task.RetryCount <= 5 {
s.taskQueue.RPush(task)
// translate prompt
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
if err == nil {
task.Params.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
task.RetryCount += 1
time.Sleep(time.Second * 3)
continue
}
// 更新任务的执行状态
s.db.Model(&model.SdJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
// 锁定任务执行通道直到任务超时5分钟
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt))
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(task.UserId)
continue
}
}
}
// PushTask 推送任务到队列
func (s *Service) PushTask(task types.SdTask) {
logger.Infof("add a new Stable Diffusion Task: %+v", task)
s.taskQueue.RPush(task)
// Txt2ImgReq 文生图请求实体
type Txt2ImgReq struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt"`
Seed int64 `json:"seed,omitempty"`
Steps int `json:"steps"`
CfgScale float32 `json:"cfg_scale"`
Width int `json:"width"`
Height int `json:"height"`
SamplerName string `json:"sampler_name"`
EnableHr bool `json:"enable_hr,omitempty"`
HrScale int `json:"hr_scale,omitempty"`
HrUpscaler string `json:"hr_upscaler,omitempty"`
HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"`
DenoisingStrength float32 `json:"denoising_strength,omitempty"`
ForceTaskId string `json:"force_task_id,omitempty"`
}
// Txt2ImgResp 文生图响应实体
type Txt2ImgResp struct {
Images []string `json:"images"`
Parameters struct {
} `json:"parameters"`
Info string `json:"info"`
}
// TaskProgressResp 任务进度响应实体
type TaskProgressResp struct {
Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"`
CurrentImage string `json:"current_image"`
}
// Txt2Img 文生图 API
func (s *Service) Txt2Img(task types.SdTask) error {
var taskInfo TaskInfo
bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath)
if err != nil {
return fmt.Errorf("error with load text2img json template file: %s", err.Error())
body := Txt2ImgReq{
Prompt: task.Params.Prompt,
NegativePrompt: task.Params.NegPrompt,
Steps: task.Params.Steps,
CfgScale: task.Params.CfgScale,
Width: task.Params.Width,
Height: task.Params.Height,
SamplerName: task.Params.Sampler,
ForceTaskId: task.Params.TaskId,
}
err = json.Unmarshal(bytes, &taskInfo)
if err != nil {
return fmt.Errorf("error with decode json params: %s", err.Error())
if task.Params.Seed > 0 {
body.Seed = task.Params.Seed
}
data := taskInfo.Data
params := task.Params
data[ParamKeys["task_id"]] = params.TaskId
data[ParamKeys["prompt"]] = params.Prompt
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
data[ParamKeys["steps"]] = params.Steps
data[ParamKeys["sampler"]] = params.Sampler
// @fix bug: 有些 stable diffusion 没有面部修复功能
//data[ParamKeys["face_fix"]] = params.FaceFix
data[ParamKeys["cfg_scale"]] = params.CfgScale
data[ParamKeys["seed"]] = params.Seed
data[ParamKeys["height"]] = params.Height
data[ParamKeys["width"]] = params.Width
data[ParamKeys["hd_fix"]] = params.HdFix
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
data[ParamKeys["hd_scale"]] = params.HdScale
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
data[ParamKeys["hd_sample_num"]] = params.HdSteps
taskInfo.SessionId = task.SessionId
taskInfo.TaskId = params.TaskId
taskInfo.Data = data
taskInfo.JobId = task.Id
if task.Params.HdFix {
body.EnableHr = true
body.HrScale = task.Params.HdScale
body.HrUpscaler = task.Params.HdScaleAlg
body.HrSecondPassSteps = task.Params.HdSteps
body.DenoisingStrength = task.Params.HdRedrawRate
}
var res Txt2ImgResp
var errChan = make(chan error)
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
logger.Debugf("send image request to %s", apiURL)
go func() {
s.runTask(taskInfo, s.httpClient)
}()
return nil
}
// 执行任务
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
body := map[string]any{
"data": taskInfo.Data,
"event_data": taskInfo.EventData,
"fn_index": taskInfo.FnIndex,
"session_hash": taskInfo.SessionHash,
}
logger.Debug(utils.JsonEncode(body))
var result = make(chan CBReq)
go func() {
var res struct {
Data []interface{} `json:"data"`
IsGenerating bool `json:"is_generating"`
Duration float64 `json:"duration"`
AverageDuration float64 `json:"average_duration"`
}
var cbReq = CBReq{TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
if err != nil {
cbReq.Message = "error with send request: " + err.Error()
cbReq.Success = false
result <- cbReq
errChan <- err
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
cbReq.Message = "error http status code: " + string(bytes)
cbReq.Success = false
result <- cbReq
errChan <- fmt.Errorf("error http code status: %v", response.Status)
return
}
var images []struct {
Name string `json:"name"`
Data interface{} `json:"data"`
IsFile bool `json:"is_file"`
}
err = utils.ForceCovert(res.Data[0], &images)
// 保存 Base64 图片
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
if err != nil {
cbReq.Message = "error with decode image:" + err.Error()
cbReq.Success = false
result <- cbReq
errChan <- fmt.Errorf("error with upload image: %v", err)
return
}
var info map[string]any
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
// 获取绘画真实的 seed
var info map[string]interface{}
err = utils.JsonDecode(res.Info, &info)
if err != nil {
logger.Error(res.Data)
cbReq.Message = "error with decode image url:" + err.Error()
cbReq.Success = false
result <- cbReq
errChan <- fmt.Errorf("error with decode task response: %v", err)
return
}
// 获取真实的 seed 值
cbReq.ImageName = images[0].Name
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
cbReq.Seed = seed
cbReq.Success = true
cbReq.Progress = 100
result <- cbReq
close(result)
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
errChan <- nil
}()
for {
select {
case value := <-result:
s.callback(value)
return
case err := <-errChan: // 任务完成
if err != nil {
return err
}
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(task.UserId)
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil
default:
var progressReq = map[string]any{
"id_task": taskInfo.TaskId,
"id_live_preview": 1,
err, resp := s.checkTaskProgress()
// 更新任务进度
if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(task.UserId)
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
}
}
var progressRes struct {
Active bool `json:"active"`
Queued bool `json:"queued"`
Completed bool `json:"completed"`
Progress float64 `json:"progress"`
Eta float64 `json:"eta"`
LivePreview string `json:"live_preview"`
IDLivePreview int `json:"id_live_preview"`
TextInfo interface{} `json:"textinfo"`
}
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
if err != nil { // TODO: 这里可以考虑设置失败重试次数
logger.Error(err)
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
logger.Error(string(bytes))
return
}
cbReq.ImageData = progressRes.LivePreview
cbReq.Progress = int(progressRes.Progress * 100)
logger.Debug(cbReq)
s.callback(cbReq)
time.Sleep(time.Second)
}
}
}
func (s *Service) callback(data CBReq) {
// 释放任务锁
s.redis.Del(context.Background(), RunningJobKey)
client := s.Clients.Get(data.SessionId)
if data.Success { // 任务成功
var job model.SdJob
res := s.db.Where("id = ?", data.JobId).First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
// 更新任务进度
job.Progress = data.Progress
// 更新任务 seed
var params types.SdTaskParams
err := utils.JsonDecode(job.Params, &params)
if err != nil {
logger.Error("任务解析失败:", err)
return
}
params.Seed = data.Seed
if data.ImageName != "" { // 下载图片
imageURL := fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(imageURL, false)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imageURL
}
job.Params = utils.JsonEncode(params)
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
var jobVo vo.SdJob
err = utils.CopyObject(job, &jobVo)
if err != nil {
logger.Error("error with copy object: ", err)
return
}
if data.Progress < 100 && data.ImageData != "" {
jobVo.ImgURL = data.ImageData
}
// 扣减绘图次数
s.db.Model(&model.User{}).Where("id = ?", jobVo.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// 推送任务到前端
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
} else { // 任务失败
logger.Error("任务执行失败:", data.Message)
// 删除任务
s.db.Delete(&model.SdJob{Id: uint(data.JobId)})
// 推送消息到前端
if client != nil {
utils.ReplyChunkMessage(client, vo.SdJob{
Id: uint(data.JobId),
Progress: -1,
TaskId: data.TaskId,
})
}
// 执行任务
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
var res TaskProgressResp
response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
if err != nil {
return err, nil
}
if response.IsErrorState() {
return fmt.Errorf("error http code status: %v", response.Status), nil
}
return nil, &res
}

View File

@@ -5,6 +5,7 @@ import logger2 "chatplus/logger"
var logger = logger2.GetLogger()
type TaskInfo struct {
UserId uint `json:"user_id"`
SessionId string `json:"session_id"`
JobId int `json:"job_id"`
TaskId string `json:"task_id"`
@@ -15,6 +16,7 @@ type TaskInfo struct {
}
type CBReq struct {
UserId uint
SessionId string
JobId int
TaskId string
@@ -32,11 +34,11 @@ var ParamKeys = map[string]int{
"negative_prompt": 2,
"steps": 4,
"sampler": 5,
"face_fix": 6, // 面部修复
"face_fix": 7, // 面部修复
"cfg_scale": 8,
"seed": 27,
"height": 9,
"width": 10,
"height": 10,
"width": 9,
"hd_fix": 11,
"hd_redraw_rate": 12, //高清修复重绘幅度
"hd_scale": 13, // 高清修复放大倍数

View File

@@ -1,4 +1,4 @@
package service
package sms
import (
"chatplus/core/types"
@@ -7,22 +7,23 @@ import (
)
type AliYunSmsService struct {
config *types.AliYunSmsConfig
config *types.SmsConfigAli
client *dysmsapi.Client
}
func NewAliYunSmsService(config *types.AppConfig) (*AliYunSmsService, error) {
func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) {
config := &appConfig.SMS.Ali
// 创建阿里云短信客户端
client, err := dysmsapi.NewClientWithAccessKey(
"cn-hangzhou",
config.SmsConfig.AccessKey,
config.SmsConfig.AccessSecret)
config.AccessKey,
config.AccessSecret)
if err != nil {
return nil, fmt.Errorf("failed to create client: %v", err)
}
return &AliYunSmsService{
config: &config.SmsConfig,
config: config,
client: client,
}, nil
}
@@ -46,6 +47,7 @@ func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error {
if response.Code != "OK" {
return fmt.Errorf("failed to send SMS:%v", response.Message)
}
return nil
}
var _ Service = &AliYunSmsService{}

72
api/service/sms/bao.go Normal file
View File

@@ -0,0 +1,72 @@
package sms
import (
"chatplus/core/types"
"chatplus/utils"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
)
type BaoSmsService struct {
config *types.SmsConfigBao
}
func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService {
config := appConfig.SMS.Bao
if config.Domain == "" { // use default domain
config.Domain = "api.smsbao.com"
logger.Infof("Using default domain for SMS-BAO: %s", config.Domain)
}
return &BaoSmsService{
config: &config,
}
}
var errMsg = map[string]string{
"0": "短信发送成功",
"-1": "参数不全",
"-2": "服务器空间不支持请确认支持curl或者fsocket联系您的空间商解决或者更换空间",
"30": "密码错误",
"40": "账号不存在",
"41": "余额不足",
"42": "账户已过期",
"43": "IP地址限制",
"50": "内容含有敏感词",
}
func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error {
content := fmt.Sprintf("%s%s", s.config.Sign, s.config.CodeTemplate)
content = strings.ReplaceAll(content, "{code}", strconv.Itoa(code))
password := utils.Md5(s.config.Password)
params := url.Values{}
params.Set("u", s.config.Username)
params.Set("p", password)
params.Set("m", mobile)
params.Set("c", content)
apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode())
response, err := http.Get(apiURL)
if err != nil {
return err
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return err
}
result := string(body)
logger.Debugf("send SmsBao result: %v", errMsg[result])
if result != "0" {
return fmt.Errorf("failed to send SMS:%v", errMsg[result])
}
return nil
}
var _ Service = &BaoSmsService{}

View File

@@ -0,0 +1,8 @@
package sms
const Ali = "ALI"
const Bao = "BAO"
type Service interface {
SendVerifyCode(mobile string, code int) error
}

View File

@@ -0,0 +1,39 @@
package sms
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"strings"
)
type ServiceManager struct {
handler Service
}
var logger = logger2.GetLogger()
func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) {
active := Ali
if config.SMS.Active != "" {
active = strings.ToUpper(config.SMS.Active)
}
var handler Service
switch active {
case Ali:
client, err := NewAliYunSmsService(config)
if err != nil {
return nil, err
}
handler = client
break
case Bao:
handler = NewSmsBaoSmsService(config)
break
}
return &ServiceManager{handler: handler}, nil
}
func (m *ServiceManager) GetService() Service {
return m.handler
}

View File

@@ -1,5 +0,0 @@
package service
type SmsService interface {
SendVerifyCode(mobile string, code int) error
}

View File

@@ -0,0 +1,44 @@
package service
import (
"bytes"
"chatplus/core/types"
"fmt"
"mime"
"net/smtp"
)
type SmtpService struct {
config *types.SmtpConfig
}
func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
return &SmtpService{
config: &appConfig.SmtpConfig,
}
}
func (s *SmtpService) SendVerifyCode(to string, code int) error {
subject := "ChatPlus注册验证码"
body := fmt.Sprintf("您正在注册 ChatPlus AI 助手账户,注册验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", code)
// 设置SMTP客户端配置
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
// 对主题进行MIME编码
encodedSubject := mime.QEncoding.Encode("UTF-8", subject)
// 组装邮件
message := bytes.NewBuffer(nil)
message.WriteString(fmt.Sprintf("From: \"%s\" <%s>\r\n", s.config.AppName, s.config.From))
message.WriteString(fmt.Sprintf("To: %s\r\n", to))
message.WriteString(fmt.Sprintf("Subject: %s\r\n", encodedSubject))
message.WriteString("\r\n" + body)
// 发送邮件
// 发送邮件
err := smtp.SendMail(s.config.Host+":"+fmt.Sprint(s.config.Port), auth, s.config.From, []string{to}, message.Bytes())
if err != nil {
return fmt.Errorf("error sending email: %v", err)
}
return nil
}

View File

@@ -23,7 +23,7 @@ func NewSnowflake() *Snowflake {
}
// Next 生成一个新的唯一ID
func (s *Snowflake) Next() (string, error) {
func (s *Snowflake) Next(raw bool) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -43,6 +43,9 @@ func (s *Snowflake) Next() (string, error) {
s.lastTimestamp = timestamp
id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence)
if raw {
return fmt.Sprintf("%d", id), nil
}
now := time.Now()
return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil
}

Some files were not shown because too many files have changed in this diff Show More