Compare commits

...

394 Commits
rag ... v4.0.5

Author SHA1 Message Date
RockYang
27c816cf3b merge conflicts for v4.0.5 2024-05-21 11:30:40 +08:00
RockYang
0d81776212 update docker image name 2024-05-21 11:21:27 +08:00
RockYang
cccab31c0f rename project name to geekai 2024-05-20 15:14:02 +08:00
RockYang
4ddf3bf2bf fixed bugs for send message captcha component 2024-05-17 08:41:09 +08:00
RockYang
3d37a3d367 update readme 2024-05-14 18:23:12 +08:00
RockYang
73d8236697 update change log 2024-05-14 18:20:52 +08:00
RockYang
3c34e8e0e7 fixed bug for dalle3 task not decrease power 2024-05-10 11:17:52 +08:00
RockYang
922202734a fixed bug for markmap 2024-05-09 21:55:40 +08:00
RockYang
b270960a04 remove license code 2024-05-07 16:41:35 +08:00
RockYang
5c4899df6e upgrade to v4.0.4 2024-05-07 16:32:05 +08:00
RockYang
b0c9ffc5a6 handle the exception for web front page 2024-05-06 17:39:58 +08:00
RockYang
f527cc5b98 fixed conflicts 2024-05-06 14:44:09 +08:00
RockYang
debe8dc209 chore: change module name to geekai, add copyright in source code 2024-05-06 14:41:27 +08:00
RockYang
2f0215ac87 Update README.md 2024-05-06 10:45:50 +08:00
RockYang
dd5cc206e5 update LICENSE.
Signed-off-by: RockYang <yangjian102621@gmail.com>
2024-05-05 03:05:23 +00:00
RockYang
142cd553a3 fix bug for waterflow component 2024-05-05 10:52:29 +08:00
RockYang
657ecccee3 update changelog 2024-05-04 21:32:43 +08:00
RockYang
1232c3cd9c fix bug for license synchronize 2024-05-04 21:30:29 +08:00
RockYang
3ac04a3938 opt: optimize mobile images page styles 2024-05-03 08:14:33 +08:00
RockYang
b7abc42209 fix: fix bug for dalle power refund 2024-05-01 08:48:32 +08:00
RockYang
a48179ce0e chore: update docker image name 2024-05-01 07:59:20 +08:00
RockYang
e589f25a05 opt: styles and view micro optimization 2024-05-01 07:40:56 +08:00
RockYang
cc1a3ce343 opt: close unused websocket connections 2024-04-30 22:54:39 +08:00
RockYang
7bb76d581c chore: replace docker image url with AliYun 2024-04-30 19:08:33 +08:00
RockYang
0d733c0be0 feat: change theme for mobile site is ready 2024-04-30 18:57:15 +08:00
RockYang
8b40ac5b5c feat: mobile page refactor is finished 2024-04-29 19:22:00 +08:00
RockYang
24479814e9 opt: add chat config for mobile chat session 2024-04-29 09:39:23 +08:00
RockYang
99df028237 feat: add index page for mobile 2024-04-28 19:09:26 +08:00
RockYang
b354b88876 theme change is ready 2024-04-27 13:13:28 +08:00
RockYang
5e0be4d10e feat: admin console dark theme 2024-04-26 18:10:17 +08:00
RockYang
468b48151f opt: optimize index and login page UI 2024-04-26 16:07:02 +08:00
RockYang
fa5c036041 synchronize license every 10 secs 2024-04-26 14:35:01 +08:00
RockYang
0fdc588167 change index page background 2024-04-25 18:54:33 +08:00
RockYang
2e023cb8dc update readme 2024-04-25 06:27:30 +08:00
RockYang
e933f32d9c remove unused files 2024-04-24 21:22:56 +08:00
RockYang
bd4b0c4d65 show license info in admin active page, optimize markdown generate prompt 2024-04-24 19:00:28 +08:00
RockYang
0b2501c1d8 output the error to chat page directly, replace the common error message 'AI开小差' 2024-04-24 10:10:03 +08:00
RockYang
9d28e62142 fixed bug for markmap 2024-04-23 20:47:06 +08:00
RockYang
c1d892069e release v4.0.4 2024-04-23 18:47:23 +08:00
RockYang
61b2dbc9f1 optimize styles and release v4.0.4 2024-04-23 18:46:32 +08:00
RockYang
be3245666e feat: hide more navigator items 2024-04-22 16:27:53 +08:00
RockYang
dacdd6fe74 feat: support markmap svg export and download as png image 2024-04-22 14:20:51 +08:00
RockYang
6807f7e88a allow users to select a chatApp to chat in chat app list page 2024-04-22 11:18:55 +08:00
RockYang
087f5ab2d1 optimize index page UI 2024-04-22 10:43:33 +08:00
RockYang
47c5a0387b optimize code for remove timeout and failed image drawing job 2024-04-21 21:44:28 +08:00
RockYang
f9da18ad52 image wall page add dalle 2024-04-21 20:42:42 +08:00
RockYang
5c9025ca22 dalle image page is ready 2024-04-21 20:23:47 +08:00
RockYang
d02cb573fd DO NOT refresh finished jobs when job is running 2024-04-20 21:30:55 +08:00
RockYang
caa538a1d0 fixed markdown generating styles 2024-04-19 18:22:45 +08:00
RockYang
b584b4bfb6 support send email use TLS 2024-04-19 12:04:59 +08:00
RockYang
bda335212d Merge branch 'main' of github.com:yangjian102621/chatgpt-plus 2024-04-19 10:56:05 +08:00
RockYang
06f4cdc649 fixed bug for QWen response blank quotes 2024-04-19 10:55:29 +08:00
RockYang
336a7d5b56 fixed bug for QWen response blank quotes 2024-04-19 10:30:02 +08:00
RockYang
a0f464830f fixed bug for chat handler doRequest method 2024-04-16 23:49:56 +08:00
RockYang
9bf7fa4081 compatible freeGPT35 API 2024-04-15 21:03:19 +08:00
RockYang
96ead65774 fixed bug for websocket close tip message 2024-04-15 18:15:15 +08:00
RockYang
7ad41927aa feat: markmap function is ready 2024-04-15 17:23:59 +08:00
RockYang
4ca9dfd9c0 fixed upscale and variation action url 2024-04-15 15:14:49 +08:00
RockYang
8a9f386d8f opt: close the old connection for mj and sd clients 2024-04-15 09:34:20 +08:00
RockYang
adfee8bf58 update version 2024-04-15 09:05:54 +08:00
RockYang
fbfa2a71a9 release v4.0.3 2024-04-15 08:26:07 +08:00
RockYang
9a1368ef17 markmap enable to select ai model 2024-04-15 06:16:53 +08:00
RockYang
31b02b97d3 fixed chat page styles 2024-04-12 21:57:41 +08:00
RockYang
42da38c5c3 feat: markmap page view is ready 2024-04-12 18:49:24 +08:00
RockYang
0a01b55713 feat: allow chat model bind a fixed api key 2024-04-12 17:09:22 +08:00
RockYang
3b292c2a12 chore: update build.sh 2024-04-11 21:26:33 +08:00
RockYang
db0ba0d9a0 feat: use custom mode for mj upscale and variation operarions 2024-04-11 17:32:34 +08:00
RockYang
3a23ff6b42 feat: add index page 2024-04-10 18:23:55 +08:00
RockYang
1e9c5adb0a feat: support for freeGPT35 API 2024-04-10 14:49:07 +08:00
RockYang
abab76ccc6 feat: support gpt-4-turbo-2014-04-09 vision function 2024-04-10 11:47:10 +08:00
RockYang
6efd92806f fixed bug for gpt-4-turbo-2024-0409 model function calls 2024-04-10 10:23:45 +08:00
RockYang
cfe333e89f fix bug: remove timeout task ONLY for unfinished(progress < 100) 2024-04-10 06:25:54 +08:00
RockYang
a7237fe62f Merge branch 'main' of gitee.com:blackfox/chatgpt-plus 2024-04-07 18:33:55 +08:00
RockYang
c3c454b7d7 Merge branch 'main' of github.com:yangjian102621/chatgpt-plus 2024-04-07 18:32:20 +08:00
RockYang
d4d708d44b update database files 2024-04-07 18:26:45 +08:00
RockYang
7f0b6a3a46 set the enable status for adding new api key with default value true 2024-04-07 10:17:10 +08:00
RockYang
c2a7c089d2 update change log 2024-04-07 08:40:43 +08:00
RockYang
df5bd4df60 Merge branch 'dev' 2024-04-07 08:03:14 +08:00
RockYang
79b6010104 optimize ngin configuration for chat-plus.conf 2024-04-07 08:02:47 +08:00
RockYang
97b0a98793 fix 404 error with remove api keys 2024-04-07 08:01:25 +08:00
RockYang
5230f90540 fix bug for remove api 404 errorc 2024-04-06 20:36:52 +08:00
RockYang
803db4e895 feat: show bind model on chat role list page 2024-04-05 21:21:28 +08:00
RockYang
7cee9f2ebb feat: support uploading role icon 2024-04-05 17:41:23 +08:00
RockYang
8be9a21efd feat: allow bind a chat model for chat role 2024-04-05 12:51:18 +08:00
RockYang
6a3e26b566 update change log 2024-04-05 06:58:13 +08:00
RockYang
0355c37bef feat: stable diffusion image drawing on mobile is ready 2024-04-03 18:13:48 +08:00
RockYang
9b7ee538c4 feat: midjourney role and style consistency is ready 2024-04-02 19:01:28 +08:00
RockYang
d900a3d08e update README.md.
Signed-off-by: RockYang <yangjian102621@gmail.com>
2024-04-02 10:00:45 +00:00
RockYang
cdf5b66729 Update README.md 2024-04-02 17:59:53 +08:00
RockYang
1cff4b63cd feat: image preview for stable-diffusion task is ready 2024-04-02 17:24:38 +08:00
RockYang
da14309ef9 feat: support midjourney --cref and --sref for role consistency 2024-04-02 14:59:53 +08:00
RockYang
fbb216fe3b feat: update menu icons, add version in site titles 2024-04-01 18:20:00 +08:00
RockYang
95efbd5659 show power for chat and imaging page 2024-03-31 20:49:12 +08:00
RockYang
4596c1049c opt: change the relative path with absolute path for midjourney image uploading 2024-03-31 17:45:22 +08:00
RockYang
b35d95f0c7 Merge branch 'main' into dev 2024-03-30 11:57:31 +08:00
RockYang
01419df998 remove dead code 2024-03-30 11:57:23 +08:00
RockYang
a6c00c42fa fixed conflicts 2024-03-29 18:10:59 +08:00
RockYang
4cc9db7115 update readme file 2024-03-29 18:06:55 +08:00
RockYang
4f1ed54059 update databases 2024-03-29 17:43:38 +08:00
RockYang
8227a73e35 feat: support custom menu 2024-03-29 15:41:58 +08:00
RockYang
adfd8c1939 add tip for mj image buttons 2024-03-28 21:41:49 +08:00
RockYang
8eed7ff534 fix: fix overflow hidden for admin page 2024-03-28 18:51:09 +08:00
RockYang
c79c4e74d0 fix: fix overflow hidden for mobile page 2024-03-28 18:13:33 +08:00
RockYang
f1855fd0a1 fix: use slide captcha for iphone 2024-03-28 15:00:53 +08:00
RockYang
1f964c74e9 feat: add mj_action_power system config item 2024-03-28 09:53:41 +08:00
RockYang
4fb2c5803c feat: change midjourney origin implements, replace midjourney bot with midjourney-proxy 2024-03-27 18:57:15 +08:00
RockYang
b5947545cb feat: auto translate and rewrite prompt for midjourney and stable-diffusion 2024-03-27 13:45:52 +08:00
RockYang
342b76f666 feat: stable-diffusion refactored, replace websocket api with sdapi 2024-03-26 18:23:08 +08:00
RockYang
49b5906bc7 fix: can not change user's power in admin console 2024-03-25 11:40:03 +08:00
RockYang
3075bfb7fc fix: fix bug for update user's power in admin page did not work 2024-03-23 16:21:37 +08:00
RockYang
82e06fad33 No need to login with Stable-Diffusion page and Invite page 2024-03-23 15:45:37 +08:00
RockYang
4a9028747b fix: fix code highlight error when add formule detecting 2024-03-22 19:09:04 +08:00
RockYang
4a8ff0ccf0 chore: correct prompt messages 2024-03-22 18:27:57 +08:00
RockYang
99341f0484 feat: add chart for admin dashbord 2024-03-22 16:57:30 +08:00
RockYang
f58ac29ad0 feat: integrate xxl-job-admin to implements automatic task scheduling 2024-03-22 13:47:16 +08:00
RockYang
7060edb3e5 feat: save prompt in power log for dalle-3 2024-03-21 15:55:39 +08:00
RockYang
41ae411f9b feat: add manager list page in console page 2024-03-21 15:24:28 +08:00
RockYang
79b7fee47c feat: add powerlog page for admin console 2024-03-21 13:46:39 +08:00
RockYang
0044bf10af opt: optimize the formula show styles 2024-03-21 11:04:12 +08:00
RockYang
e9348d3611 always parse authorization token for all request 2024-03-20 21:11:52 +08:00
RockYang
b9236e09a7 fixed conflicts 2024-03-20 20:40:22 +08:00
RockYang
09b38d5f42 update README.md.
Signed-off-by: RockYang <yangjian102621@gmail.com>
2024-03-20 20:39:44 +08:00
RockYang
7bb539a06e feat: Async loading midjourney job for mobile MidJourney page 2024-03-20 18:39:14 +08:00
RockYang
5cdada8265 feat: h5 payment for payjs is ready 2024-03-20 17:46:39 +08:00
RockYang
4147c217b1 feat: payment for mobile pages is ready 2024-03-20 16:14:02 +08:00
RockYang
8dda639b23 feat: the power log page is ready 2024-03-20 14:14:30 +08:00
RockYang
8487d2c9eb feat: no need login refactor with member and chatApps page 2024-03-19 18:59:02 +08:00
RockYang
c5e583b215 feat: load preview page do not require user to login 2024-03-19 18:25:01 +08:00
RockYang
549f618cff feat: optimize login dialog 2024-03-19 10:47:13 +08:00
RockYang
e9a3510346 feat: refactoring adjustments for reward page is ready 2024-03-18 18:28:34 +08:00
RockYang
30e6e963b3 feat: refactoring adjustments for member pages 2024-03-18 16:59:07 +08:00
RockYang
c72d963f45 feat: The 'chat_models' field of user table, holds the model IDS in place of the model values 2024-03-18 15:37:46 +08:00
RockYang
172d498618 remove new-ui files 2024-03-18 12:01:34 +08:00
RockYang
313993532e chore: adjust page styles 2024-03-18 06:46:08 +08:00
RockYang
e53db3582c 重构主体工作完成 2024-03-15 18:35:10 +08:00
RockYang
72c6bd3f77 restore new ui files 2024-03-15 11:13:02 +08:00
廖彦棋
ca8b349df3 fix(ui): 交互修复调整 2024-03-15 11:07:41 +08:00
RockYang
1b206c3640 feat: refactor user list page for new UI 2024-03-15 09:29:19 +08:00
廖彦棋
c60276fc9f fix(ui): 用户管理有效期传参调整,权限标识补充 2024-03-15 09:25:06 +08:00
廖彦棋
d00a3167c0 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:49:18 +08:00
廖彦棋
6b1cd8c30c refactor(ui): 无权限页面调整 2024-03-14 17:49:15 +08:00
吴汉强
46f12dc9ad feat(ui): 后台首页去掉权限判断 2024-03-14 17:24:40 +08:00
廖彦棋
a3e1d8ae21 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:19:41 +08:00
廖彦棋
72a066b93e feat(ui): 无权限判断 2024-03-14 17:19:39 +08:00
吴汉强
0327a829ac Merge remote-tracking branch 'origin/ui' into ui 2024-03-14 17:12:53 +08:00
吴汉强
882e9b8819 feat(ui): 新增 sql 2024-03-14 17:12:48 +08:00
廖彦棋
ef58cfadaa Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:06:15 +08:00
吴汉强
bf958d6113 feat(ui): 403,没权限 2024-03-14 16:41:38 +08:00
吴汉强
71611273d7 feat(ui): 网站配置需授权,去掉 2024-03-14 16:28:49 +08:00
廖彦棋
b27c654311 fix(ui): 调整 2024-03-14 16:14:49 +08:00
吴汉强
90930ea9f9 feat(ui): 后端加权限验证 2024-03-14 15:39:12 +08:00
廖彦棋
1ab2185ff1 feat(ui): 角色管理 2024-03-14 15:25:17 +08:00
廖彦棋
0f2f978d4c feat(ui): 新增系统分类菜单 2024-03-14 15:17:53 +08:00
廖彦棋
f61963b0b0 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 10:56:56 +08:00
廖彦棋
2aa413960d fix(ui): 修复 2024-03-14 10:56:54 +08:00
吴汉强
aa4bbba5ec Merge remote-tracking branch 'origin/ui' into ui 2024-03-14 10:28:39 +08:00
吴汉强
eba61fea2d feat(ui): 登录接口返回权限 2024-03-14 10:28:32 +08:00
廖彦棋
34e3455128 feat(ui): 管理后台新增权限及部分组合式函数优化 2024-03-14 10:27:09 +08:00
廖彦棋
07dca3e739 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 17:30:26 +08:00
廖彦棋
4cb4b145f9 feat(ui): web移动端初始化 2024-03-13 17:30:24 +08:00
吴汉强
1ed417cb69 Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 17:24:38 +08:00
吴汉强
6cf91a84ca feat(ui): 增加角色管理,管理员方法新增角色关联 2024-03-13 17:24:30 +08:00
chenzifan
0b566980fc Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 14:40:45 +08:00
chenzifan
f86176b342 refactor: remove api change to post request 2024-03-13 14:40:38 +08:00
吴汉强
c700b32670 Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 11:41:07 +08:00
吴汉强
22641b452a feat(ui): 增加系统权限管理 2024-03-13 11:41:01 +08:00
廖彦棋
d3fbb8c19e fix(ui): ui调整 2024-03-13 09:54:20 +08:00
RockYang
e3bb69ff10 docs: update sql file 2024-03-13 08:49:40 +08:00
RockYang
770360c614 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 08:48:10 +08:00
chenzifan
f302a0478f fix: 删除系统管理员失效的问题 2024-03-13 08:47:17 +08:00
chenzifan
a88697b43a Merge remote-tracking branch 'origin/ui' into ui
# Conflicts:
#	api/handler/admin/admin_user_handler.go
2024-03-13 08:46:16 +08:00
chenzifan
cc6f140812 fix: 删除系统管理员失效的问题 2024-03-13 08:45:09 +08:00
廖彦棋
424f2b3bdc Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 08:44:50 +08:00
廖彦棋
ec0c13a600 feat(ui): 调整 2024-03-13 08:44:48 +08:00
chenzf@pvc123.com
a1f03bec4c feat: 超级管理员不支持修改和删除 2024-03-12 21:16:05 +08:00
RockYang
b5bd4a5e0e Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-12 18:07:24 +08:00
RockYang
7c2e49bfdb fix conflicts 2024-03-12 18:07:19 +08:00
chenzifan
f80fe6d041 feat: 增加系统管理员 2024-03-12 18:06:49 +08:00
RockYang
72f80a96bc fix conflicts 2024-03-12 18:03:24 +08:00
RockYang
2de655a1cf refactor: use power replace calls for front pages 2024-03-12 17:47:06 +08:00
RockYang
da2bd4a501 refactor: 重构项目,为所有的 AI 工具都引入算力,采用算力统一结算各个工具的调用次数和权限 2024-03-12 15:40:44 +08:00
廖彦棋
e0aa62c40d Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-12 08:37:29 +08:00
廖彦棋
9d26a892d1 refactor(ui): 调整 2024-03-12 08:37:27 +08:00
huangqj
4ece7f2847 fix(ui):环境变量 2024-03-11 16:10:49 +08:00
廖彦棋
32368caf1b feat(ui): 新增系统管理员 2024-03-11 15:59:15 +08:00
廖彦棋
e91f54e79e fix(ui): 删除冗余 2024-03-11 14:23:11 +08:00
廖彦棋
bb8f4c57c4 fix(ui): 删除冗余 2024-03-11 14:22:39 +08:00
RockYang
43bfac99b6 feat: replace Tools param with Function param for OpenAI chat API 2024-03-11 14:09:19 +08:00
廖彦棋
be379b6d63 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-11 13:52:24 +08:00
廖彦棋
17f3c9b840 fix(ui): type 2024-03-11 13:52:22 +08:00
chenzifan
24de97fac2 feat: 优化后台UI 2024-03-11 13:51:26 +08:00
chenzifan
bf27b44fee Merge remote-tracking branch 'origin/ui' into ui 2024-03-11 13:46:46 +08:00
廖彦棋
1802b4fe4d refactor(ui): 调整优化 2024-03-11 13:46:08 +08:00
廖彦棋
241a5c7bc9 feat(ui): 细节优化 2024-03-11 12:02:20 +08:00
廖彦棋
557d547bf1 feat(ui): 上传功能补充 2024-03-11 11:41:50 +08:00
廖彦棋
2e7b75affb feat(ui): 登录新增验证码及记住密码功能 2024-03-11 10:49:13 +08:00
廖彦棋
bc21a1d443 feat(ul): 顶部信息 2024-03-11 09:00:00 +08:00
huangqj
3fc9e10a24 feat(ui):看板图表 样式调整 2024-03-11 08:35:21 +08:00
chenzifan
5fa1aa2060 Merge remote-tracking branch 'origin/ui' into ui 2024-03-11 08:01:54 +08:00
RockYang
be8a0ec184 opt: remove global keyup event bind 2024-03-10 09:45:59 +08:00
RockYang
b02e3aad95 docs: update config.toml 2024-03-10 09:45:59 +08:00
RockYang
08eca511ad docs: add database sql file for v3.2.7 2024-03-10 09:45:59 +08:00
RockYang
c34e911596 feat: allow to view chat message in manager console 2024-03-10 09:45:59 +08:00
RockYang
8a452c3072 feat: download image which ai generated in dialog and replace the image url 2024-03-10 09:45:59 +08:00
RockYang
13bfb14107 feat: image-wall page for mobile is ready 2024-03-10 09:45:59 +08:00
RockYang
4188b0969e feat: allow user config third-party platform openai and mj api key 2024-03-10 09:45:59 +08:00
RockYang
0c27795a10 feat: added delete file function 2024-03-10 09:45:59 +08:00
RockYang
d05693c5c1 fix: verifycation component touch event coordinates misplace in iphone browser 2024-03-10 09:45:59 +08:00
RockYang
c0b2063b38 fix: fix bug for regenerate button did not work 2024-03-10 09:45:59 +08:00
RockYang
4d183747b1 fix: Upscale and Variation task overrite each other 2024-03-10 09:45:59 +08:00
RockYang
08fe1b2f75 feat: midjourney mobile page all function is ready 2024-03-10 09:45:59 +08:00
RockYang
db3e8a267e feat: mobile mj list page is ready 2024-03-10 09:45:59 +08:00
RockYang
8fc62682c4 feat: add mj image list component for mobile page. fixed bug for html tag escape 2024-03-10 09:45:59 +08:00
RockYang
75031914a3 feat: add functions mj page for mobile 2024-03-10 09:45:59 +08:00
RockYang
a4c9fdd95a opt: add default extension for mj image 2024-03-10 09:45:59 +08:00
RockYang
6a9bfeb5aa feat: mj for mobile page payout is ready 2024-03-10 09:45:59 +08:00
RockYang
e654766f60 add docs and github link 2024-03-10 09:45:59 +08:00
RockYang
0ef6955f96 opt: enable use cdn url for mj-plus 2024-03-10 09:45:59 +08:00
RockYang
b4501557c9 feat: LaTeX parse is ready 2024-03-10 09:45:59 +08:00
RockYang
a2ed99e6cb feat: add model field for chat_item and and chat_history data table 2024-03-10 09:45:59 +08:00
RockYang
6bd6bb3885 feat: add err_msg field for mj and sd jobs 2024-03-10 09:45:59 +08:00
RockYang
399cf65fc9 feat: blend and swap face function for midjourney-plus is ready 2024-03-10 09:45:59 +08:00
RockYang
24906a6df1 feat: add blend and swapface task implements for midjourney 2024-03-10 09:45:59 +08:00
RockYang
d772bbebe6 opt: refactor chat session page for mobile device 2024-03-10 09:45:59 +08:00
RockYang
14988853a3 opt: optimize chat list page for mobile 2024-03-10 09:45:59 +08:00
RockYang
7b3f16ac9f feat: use vant replace element-plus as mobile UI framework 2024-03-10 09:45:59 +08:00
RockYang
82b2755c18 feat: add websocket heartbeat message for mj page 2024-03-10 09:45:59 +08:00
廖彦棋
ff4b267858 fix(ui): 细节调整 2024-03-08 17:46:48 +08:00
huangqj
a590d0497f feat(ui):细节调整 2024-03-08 10:24:38 +08:00
廖彦棋
ac30d906f0 feat(ui): prettier 2024-03-08 09:59:40 +08:00
廖彦棋
5bc071e038 refactor(ui): 登录页重构 2024-03-08 09:45:09 +08:00
廖彦棋
88b956cf98 refactor(ui): 优化 2024-03-08 09:12:39 +08:00
huangqj
f725cf4661 feat(ui):路由 2024-03-08 08:35:41 +08:00
廖彦棋
057cc1e8a6 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-07 18:03:08 +08:00
廖彦棋
de122735b8 feat(ui): 过期跳转登录 2024-03-07 18:03:06 +08:00
huangqj
e87ede981c feat(ui):apiKey 语言模型 角色管理 产品 2024-03-07 17:58:25 +08:00
廖彦棋
606fb498e1 feat(ui): 新增系统设置 2024-03-07 17:24:50 +08:00
廖彦棋
a0c06e40a4 feat(ui): 对话管理 2024-03-07 15:32:32 +08:00
huangqj
aba8f57279 feat(ui):用户 2024-03-07 15:05:01 +08:00
huangqj
960286a350 feat(ui):simpleTable 2024-03-07 14:59:45 +08:00
huangqj
8c93fa51f6 feat(ui):searchtable 2024-03-07 14:59:16 +08:00
huangqj
cb0e7d64ff feat(ui):用户 2024-03-07 14:03:55 +08:00
廖彦棋
8e7413da97 feat(ui): 函数管理 2024-03-07 11:40:57 +08:00
廖彦棋
a36f14eb94 feat(ui): 新增弹窗及时间格式化 2024-03-07 09:23:45 +08:00
chenzifan
f2f9f6e488 Merge branch 'main' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-07 08:37:54 +08:00
chenzifan
85068b8ca2 feat: 增加系统用户管理 2024-03-07 08:37:48 +08:00
廖彦棋
f2cfcfeefc fix(ui): ts类型 2024-03-06 18:20:07 +08:00
RockYang
755273a898 feat: update changelogs 2024-03-06 17:58:17 +08:00
廖彦棋
d4a24a0f1d feat(ui): 新增登录 2024-03-06 17:54:38 +08:00
RockYang
92281fcbb7 feat: Mj and sd jobs data loading in pages 2024-03-06 17:31:54 +08:00
RockYang
636db4afcc add prompt translating function for mobile midjourney page 2024-03-06 16:22:03 +08:00
huangqj
ba25b8755e feat(ui):simpleTable 2024-03-06 15:33:37 +08:00
廖彦棋
6399d13a49 feat(ui): 新增请求方法及表格 2024-03-06 13:55:38 +08:00
廖彦棋
06fa54fd25 feat(ui): 管理后台基础配置 2024-03-06 10:23:55 +08:00
廖彦棋
a335b965d0 chore(ui): 目录结构调整 2024-03-06 09:32:47 +08:00
廖彦棋
725adaa7d0 feat(ui): 初始化 2024-03-06 09:27:11 +08:00
chenzifan
7e7e81e974 refactor: 初始化UI重构 2024-03-06 08:57:46 +08:00
RockYang
8cfe6bfc17 Merge branch 'dev' of gitee.com:blackfox/chatgpt-plus-pro into dev 2024-03-04 08:34:12 +08:00
RockYang
33de83f2ac feat: add removing order button in admin order list page 2024-03-03 19:27:22 +08:00
RockYang
3f856afec8 fix: fix major bugs for unauthorized access to data 2024-03-03 10:40:32 +08:00
RockYang
4e4dc4cb73 Update README.md 2024-03-01 23:08:05 +08:00
RockYang
02a9c422fe fix: fixed bug image preview im mobile chat session page 2024-02-29 15:41:45 +08:00
RockYang
ca69341024 feat: add draw same image for midjourney page 2024-02-29 11:44:09 +08:00
RockYang
169bf069ce opt: add logs for mj-plus api error 2024-02-28 15:50:42 +08:00
RockYang
1bee0ab04d opt: replace proxy url for discord image url 2024-02-27 17:45:57 +08:00
RockYang
440d91dd0e feat: add change password in with mobile page 2024-02-27 15:36:20 +08:00
RockYang
8168e246a8 feat: add image preview for mobile chat page 2024-02-26 18:11:37 +08:00
RockYang
2ef07574ae feat: replace http polling with webscoket notify in sd image page 2024-02-26 15:45:54 +08:00
RockYang
37392f2bb2 chore: replace 'token' with power 2024-02-23 18:11:57 +08:00
RockYang
a80cd3848e docs: update mj-plus api domain 2024-02-23 15:41:02 +08:00
RockYang
db6ed84451 opt: remove global keyup event bind 2024-02-23 11:43:27 +08:00
RockYang
4463cc5963 docs: update config.toml 2024-02-22 17:53:39 +08:00
RockYang
d316158fe2 docs: add database sql file for v3.2.7 2024-02-22 17:28:22 +08:00
RockYang
e02a8d7586 feat: allow to view chat message in manager console 2024-02-22 17:16:44 +08:00
RockYang
9988dff885 feat: download image which ai generated in dialog and replace the image url 2024-02-20 18:38:03 +08:00
RockYang
35ef5674ff feat: image-wall page for mobile is ready 2024-02-20 17:38:18 +08:00
RockYang
976da45bce feat: allow user config third-party platform openai and mj api key 2024-02-20 11:23:55 +08:00
RockYang
c83ac48bd2 feat: added delete file function 2024-02-19 16:43:03 +08:00
RockYang
3d159a833e fix: verifycation component touch event coordinates misplace in iphone browser 2024-02-19 14:04:50 +08:00
RockYang
4b09878bdd fix: fix bug for regenerate button did not work 2024-02-19 11:22:42 +08:00
RockYang
b0162e6a92 fix: Upscale and Variation task overrite each other 2024-02-16 18:08:29 +08:00
RockYang
8ab15e5dc4 feat: midjourney mobile page all function is ready 2024-02-16 15:55:04 +08:00
RockYang
d2ac807252 feat: mobile mj list page is ready 2024-02-15 18:11:22 +08:00
RockYang
0af01f6f1f feat: add mj image list component for mobile page. fixed bug for html tag escape 2024-02-15 11:39:04 +08:00
RockYang
013b319fab feat: add functions mj page for mobile 2024-01-31 07:24:35 +08:00
RockYang
2899ba5949 opt: add default extension for mj image 2024-01-30 21:46:17 +08:00
RockYang
a558b7e104 feat: mj for mobile page payout is ready 2024-01-30 18:34:01 +08:00
RockYang
7a833e2233 add docs and github link 2024-01-30 16:18:27 +08:00
RockYang
bf65746d00 opt: enable use cdn url for mj-plus 2024-01-28 21:56:25 +08:00
RockYang
f08a7862de feat: LaTeX parse is ready 2024-01-26 18:04:53 +08:00
RockYang
023a2c2f09 feat: add model field for chat_item and and chat_history data table 2024-01-26 16:54:00 +08:00
RockYang
1bcd0f4c1a feat: add err_msg field for mj and sd jobs 2024-01-26 14:50:36 +08:00
RockYang
a0f3bc8ccb feat: blend and swap face function for midjourney-plus is ready 2024-01-26 11:57:08 +08:00
RockYang
dea72738c1 feat: add blend and swapface task implements for midjourney 2024-01-25 18:50:24 +08:00
RockYang
a1d1fe7763 opt: refactor chat session page for mobile device 2024-01-25 14:07:10 +08:00
RockYang
a39ed9764c opt: optimize chat list page for mobile 2024-01-24 18:23:24 +08:00
RockYang
aaa5ba99aa feat: use vant replace element-plus as mobile UI framework 2024-01-24 17:34:30 +08:00
RockYang
2113508b6d feat: add websocket heartbeat message for mj page 2024-01-24 09:33:04 +08:00
RockYang
7fe4212684 add v3.2.6 database sql file 2024-01-23 18:00:49 +08:00
RockYang
8bdda64794 doc: update config sample file 2024-01-23 17:56:22 +08:00
RockYang
ec08c24dca fix: auto fill apiURL when platform changed for ApiKey add page 2024-01-23 17:30:54 +08:00
RockYang
a992a5b3b3 feat: merge sms branch,add DuanXinBao sms service implemetation 2024-01-23 16:16:47 +08:00
RockYang
0f05970141 opt: add heartbeat message for websocket connects 2024-01-22 18:42:51 +08:00
whale_fall
e5e762efcd feat: 添加支持多个短信服务商支持 添加短信宝服务商支持,同时添加配置示例 2024-01-22 16:38:44 +08:00
RockYang
b3d0c1ef9c fix: fix bug for wechat transfer message parse failed 2024-01-22 16:10:08 +08:00
RockYang
397078f7ff feat: HuPiPay order check function is ready 2024-01-22 15:17:26 +08:00
RockYang
3ad8065e20 opt: verify the order in notify callback 2024-01-22 13:58:25 +08:00
RockYang
66c7717f04 chore: print error detail when call http api failed with mj 2024-01-21 22:30:24 +08:00
RockYang
412f8ecc6c opt: add image upload support for md-editor-3 2024-01-19 18:43:13 +08:00
RockYang
51dcf642b3 fixed conflicts 2024-01-19 18:21:49 +08:00
RockYang
bfeea555b2 feat: system notice function is ready 2024-01-19 18:19:51 +08:00
RockYang
479f94c372 feat: system notice function is ready 2024-01-19 18:18:10 +08:00
RockYang
0140713e86 fix: fixed bug for img_call increased when upscale task run failed 2024-01-19 17:10:52 +08:00
RockYang
15b2ec9721 feat: add system config item for wechat qrcode 2024-01-19 16:58:13 +08:00
RockYang
c9cd082855 chore: optimize variable name 2024-01-19 11:26:22 +08:00
RockYang
d7c002890c Merge branch 'main' into dev 2024-01-19 10:09:18 +08:00
RockYang
348dd22279 Merge branch 'main' of gitee.com:blackfox/chatgpt-plus 2024-01-19 10:09:01 +08:00
RockYang
3e99b4cbf6 !4 添加支持阿里旗下的大模型 通义千问对话
Merge pull request !4 from 鲸落/qwen
2024-01-19 02:06:17 +00:00
whale_fall
6968da3ac7 feat: 添加支持阿里的通义千问对话 2024-01-19 09:52:16 +08:00
RockYang
bf1c1b84c3 feat: add image publish function, ONLY published image show in image wall page 2024-01-19 06:52:23 +08:00
RockYang
c70314d930 update change log 2024-01-18 18:11:25 +08:00
RockYang
9104ca8e49 feat: add system config disable user registeration 2024-01-18 17:24:02 +08:00
RockYang
2af33b3630 opt: compatible wechat old message format for parsing wechat transfer message 2024-01-18 16:58:20 +08:00
RockYang
654e795545 opt: optimize order query alg, reduce polling times 2024-01-18 09:39:36 +08:00
RockYang
c62ba2451e update config 2024-01-16 15:24:06 +08:00
RockYang
d72d1b8a99 docs: update change log 2024-01-16 15:05:54 +08:00
RockYang
b939d6016b docs: update config files 2024-01-16 14:38:18 +08:00
RockYang
36a2626ccc opt: optimize markdown image parser, identify image and blockquote tags 2024-01-16 10:13:00 +08:00
RockYang
bd057a4cc9 feat: attachments manage function is ready 2024-01-15 18:48:01 +08:00
RockYang
dc24a8c781 feat: gpt-4-gizmo-g-* model is supported 2024-01-15 15:03:05 +08:00
RockYang
59fa21779b feat: gpt-4-all model is ready 2024-01-15 14:07:24 +08:00
RockYang
a140671aad opt: optimize vip recharge logic 2024-01-15 11:01:57 +08:00
RockYang
5fe8990fb4 Merge branch 'dev' 2024-01-15 10:29:01 +08:00
RockYang
12799b7159 opt: optimize vip recharge logic 2024-01-15 10:28:46 +08:00
RockYang
9929746b1d feat: add asynchronously pull midjourney task progress in case the synchronization callback is fails 2024-01-12 18:24:28 +08:00
RockYang
d70035ff0c feat: midjourney plus service is ready 2024-01-11 18:16:48 +08:00
RockYang
eec90274d8 feat: update video tutorial 2024-01-10 08:50:13 +08:00
RockYang
e8fff55c42 feat: update video tutorial 2024-01-10 08:48:05 +08:00
RockYang
3cf3cdd705 remove api key for hupipay 2024-01-09 17:16:27 +08:00
RockYang
9801fce659 fix: fixed bug for gorm insert record failed and Error is not nil 2024-01-08 18:10:32 +08:00
RockYang
4c1f51110b feat: change mobile field to username 2024-01-08 17:34:09 +08:00
RockYang
913d538587 add changelog 2024-01-08 12:01:58 +08:00
RockYang
9e704365fc chore: do not close pop window when click model 2024-01-08 11:01:19 +08:00
RockYang
485bdbc56a fix: function call 兼容中转 API 2024-01-07 22:32:59 +08:00
RockYang
7000168fd4 opt: add support to disable code verify 2024-01-07 17:31:26 +08:00
RockYang
5694f97a6b feat: payjs payment channel is ready 2024-01-07 14:36:02 +08:00
RockYang
b677d3fac7 fix: add user failed in admin user list page 2024-01-07 10:49:36 +08:00
RockYang
dc6719cf54 release v3.2.4 2024-01-06 21:09:19 +08:00
RockYang
7de5b55091 chore: remove useless system config items 2024-01-06 17:38:55 +08:00
RockYang
76c5101092 chore: error recover is enable ONLY in debug mode 2024-01-06 17:16:02 +08:00
RockYang
2f8d2f4854 feat: payjs service is ready 2024-01-06 15:53:30 +08:00
RockYang
b1ee34ba0c chore: rename bind username api 2024-01-05 18:21:47 +08:00
RockYang
069ad6a09a feat: email registration function is ready 2024-01-05 18:17:11 +08:00
RockYang
bf1403c818 feat: update api key last_use_time after dalle3 call 2024-01-04 18:15:00 +08:00
RockYang
bcc622a24d feat: support dall-e3 api mirrors, add name field for ApiKey 2024-01-04 16:29:57 +08:00
RockYang
a06a81a415 feat: refactor LLM api request code, get API URL from ApiKey object 2024-01-04 14:51:33 +08:00
RockYang
d1950acd01 feat: api key manage page funciton is ready 2024-01-04 10:48:04 +08:00
RockYang
039b70eed2 fix: fixed bug for concurrency risk for getting token for chat histroy with issue #92 2024-01-04 09:03:19 +08:00
RockYang
d8e4308b1b fix: add unique key for MidJourney task_id 2024-01-03 18:06:10 +08:00
RockYang
434fbb3463 feat: show notice in chat page 2024-01-03 15:19:24 +08:00
RockYang
de3eb8969c feat: fixed bug for wechat bot to parse transactions. enable user to exchange reward with img_calls 2024-01-03 11:15:54 +08:00
RockYang
fbd6eac877 fix: fixed chat export page styles 2024-01-02 11:32:36 +08:00
RockYang
1fecab177b fix: fixed for img_call repeated reductions 2024-01-01 18:54:48 +08:00
RockYang
b1b385c455 feat: add switch for enable|disable chat role 2023-12-29 17:51:56 +08:00
RockYang
3c6e86d04b feat: add nickname field for user 2023-12-29 17:39:37 +08:00
RockYang
3d2035d08a feat: add authorization for local function call 2023-12-29 17:21:29 +08:00
RockYang
da86f916d8 update changelog 2023-12-29 11:53:37 +08:00
RockYang
e7a07f7e92 feat: add router for function manager 2023-12-29 11:22:26 +08:00
RockYang
b01e6387fc fix: restore user img_calls quota when image task run failed 2023-12-29 10:41:29 +08:00
RockYang
d86aca0f5d merge pull request #72 2023-12-29 10:09:37 +08:00
RockYang
09414fe36a Merge branch 'main' of github.com:yangjian102621/chatgpt-plus 2023-12-29 09:39:52 +08:00
RockYang
df0e7508db Merge pull request #72 from Unclesimonlau/main
重新设计了移动端web页面,新增了移动端CSS,增加移动端SD绘图页面
2023-12-29 09:39:33 +08:00
RockYang
92b1f01118 merge pull request #71 2023-12-29 09:31:25 +08:00
RockYang
8fb8bd932b Merge pull request #71 from JingHong0202/main
fix: Azure Api request failure after changing the API-version parameter
2023-12-29 09:27:23 +08:00
RockYang
3f74b94784 chore: remove dead code 2023-12-29 09:02:55 +08:00
RockYang
e9467341fa feat: function manager refactor is ready 2023-12-28 18:14:38 +08:00
JingHong
131e051ddc Merge branch 'yangjian102621:main' into main 2023-12-27 23:33:46 +08:00
UncleSimonlau
f626fe3166 Merge branch 'main' of https://github.com/Unclesimonlau/chatgpt-plus 2023-12-26 14:19:25 +08:00
RockYang
6bc57b6132 fix: fixed bug #70, XunFei 1.5 url version map error 2023-12-26 14:19:00 +08:00
RockYang
d972e97c88 fix: fixed bug #70, XunFei 1.5 url version map error 2023-12-25 08:54:17 +08:00
RockYang
3991f4daec feat: function CRUD operation is ready 2023-12-24 22:12:12 +08:00
jinghong0202
f6b567d6fc fix: Azure Api 更换api-version参数后请求失败的问题 2023-12-24 08:36:34 +01:00
RockYang
8addba8203 feat: function add for admin page is ready 2023-12-23 22:30:27 +08:00
RockYang
3ab930a107 feat: support CDN reverse proxy for MidJourney and OpenAI API 2023-12-22 17:25:31 +08:00
RockYang
de512a5ea2 feat: add function list page in admin console 2023-12-21 18:06:09 +08:00
RockYang
113cfae2dc opt: optimize image compress alg, add cache control for image 2023-12-21 15:00:46 +08:00
RockYang
33aebf9cb5 feat: add funcitons manger page 2023-12-21 08:58:24 +08:00
RockYang
6e58ddf681 feat: auto translate image creating prompt 2023-12-19 18:54:19 +08:00
RockYang
cae5c049e4 fix: fixed bug for HuPiPay qrcode generation. set field 'openid' of result struct to Any data type 2023-12-19 11:31:57 +08:00
RockYang
ff76e4bd89 chore: update copyright information 2023-12-18 18:19:41 +08:00
RockYang
a0a506a3c4 feat: add remove action to remove task and images for MJ and SD task list page 2023-12-18 17:44:52 +08:00
RockYang
aa5a4a9977 opt: merge RAG branch 2023-12-18 16:41:40 +08:00
RockYang
abf4f061c1 opt: make sure the Upscale and Variation task is assign to the same mj service with Image task 2023-12-18 16:34:33 +08:00
RockYang
245cd3ee1a fix: fixed bug for mj service pool config pointer 2023-12-15 22:52:57 +08:00
343 changed files with 31620 additions and 8298 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
deploy
docs
api/static
web/node_modules
desktop

View File

@@ -1,5 +1,5 @@
name: Bug 报告 🐛
description: chatgpt-plus 提交错误报告
description: geekai 提交错误报告
labels: ['Bug']
body:
- type: checkboxes

View File

@@ -1,5 +1,5 @@
name: 功能优化 🚀
description: chatgpt-plus 提交优化建议
description: geekai 提交优化建议
labels: ['feature']
body:
- type: checkboxes

View File

@@ -1,12 +1,165 @@
# 更新日志
## v4.0.5
* 功能优化:已授权系统在后台显示授权信息
* 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
* 功能优化:优化首页登录注册页面的 UI
* BUG修复修复License验证的逻辑漏洞
* Bug修复后台添加用户的时候密码规则限制跟前台注册保持一致
* 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
* 功能新增:移动端新增 DALL-E 绘画功能
* 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
* 功能新增:移动支持免登录预览功能
* Bug修复解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
* Bug修复修复部分中转 API 模型会出现第一输出的字符被淹没的Bug
## v4.0.4
* Bug修复修复统一千问第二句不回复的问题
* 功能优化MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
* 功能新增Dalle AI 绘画功能实现
* Bug修复修复思维导图格式乱码问题
* 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
* 功能新增:支持从应用列表直接和某个应用对话
* 功能优化优化算力日志的页面和首页的UI
* 功能新增:支持思维导图导出 PNG 图片下载
## v4.0.3
* 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
* Bug修复兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
* Bug修复修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题
* 功能新增:支持上传图片和视觉模型
* 功能优化:优化聊天页面的复制代码按钮样式乱码
* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
* 功能新增支持为角色绑定对话模型比如绑定某个角色只能用GPT3.5或者 GPT4
* 功能新增:支持为模型绑定 API KEY比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
* 功能新增:支持管理后台 Logo 修改
## 4.0.2
* 功能新增:支持前端菜单可以配置
* 功能优化:在登录和注册界面标题显示软件版本号
* 功能优化MJ 绘画支持 --sref 和 --cref 图片一致性参数
* 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
* Bug修复解决因为图片上传使用相对路径而导致融图失败的问题。
* 功能新增:手机端支持 Stable-Diffusion 绘画
* 功能新增:管理后台登录页面增加行为验证码,防止爆破
## v4.0.1
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口SDAPI 兼容各种 stable-diffusion
发行版,稳定性更强一些
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API兼容
MJ-Plus 中转
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
* Bug修复修复 iphone 手机无法通过图形验证码的Bug使用滑动验证码替换
* Bug修复修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
## v4.0.0
非兼容版本重大重构引入算力概念将系统中所有的能力AI对话MJ绘画SD绘画DALL绘画全部使用算力来兑换。
只要你的算力值余额不为0你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力一次 GPT4 对话消耗10个算力。一次 MJ
对话消耗15个算力...
* 功能重构:重构整体系统,全部采用算力来进行结算
* 功能优化SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
* 功能优化:移动端聊天页面图片支持预览和放大功能
* 功能优化MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
* 功能优化:**PC端不登录也可以预览功能只有在发起操作的时候才需要登录**
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
* 功能新增支持H5支付
* 功能优化:支持数学公式的识别和美化输出
* 功能新增:新增算力消费日志功能
* 功能优化:整合 XXL-JOB 实现订单清理每日算力派发VIP 算力重置等任务
* 功能新增管理后台新增7日内新增用户和新增订单统计
## v3.2.7
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
* Bug修复修复 issue [
管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532)
* 功能优化:在对话和聊天记录表中新增冗余字段 model存储对话模型
* Bug修复IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
* Bug修复重新生成按钮功能失效问题
* Bug修复对话输入HTML标签不显示的问题
* 功能优化gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
* 功能新增:新增删除文件功能
* Bug修复解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
* 功能新增:后台管理新怎对话查看和检索功能
## v3.2.6
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
* 功能优化:兼用旧版本微信收款消息解析
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
* 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
* 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
* 功能新增:新增网站公告,可以在管理后台自定义配置
* 功能新增:新增阿里通义千问大模型支持
* Bug修复修复 MJ 放大任务失败时候 img_call 会增加的 Bug
* 功能优化新增虎皮椒和PayJS订单状态校验功能增加安全性
* Bug修复修复微信转账交易 ID 提取失败 Bug
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
* 功能新增:新增短信宝短信平台发送平台集成
## v3.2.5
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
Plus 账号了!!!
* 功能优化:增强 markdown 图片和引用块解析。
* 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
* 功能优化function call 兼用中转 API。
* Bug修复修复部分已知的 Bug。
## v3.2.4.1
* 功能新增:新增 PayJs 支付通道
* Bug修复紧急修复后台添加用户失败问题
* Bug修复紧急修复使用中转 API-KEY 无法绘图的问题
* Bug修复允许用户关闭手机和邮箱注册通道移除验证码依赖
## v3.2.4
* 功能新增:重磅更新,支持邮箱注册
* 功能优化:优化函数调用授权
* 功能优化:给用户表新增 nickname 字段
* 功能优化:管理后台给聊天角色增加启用/禁用开关
* Bug修复SD绘画出现重复扣减绘图次数
* 功能优化:优化聊天对话导出样式,适应移动端
* 功能新增:众筹核销可以选择兑换对话还是绘图的额度
* Bug修复修复[从历史记录获取reply有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92)
* Bug修复修复 MidJourney 绘图任务调度Bug为 task_id 建议唯一索引
* 功能重构:重构了 API KEY模块支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。
## v3.2.3
* 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现
* 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数
* Bug修复修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73)
* Bug修复确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道
* 功能新增:新增删除绘图任何和图片功能
* Bug修复修复虎皮椒支付二维码重复扫码时报错问题
* 功能优化:自动将 AI 绘画中的中文提示词翻译成英文
* 功能优化优化AI绘画的大图压缩算法新增图片缓存
* 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率
* Bug修复修复[Azure Api 更换api-version参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71)
* Bug修复修复科大讯飞 V1.5 API 请求失败的问题
* Bug修复绘图失败后自动恢复用户的剩余绘图次数
* 功能新增:为移动端新增 SD 绘图功能,分享功能
## v3.2.2
* 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务
* 功能新增AI绘画页面增加翻译和重写提示词功能
* 功能优化OSS上传组件支持在 Bucket 下设置二级目录
* Bug修复修复阿里云 OSS 访问路径错误
* 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket
* 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket
## v3.2.1
* 功能优化:切换角色和模型的时候自动创建新的对话
* Bug修复修复文件上传失败No such file bug
* 功能新增MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数
@@ -16,6 +169,7 @@
* 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道
## v3.2.0
* 功能新增:新增邀请注册功能
* 功能优化增加中间件自动对HTTP请求的参数去掉首尾空格
* 功能优化:增加中间件自动为大图片生成缩略图
@@ -25,6 +179,7 @@
* Bug修复修复MidJourney绘图失败后重复添加到队列的问题
## v3.1.9
* 功能新增:增加讯飞星火大模型 v3.0 支持
* 功能新增:新增找回密码功能
* 功能新增:支持 Markdown 代码复制功能

214
LICENSE
View File

@@ -1,21 +1,201 @@
MIT License
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
Copyright (c) 2023 RockYang
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
1. Definitions.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,22 +1,29 @@
# ChatGPT-Plus
# GeekAI
### 本项目已经正式更名为 GeekAI请大家及时更新代码克隆地址。
**ChatGPT-PLUS** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。主要有如下特性:
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
* 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
* 基于 Websocket 实现,完美的打字机体验。
* 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求
* 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型
* 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用
* 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道
* 已集成支付宝支付功能,支持多种会员套餐和点卡购买功能。
* 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
主要特性:
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用
- 基于 Websocket 实现,完美的打字机体验
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求
- 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型
- 支持 Suno 文生音乐
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
绘画函数插件。
## 关于部署镜像申明
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus)
由于目前部署人数越来越多,本人的阿里云镜像仓库流量不够支撑大家使用了。所以从 v3.2.0 版本开始,一键部署脚本和部署镜像将只提供给 **[付费技术交流群]** 内用户使用。
代码依旧是全部开源的,大家可自行编译打包镜像。
- [x] 更友好的 UI 界面
- [x] 支持 Dall-E 文生图功能
- [x] 支持文生思维导图
- [x] 支持为模型绑定指定的 API KEY支持为角色绑定指定的模型等功能
- [x] 支持网站 Logo 版权等信息的修改
## 功能截图
@@ -68,34 +75,40 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
![Mobile chat setting](/docs/imgs/mobile_user_profile.png)
![Mobile chat setting](/docs/imgs/mobile_pay.png)
### 7. 体验地址
### 体验地址
> 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/>
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
## 快速部署
请参考文档 [**GeekAI 快速部署**](https://ai.r9it.com/docs/install/)。
## 使用须知
1. 本项目基于 MIT 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
## 项目地址
* Github 地址https://github.com/yangjian102621/chatgpt-plus
* 码云地址https://gitee.com/blackfox/chatgpt-plus
* Github 地址https://github.com/yangjian102621/geekai
* 码云地址https://gitee.com/blackfox/geekai
## 客户端下载
目前已经支持 Win/Linux/Mac/Android 客户端下载地址为https://github.com/yangjian102621/chatgpt-plus/releases/tag/v3.1.2
目前已经支持 Win/Linux/Mac/Android 客户端下载地址为https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
## TODOLIST
* [ ] 支持基于知识库的 AI 问答
* [ ] 会员邀请注册推广功能
* [ ] 微信支付功能
## 项目文档
详细的部署和开发文档请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/)
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
详细的部署和开发文档请参考 [**GeekAI 文档**](https://ai.r9it.com/docs/)。
加微信进入微信讨论群可获取 **一键部署脚本添加好友时请注明来自Github!!!)。**
@@ -123,7 +136,4 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
![打赏](docs/imgs/donate.png)
![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/chatgpt-plus&type=Date)
![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/geekai&type=Date)

1
api/.gitignore vendored
View File

@@ -18,4 +18,3 @@ data
config.toml
static/upload
storage.json
certs/alipay/*

View File

@@ -1,5 +1,5 @@
SHELL=/usr/bin/env bash
NAME := chatgpt-plus
NAME := geekai
all: amd64 arm64
amd64:

View File

@@ -1,6 +1,6 @@
Listen = "0.0.0.0:5678"
ProxyURL = "" # 如 http://127.0.0.1:7777
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8&parseTime=True&loc=Local"
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
StaticDir = "./static" # 静态资源的目录
StaticUrl = "/static" # 静态资源访问 URL
AesEncryptKey = ""
@@ -10,10 +10,6 @@ WeChatBot = false
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
MaxAge = 86400
[Manager]
Username = "admin"
Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改
[Redis] # redis 配置信息
Host = "localhost"
Port = 6379
@@ -25,19 +21,28 @@ WeChatBot = false
AppId = ""
Token = ""
[SmsConfig] # 阿里云短信服务配置
AccessKey = ""
AccessSecret = ""
Product = "Dysmsapi"
Domain = "dysmsapi.aliyuncs.com"
Sign = ""
CodeTempId = ""
[SMS] # Sms 配置,用于发送短信
Active = "Ali" # 当前启用的短信服务,默认使用阿里云
[SMS.Bao]
Username = ""
Password = ""
Domain = "api.smsbao.com"
Sign = "【极客学长】"
CodeTemplate = "您的验证码是{code}。5分钟有效若非本人操作请忽略本短信。"
[SMS.Ali]
AccessKey = ""
AccessSecret = ""
Product = "Dysmsapi"
Domain = "dysmsapi.aliyuncs.com"
Sign = ""
CodeTempId = ""
[OSS] # OSS 配置,用于存储 MJ 绘画图片
Active = "local" # 默认使用本地文件存储引擎
[OSS.Local]
BasePath = "./static/upload" # 本地文件上传根路径
BaseURL = "http://localhost:5678/static/upload" # 本地上传文件 URL 如果是线上,则直接设置为 /static/upload 即可
BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP
[OSS.Minio]
Endpoint = "" # 如 172.22.11.200:9000
AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
@@ -51,20 +56,24 @@ WeChatBot = false
AccessSecret = ""
Bucket = ""
Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
[OSS.AliYun]
Endpoint = "oss-cn-hangzhou.aliyuncs.com"
AccessKey = ""
AccessSecret = ""
Bucket = "chatgpt-plus"
SubDir = ""
Domain = ""
[[MjConfigs]]
Enabled = false
UserToken = ""
BotToken = ""
GuildId = ""
ChanelId = ""
[[MjProxyConfigs]]
Enabled = true
ApiURL = "http://midjourney-proxy:8082"
ApiKey = "sk-geekmaster"
[[MjConfigs]]
[[MjPlusConfigs]]
Enabled = false
UserToken = ""
BotToken = ""
GuildId = ""
ChanelId = ""
ApiURL = "https://api.chat-plus.net"
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
ApiKey = "sk-xxx"
[[SdConfigs]]
Enabled = false
@@ -72,12 +81,6 @@ WeChatBot = false
ApiKey = ""
Txt2ImgJsonPath = "res/sd/text2img.json"
[[SdConfigs]]
Enabled = false
ApiURL = ""
ApiKey = ""
Txt2ImgJsonPath = "res/text2img.json"
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动
Enabled = false # 是否启用 XXL JOB 服务
ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
@@ -95,4 +98,28 @@ WeChatBot = false
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
NotifyURL = "http://r9it.com:6004/api/payment/alipay/notify" # 支付异步回调地址
NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
[HuPiPayConfig]
Enabled = false
Name = "wechat"
AppId = ""
AppSecret = ""
ApiURL = "https://api.xunhupay.com"
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
[SmtpConfig] # 注意阿里云服务器禁用了25号端口请使用 465 端口,并开启 TLS 连接
UseTls = false
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"
From = "test@163.com" # 发件邮箱人地址
Password = "" #邮箱 stmp 服务授权码
[JPayConfig] # PayJs 支付配置
Enabled = false
Name = "wechat" # 请不要改动
AppId = "" # 商户 ID
PrivateKey = "" # 秘钥
ApiURL = "https://payjs.cn"
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的

View File

@@ -1,23 +1,29 @@
package core
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core/types"
"chatplus/service/fun"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core/types"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/nfnt/resize"
"golang.org/x/image/webp"
"gorm.io/gorm"
"image"
"image/jpeg"
"io"
"log"
"net/http"
"os"
"runtime/debug"
@@ -29,31 +35,28 @@ type AppServer struct {
Debug bool
Config *types.AppConfig
Engine *gin.Engine
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
ChatConfig *types.ChatConfig // chat config cache
SysConfig *types.SystemConfig // system config cache
SysConfig *types.SystemConfig // system config cache
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
ChatSession *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
Functions map[string]fun.Function
}
func NewServer(appConfig *types.AppConfig, functions map[string]fun.Function) *AppServer {
func NewServer(appConfig *types.AppConfig) *AppServer {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
return &AppServer{
Debug: false,
Config: appConfig,
Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []interface{}](),
ChatContexts: types.NewLMap[string, []types.Message](),
ChatSession: types.NewLMap[string, *types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
Functions: functions,
}
}
@@ -72,23 +75,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
}
func (s *AppServer) Run(db *gorm.DB) error {
// load chat config from database
var chatConfig model.Config
res := db.Where("marker", "chat").First(&chatConfig)
if res.Error != nil {
return res.Error
}
err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig)
if err != nil {
return err
}
// load system configs
var sysConfig model.Config
res = db.Where("marker", "system").First(&sysConfig)
res := db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return res.Error
}
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil {
return err
}
@@ -146,70 +139,64 @@ func corsMiddleware() gin.HandlerFunc {
// 用户授权验证
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/user/resetPass" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/mj/jobs" ||
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/jobs" ||
strings.HasPrefix(c.Request.URL.Path, "/test/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
c.Request.URL.Path == "/api/admin/config/get" {
c.Next()
return
}
var tokenString string
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
if isAdminApi { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader)
} else if c.Request.URL.Path == "/api/chat/new" {
tokenString = c.Query("token")
} else {
tokenString = c.GetHeader(types.UserAuthHeader)
}
if tokenString == "" {
resp.ERROR(c, "You should put Authorization in request headers")
c.Abort()
return
if needLogin(c) {
resp.ERROR(c, "You should put Authorization in request headers")
c.Abort()
return
} else { // 直接放行
c.Next()
return
}
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
if isAdminApi {
return []byte(s.Config.AdminSession.SecretKey), nil
} else {
return []byte(s.Config.Session.SecretKey), nil
}
return []byte(s.Config.Session.SecretKey), nil
})
if err != nil {
if err != nil && needLogin(c) {
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
if !ok || !token.Valid && needLogin(c) {
resp.NotAuth(c, "Token is invalid")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
resp.NotAuth(c, "Token is expired")
c.Abort()
return
}
key := fmt.Sprintf("users/%v", claims["user_id"])
if _, err := client.Get(context.Background(), key).Result(); err != nil {
if isAdminApi {
key = fmt.Sprintf("admin/%v", claims["user_id"])
}
if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
resp.NotAuth(c, "Token is not found in redis")
c.Abort()
return
@@ -218,6 +205,43 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
}
}
func needLogin(c *gin.Context) bool {
if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/user/logout" ||
c.Request.URL.Path == "/api/user/resetPass" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/admin/logout" ||
c.Request.URL.Path == "/api/admin/login/captcha" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/user/session" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" ||
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/model/list" ||
c.Request.URL.Path == "/api/mj/imgWall" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/imgWall" ||
c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/dall/imgWall" ||
c.Request.URL.Path == "/api/dall/client" ||
c.Request.URL.Path == "/api/config/get" ||
c.Request.URL.Path == "/api/product/list" ||
c.Request.URL.Path == "/api/menu/list" ||
c.Request.URL.Path == "/api/markMap/client" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
return false
}
return true
}
// 统一参数处理
func parameterHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
@@ -316,6 +340,10 @@ func staticResourceMiddleware() gin.HandlerFunc {
// 解码图片
img, _, err := image.Decode(file)
// for .webp image
if err != nil {
img, err = webp.Decode(file)
}
if err != nil {
c.String(http.StatusInternalServerError, "Error decoding image")
return
@@ -332,12 +360,16 @@ func staticResourceMiddleware() gin.HandlerFunc {
var buffer bytes.Buffer
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
if err != nil {
log.Fatal(err)
logger.Error(err)
c.String(http.StatusInternalServerError, err.Error())
return
}
// 设置图片缓存有效期为一年 (365天)
c.Header("Cache-Control", "max-age=31536000, public")
// 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
return
c.Abort() // 中断请求
}
c.Next()
}

View File

@@ -1,10 +1,17 @@
package core
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/utils"
"os"
"github.com/BurntSushi/toml"
@@ -14,18 +21,16 @@ var logger = logger2.GetLogger()
func NewDefaultConfig() *types.AppConfig {
return &types.AppConfig{
Listen: "0.0.0.0:5678",
ProxyURL: "",
Manager: types.Manager{Username: "admin", Password: "admin123"},
StaticDir: "./static",
StaticUrl: "http://localhost/5678/static",
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
AesEncryptKey: utils.RandString(24),
Listen: "0.0.0.0:5678",
ProxyURL: "",
StaticDir: "./static",
StaticUrl: "http://localhost/5678/static",
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
Session: types.Session{
SecretKey: utils.RandString(64),
MaxAge: 86400,
},
ApiConfig: types.ChatPlusApiConfig{},
ApiConfig: types.ApiConfig{},
OSS: types.OSSConfig{
Active: "local",
Local: types.LocalStorageConfig{

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// ApiRequest API 请求实体
type ApiRequest struct {
Model string `json:"model,omitempty"` // 兼容百度文心一言
@@ -8,7 +15,13 @@ type ApiRequest struct {
Stream bool `json:"stream"`
Messages []interface{} `json:"messages,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
Functions []Function `json:"functions,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
ToolChoice string `json:"tool_choice,omitempty"`
Input map[string]interface{} `json:"input,omitempty"` //兼容阿里通义千问
Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问
}
type Message struct {
@@ -27,10 +40,14 @@ type ChoiceItem struct {
}
type Delta struct {
Role string `json:"role"`
Name string `json:"name"`
Content interface{} `json:"content"`
FunctionCall FunctionCall `json:"function_call,omitempty"`
Role string `json:"role"`
Name string `json:"name"`
Content interface{} `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FunctionCall struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
} `json:"function_call,omitempty"`
}
// ChatSession 聊天会话对象
@@ -44,10 +61,15 @@ type ChatSession struct {
}
type ChatModel struct {
Id uint `json:"id"`
Platform Platform `json:"platform"`
Value string `json:"value"`
Weight int `json:"weight"`
Id uint `json:"id"`
Platform Platform `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
KeyId int `json:"key_id"` // 绑定 API KEY
}
type ApiError struct {
@@ -61,17 +83,37 @@ type ApiError struct {
const PromptMsg = "prompt" // prompt message
const ReplyMsg = "reply" // reply message
const MjMsg = "mj"
var ModelToTokens = map[string]int{
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"chatglm_pro": 32768, // 清华智普
"chatglm_std": 16384,
"chatglm_lite": 4096,
"ernie_bot_turbo": 8192, // 文心一言
"general": 8192, // 科大讯飞
"general2": 8192,
// PowerType 算力日志类型
type PowerType int
const (
PowerRecharge = PowerType(1) // 充值
PowerConsume = PowerType(2) // 消费
PowerRefund = PowerType(3) // 任务SD,MJ执行失败退款
PowerInvite = PowerType(4) // 邀请奖励
PowerReward = PowerType(5) // 众筹
PowerGift = PowerType(6) // 系统赠送
)
func (t PowerType) String() string {
switch t {
case PowerRecharge:
return "充值"
case PowerConsume:
return "消费"
case PowerRefund:
return "退款"
case PowerReward:
return "众筹"
}
return "其他"
}
type PowerMark int
const (
PowerSub = PowerMark(0)
PowerAdd = PowerMark(1)
)

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"github.com/gorilla/websocket"

View File

@@ -1,60 +1,75 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
)
type AppConfig struct {
Path string `toml:"-"`
Listen string
Session Session
ProxyURL string
MysqlDns string // mysql 连接地址
Manager Manager // 后台管理员账户信息
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
AesEncryptKey string
SmsConfig AliYunSmsConfig // AliYun send message service config
OSS OSSConfig // OSS config
MjConfigs []MidJourneyConfig // mj AI draw service pool
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
Path string `toml:"-"`
Listen string
Session Session
AdminSession Session
ProxyURL string
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
MjProxyConfigs []MjProxyConfig // MJ proxy config
MjPlusConfigs []MjPlusConfig // MJ plus config
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig
AlipayConfig AlipayConfig
HuPiPayConfig HuPiPayConfig
SmtpConfig SmtpConfig // 邮件发送配置
JPayConfig JPayConfig // payjs 支付配置
}
type ChatPlusApiConfig struct {
type SmtpConfig struct {
UseTls bool // 是否使用 TLS 发送
Host string
Port int
AppName string // 应用名称
From string // 发件人邮箱地址
Password string // 发件人邮箱密码
}
type ApiConfig struct {
ApiURL string
AppId string
Token string
}
type MidJourneyConfig struct {
Enabled bool
UserToken string
BotToken string
GuildId string // Server ID
ChanelId string // Chanel ID
type MjProxyConfig struct {
Enabled bool
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type StableDiffusionConfig struct {
Enabled bool
ApiURL string
ApiKey string
Txt2ImgJsonPath string
Enabled bool
Model string // 模型名称
ApiURL string
ApiKey string
}
type AliYunSmsConfig struct {
AccessKey string
AccessSecret string
Product string
Domain string
Sign string // 短信签名
CodeTempId string // 验证码短信模板 ID
type MjPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type AlipayConfig struct {
@@ -67,6 +82,7 @@ type AlipayConfig struct {
AlipayPublicKey string // 支付宝公钥文件路径
RootCert string // Root 秘钥路径
NotifyURL string // 异步通知回调
ReturnURL string // 支付成功返回地址
}
type HuPiPayConfig struct { //虎皮椒第四方支付配置
@@ -74,8 +90,20 @@ type HuPiPayConfig struct { //虎皮椒第四方支付配置
Name string // 支付名称wechat/alipay
AppId string // App ID
AppSecret string // app 密钥
ApiURL string // 支付网关
NotifyURL string // 异步通知回调
PayURL string // 支付网关
ReturnURL string // 支付成功返回地址
}
// JPayConfig PayJs 支付配置
type JPayConfig struct {
Enabled bool
Name string // 支付名称,默认 wechat
AppId string // 商户 ID
PrivateKey string // 私钥
ApiURL string // API 网关
NotifyURL string // 异步回调地址
ReturnURL string // 支付成功返回地址
}
type XXLConfig struct { // XXL 任务调度配置
@@ -94,31 +122,21 @@ type RedisConfig struct {
DB int
}
// LicenseKey 存储许可证书的 KEY
const LicenseKey = "Geek-AI-License"
type License struct {
Key string `json:"key"` // 许可证书密钥
MachineId string `json:"machine_id"` // 机器码
UserNum int `json:"user_num"` // 用户数量
ExpiredAt int64 `json:"expired_at"` // 过期时间
IsActive bool `json:"is_active"` // 是否激活
}
func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port)
}
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
}
// ChatConfig 系统默认的聊天配置
type ChatConfig struct {
OpenAI ModelAPIConfig `json:"open_ai"`
Azure ModelAPIConfig `json:"azure"`
ChatGML ModelAPIConfig `json:"chat_gml"`
Baidu ModelAPIConfig `json:"baidu"`
XunFei ModelAPIConfig `json:"xun_fei"`
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
ContextDeep int `json:"context_deep"` // 上下文深度
DallApiURL string `json:"dall_api_url"` // dall-e3 绘图 API 地址
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
}
type Platform string
const OpenAI = Platform("OpenAI")
@@ -126,42 +144,35 @@ const Azure = Platform("Azure")
const ChatGLM = Platform("ChatGLM")
const Baidu = Platform("Baidu")
const XunFei = Platform("XunFei")
// UserChatConfig 用户的聊天配置
type UserChatConfig struct {
ApiKeys map[Platform]string `json:"api_keys"`
}
type InviteReward struct {
ChatCalls int `json:"chat_calls"`
ImgCalls int `json:"img_calls"`
}
type ModelAPIConfig struct {
ApiURL string `json:"api_url,omitempty"`
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
ApiKey string `json:"api_key"`
}
const QWen = Platform("QWen")
type SystemConfig struct {
Title string `json:"title"`
AdminTitle string `json:"admin_title"`
Models []string `json:"models"`
InitChatCalls int `json:"init_chat_calls"` // 新用户注册赠送对话次数
InitImgCalls int `json:"init_img_calls"` // 新用户注册赠送绘图次数
VipMonthCalls int `json:"vip_month_calls"` // VIP 会员每月赠送的对话次数
VipMonthImgCalls int `json:"vip_month_img_calls"` // VIP 会员每月赠送绘图次数
EnabledRegister bool `json:"enabled_register"` // 是否启用注册功能,关闭注册功能之后将无法注册
EnabledMsg bool `json:"enabled_msg"` // 是否启用短信验证码服务
RewardImg string `json:"reward_img"` // 众筹收款二维码地址
EnabledFunction bool `json:"enabled_function"` // 启用 API 函数功能
EnabledReward bool `json:"enabled_reward"` // 启用众筹功能
EnabledAlipay bool `json:"enabled_alipay"` // 是否启用支付宝支付通道
OrderPayTimeout int `json:"order_pay_timeout"` //订单支付超时时间
DefaultModels []string `json:"default_models"` // 默认开通的 AI 模型
OrderPayInfoText string `json:"order_pay_info_text"` // 订单支付页面说明文字
InviteChatCalls int `json:"invite_chat_calls"` // 邀请用户注册奖励对话次数
InviteImgCalls int `json:"invite_img_calls"` // 邀请用户注册奖励绘图次数
ForceInvite bool `json:"force_invite"` // 是否强制必须使用邀请码才能注册
Title string `json:"title,omitempty"`
AdminTitle string `json:"admin_title,omitempty"`
Logo string `json:"logo,omitempty"`
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式支持手机mobile邮箱注册email账号密码注册
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
EnableContext bool `json:"enable_context,omitempty"`
ContextDeep int `json:"context_deep,omitempty"`
}

View File

@@ -1,92 +1,28 @@
package types
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type ToolCall struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
}
type Tool struct {
Type string `json:"type"`
Function Function `json:"function"`
}
type Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters Parameters `json:"parameters"`
}
type Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]Property `json:"properties"`
}
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
}
const (
FuncZaoBao = "zao_bao" // 每日早报
FuncHeadLine = "headline" // 今日头条
FuncWeibo = "weibo_hot" // 微博热搜
FuncImage = "draw_image" // AI 绘画
)
var InnerFunctions = []Function{
{
Name: FuncZaoBao,
Description: "每日早报,获取当天全球的热门新闻事件列表",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: FuncWeibo,
Description: "新浪微博热搜榜,微博当日热搜榜单",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: FuncHeadLine,
Description: "今日头条,给用户推荐当天的头条新闻,周榜热文",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: FuncImage,
Description: "AI 绘画工具,根据输入的绘图描述用 AI 工具进行绘画",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"prompt": {
Type: "string",
Description: "提示词,如果该参数中有中文的话,则需要翻译成英文。",
},
},
Required: []string{},
},
},
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
Required interface{} `json:"required,omitempty"`
}

View File

@@ -1,15 +1,22 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"sync"
)
type MKey interface {
string | int
string | int | uint
}
type MValue interface {
*WsClient | *ChatSession | context.CancelFunc | []interface{}
*WsClient | *ChatSession | context.CancelFunc | []Message
}
type LMap[K MKey, T MValue] struct {
lock sync.RWMutex

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type OrderStatus int
const (
@@ -9,10 +16,9 @@ const (
)
type OrderRemark struct {
Days int `json:"days"` // 有效期
Calls int `json:"calls"` // 增加对话次
ImgCalls int `json:"img_calls"` // 增加绘图次数
Name string `json:"name"` // 产品名称
Days int `json:"days"` // 有效期
Power int `json:"power"` // 增加算力点
Name string `json:"name"` // 产品名称
Price float64 `json:"price"`
Discount float64 `json:"discount"`
}

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type OSSConfig struct {
Active string
Local LocalStorageConfig

View File

@@ -1,11 +1,17 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
const LoginUserID = "LOGIN_USER_ID"
const LoginUserCache = "LOGIN_USER_CACHE"
const UserAuthHeader = "Authorization"
const AdminAuthHeader = "Admin-Authorization"
const ChatTokenHeader = "Chat-Token"
// Session configs struct
type Session struct {

33
api/core/types/sms.go Normal file
View File

@@ -0,0 +1,33 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type SMSConfig struct {
Active string
Ali SmsConfigAli
Bao SmsConfigBao
}
// SmsConfigAli 阿里云短信平台配置
type SmsConfigAli struct {
AccessKey string
AccessSecret string
Product string
Domain string
Sign string // 短信签名
CodeTempId string // 验证码短信模板 ID
}
// SmsConfigBao 短信宝平台配置
type SmsConfigBao struct {
Username string //短信宝平台注册的用户名
Password string //短信宝平台注册的密码
Domain string //域名
Sign string // 短信签名
CodeTemplate string // 验证码短信模板 匹配
}

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// TaskType 任务类别
type TaskType string
@@ -9,17 +16,24 @@ func (t TaskType) String() string {
const (
TaskImage = TaskType("image")
TaskBlend = TaskType("blend")
TaskSwapFace = TaskType("swapFace")
TaskUpscale = TaskType("upscale")
TaskVariation = TaskType("variation")
)
// MjTask MidJourney 任务
type MjTask struct {
Id int `json:"id"`
Id uint `json:"id"`
TaskId string `json:"task_id"`
ImgArr []string `json:"img_arr"`
ChannelId string `json:"channel_id"`
SessionId string `json:"session_id"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
NegPrompt string `json:"neg_prompt,omitempty"`
Params string `json:"full_prompt"`
Index int `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`
@@ -31,25 +45,37 @@ type SdTask struct {
SessionId string `json:"session_id"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"`
}
type SdTaskParams struct {
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
NegativePrompt string `json:"negative_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // 启用高清修复
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
HdScale int `json:"hd_scale"` // 放大倍数
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // 启用高清修复
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
HdScale int `json:"hd_scale"` // 放大倍数
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
}
// DallTask DALL-E task
type DallTask struct {
JobId uint `json:"job_id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
}

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// BizVo 业务返回 VO
type BizVo struct {
Code BizCode `json:"code"`
@@ -21,7 +28,7 @@ const (
WsStart = WsMsgType("start")
WsMiddle = WsMsgType("middle")
WsEnd = WsMsgType("end")
WsMjImg = WsMsgType("mj")
WsErr = WsMsgType("error")
)
type BizCode int
@@ -30,6 +37,7 @@ const (
Success = BizCode(0)
Failed = BizCode(1)
NotAuthorized = BizCode(400) // 未授权
NotPermission = BizCode(403) // 没有权限
OkMsg = "Success"
ErrorMsg = "系统开小差了"

View File

@@ -1,4 +1,4 @@
module chatplus
module geekai
go 1.19
@@ -6,7 +6,6 @@ require (
github.com/BurntSushi/toml v1.1.0
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
github.com/bwmarrin/discordgo v0.27.1
github.com/eatmoreapple/openwechat v1.2.1
github.com/gin-gonic/gin v1.9.1
github.com/go-redis/redis/v8 v8.11.5
@@ -26,6 +25,23 @@ require (
require github.com/xxl-job/xxl-job-executor-go v1.2.0
require (
github.com/mojocn/base64Captcha v1.3.1
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.3.1
github.com/syndtr/goleveldb v1.0.0
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
)
require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
)
require (
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
@@ -96,6 +112,6 @@ require (
go.uber.org/fx v1.19.3
go.uber.org/multierr v1.6.0 // indirect
golang.org/x/crypto v0.12.0
golang.org/x/sys v0.11.0 // indirect
golang.org/x/sys v0.15.0 // indirect
gorm.io/gorm v1.25.1
)

View File

@@ -7,8 +7,6 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY=
github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -29,6 +27,7 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
@@ -41,6 +40,8 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
@@ -65,10 +66,15 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -76,7 +82,6 @@ github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9S
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -84,6 +89,7 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -129,12 +135,17 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A=
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
@@ -166,6 +177,10 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
@@ -190,6 +205,12 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
@@ -202,6 +223,8 @@ github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onq
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
@@ -219,7 +242,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
@@ -227,10 +249,14 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -239,13 +265,16 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -257,8 +286,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -290,12 +319,15 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -1,14 +1,26 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/mojocn/base64Captcha"
"time"
"github.com/gin-gonic/gin"
@@ -17,47 +29,88 @@ import (
var logger = logger2.GetLogger()
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
Captcha string `json:"captcha"` // 验证码
CaptchaId string `json:"captcha_id"` // 验证码id
}
const SuperManagerID = 1
type ManagerHandler struct {
handler.BaseHandler
db *gorm.DB
redis *redis.Client
}
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
h := ManagerHandler{db: db, redis: client}
h.App = app
return &h
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data types.Manager
var data Manager
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
manager := h.App.Config.Manager
if data.Username == manager.Username && data.Password == manager.Password {
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": manager.Username,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := "users/" + manager.Username
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
resp.SUCCESS(c, tokenString)
} else {
resp.ERROR(c, "用户名或者密码错误")
// add captcha
if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
resp.ERROR(c, "验证码错误!")
return
}
var manager model.AdminUser
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil {
resp.ERROR(c, "请检查用户名或者密码是否填写正确")
return
}
password := utils.GenPassword(data.Password, manager.Salt)
if password != manager.Password {
resp.ERROR(c, "用户名或密码错误")
return
}
// 超级管理员默认是ID:1
if manager.Id != SuperManagerID && manager.Status == false {
resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员")
return
}
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": manager.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := fmt.Sprintf("admin/%d", manager.Id)
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
// 更新最后登录时间和IP
manager.LastLoginIp = c.ClientIP()
manager.LastLoginAt = time.Now().Unix()
h.DB.Updates(&manager)
var result = struct {
IsSuperAdmin bool `json:"is_super_admin"`
Token string `json:"token"`
}{
IsSuperAdmin: manager.Id == 1,
Token: tokenString,
}
resp.SUCCESS(c, result)
}
// Logout 注销
@@ -72,10 +125,155 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
// Session 会话检测
func (h *ManagerHandler) Session(c *gin.Context) {
token := c.GetHeader(types.AdminAuthHeader)
if token == "" {
id := h.GetLoginUserId(c)
key := fmt.Sprintf("admin/%d", id)
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c)
} else {
resp.SUCCESS(c)
return
}
var manager model.AdminUser
res := h.DB.Where("id", id).First(&manager)
if res.Error != nil {
resp.NotAuth(c)
return
}
resp.SUCCESS(c, manager)
}
// List 数据列表
func (h *ManagerHandler) List(c *gin.Context) {
var items []model.AdminUser
res := h.DB.Find(&items)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
users := make([]vo.AdminUser, 0)
for _, item := range items {
var u vo.AdminUser
err := utils.CopyObject(item, &u)
if err != nil {
continue
}
u.Id = item.Id
u.CreatedAt = item.CreatedAt.Unix()
users = append(users, u)
}
resp.SUCCESS(c, users)
}
func (h *ManagerHandler) Save(c *gin.Context) {
var data struct {
Username string `json:"username"`
Password string `json:"password"`
Status bool `json:"status"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.AdminUser
res := h.DB.Where("username", data.Username).First(&user)
if res.Error == nil {
resp.ERROR(c, "用户名已存在")
return
}
// 生成密码
salt := utils.RandString(8)
password := utils.GenPassword(data.Password, salt)
res = h.DB.Save(&model.AdminUser{
Username: data.Username,
Password: password,
Salt: salt,
Status: data.Status,
})
if res.Error != nil {
resp.ERROR(c, "failed with update database")
return
}
resp.SUCCESS(c)
}
// Remove 删除管理员
func (h *ManagerHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if id == SuperManagerID {
resp.ERROR(c, "超级管理员不能删除")
return
}
res := h.DB.Where("id", id).Delete(&model.AdminUser{})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}
// Enable 启用/禁用
func (h *ManagerHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.AdminUser{}).Where("id", data.Id).UpdateColumn("status", data.Enabled)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}
// ResetPass 重置密码
func (h *ManagerHandler) ResetPass(c *gin.Context) {
id := h.GetLoginUserId(c)
if id != SuperManagerID {
resp.ERROR(c, "只有超级管理员能够进行该操作")
return
}
var data struct {
Id int `json:"id"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.AdminUser
res := h.DB.Where("id", data.Id).First(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -1,34 +1,43 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ApiKeyHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
h := ApiKeyHandler{db: db}
h.App = app
return &h
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
}
func (h *ApiKeyHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Platform string `json:"platform"`
Name string `json:"name"`
Type string `json:"type"`
Value string `json:"value"`
ApiURL string `json:"api_url"`
Enabled bool `json:"enabled"`
ProxyURL string `json:"proxy_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -37,12 +46,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey := model.ApiKey{}
if data.Id > 0 {
h.db.Find(&apiKey, data.Id)
h.DB.Find(&apiKey, data.Id)
}
apiKey.Platform = data.Platform
apiKey.Value = data.Value
apiKey.Type = data.Type
res := h.db.Save(&apiKey)
apiKey.ApiURL = data.ApiURL
apiKey.Enabled = data.Enabled
apiKey.ProxyURL = data.ProxyURL
apiKey.Name = data.Name
res := h.DB.Save(&apiKey)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -60,9 +73,20 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
}
func (h *ApiKeyHandler) List(c *gin.Context) {
status := h.GetBool(c, "status")
t := h.GetTrim(c, "type")
session := h.DB.Session(&gorm.Session{})
if status {
session = session.Where("enabled", true)
}
if t != "" {
session = session.Where("type", t)
}
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)
res := h.db.Find(&items)
res := session.Find(&items)
if res.Error == nil {
for _, item := range items {
var key vo.ApiKey
@@ -80,15 +104,37 @@ func (h *ApiKeyHandler) List(c *gin.Context) {
resp.SUCCESS(c, keys)
}
func (h *ApiKeyHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
func (h *ApiKeyHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
Value interface{} `json:"value"`
}
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.ApiKey{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *ApiKeyHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id", id).Delete(&model.ApiKey{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,46 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/handler"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/mojocn/base64Captcha"
)
type CaptchaHandler struct {
handler.BaseHandler
}
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
}
type CaptchaVo struct {
CaptchaId string `json:"captcha_id"`
PicPath string `json:"pic_path"`
}
// GetCaptcha 获取验证码
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
var captchaVo CaptchaVo
driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
// b64s是图片的base64编码
id, b64s, err := cp.Generate()
if err != nil {
resp.ERROR(c, "生成验证码错误!")
return
}
captchaVo.CaptchaId = id
captchaVo.PicPath = b64s
resp.SUCCESS(c, captchaVo)
}

View File

@@ -0,0 +1,268 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatHandler struct {
handler.BaseHandler
}
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
type chatItemVo struct {
Username string `json:"username"`
UserId uint `json:"user_id"`
ChatId string `json:"chat_id"`
Title string `json:"title"`
Role vo.ChatRole `json:"role"`
Model string `json:"model"`
Token int `json:"token"`
CreatedAt int64 `json:"created_at"`
MsgNum int `json:"msg_num"` // 消息数量
}
func (h *ChatHandler) List(c *gin.Context) {
var data struct {
Title string `json:"title"`
UserId uint `json:"user_id"`
Model string `json:"model"`
CreateAt []string `json:"created_time"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Title != "" {
session = session.Where("title LIKE ?", "%"+data.Title+"%")
}
if data.UserId > 0 {
session = session.Where("user_id = ?", data.UserId)
}
if data.Model != "" {
session = session.Where("model = ?", data.Model)
}
if len(data.CreateAt) == 2 {
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.ChatItem{}).Count(&total)
var items []model.ChatItem
var list = make([]chatItemVo, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
userIds := make([]uint, 0)
chatIds := make([]string, 0)
roleIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
chatIds = append(chatIds, item.ChatId)
roleIds = append(roleIds, item.RoleId)
}
var messages []model.ChatMessage
var users []model.User
var roles []model.ChatRole
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
h.DB.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", roleIds).Find(&roles)
tokenMap := make(map[string]int)
userMap := make(map[uint]string)
msgMap := make(map[string]int)
roleMap := make(map[uint]vo.ChatRole)
for _, msg := range messages {
tokenMap[msg.ChatId] += msg.Tokens
msgMap[msg.ChatId] += 1
}
for _, user := range users {
userMap[user.Id] = user.Username
}
for _, r := range roles {
var roleVo vo.ChatRole
err := utils.CopyObject(r, &roleVo)
if err != nil {
continue
}
roleMap[r.Id] = roleVo
}
for _, item := range items {
list = append(list, chatItemVo{
UserId: item.UserId,
Username: userMap[item.UserId],
ChatId: item.ChatId,
Title: item.Title,
Model: item.Model,
Token: tokenMap[item.ChatId],
MsgNum: msgMap[item.ChatId],
Role: roleMap[item.RoleId],
CreatedAt: item.CreatedAt.Unix(),
})
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}
type chatMessageVo struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
Username string `json:"username"`
Content string `json:"content"`
Type string `json:"type"`
Model string `json:"model"`
Token int `json:"token"`
Icon string `json:"icon"`
CreatedAt int64 `json:"created_at"`
}
// Messages 读取聊天记录列表
func (h *ChatHandler) Messages(c *gin.Context) {
var data struct {
UserId uint `json:"user_id"`
Content string `json:"content"`
Model string `json:"model"`
CreateAt []string `json:"created_time"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Content != "" {
session = session.Where("content LIKE ?", "%"+data.Content+"%")
}
if data.UserId > 0 {
session = session.Where("user_id = ?", data.UserId)
}
if data.Model != "" {
session = session.Where("model = ?", data.Model)
}
if len(data.CreateAt) == 2 {
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.ChatMessage{}).Count(&total)
var items []model.ChatMessage
var list = make([]chatMessageVo, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
userIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
}
var users []model.User
h.DB.Where("id IN ?", userIds).Find(&users)
userMap := make(map[uint]string)
for _, user := range users {
userMap[user.Id] = user.Username
}
for _, item := range items {
list = append(list, chatMessageVo{
Id: item.Id,
UserId: item.UserId,
Username: userMap[item.UserId],
Content: item.Content,
Model: item.Model,
Token: item.Tokens,
Icon: item.Icon,
Type: item.Type,
CreatedAt: item.CreatedAt.Unix(),
})
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}
// History 获取聊天历史记录
func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
} else {
for _, item := range items {
var v vo.HistoryMessage
err := utils.CopyObject(item, &v)
v.CreatedAt = item.CreatedAt.Unix()
v.UpdatedAt = item.UpdatedAt.Unix()
if err == nil {
messages = append(messages, v)
}
}
}
resp.SUCCESS(c, messages)
}
// RemoveChat 删除对话
func (h *ChatHandler) RemoveChat(c *gin.Context) {
chatId := h.GetTrim(c, "chat_id")
if chatId == "" {
resp.ERROR(c, "请传入 ChatId")
return
}
tx := h.DB.Begin()
// 删除聊天记录
res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
if res.Error != nil {
resp.ERROR(c, "failed to remove chat message")
return
}
// 删除对话
res = tx.Unscoped().Where("chat_id = ?", chatId).Delete(model.ChatItem{})
if res.Error != nil {
tx.Rollback() // 回滚
resp.ERROR(c, "failed to remove chat")
return
}
tx.Commit()
resp.SUCCESS(c)
}
// RemoveMessage 删除聊天记录
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
if tx.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -1,40 +1,47 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type ChatModelHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db}
h.App = app
return &h
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ChatModelHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Open bool `json:"open"`
Platform string `json:"platform"`
Weight int `json:"weight"`
CreatedAt int64 `json:"created_at"`
Id uint `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"`
Open bool `json:"open"`
Platform string `json:"platform"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
KeyId int `json:"key_id,omitempty"`
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -42,18 +49,24 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
}
item := model.ChatModel{
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
Weight: data.Weight}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: data.Temperature,
KeyId: data.KeyId,
Power: data.Power}
var res *gorm.DB
if data.Id > 0 {
item.Id = data.Id
res = h.DB.Select("*").Omit("created_at").Updates(&item)
} else {
res = h.DB.Create(&item)
}
res := h.db.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -72,7 +85,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
if enable {
session = session.Where("enabled", enable)
@@ -80,18 +93,33 @@ func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var cms = make([]vo.ChatModel, 0)
res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var cm vo.ChatModel
err := utils.CopyObject(item, &cm)
if err == nil {
cm.Id = item.Id
cm.CreatedAt = item.CreatedAt.Unix()
cm.UpdatedAt = item.UpdatedAt.Unix()
cms = append(cms, cm)
} else {
logger.Error(err)
}
if res.Error != nil {
resp.SUCCESS(c, cms)
return
}
// initialize key name
keyIds := make([]int, 0)
for _, v := range items {
keyIds = append(keyIds, v.KeyId)
}
var keys []model.ApiKey
keyMap := make(map[uint]string)
h.DB.Where("id IN ?", keyIds).Find(&keys)
for _, v := range keys {
keyMap[v.Id] = v.Name
}
for _, item := range items {
var cm vo.ChatModel
err := utils.CopyObject(item, &cm)
if err == nil {
cm.Id = item.Id
cm.CreatedAt = item.CreatedAt.Unix()
cm.UpdatedAt = item.UpdatedAt.Unix()
cm.KeyName = keyMap[uint(item.KeyId)]
cms = append(cms, cm)
} else {
logger.Error(err)
}
}
resp.SUCCESS(c, cms)
@@ -109,7 +137,7 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -129,7 +157,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -141,13 +169,15 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
func (h *ChatModelHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -1,27 +1,32 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type ChatRoleHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
h := ChatRoleHandler{db: db}
h.App = app
return &h
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// Save 创建或者更新某个角色
@@ -41,7 +46,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&role)
res := h.DB.Save(&role)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -55,12 +60,31 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
func (h *ChatRoleHandler) List(c *gin.Context) {
var items []model.ChatRole
var roles = make([]vo.ChatRole, 0)
res := h.db.Order("sort_num ASC").Find(&items)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
return
}
// initialize model mane for role
modelIds := make([]int, 0)
for _, v := range items {
if v.ModelId > 0 {
modelIds = append(modelIds, v.ModelId)
}
}
modelNameMap := make(map[int]string)
if len(modelIds) > 0 {
var models []model.ChatModel
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
if tx.Error == nil {
for _, m := range models {
modelNameMap[int(m.Id)] = m.Name
}
}
}
for _, v := range items {
var role vo.ChatRole
err := utils.CopyObject(v, &role)
@@ -68,6 +92,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
role.Id = v.Id
role.CreatedAt = v.CreatedAt.Unix()
role.UpdatedAt = v.UpdatedAt.Unix()
role.ModelName = modelNameMap[role.ModelId]
roles = append(roles, role)
}
}
@@ -88,7 +113,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -98,14 +123,39 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
resp.SUCCESS(c)
}
func (h *ChatRoleHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
func (h *ChatRoleHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
Value interface{} `json:"value"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Where("id = ?", id).Delete(&model.ChatRole{})
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *ChatRoleHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
if res.Error != nil {
resp.ERROR(c, "删除失败!")
return

View File

@@ -1,49 +1,62 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/shirou/gopsutil/host"
"gorm.io/gorm"
)
type ConfigHandler struct {
handler.BaseHandler
db *gorm.DB
levelDB *store.LevelDB
licenseService *service.LicenseService
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
h := ConfigHandler{db: db}
h.App = app
return &h
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, levelDB: levelDB, licenseService: licenseService}
}
func (h *ConfigHandler) Update(c *gin.Context) {
var data struct {
Key string `json:"key"`
Config map[string]interface{} `json:"config"`
Key string `json:"key"`
Config struct {
types.SystemConfig
Content string `json:"content,omitempty"`
Updated bool `json:"updated,omitempty"`
} `json:"config"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
str := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: str}
res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key})
value := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: value}
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
if config.Id > 0 {
config.Config = str
res := h.db.Updates(&config)
config.Config = value
res := h.DB.Updates(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -51,12 +64,10 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// update config cache for AppServer
var cfg model.Config
h.db.Where("marker", data.Key).First(&cfg)
h.DB.Where("marker", data.Key).First(&cfg)
var err error
if data.Key == "system" {
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
} else if data.Key == "chat" {
err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig)
}
if err != nil {
resp.ERROR(c, "Failed to update config cache: "+err.Error())
@@ -72,18 +83,48 @@ func (h *ConfigHandler) Update(c *gin.Context) {
func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key")
var config model.Config
res := h.db.Where("marker", key).First(&config)
res := h.DB.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
var m map[string]interface{}
err := utils.JsonDecode(config.Config, &m)
var value map[string]interface{}
err := utils.JsonDecode(config.Config, &value)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, m)
resp.SUCCESS(c, value)
}
// Active 激活系统
func (h *ConfigHandler) Active(c *gin.Context) {
var data struct {
License string `json:"license"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
info, err := host.Info()
if err != nil {
resp.ERROR(c, err.Error())
return
}
err = h.licenseService.ActiveLicense(data.License, info.HostID)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, info.HostID)
}
// GetLicense 获取 License 信息
func (h *ConfigHandler) GetLicense(c *gin.Context) {
license := h.licenseService.GetLicense()
resp.SUCCESS(c, license)
}

View File

@@ -1,32 +1,38 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"gorm.io/gorm"
"time"
)
type DashboardHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
h := DashboardHandler{db: db}
h.App = app
return &h
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
type statsVo struct {
Users int64 `json:"users"`
Chats int64 `json:"chats"`
Tokens int `json:"tokens"`
Income float64 `json:"income"`
Users int64 `json:"users"`
Chats int64 `json:"chats"`
Tokens int `json:"tokens"`
Income float64 `json:"income"`
Chart map[string]map[string]float64 `json:"chart"`
}
func (h *DashboardHandler) Stats(c *gin.Context) {
@@ -35,37 +41,84 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
var userCount int64
now := time.Now()
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
res := h.db.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
if res.Error == nil {
stats.Users = userCount
}
// new chats statistic
var chatCount int64
res = h.db.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
if res.Error == nil {
stats.Chats = chatCount
}
// tokens took stats
var historyMessages []model.HistoryMessage
res = h.db.Where("created_at > ?", zeroTime).Find(&historyMessages)
var historyMessages []model.ChatMessage
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
for _, item := range historyMessages {
stats.Tokens += item.Tokens
}
// 众筹收入
var rewards []model.Reward
res = h.db.Where("created_at > ?", zeroTime).Find(&rewards)
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
for _, item := range rewards {
stats.Income += item.Amount
}
// 订单收入
var orders []model.Order
res = h.db.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
for _, item := range orders {
stats.Income += item.Amount
}
// 统计7天的订单的图表
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
var statsChart = make(map[string]map[string]float64)
//// 初始化
var userStatistic, historyMessagesStatistic, incomeStatistic = make(map[string]float64), make(map[string]float64), make(map[string]float64)
for i := 0; i < 7; i++ {
var initTime = time.Date(now.Year(), now.Month(), now.Day()-i, 0, 0, 0, 0, now.Location()).Format("2006-01-02")
userStatistic[initTime] = float64(0)
historyMessagesStatistic[initTime] = float64(0)
incomeStatistic[initTime] = float64(0)
}
// 统计用户7天增加的曲线
var users []model.User
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
if res.Error == nil {
for _, item := range users {
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
}
}
// 统计7天Token 消耗
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
for _, item := range historyMessages {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
}
// 浮点数相加?
// 统计最近7天的众筹
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
for _, item := range rewards {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
// 统计最近7天的订单
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
for _, item := range orders {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
statsChart["users"] = userStatistic
statsChart["historyMessage"] = historyMessagesStatistic
statsChart["orders"] = incomeStatistic
stats.Chart = statsChart
resp.SUCCESS(c, stats)
}

View File

@@ -0,0 +1,128 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/golang-jwt/jwt/v5"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type FunctionHandler struct {
handler.BaseHandler
}
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *FunctionHandler) Save(c *gin.Context) {
var data vo.Function
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var f = model.Function{
Id: data.Id,
Name: data.Name,
Label: data.Label,
Description: data.Description,
Parameters: utils.JsonEncode(data.Parameters),
Action: data.Action,
Token: data.Token,
Enabled: data.Enabled,
}
res := h.DB.Save(&f)
if res.Error != nil {
resp.ERROR(c, "error with save data:"+res.Error.Error())
return
}
data.Id = f.Id
resp.SUCCESS(c, data)
}
func (h *FunctionHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
Value interface{} `json:"value"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *FunctionHandler) List(c *gin.Context) {
var items []model.Function
res := h.DB.Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
return
}
functions := make([]vo.Function, 0)
for _, v := range items {
var f vo.Function
err := utils.CopyObject(v, &f)
if err != nil {
continue
}
functions = append(functions, f)
}
resp.SUCCESS(c, functions)
}
func (h *FunctionHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.DB.Delete(&model.Function{Id: uint(id)})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}
// GenToken generate function api access token
func (h *FunctionHandler) GenToken(c *gin.Context) {
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": 0,
"expired": 0,
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
logger.Error("error with generate token", err)
resp.ERROR(c)
return
}
resp.SUCCESS(c, tokenString)
}

View File

@@ -0,0 +1,128 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MenuHandler struct {
handler.BaseHandler
}
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *MenuHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Icon string `json:"icon"`
URL string `json:"url"`
SortNum int `json:"sort_num"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Save(&model.Menu{
Id: data.Id,
Name: data.Name,
Icon: data.Icon,
URL: data.URL,
SortNum: data.SortNum,
Enabled: data.Enabled,
})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
// List 数据列表
func (h *MenuHandler) List(c *gin.Context) {
var items []model.Menu
var list = make([]vo.Menu, 0)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Menu
err := utils.CopyObject(item, &product)
if err == nil {
list = append(list, product)
}
}
}
resp.SUCCESS(c, list)
}
func (h *MenuHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *MenuHandler) Sort(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
Sorts []int `json:"sorts"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
for index, id := range data.Ids {
res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}
func (h *MenuHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.DB.Where("id", id).Delete(&model.Menu{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -1,31 +1,37 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type OrderHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db}
h.App = app
return &h
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *OrderHandler) List(c *gin.Context) {
var data struct {
OrderNo string `json:"order_no"`
Status int `json:"status"`
PayTime []string `json:"pay_time"`
Page int `json:"page"`
PageSize int `json:"page_size"`
@@ -35,7 +41,7 @@ func (h *OrderHandler) List(c *gin.Context) {
return
}
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if data.OrderNo != "" {
session = session.Where("order_no", data.OrderNo)
}
@@ -44,8 +50,9 @@ func (h *OrderHandler) List(c *gin.Context) {
end := utils.Str2stamp(data.PayTime[1] + " 00:00:00")
session = session.Where("pay_time >= ? AND pay_time <= ?", start, end)
}
session = session.Where("status = ?", types.OrderPaidSuccess)
if data.Status >= 0 {
session = session.Where("status", data.Status)
}
var total int64
session.Model(&model.Order{}).Count(&total)
var items []model.Order
@@ -74,7 +81,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
if id > 0 {
var item model.Order
res := h.db.First(&item, id)
res := h.DB.First(&item, id)
if res.Error != nil {
resp.ERROR(c, "记录不存在!")
return
@@ -85,7 +92,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
return
}
res = h.db.Where("id = ?", id).Delete(&model.Order{})
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -0,0 +1,84 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type PowerLogHandler struct {
handler.BaseHandler
}
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Username string `json:"username"`
Type int `json:"type"`
Model string `json:"model"`
Date []string `json:"date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Model != "" {
session = session.Where("model", data.Model)
}
if data.Type > 0 {
session = session.Where("type", data.Type)
}
if len(data.Date) == 2 {
start := data.Date[0] + " 00:00:00"
end := data.Date[1] + " 00:00:00"
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.PowerLog{}).Count(&total)
var items []model.PowerLog
var list = make([]vo.PowerLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var log vo.PowerLog
err := utils.CopyObject(item, &log)
if err != nil {
continue
}
log.Id = item.Id
log.CreatedAt = item.CreatedAt.Unix()
log.TypeStr = item.Type.String()
list = append(list, log)
}
}
// 统计消费算力总和
var totalPower float64
if len(data.Date) == 2 {
session.Where("mark", 0).Select("SUM(amount) as total_sum").Scan(&totalPower)
}
resp.SUCCESS(c, gin.H{"data": vo.NewPage(total, data.Page, data.PageSize, list), "stat": totalPower})
}

View File

@@ -1,13 +1,20 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
@@ -15,13 +22,10 @@ import (
type ProductHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db}
h.App = app
return &h
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ProductHandler) Save(c *gin.Context) {
@@ -32,8 +36,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
Discount float64 `json:"discount"`
Enabled bool `json:"enabled"`
Days int `json:"days"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
Power int `json:"power"`
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -46,14 +49,13 @@ func (h *ProductHandler) Save(c *gin.Context) {
Price: data.Price,
Discount: data.Discount,
Days: data.Days,
Calls: data.Calls,
ImgCalls: data.ImgCalls,
Power: data.Power,
Enabled: data.Enabled}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&item)
res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -70,16 +72,11 @@ func (h *ProductHandler) Save(c *gin.Context) {
resp.SUCCESS(c, itemVo)
}
// List 模型列表
// List 数据列表
func (h *ProductHandler) List(c *gin.Context) {
session := h.db.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
if enable {
session = session.Where("enabled", enable)
}
var items []model.Product
var list = make([]vo.Product, 0)
res := session.Order("sort_num ASC").Find(&items)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Product
@@ -108,7 +105,7 @@ func (h *ProductHandler) Enable(c *gin.Context) {
return
}
res := h.db.Model(&model.Product{}).Where("id = ?", data.Id).Update("enabled", data.Enabled)
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -128,7 +125,7 @@ func (h *ProductHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -142,7 +139,7 @@ func (h *ProductHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.Product{})
res := h.DB.Where("id", id).Delete(&model.Product{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -1,30 +1,35 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type RewardHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db}
h.App = app
return &h
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RewardHandler) List(c *gin.Context) {
var items []model.Reward
res := h.db.Order("id DESC").Find(&items)
res := h.DB.Order("id DESC").Find(&items)
var rewards = make([]vo.Reward, 0)
if res.Error == nil {
userIds := make([]uint, 0)
@@ -32,7 +37,7 @@ func (h *RewardHandler) List(c *gin.Context) {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.db.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
@@ -46,7 +51,7 @@ func (h *RewardHandler) List(c *gin.Context) {
}
r.Id = v.Id
r.Username = userMap[v.UserId].Mobile
r.Username = userMap[v.UserId].Username
r.CreatedAt = v.CreatedAt.Unix()
r.UpdatedAt = v.UpdatedAt.Unix()
rewards = append(rewards, r)
@@ -55,3 +60,21 @@ func (h *RewardHandler) List(c *gin.Context) {
resp.SUCCESS(c, rewards)
}
func (h *RewardHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,52 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/handler"
"geekai/service/oss"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type UploadHandler struct {
handler.BaseHandler
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
userId := 0
res := h.DB.Create(&model.File{
UserId: userId,
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}

View File

@@ -1,42 +1,51 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"fmt"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type UserHandler struct {
handler.BaseHandler
db *gorm.DB
licenseService *service.LicenseService
}
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
h := UserHandler{db: db}
h.App = app
return &h
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler {
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService}
}
// List 用户列表
func (h *UserHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
mobile := h.GetTrim(c, "mobile")
username := h.GetTrim(c, "username")
offset := (page - 1) * pageSize
var items []model.User
var users = make([]vo.User, 0)
var total int64
session := h.db.Session(&gorm.Session{})
if mobile != "" {
session = session.Where("mobile LIKE ?", "%"+mobile+"%")
session := h.DB.Session(&gorm.Session{})
if username != "" {
session = session.Where("username LIKE ?", "%"+username+"%")
}
session.Model(&model.User{}).Count(&total)
@@ -63,57 +72,83 @@ func (h *UserHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Password string `json:"password"`
Mobile string `json:"mobile"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
Username string `json:"username"`
ChatRoles []string `json:"chat_roles"`
ChatModels []string `json:"chat_models"`
ChatModels []int `json:"chat_models"`
ExpiredTime string `json:"expired_time"`
Status bool `json:"status"`
Vip bool `json:"vip"`
Power int `json:"power"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 检测最大注册人数
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if int(totalUser) >= h.licenseService.GetLicense().UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
var user = model.User{}
var res *gorm.DB
var userVo vo.User
if data.Id > 0 { // 更新
user.Id = data.Id
// 此处需要用 map 更新,用结构体无法更新 0 值
res = h.db.Model(&user).Updates(map[string]interface{}{
"mobile": data.Mobile,
"calls": data.Calls,
"img_calls": data.ImgCalls,
"status": data.Status,
"vip": data.Vip,
"chat_roles_json": utils.JsonEncode(data.ChatRoles),
"chat_models_json": utils.JsonEncode(data.ChatModels),
"expired_time": utils.Str2stamp(data.ExpiredTime),
})
res = h.DB.Where("id", data.Id).First(&user)
if res.Error != nil {
resp.ERROR(c, "user not found")
return
}
var oldPower = user.Power
user.Username = data.Username
user.Status = data.Status
user.Vip = data.Vip
user.Power = data.Power
user.ChatRoles = utils.JsonEncode(data.ChatRoles)
user.ChatModels = utils.JsonEncode(data.ChatModels)
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力日志
if oldPower != user.Power {
mark := types.PowerAdd
amount := user.Power - oldPower
if oldPower > user.Power {
mark = types.PowerSub
amount = oldPower - user.Power
}
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerGift,
Amount: amount,
Balance: user.Power,
Mark: mark,
Model: "管理员",
Remark: fmt.Sprintf("后台管理员强制修改用户算力,修改前:%d,修改后:%d, 管理员ID%d", oldPower, user.Power, h.GetLoginUserId(c)),
CreatedAt: time.Now(),
})
}
} else {
salt := utils.RandString(8)
u := model.User{
Mobile: data.Mobile,
Username: data.Username,
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
Password: utils.GenPassword(data.Password, salt),
Avatar: "/images/avatar/user.png",
Salt: salt,
Power: data.Power,
Status: true,
ChatRoles: utils.JsonEncode(data.ChatRoles),
ChatModels: utils.JsonEncode(data.ChatModels),
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
ChatConfig: utils.JsonEncode(types.UserChatConfig{
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
Calls: data.Calls,
ImgCalls: data.ImgCalls,
}
res = h.db.Create(&u)
res = h.DB.Create(&u)
_ = utils.CopyObject(u, &userVo)
userVo.Id = u.Id
userVo.CreatedAt = u.CreatedAt.Unix()
@@ -140,7 +175,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
var user model.User
res := h.db.First(&user, data.Id)
res := h.DB.First(&user, data.Id)
if res.Error != nil {
resp.ERROR(c, "No user found")
return
@@ -148,7 +183,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
} else {
@@ -158,36 +193,32 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
func (h *UserHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
tx := h.db.Begin()
res := h.db.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
res = h.db.Where("user_id = ?", id).Delete(&model.ChatItem{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除聊天历史记录
res = h.db.Where("user_id = ?", id).Delete(&model.HistoryMessage{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除登录日志
res = h.db.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
tx.Commit()
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
// 删除用户
res := h.DB.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
// 删除聊天历史记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
// 删除登录日志
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
// 删除算力日志
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
// 删除众筹日志
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
// 删除绘图任务
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
// 删除订单
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
resp.SUCCESS(c)
}
@@ -195,10 +226,10 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
var total int64
h.db.Model(&model.UserLoginLog{}).Count(&total)
h.DB.Model(&model.UserLoginLog{}).Count(&total)
offset := (page - 1) * pageSize
var items []model.UserLoginLog
res := h.db.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "获取数据失败")
return

View File

@@ -1,11 +1,21 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
"geekai/core"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/store/model"
"geekai/utils"
"errors"
"fmt"
"gorm.io/gorm"
"strings"
"github.com/gin-gonic/gin"
@@ -15,6 +25,7 @@ var logger = logger2.GetLogger()
type BaseHandler struct {
App *core.AppServer
DB *gorm.DB
}
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
@@ -57,3 +68,27 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
}
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
}
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
return h.GetLoginUserId(c) > 0
}
func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
value, exists := c.Get(types.LoginUserCache)
if exists {
return value.(model.User), nil
}
userId, ok := c.Get(types.LoginUserID)
if !ok {
return model.User{}, errors.New("user not login")
}
var user model.User
res := h.DB.First(&user, userId)
// 更新缓存
if res.Error == nil {
c.Set(types.LoginUserCache, user)
}
return user, res.Error
}

View File

@@ -1,9 +1,16 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/service"
"chatplus/utils/resp"
"geekai/core/types"
"geekai/service"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
)
@@ -45,3 +52,33 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
}
}
// SlideGet 获取滑动验证图片
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
data, err := h.service.SlideGet()
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// SlideCheck 滑动验证结果校验
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
var data struct {
Key string `json:"key"`
X int `json:"x"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if h.service.SlideCheck(data) {
resp.SUCCESS(c)
} else {
resp.ERROR(c)
}
}

View File

@@ -1,48 +1,52 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatModelHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db}
h.App = app
return &h
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0)
// 只加载用户订阅的 AI 模型
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
var res *gorm.DB
// 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) {
res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items)
} else {
user, _ := h.GetLoginUser(c)
var models []int
err := utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res = h.DB.Where("enabled = ?", true).Where(
h.DB.Where("id IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
}
var models []string
err = utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res := h.db.Where("enabled = ?", true).Where(
h.db.Where("value IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var cm vo.ChatModel

View File

@@ -1,12 +1,19 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -14,27 +21,26 @@ import (
type ChatRoleHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
handler := &ChatRoleHandler{db: db}
handler.App = app
return handler
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List get user list
// List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) {
all := h.GetBool(c, "all")
userId := h.GetLoginUserId(c)
var roles []model.ChatRole
res := h.db.Where("enable", true).Order("sort_num ASC").Find(&roles)
var roleVos = make([]vo.ChatRole, 0)
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
if res.Error != nil {
resp.ERROR(c, "No roles found,"+res.Error.Error())
resp.SUCCESS(c, roleVos)
return
}
// 获取所有角色
if all {
if userId == 0 || all {
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
@@ -49,21 +55,15 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.NotAuth(c)
return
}
var user model.User
h.db.First(&user, userId)
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
resp.ERROR(c, "角色解析失败!")
return
}
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
if !utils.ContainsStr(roleKeys, r.Key) {
continue
@@ -80,7 +80,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
// UpdateRole 更新用户聊天角色
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
@@ -94,7 +94,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
return
}
res := h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
if res.Error != nil {
logger.Error("添加应用失败:", err)
resp.ERROR(c, "更新数据库失败!")

View File

@@ -1,15 +1,21 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"gorm.io/gorm"
"html/template"
"io"
"strings"
@@ -20,7 +26,7 @@ import (
// 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -30,8 +36,8 @@ func (h *ChatHandler) sendAzureMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
@@ -57,9 +63,6 @@ func (h *ChatHandler) sendAzureMessage(
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var functionCall = false
var functionName string
var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
@@ -69,34 +72,17 @@ func (h *ChatHandler) sendAzureMessage(
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
if err != nil { // 数据解析出错
logger.Error(err, line)
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
break
}
fun := responseBody.Choices[0].Delta.FunctionCall
if functionCall && fun.Name == "" {
arguments = append(arguments, fun.Arguments)
if len(responseBody.Choices) == 0 {
continue
}
if !utils.IsEmptyValue(fun) {
functionName = fun.Name
f := h.App.Functions[functionName]
if f != nil {
functionCall = true
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
continue
}
}
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
@@ -122,53 +108,8 @@ func (h *ChatHandler) sendAzureMessage(
}
}
if functionCall { // 调用函数完成任务
var params map[string]interface{}
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
// for creating image, check if the user's img_calls > 0
if functionName == types.FuncImage && userVo.ImgCalls <= 0 {
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
utils.ReplyMessage(ws, ErrImg)
} else {
f := h.App.Functions[functionName]
if functionName == types.FuncImage {
params["user_id"] = userVo.Id
params["role_id"] = role.Id
params["chat_id"] = session.ChatId
params["icon"] = "/images/avatar/mid_journey.png"
params["session_id"] = session.SessionId
}
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
content := data
if functionName == types.FuncImage {
content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景", params["prompt"])
// update user's img_calls
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: content,
})
contents = append(contents, content)
}
}
}
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
@@ -177,77 +118,64 @@ func (h *ChatHandler) sendAzureMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext && functionCall == false {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
useContext := true
if functionCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if functionCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: useContext,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -258,7 +186,8 @@ func (h *ChatHandler) sendAzureMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -1,11 +1,18 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
@@ -36,7 +43,7 @@ type baiduResp struct {
// 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -46,8 +53,8 @@ func (h *ChatHandler) sendBaiduMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
@@ -128,9 +135,6 @@ func (h *ChatHandler) sendBaiduMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -138,63 +142,63 @@ func (h *ChatHandler) sendBaiduMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -205,7 +209,8 @@ func (h *ChatHandler) sendBaiduMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -1,50 +1,68 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"net/url"
"strings"
"time"
)
const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
const ErrImg = "![](/images/wx.png)"
var ErrImg = "![](/images/wx.png)"
var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
db *gorm.DB
redis *redis.Client
redis *redis.Client
uploadManager *oss.UploaderManager
licenseService *service.LicenseService
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client) *ChatHandler {
h := ChatHandler{
db: db,
redis: redis,
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler {
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
licenseService: licenseService,
}
h.App = app
return &h
}
var chatConfig types.ChatConfig
func (h *ChatHandler) Init() {
// 如果后台有上传微信客服微信二维码,则覆盖
if h.App.SysConfig.WechatCardURL != "" {
ErrImg = fmt.Sprintf("![](%s)", h.App.SysConfig.WechatCardURL)
}
}
// ChatHandle 处理聊天 WebSocket 请求
func (h *ChatHandler) ChatHandle(c *gin.Context) {
@@ -60,9 +78,20 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
modelId := h.GetInt(c, "model_id", 0)
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res := h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
modelId = chatRole.ModelId
}
// get model info
var chatModel model.ChatModel
res := h.db.First(&chatModel, modelId)
res = h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort()
@@ -71,7 +100,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session := h.App.ChatSession.Get(sessionId)
if session == nil {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
logger.Info("用户未登录")
c.Abort()
@@ -80,7 +109,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session = &types.ChatSession{
SessionId: sessionId,
ClientIP: c.ClientIP(),
Username: user.Mobile,
Username: user.Username,
UserId: user.Id,
}
h.App.ChatSession.Put(sessionId, session)
@@ -88,7 +117,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
// use old chat data override the chat model and role ID
var chat model.ChatItem
res = h.db.Where("chat_id=?", chatId).First(&chat)
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
if res.Error == nil {
chatModel.Id = chat.ModelId
roleId = int(chat.RoleId)
@@ -96,28 +125,18 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session.ChatId = chatId
session.Model = types.ChatModel{
Id: chatModel.Id,
Value: chatModel.Value,
Weight: chatModel.Weight,
Platform: types.Platform(chatModel.Platform)}
Id: chatModel.Id,
Name: chatModel.Name,
Value: chatModel.Value,
Power: chatModel.Power,
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
KeyId: chatModel.KeyId,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
// 初始化聊天配置
var config model.Config
h.db.Where("marker", "chat").First(&config)
err = utils.JsonDecode(config.Config, &chatConfig)
if err != nil {
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
c.Abort()
return
}
h.Init()
// 保存会话连接
h.App.ChatClients.Put(sessionId, client)
@@ -125,9 +144,10 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Error(err)
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
h.App.ChatClients.Delete(sessionId)
h.App.ChatSession.Delete(sessionId)
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
if cancelFunc != nil {
cancelFunc()
@@ -136,19 +156,30 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
return
}
message := string(msg)
logger.Info("Receive a message: ", message)
//utils.ReplyMessage(client, "这是一条测试消息!")
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 Chat 心跳消息:", message.Content)
continue
}
logger.Info("Receive a message: ", message.Content)
ctx, cancel := context.WithCancel(context.Background())
h.App.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, message, client)
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
} else {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
logger.Info("回答完毕: " + string(message))
logger.Infof("回答完毕: %v", message.Content)
}
}
@@ -156,16 +187,18 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
defer func() {
if r := recover(); r != nil {
logger.Error("Recover message from error: ", r)
}
}()
if !h.App.Debug {
defer func() {
if r := recover(); r != nil {
logger.Error("Recover message from error: ", r)
}
}()
}
var user model.User
res := h.db.Model(&model.User{}).First(&user, session.UserId)
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
utils.ReplyMessage(ws, "非法用户,请联系管理员")
utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作")
return res.Error
}
var userVo vo.User
@@ -181,14 +214,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil
}
if userVo.Calls < session.Model.Weight {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数%d已不足以支付当前模型的单次对话需要消耗的对话额度%d", userVo.Calls, session.Model.Weight))
utils.ReplyMessage(ws, ErrImg)
return nil
}
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
if userVo.Power < session.Model.Power {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力%d已不足以支付当前模型的单次对话需要消耗的算力%d", userVo.Power, session.Model.Power))
utils.ReplyMessage(ws, ErrImg)
return nil
}
@@ -198,36 +225,69 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
utils.ReplyMessage(ws, ErrImg)
return nil
}
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > session.Model.MaxContext {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
}
var req = types.ApiRequest{
Model: session.Model.Value,
Stream: true,
}
switch session.Model.Platform {
case types.Azure:
req.Temperature = h.App.ChatConfig.Azure.Temperature
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
break
case types.ChatGLM:
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
break
case types.Baidu:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
// TODO 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
case types.OpenAI:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
if h.App.SysConfig.EnabledFunction {
var functions = make([]types.Function, 0)
for _, f := range types.InnerFunctions {
functions = append(functions, f)
}
req.Functions = functions
var items []model.Function
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
case types.XunFei:
req.Temperature = h.App.ChatConfig.XunFei.Temperature
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
required := parameters["required"]
delete(parameters, "required")
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
}
// Fixed: compatible for gpt4-turbo-xxx model
if !strings.HasPrefix(req.Model, "gpt-4-turbo-") {
tool.Function.Required = required
}
tools = append(tools, tool)
}
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
case types.QWen:
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
break
default:
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
@@ -235,43 +295,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
// 加载聊天上下文
var chatCtx []interface{}
if h.App.ChatConfig.EnableContext {
chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.SysConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
messages = h.App.ChatContexts.Get(session.ChatId)
} else {
// calculate the tokens of current request, to prevent to exceeding the max tokens num
tokens := req.MaxTokens
for _, f := range types.InnerFunctions {
tks, _ := utils.CalcTokens(utils.JsonEncode(f), req.Model)
tokens += tks
}
// loading the role context
var messages []types.Message
err := utils.JsonDecode(role.Context, &messages)
if err == nil {
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
if tokens+tks >= types.ModelToTokens[req.Model] {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
}
// loading recent chat history as chat context
if chatConfig.ContextDeep > 0 {
var historyMessages []model.HistoryMessage
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
_ = utils.JsonDecode(role.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] {
break
}
tokens += msg.Tokens
ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg {
ms.Role = "assistant"
@@ -281,6 +317,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
}
}
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
// MaxContextLength = Response + Tool + Prompt + Context
tokens := req.MaxTokens // 最大响应长度
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
break
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
logger.Debugf("聊天上下文:%+v", chatCtx)
}
reqMgs := make([]interface{}, 0)
@@ -288,10 +347,49 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
reqMgs = append(reqMgs, m)
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
if session.Model.Platform == types.QWen {
req.Input = make(map[string]interface{})
reqMgs = append(reqMgs, types.Message{
Role: "user",
Content: prompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURL(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
text := prompt
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
}
data = append(data, gin.H{
"type": "text",
"text": text,
})
content = data
} else {
content = prompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
}
logger.Debugf("%+v", req.Messages)
switch session.Model.Platform {
case types.Azure:
@@ -304,7 +402,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.QWen:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
@@ -316,8 +415,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
// Tokens 统计 token 数量
func (h *ChatHandler) Tokens(c *gin.Context) {
var data struct {
Text string `json:"text"`
Model string `json:"model"`
Text string `json:"text"`
Model string `json:"model"`
ChatId string `json:"chat_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -325,10 +425,10 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
}
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
if data.Text == "" {
var item model.HistoryMessage
if data.Text == "" && data.ChatId != "" {
var item model.ChatMessage
userId, _ := c.Get(types.LoginUserID)
res := h.db.Where("user_id = ?", userId).Last(&item)
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -378,39 +478,53 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
if session.Model.KeyId > 0 {
h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
h.DB.Debug().Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
if apiKey.Id == 0 {
return nil, errors.New("no available key, please import key")
}
// ONLY allow apiURL in blank list
if session.Model.Platform == types.OpenAI {
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
}
var apiURL string
switch platform {
switch session.Model.Platform {
case types.Azure:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break
case types.ChatGLM:
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu:
apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
case types.QWen:
apiURL = apiKey.ApiURL
req.Messages = nil
break
default:
apiURL = h.App.ChatConfig.OpenAI.ApiURL
apiURL = apiKey.ApiURL
}
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ? AND type = ?", platform, "chat").Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if platform == types.Baidu {
token, err := h.getBaiduToken(*apiKey)
if session.Model.Platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
}
@@ -418,6 +532,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
}
logger.Debugf(utils.JsonEncode(req))
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
@@ -431,9 +547,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request = request.WithContext(ctx)
request.Header.Set("Content-Type", "application/json")
proxyURL := h.App.Config.ProxyURL
if proxyURL != "" && platform == types.OpenAI { // 使用代理
proxy, _ := url.Parse(proxyURL)
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
@@ -442,42 +557,79 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
switch platform {
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
switch session.Model.Platform {
case types.Azure:
request.Header.Set("api-key", *apiKey)
request.Header.Set("api-key", apiKey.Value)
break
case types.ChatGLM:
token, err := h.getChatGLMToken(*apiKey)
token, err := h.getChatGLMToken(apiKey.Value)
if err != nil {
return nil, err
}
logger.Info(token)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
case types.Baidu:
request.RequestURI = ""
case types.OpenAI:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
break
case types.QWen:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
request.Header.Set("X-DashScope-SSE", "enable")
break
}
return client.Do(request)
}
// 扣减用户的对话次数
func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) {
// 仅当用户没有导入自己的 API KEY 时才进行扣减
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
num := 1
if session.Model.Weight > 0 {
num = session.Model.Weight
}
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num))
// 扣减用户算力
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
power := 1
if session.Model.Power > 0 {
power = session.Model.Power
}
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userVo.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,
Amount: power,
Mark: types.PowerSub,
Balance: u.Power,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
CreatedAt: time.Now(),
})
}
}
func (h *ChatHandler) incUserTokenFee(userId uint, tokens int) {
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", tokens))
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("tokens", gorm.Expr("tokens + ?", tokens))
// 将AI回复消息中生成的图片链接下载到本地
func (h *ChatHandler) extractImgUrl(text string) string {
pattern := `!\[([^\]]*)]\(([^)]+)\)`
re := regexp.MustCompile(pattern)
matches := re.FindAllStringSubmatch(text, -1)
// 下载图片并替换链接地址
for _, match := range matches {
imageURL := match[2]
logger.Debug(imageURL)
// 对于相同地址的图片,已经被替换了,就不再重复下载了
if !strings.Contains(text, imageURL) {
continue
}
newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
if err != nil {
logger.Error("error with download image: ", err)
continue
}
text = strings.ReplaceAll(text, imageURL, newImgURL)
}
return text
}

View File

@@ -1,32 +1,41 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// List 获取会话列表
func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.ERROR(c, "The parameter 'user_id' is needed.")
if !h.IsLogin(c) {
resp.SUCCESS(c)
return
}
userId := h.GetLoginUserId(c)
var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
if res.Error == nil {
var roleIds = make([]uint, 0)
for _, chat := range chats {
roleIds = append(roleIds, chat.RoleId)
}
var roles []model.ChatRole
res = h.db.Find(&roles, roleIds)
res = h.DB.Find(&roles, roleIds)
if res.Error == nil {
roleMap := make(map[uint]model.ChatRole)
for _, role := range roles {
@@ -58,7 +67,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
@@ -70,14 +79,14 @@ func (h *ChatHandler) Update(c *gin.Context) {
// Clear 清空所有聊天记录
func (h *ChatHandler) Clear(c *gin.Context) {
// 获取当前登录用户所有的聊天会话
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
var chats []model.ChatItem
res := h.db.Where("user_id = ?", user.Id).Find(&chats)
res := h.DB.Where("user_id = ?", user.Id).Find(&chats)
if res.Error != nil {
resp.ERROR(c, "No chats found")
return
@@ -89,13 +98,13 @@ func (h *ChatHandler) Clear(c *gin.Context) {
// 清空会话上下文
h.App.ChatContexts.Delete(chat.ChatId)
}
err = h.db.Transaction(func(tx *gorm.DB) error {
res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
err = h.DB.Transaction(func(tx *gorm.DB) error {
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
if res.Error != nil {
return res.Error
}
res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.HistoryMessage{})
res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
if res.Error != nil {
return res.Error
}
@@ -116,9 +125,9 @@ func (h *ChatHandler) Clear(c *gin.Context) {
// History 获取聊天历史记录
func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.HistoryMessage
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.db.Where("chat_id = ?", chatId).Find(&items)
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
@@ -144,20 +153,20 @@ func (h *ChatHandler) Remove(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
}
// 删除当前会话的聊天记录
res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to remove chat from database.")
return
@@ -179,18 +188,26 @@ func (h *ChatHandler) Detail(c *gin.Context) {
}
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", chatId).First(&chatItem)
res := h.DB.Where("chat_id = ?", chatId).First(&chatItem)
if res.Error != nil {
resp.ERROR(c, "No chat found")
return
}
// 填充角色名称
var role model.ChatRole
res = h.DB.Where("id", chatItem.RoleId).First(&role)
if res.Error != nil {
resp.ERROR(c, "Role not found")
return
}
var chatItemVo vo.ChatItem
err := utils.CopyObject(chatItem, &chatItemVo)
if err != nil {
resp.ERROR(c, err.Error())
return
}
chatItemVo.RoleName = role.Name
resp.SUCCESS(c, chatItemVo)
}

View File

@@ -1,11 +1,18 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
@@ -20,7 +27,7 @@ import (
// 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -30,8 +37,8 @@ func (h *ChatHandler) sendChatGLMMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
@@ -107,9 +114,6 @@ func (h *ChatHandler) sendChatGLMMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -117,63 +121,64 @@ func (h *ChatHandler) sendChatGLMMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -184,7 +189,8 @@ func (h *ChatHandler) sendChatGLMMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -1,25 +1,33 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"fmt"
"gorm.io/gorm"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
req2 "github.com/imroc/req/v3"
)
// OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -29,22 +37,20 @@ func (h *ChatHandler) sendOpenAiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
logger.Error(err)
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
utils.ReplyMessage(ws, err.Error())
return err
} else {
defer response.Body.Close()
@@ -56,10 +62,11 @@ func (h *ChatHandler) sendOpenAiMessage(
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var functionCall = false
var functionName string
var function model.Function
var toolCall = false
var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -68,44 +75,67 @@ func (h *ChatHandler) sendOpenAiMessage(
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
if err != nil { // 数据解析出错
logger.Error(err, line)
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
break
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.ReplyMessage(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
break
}
var tool types.ToolCall
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
tool = responseBody.Choices[0].Delta.ToolCalls[0]
if toolCall && tool.Function.Name == "" {
arguments = append(arguments, tool.Function.Arguments)
continue
}
}
// 兼容 Function Call
fun := responseBody.Choices[0].Delta.FunctionCall
if functionCall && fun.Name == "" {
if fun.Name != "" {
tool = *new(types.ToolCall)
tool.Function.Name = fun.Name
} else if toolCall {
arguments = append(arguments, fun.Arguments)
continue
}
if !utils.IsEmptyValue(fun) {
functionName = fun.Name
f := h.App.Functions[functionName]
if f != nil {
functionCall = true
if !utils.IsEmptyValue(tool) {
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
contents = append(contents, callMsg)
}
continue
}
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
if responseBody.Choices[0].FinishReason == "tool_calls" ||
responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue
} else if responseBody.Choices[0].FinishReason != "" {
// output stopped
if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
if isNew {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
@@ -121,53 +151,40 @@ func (h *ChatHandler) sendOpenAiMessage(
}
}
if functionCall { // 调用函数完成任务
if toolCall { // 调用函数完成任务
var params map[string]interface{}
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
// for creating image, check if the user's img_calls > 0
if functionName == types.FuncImage && userVo.ImgCalls <= 0 {
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
utils.ReplyMessage(ws, ErrImg)
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
params["user_id"] = userVo.Id
var apiRes types.BizVo
r, err := req2.C().R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", function.Token).
SetBody(params).
SetSuccessResult(&apiRes).Post(function.Action)
errMsg := ""
if err != nil {
errMsg = err.Error()
} else if r.IsErrorState() {
errMsg = r.Status
}
if errMsg != "" || apiRes.Code != types.Success {
msg := "调用函数工具出错:" + apiRes.Message + errMsg
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
f := h.App.Functions[functionName]
if functionName == types.FuncImage {
params["user_id"] = userVo.Id
params["role_id"] = role.Id
params["chat_id"] = session.ChatId
params["session_id"] = session.SessionId
}
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
content := data
if functionName == types.FuncImage {
content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景。%s", params["prompt"], data)
// update user's img_calls
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: content,
})
contents = append(contents, content)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: apiRes.Data,
})
contents = append(contents, utils.InterfaceToString(apiRes.Data))
}
}
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -175,77 +192,77 @@ func (h *ChatHandler) sendOpenAiMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext && functionCall == false {
if h.App.SysConfig.EnableContext && toolCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
useContext := true
if functionCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if functionCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: useContext,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
useContext := true
if toolCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyTokens += tokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -256,17 +273,20 @@ func (h *ChatHandler) sendOpenAiMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
return fmt.Errorf("error with decode response: %v", err)
}
@@ -274,7 +294,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {

View File

@@ -0,0 +1,248 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
type qWenResp struct {
Output struct {
FinishReason string `json:"finish_reason"`
Text string `json:"text"`
} `json:"output,omitempty"`
Usage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage,omitempty"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var content, lastText, newText string
var outPutStart = false
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") ||
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
continue
}
if !strings.HasPrefix(line, "data:") {
continue
}
content = line[5:]
var resp qWenResp
if len(contents) == 0 { // 发送消息头
if !outPutStart {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
outPutStart = true
continue
} else {
// 处理代码换行
content = "\n"
}
} else {
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", content)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if resp.Message != "" {
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
break
}
}
//通过比较 lastText上一次的文本和 currentText当前的文本
//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
//每次循环结束后lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
currentText := resp.Output.Text
if currentText != lastText {
// 提取新增文本
newText = strings.Replace(currentText, lastText, "", 1)
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(newText),
})
lastText = currentText // 更新 lastText
}
contents = append(contents, newText)
if resp.Output.FinishReason == "stop" {
break
}
} //end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg)
}
return nil
}

View File

@@ -1,17 +1,25 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"html/template"
"io"
"net/http"
@@ -50,15 +58,16 @@ type xunFeiResp struct {
}
var Model2URL = map[string]string{
"generalv1": "1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"generalv3.5": "v3.5",
}
// 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -67,29 +76,34 @@ func (h *ChatHandler) sendXunFeiMessage(
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
if apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ? AND type = ?", session.Model.Platform, "chat").Order("last_used_at ASC").First(&key)
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
apiKey = key.Value
var apiKey model.ApiKey
var res *gorm.DB
// use the bind key
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
}
// use the last unused key
if res.Error != nil {
res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
}
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
key := strings.Split(apiKey, "|")
key := strings.Split(apiKey.Value, "|")
if len(key) != 3 {
utils.ReplyMessage(ws, "非法的 API KEY")
return nil
}
apiURL := strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", Model2URL[req.Model], 1)
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)
@@ -170,9 +184,6 @@ func (h *ChatHandler) sendXunFeiMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -180,63 +191,64 @@ func (h *ChatHandler) sendXunFeiMessage(
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -247,7 +259,8 @@ func (h *ChatHandler) sendXunFeiMessage(
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
@@ -263,7 +276,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
"parameter": map[string]interface{}{
"chat": map[string]interface{}{
"domain": req.Model,
"temperature": float64(req.Temperature),
"temperature": req.Temperature,
"top_k": int64(6),
"max_tokens": int64(req.MaxTokens),
"auditing": "default",

View File

@@ -0,0 +1,46 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ConfigHandler struct {
BaseHandler
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key")
var config model.Config
res := h.DB.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
var value map[string]interface{}
err := utils.JsonDecode(config.Config, &value)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, value)
}

View File

@@ -0,0 +1,262 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"net/http"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type DallJobHandler struct {
BaseHandler
redis *redis.Client
service *dalle.Service
uploader *oss.UploaderManager
}
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
return &DallJobHandler{
service: service,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *DallJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.service.Clients.Delete(uint(userId))
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 DallE 心跳消息:", message.Content)
continue
}
}
}()
}
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
return true
}
// Image 创建一个绘画任务
func (h *DallJobHandler) Image(c *gin.Context) {
if !h.preCheck(c) {
return
}
var data types.DallTask
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
job := model.DallJob{
UserId: uint(userId),
Prompt: data.Prompt,
Power: h.App.SysConfig.DallPower,
}
res := h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error())
return
}
h.service.PushTask(types.DallTask{
JobId: job.Id,
UserId: uint(userId),
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
Power: job.Power,
})
client := h.service.Clients.Get(job.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
// ImgWall 照片墙
func (h *DallJobHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 SD 任务列表
func (h *DallJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取任务列表
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
}
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var items []model.DallJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
}
var jobs = make([]vo.DallJob, 0)
for _, item := range items {
var job vo.DallJob
err := utils.CopyObject(item, &job)
if err != nil {
continue
}
jobs = append(jobs, job)
}
return nil, jobs
}
// Remove remove task image
func (h *DallJobHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// remove job recode
res := h.DB.Delete(&model.DallJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}
// Publish 发布/取消发布图片到画廊显示
func (h *DallJobHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,226 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"errors"
"fmt"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
type FunctionHandler struct {
BaseHandler
config types.ApiConfig
uploadManager *oss.UploaderManager
dallService *dalle.Service
}
func NewFunctionHandler(
server *core.AppServer,
db *gorm.DB,
config *types.AppConfig,
manager *oss.UploaderManager,
dallService *dalle.Service) *FunctionHandler {
return &FunctionHandler{
BaseHandler: BaseHandler{
App: server,
DB: db,
},
config: config.ApiConfig,
uploadManager: manager,
dallService: dallService,
}
}
type resVo struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data struct {
Title string `json:"title"`
UpdatedAt string `json:"updated_at"`
Items []dataItem `json:"items"`
} `json:"data"`
}
type dataItem struct {
Title string `json:"title"`
Url string `json:"url"`
Remark string `json:"remark"`
}
// check authorization
func (h *FunctionHandler) checkAuth(c *gin.Context) error {
tokenString := c.GetHeader(types.UserAuthHeader)
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(h.App.Config.Session.SecretKey), nil
})
if err != nil {
return fmt.Errorf("error with parse auth token: %v", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return errors.New("token is invalid")
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
return errors.New("token is expired")
}
return nil
}
// WeiBo 微博热搜
func (h *FunctionHandler) WeiBo(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
resp.ERROR(c, err.Error())
return
}
if h.config.Token == "" {
resp.ERROR(c, "无效的 API Token")
return
}
url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
var res resVo
r, err := req.C().R().
SetHeader("AppId", h.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
return
}
if res.Code != types.Success {
resp.ERROR(c, res.Message)
return
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
for i, v := range res.Data.Items {
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
}
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
}
// ZaoBao 今日早报
func (h *FunctionHandler) ZaoBao(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
resp.ERROR(c, err.Error())
return
}
if h.config.Token == "" {
resp.ERROR(c, "无效的 API Token")
return
}
url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
var res resVo
r, err := req.C().R().
SetHeader("AppId", h.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
return
}
if res.Code != types.Success {
resp.ERROR(c, res.Message)
return
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
for _, v := range res.Data.Items {
builder = append(builder, v.Title)
}
builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
}
// Dall3 DallE3 AI 绘图
func (h *FunctionHandler) Dall3(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
resp.ERROR(c, err.Error())
return
}
var params map[string]interface{}
if err := c.ShouldBindJSON(&params); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
logger.Debugf("绘画参数:%+v", params)
var user model.User
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
if res.Error != nil {
resp.ERROR(c, "当前用户不存在!")
return
}
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
return
}
// create dall task
prompt := utils.InterfaceToString(params["prompt"])
job := model.DallJob{
UserId: user.Id,
Prompt: prompt,
Power: h.App.SysConfig.DallPower,
}
res = h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error())
return
}
content, err := h.dallService.Image(types.DallTask{
JobId: job.Id,
UserId: user.Id,
Prompt: job.Prompt,
N: 1,
Quality: "standard",
Size: "1024x1024",
Style: "vivid",
Power: job.Power,
}, true)
if err != nil {
resp.ERROR(c, "任务执行失败:"+err.Error())
return
}
resp.SUCCESS(c, content)
}

View File

@@ -1,12 +1,19 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"strings"
@@ -15,32 +22,29 @@ import (
// InviteHandler 用户邀请
type InviteHandler struct {
BaseHandler
db *gorm.DB
}
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
h := InviteHandler{db: db}
h.App = app
return &h
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Code 获取当前用户邀请码
func (h *InviteHandler) Code(c *gin.Context) {
userId := h.GetLoginUserId(c)
var inviteCode model.InviteCode
res := h.db.Where("user_id = ?", userId).First(&inviteCode)
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
// 如果邀请码不存在,则创建一个
if res.Error != nil {
code := strings.ToUpper(utils.RandString(8))
for {
res = h.db.Where("code = ?", code).First(&inviteCode)
res = h.DB.Where("code = ?", code).First(&inviteCode)
if res.Error != nil { // 不存在相同的邀请码则退出
break
}
}
inviteCode.UserId = userId
inviteCode.Code = code
h.db.Create(&inviteCode)
h.DB.Create(&inviteCode)
}
var codeVo vo.InviteCode
@@ -65,7 +69,7 @@ func (h *InviteHandler) List(c *gin.Context) {
return
}
userId := h.GetLoginUserId(c)
session := h.db.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
var total int64
session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog
@@ -91,6 +95,6 @@ func (h *InviteHandler) List(c *gin.Context) {
// Hits 访问邀请码
func (h *InviteHandler) Hits(c *gin.Context) {
code := c.Query("code")
h.db.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,255 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/utils"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// MarkMapHandler 生成思维导图
type MarkMapHandler struct {
BaseHandler
clients *types.LMap[int, *types.WsClient]
}
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler {
return &MarkMapHandler{
BaseHandler: BaseHandler{App: app, DB: db},
clients: types.NewLMap[int, *types.WsClient](),
}
}
func (h *MarkMapHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
modelId := h.GetInt(c, "model_id", 0)
userId := h.GetInt(c, "user_id", 0)
client := types.NewWsClient(ws)
h.clients.Put(userId, client)
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.clients.Delete(userId)
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
continue
}
// change model
if message.Type == "model_id" {
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
continue
}
logger.Info("Receive a message: ", message.Content)
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
}
}
}()
}
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
var user model.User
res := h.DB.Model(&model.User{}).First(&user, userId)
if res.Error != nil {
return fmt.Errorf("error with query user info: %v", res.Error)
}
var chatModel model.ChatModel
res = h.DB.Where("id", modelId).First(&chatModel)
if res.Error != nil {
return fmt.Errorf("error with query chat model: %v", res.Error)
}
if user.Status == false {
return errors.New("当前用户被禁用")
}
if user.Power < chatModel.Power {
return fmt.Errorf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power)
}
messages := make([]interface{}, 0)
messages = append(messages, types.Message{Role: "system", Content: `
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
# Geek-AI 助手
## 完整的开源系统
### 前端开源
### 后端开源
## 支持各种大模型
### OpenAI
### Azure
### 文心一言
### 通义千问
## 集成多种收费方式
### 支付宝
### 微信
另外,除此之外不要任何解释性语句。
`})
messages = append(messages, types.Message{Role: "user", Content: prompt})
var req = types.ApiRequest{
Model: chatModel.Value,
Stream: true,
Messages: messages,
}
var apiKey model.ApiKey
response, err := h.doRequest(req, chatModel, &apiKey)
if err != nil {
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
}
defer response.Body.Close()
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
// 循环读取 Chunk 消息
var message = types.Message{}
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
return fmt.Errorf("error with decode data: %v", err)
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
if isNew {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(client, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
} // end for
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("读取响应失败: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("解析响应失败: %v", err)
}
// OpenAI API 调用异常处理
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
// remove key
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
return errors.New("请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
return errors.New("请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else {
return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message)
}
}
return nil
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
var res *gorm.DB
if chatModel.KeyId > 0 {
res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", types.OpenAI).
Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := apiKey.ApiURL
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
} else {
client = http.DefaultClient
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
return client.Do(request)
}

View File

@@ -0,0 +1,43 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MenuHandler struct {
BaseHandler
}
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 数据列表
func (h *MenuHandler) List(c *gin.Context) {
var items []model.Menu
var list = make([]vo.Menu, 0)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Menu
err := utils.CopyObject(item, &product)
if err == nil {
list = append(list, product)
}
}
}
resp.SUCCESS(c, list)
}

View File

@@ -1,48 +1,61 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/mj"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/mj"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
type MidJourneyHandler struct {
BaseHandler
db *gorm.DB
pool *mj.ServicePool
snowflake *service.Snowflake
uploader *oss.UploaderManager
}
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool) *MidJourneyHandler {
h := MidJourneyHandler{
db: db,
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
return &MidJourneyHandler{
snowflake: snowflake,
pool: pool,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
h.App = app
return &h
}
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值")
if user.Power < h.App.SysConfig.MjPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return false
}
@@ -55,22 +68,47 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
// Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct {
SessionId string `json:"session_id"`
Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
Img string `json:"img"`
Tile bool `json:"tile"`
Quality float32 `json:"quality"`
Weight float32 `json:"weight"`
SessionId string `json:"session_id"`
TaskType string `json:"task_type"`
Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
ImgArr []string `json:"img_arr"`
Tile bool `json:"tile"`
Quality float32 `json:"quality"`
Iw float32 `json:"iw"`
CRef string `json:"cref"` //生成角色一致的图像
SRef string `json:"sref"` //生成风格一致的图像
Cw int `json:"cw"` // 参考程度
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -80,39 +118,57 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
var prompt = data.Prompt
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
prompt += " --ar " + data.Rate
var params = ""
if data.Rate != "" && !strings.Contains(params, "--ar") {
params += " --ar " + data.Rate
}
if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
prompt += fmt.Sprintf(" --seed %d", data.Seed)
if data.Seed > 0 && !strings.Contains(params, "--seed") {
params += fmt.Sprintf(" --seed %d", data.Seed)
}
if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
prompt += fmt.Sprintf(" --s %d", data.Stylize)
if data.Stylize > 0 && !strings.Contains(params, "--s") && !strings.Contains(params, "--stylize") {
params += fmt.Sprintf(" --s %d", data.Stylize)
}
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
prompt += fmt.Sprintf(" --c %d", data.Chaos)
if data.Chaos > 0 && !strings.Contains(params, "--c") && !strings.Contains(params, "--chaos") {
params += fmt.Sprintf(" --c %d", data.Chaos)
}
if data.Img != "" {
prompt = fmt.Sprintf("%s %s", data.Img, prompt)
if data.Weight > 0 {
prompt += fmt.Sprintf(" --iw %f", data.Weight)
}
if len(data.ImgArr) > 0 && data.Iw > 0 {
params += fmt.Sprintf(" --iw %.2f", data.Iw)
}
if data.Raw {
prompt += " --style raw"
params += " --style raw"
}
if data.Quality > 0 {
prompt += fmt.Sprintf(" --q %.2f", data.Quality)
}
if data.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", data.NegPrompt)
params += fmt.Sprintf(" --q %.2f", data.Quality)
}
if data.Tile {
prompt += " --tile "
params += " --tile "
}
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
prompt += fmt.Sprintf(" %s", data.Model)
if data.CRef != "" {
params += fmt.Sprintf(" --cref %s", data.CRef)
if data.Cw > 0 {
params += fmt.Sprintf(" --cw %d", data.Cw)
} else {
params += " --cw 100"
}
}
if data.SRef != "" {
params += fmt.Sprintf(" --sref %s", data.SRef)
}
if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
params += fmt.Sprintf(" %s", data.Model)
}
// 处理融图和换脸的提示词
if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() {
params = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
}
// 如果本地图片上传的是相对地址,处理成绝对地址
for k, v := range data.ImgArr {
if !strings.HasPrefix(v, "http") {
data.ImgArr[k] = fmt.Sprintf("http://localhost:5678/%s", strings.TrimLeft(v, "/"))
}
}
idValue, _ := c.Get(types.LoginUserID)
@@ -124,31 +180,68 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
job := model.MidJourneyJob{
Type: types.TaskImage.String(),
Type: data.TaskType,
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: prompt,
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil {
opt := "绘图"
if data.TaskType == types.TaskBlend.String() {
job.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
opt = "融图"
} else if data.TaskType == types.TaskSwapFace.String() {
job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
opt = "换脸"
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{
Id: int(job.Id),
Id: job.Id,
TaskId: taskId,
SessionId: data.SessionId,
Type: types.TaskImage,
Prompt: fmt.Sprintf("%s %s", taskId, prompt),
Type: types.TaskType(data.TaskType),
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
})
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
type reqVo struct {
Src string `json:"src"`
Index int `json:"index"`
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"`
SessionId string `json:"session_id"`
@@ -171,18 +264,56 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
}
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
taskId, _ := h.snowflake.Next(true)
job := model.MidJourneyJob{
Type: types.TaskUpscale.String(),
ReferenceId: data.MessageId,
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{
Id: jobId,
Id: job.Id,
SessionId: data.SessionId,
Type: types.TaskUpscale,
Prompt: data.Prompt,
UserId: userId,
ChannelId: data.ChannelId,
Index: data.Index,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
})
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
@@ -199,30 +330,95 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
}
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
taskId, _ := h.snowflake.Next(true)
job := model.MidJourneyJob{
Type: types.TaskVariation.String(),
ChannelId: data.ChannelId,
ReferenceId: data.MessageId,
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{
Id: jobId,
Id: job.Id,
SessionId: data.SessionId,
Type: types.TaskVariation,
Prompt: data.Prompt,
UserId: userId,
Index: data.Index,
ChannelId: data.ChannelId,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
})
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
// ImgWall 照片墙
func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0)
userId := h.GetInt(c, "user_id", 0)
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
session := h.db.Session(&gorm.Session{})
if status == 1 {
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
@@ -230,6 +426,9 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish = ?", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
@@ -238,8 +437,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
var items []model.MidJourneyJob
res := session.Find(&items)
if res.Error != nil {
resp.ERROR(c, types.NoData)
return
return res.Error, nil
}
var jobs = make([]vo.MidJourneyJob, 0)
@@ -250,27 +448,72 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
continue
}
if job.Progress == -1 {
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
}
if item.Progress < 100 {
// 10 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
h.db.Delete(&item)
continue
}
// 正在运行中任务使用代理访问图片
if item.ImgURL == "" && item.OrgURL != "" {
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
// discord 服务器图片需要使用代理转发图片数据流
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
} else {
job.ImgURL = job.OrgURL
}
}
jobs = append(jobs, job)
}
resp.SUCCESS(c, jobs)
return nil, jobs
}
// Remove remove task image
func (h *MidJourneyHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// remove job recode
res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
// Publish 发布图片到画廊显示
func (h *MidJourneyHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@@ -1,25 +1,30 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type OrderHandler struct {
BaseHandler
db *gorm.DB
}
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db}
h.App = app
return &h
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
func (h *OrderHandler) List(c *gin.Context) {
@@ -31,8 +36,8 @@ func (h *OrderHandler) List(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
user, _ := utils.GetLoginUser(c, h.db)
session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess)
userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
var total int64
session.Model(&model.Order{}).Count(&total)
var items []model.Order

View File

@@ -1,27 +1,38 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/payment"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/payment"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"embed"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"github.com/shopspring/decimal"
"math"
"net/http"
"net/url"
"sync"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
const (
PayWayAlipay = "支付宝"
PayWayXunHu = "虎皮椒"
PayWayJs = "PayJS"
)
// PaymentHandler 支付服务回调 handler
@@ -29,23 +40,32 @@ type PaymentHandler struct {
BaseHandler
alipayService *payment.AlipayService
huPiPayService *payment.HuPiPayService
js *payment.PayJS
snowflake *service.Snowflake
db *gorm.DB
fs embed.FS
lock sync.Mutex
}
func NewPaymentHandler(server *core.AppServer, alipayService *payment.AlipayService, huPiPayService *payment.HuPiPayService, snowflake *service.Snowflake, db *gorm.DB, fs embed.FS) *PaymentHandler {
h := PaymentHandler{
func NewPaymentHandler(
server *core.AppServer,
alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService,
js *payment.PayJS,
db *gorm.DB,
snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler {
return &PaymentHandler{
alipayService: alipayService,
huPiPayService: huPiPayService,
js: js,
snowflake: snowflake,
fs: fs,
db: db,
lock: sync.Mutex{},
BaseHandler: BaseHandler{
App: server,
DB: db,
},
}
h.App = server
return &h
}
func (h *PaymentHandler) DoPay(c *gin.Context) {
@@ -58,14 +78,20 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
}
var order model.Order
res := h.db.Where("order_no = ?", orderNo).First(&order)
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
}
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
if order.Status == types.OrderPaidSuccess {
resp.ERROR(c, "This order had been paid, please do not pay twice")
return
}
// 更新扫码状态
h.db.Model(&order).UpdateColumn("status", types.OrderScanned)
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
if payWay == "alipay" { // 支付宝
// 生成支付链接
notifyURL := h.App.Config.AlipayConfig.NotifyURL
@@ -81,40 +107,20 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
c.Redirect(302, uri)
return
} else if payWay == "hupi" { // 虎皮椒支付
params := map[string]string{
"version": "1.1",
"trade_order_id": orderNo,
"total_fee": fmt.Sprintf("%f", order.Amount),
"title": order.Subject,
"notify_url": h.App.Config.HuPiPayConfig.NotifyURL,
"return_url": "",
"wap_name": "极客学长",
"callback_url": "",
params := payment.HuPiPayReq{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", order.Amount),
Title: order.Subject,
NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL,
WapName: "极客学长",
}
res, err := h.huPiPayService.Pay(params)
r, err := h.huPiPayService.Pay(params)
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
resp.ERROR(c, err.Error())
return
}
var r struct {
Openid int64 `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
err = utils.JsonDecode(res, &r)
if err != nil {
resp.ERROR(c, "error with decode payment result: "+err.Error())
return
}
if r.ErrCode != 0 {
resp.ERROR(c, "error with generate pay url: "+r.ErrMsg)
return
}
c.Redirect(302, r.URL)
}
resp.ERROR(c, "Invalid operations")
@@ -131,7 +137,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
}
var order model.Order
res := h.db.Where("order_no = ?", data.OrderNo).First(&order)
res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
@@ -146,7 +152,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
for {
time.Sleep(time.Second)
var item model.Order
h.db.Where("order_no = ?", data.OrderNo).First(&item)
h.DB.Where("order_no = ?", data.OrderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
@@ -170,7 +176,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
}
var product model.Product
res := h.db.First(&product, data.ProductId)
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
@@ -182,42 +188,69 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
return
}
var user model.User
res = h.db.First(&user, data.UserId)
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
payWay := PayWayAlipay
if data.PayWay == "hupi" {
var payWay string
var notifyURL string
switch data.PayWay {
case "hupi":
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
case "payjs":
payWay = PayWayJs
notifyURL = h.App.Config.JPayConfig.NotifyURL
default:
payWay = PayWayAlipay
notifyURL = h.App.Config.AlipayConfig.NotifyURL
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Calls: product.Calls,
ImgCalls: product.ImgCalls,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
order := model.Order{
UserId: user.Id,
Mobile: user.Mobile,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: product.Price - product.Discount,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.db.Create(&order)
if res.Error != nil {
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
}
// PayJs 单独处理,只能用官方生成的二维码
if data.PayWay == "payjs" {
params := payment.JPayReq{
TotalFee: int(math.Ceil(order.Amount * 100)),
OutTradeNo: order.OrderNo,
Subject: product.Name,
}
r := h.js.Pay(params)
if r.IsOK() {
resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
return
} else {
resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg)
return
}
}
var logo string
if data.PayWay == "alipay" {
logo = "res/img/alipay.jpg"
@@ -235,7 +268,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
return
}
parse, err := url.Parse(h.App.Config.AlipayConfig.NotifyURL)
parse, err := url.Parse(notifyURL)
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -251,52 +284,137 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
}
// AlipayNotify 支付宝支付回调
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
err := c.Request.ParseForm()
// Mobile 移动端支付
func (h *PaymentHandler) Mobile(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
c.String(http.StatusOK, "fail")
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
// TODO这里最好用支付宝的公钥签名签证一下交易真假
//res := h.alipayService.TradeVerify(c.Request.Form)
r := h.alipayService.TradeQuery(c.Request.Form.Get("out_trade_no"))
logger.Infof("验证支付结果:%+v", r)
if !r.Success() {
c.String(http.StatusOK, "fail")
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
var payWay string
var notifyURL, returnURL string
var payURL string
switch data.PayWay {
case "hupi":
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
params := payment.HuPiPayReq{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name,
NotifyURL: notifyURL,
ReturnURL: returnURL,
CallbackURL: returnURL,
WapName: "极客学长",
}
r, err := h.huPiPayService.Pay(params)
if err != nil {
logger.Error("error with generating Pay URL: ", err.Error())
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
return
}
payURL = r.URL
case "payjs":
payWay = PayWayJs
notifyURL = h.App.Config.JPayConfig.NotifyURL
returnURL = h.App.Config.JPayConfig.ReturnURL
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
params := url.Values{}
params.Add("total_fee", fmt.Sprintf("%d", totalFee))
params.Add("out_trade_no", orderNo)
params.Add("body", product.Name)
params.Add("notify_url", notifyURL)
params.Add("auto", "0")
payURL = h.js.PayH5(params)
case "alipay":
payWay = PayWayAlipay
notifyURL = h.App.Config.AlipayConfig.NotifyURL
returnURL = h.App.Config.AlipayConfig.ReturnURL
payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
if err != nil {
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
return
}
default:
resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
return
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
}
h.lock.Lock()
defer h.lock.Unlock()
err = h.notify(r.OutTradeNo)
if err != nil {
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
resp.SUCCESS(c, payURL)
}
// 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string) error {
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var order model.Order
res := h.db.Where("order_no = ?", orderNo).First(&order)
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
err := fmt.Errorf("error with fetch order: %v", res.Error)
logger.Error(err)
return err
}
h.lock.Lock()
defer h.lock.Unlock()
// 已支付订单,直接返回
if order.Status == types.OrderPaidSuccess {
return nil
}
var user model.User
res = h.db.First(&user, order.UserId)
res = h.DB.First(&user, order.UserId)
if res.Error != nil {
err := fmt.Errorf("error with fetch user info: %v", res.Error)
logger.Error(err)
@@ -311,34 +429,27 @@ func (h *PaymentHandler) notify(orderNo string) error {
return err
}
// 1. 点卡days == 0, calls > 0
// 2. vip 套餐days > 0, calls == 0
if remark.Days > 0 {
if user.ExpiredTime > time.Now().Unix() {
var opt string
var power int
if remark.Days > 0 { // VIP 充值
if user.ExpiredTime >= time.Now().Unix() {
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
opt = "VIP充值VIP 没到期,只延期不增加算力"
} else {
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Power += h.App.SysConfig.VipMonthPower
power = h.App.SysConfig.VipMonthPower
opt = "VIP充值"
}
user.Vip = true
} else if !user.Vip { // 充值点卡的非 VIP 用户
user.ExpiredTime = time.Now().AddDate(0, 0, 30).Unix()
}
if remark.Calls > 0 { // 充值点卡
user.Calls += remark.Calls
} else {
user.Calls += h.App.SysConfig.VipMonthCalls
}
if remark.ImgCalls > 0 {
user.ImgCalls += remark.ImgCalls
} else {
user.ImgCalls += h.App.SysConfig.VipMonthImgCalls
} else { // 充值点卡,直接增加次数即可
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
// 更新用户信息
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
err := fmt.Errorf("error with update user info: %v", res.Error)
logger.Error(err)
@@ -348,7 +459,8 @@ func (h *PaymentHandler) notify(orderNo string) error {
// 更新订单状态
order.PayTime = time.Now().Unix()
order.Status = types.OrderPaidSuccess
res = h.db.Updates(&order)
order.TradeNo = tradeNo
res = h.DB.Updates(&order)
if res.Error != nil {
err := fmt.Errorf("error with update order info: %v", res.Error)
logger.Error(err)
@@ -356,7 +468,23 @@ func (h *PaymentHandler) notify(orderNo string) error {
}
// 更新产品销量
h.db.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
// 记录算力充值日志
if opt != "" {
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRecharge,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: order.PayWay,
Remark: fmt.Sprintf("%s金额%f订单号%s", opt, order.Amount, order.OrderNo),
CreatedAt: time.Now(),
})
}
return nil
}
@@ -369,6 +497,9 @@ func (h *PaymentHandler) GetPayWays(c *gin.Context) {
if h.App.Config.HuPiPayConfig.Enabled {
data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name}
}
if h.App.Config.JPayConfig.Enabled {
data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name}
}
resp.SUCCESS(c, data)
}
@@ -381,12 +512,76 @@ func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
}
orderNo := c.Request.Form.Get("trade_order_id")
logger.Infof("收到订单支付回调,订单 NO%s", orderNo)
// TODO 是否要保存订单交易流水号
h.lock.Lock()
defer h.lock.Unlock()
tradeNo := c.Request.Form.Get("open_order_id")
logger.Infof("收到虎皮椒订单支付回调,订单 NO%s交易流水号%s", orderNo, tradeNo)
err = h.notify(orderNo)
if err = h.huPiPayService.Check(tradeNo); err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
}
// AlipayNotify 支付宝支付回调
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
c.String(http.StatusOK, "fail")
return
}
// TODO验证交易签名
res := h.alipayService.TradeVerify(c.Request.Form)
logger.Infof("验证支付结果:%+v", res)
if !res.Success() {
logger.Error("订单校验失败:", res.Message)
c.String(http.StatusOK, "fail")
return
}
tradeNo := c.Request.Form.Get("trade_no")
err = h.notify(res.OutTradeNo, tradeNo)
if err != nil {
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
}
// PayJsNotify PayJs 支付异步回调
func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
c.String(http.StatusOK, "fail")
return
}
orderNo := c.Request.Form.Get("out_trade_no")
returnCode := c.Request.Form.Get("return_code")
logger.Infof("收到订单支付回调,订单 NO%s支付结果代码%v", orderNo, returnCode)
// 支付失败
if returnCode != "1" {
return
}
// 校验订单支付状态
tradeNo := c.Request.Form.Get("payjs_order_id")
err = h.js.Check(tradeNo)
if err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
c.String(http.StatusOK, "fail")
return

View File

@@ -0,0 +1,74 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type PowerLogHandler struct {
BaseHandler
}
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Model string `json:"model"`
Date []string `json:"date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
userId := h.GetLoginUserId(c)
session = session.Where("user_id", userId)
if data.Model != "" {
session = session.Where("model", data.Model)
}
if len(data.Date) == 2 {
start := data.Date[0] + " 00:00:00"
end := data.Date[1] + " 00:00:00"
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.PowerLog{}).Count(&total)
var items []model.PowerLog
var list = make([]vo.PowerLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var log vo.PowerLog
err := utils.CopyObject(item, &log)
if err != nil {
continue
}
log.Id = item.Id
log.CreatedAt = item.CreatedAt.Unix()
log.TypeStr = item.Type.String()
list = append(list, log)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}

View File

@@ -1,31 +1,35 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ProductHandler struct {
BaseHandler
db *gorm.DB
}
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db}
h.App = app
return &h
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 模型列表
func (h *ProductHandler) List(c *gin.Context) {
var items []model.Product
var list = make([]vo.Product, 0)
res := h.db.Where("enabled", true).Order("sort_num ASC").Find(&items)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Product

View File

@@ -1,120 +0,0 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/utils/resp"
"fmt"
"github.com/imroc/req/v3"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
type PromptHandler struct {
BaseHandler
db *gorm.DB
}
func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler {
h := &PromptHandler{db: db}
h.App = app
return h
}
type apiRes struct {
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type apiErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
// Rewrite translate and rewrite prompt with ChatGPT
func (h *PromptHandler) Rewrite(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := h.request(data.Prompt, rewritePromptTemplate)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) Translate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := h.request(data.Prompt, translatePromptTemplate)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
// 获取 OpenAI 的 API KEY
var apiKey model.ApiKey
res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
if res.Error != nil {
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error)
}
messages := make([]interface{}, 1)
messages[0] = types.Message{
Role: "user",
Content: fmt.Sprintf(promptTemplate, prompt),
}
var response apiRes
var errRes apiErrRes
r, err := req.C().SetProxyURL(h.App.Config.ProxyURL).R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(types.ApiRequest{
Model: "gpt-3.5-turbo",
Temperature: 0.9,
MaxTokens: 1024,
Stream: false,
Messages: messages,
}).
SetErrorResult(&errRes).
SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
}
return response.Choices[0].Message.Content, nil
}

View File

@@ -1,25 +1,35 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"math"
"strings"
"sync"
"time"
)
type RewardHandler struct {
BaseHandler
db *gorm.DB
lock sync.Mutex
}
func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db}
h.App = server
return &h
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Verify 打赏码核销
@@ -32,11 +42,20 @@ func (h *RewardHandler) Verify(c *gin.Context) {
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.HACKER(c)
return
}
// 移除转账单号中间的空格,防止有人复制的时候多复制了空格
data.TxId = strings.ReplaceAll(data.TxId, " ", "")
h.lock.Lock()
defer h.lock.Unlock()
var item model.Reward
res := h.db.Where("tx_id = ?", data.TxId).First(&item)
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的众筹交易流水号!")
return
@@ -47,16 +66,13 @@ func (h *RewardHandler) Verify(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.HACKER(c)
return
}
tx := h.db.Begin()
calls := (item.Amount + 0.1) * 10
res = h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls + ?", calls))
tx := h.DB.Begin()
exchange := vo.RewardExchange{}
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
exchange.Power = int(power)
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -64,13 +80,26 @@ func (h *RewardHandler) Verify(c *gin.Context) {
// 更新核销状态
item.Status = true
item.UserId = user.Id
res = h.db.Updates(&item)
item.Exchange = utils.JsonEncode(exchange)
res = tx.Updates(&item)
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力充值日志
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerReward,
Amount: exchange.Power,
Balance: user.Power + exchange.Power,
Mark: types.PowerAdd,
Model: "众筹支付",
Remark: fmt.Sprintf("众筹充值算力,金额:%f价格%f", item.Amount, h.App.SysConfig.PowerPrice),
CreatedAt: time.Now(),
})
tx.Commit()
resp.SUCCESS(c)

View File

@@ -1,39 +1,79 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service/sd"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"encoding/base64"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"fmt"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
"time"
)
type SdJobHandler struct {
BaseHandler
redis *redis.Client
db *gorm.DB
pool *sd.ServicePool
redis *redis.Client
pool *sd.ServicePool
uploader *oss.UploaderManager
snowflake *service.Snowflake
leveldb *store.LevelDB
}
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool) *SdJobHandler {
h := SdJobHandler{
db: db,
pool: pool,
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
return &SdJobHandler{
pool: pool,
uploader: manager,
snowflake: snowflake,
leveldb: levelDB,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
h.App = app
return &h
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
@@ -44,8 +84,8 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值")
if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return false
}
@@ -55,7 +95,7 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
// Image 创建一个绘画任务
func (h *SdJobHandler) Image(c *gin.Context) {
if !h.checkLimits(c) {
if !h.preCheck(c) {
return
}
@@ -88,23 +128,29 @@ func (h *SdJobHandler) Image(c *gin.Context) {
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
params := types.SdTaskParams{
TaskId: fmt.Sprintf("task(%s)", utils.RandString(15)),
Prompt: data.Prompt,
NegativePrompt: data.NegativePrompt,
Steps: data.Steps,
Sampler: data.Sampler,
FaceFix: data.FaceFix,
CfgScale: data.CfgScale,
Seed: data.Seed,
Height: data.Height,
Width: data.Width,
HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
taskId, err := h.snowflake.Next(true)
if err != nil {
resp.ERROR(c, "error with generate task id: "+err.Error())
return
}
params := types.SdTaskParams{
TaskId: taskId,
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Steps: data.Steps,
Sampler: data.Sampler,
FaceFix: data.FaceFix,
CfgScale: data.CfgScale,
Seed: data.Seed,
Height: data.Height,
Width: data.Width,
HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
}
job := model.SdJob{
UserId: userId,
Type: types.TaskImage.String(),
@@ -112,9 +158,10 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Params: utils.JsonEncode(params),
Prompt: data.Prompt,
Progress: 0,
Power: h.App.SysConfig.SdPower,
CreatedAt: time.Now(),
}
res := h.db.Create(&job)
res := h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error())
return
@@ -124,23 +171,71 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Id: int(job.Id),
SessionId: data.SessionId,
Type: types.TaskImage,
Prompt: data.Prompt,
Params: params,
UserId: userId,
})
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
// JobList 获取 stable diffusion 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0)
userId := h.GetInt(c, "user_id", 0)
// ImgWall 照片墙
func (h *SdJobHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
session := h.db.Session(&gorm.Session{})
if status == 1 {
resp.SUCCESS(c, jobs)
}
// JobList 获取 SD 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
@@ -148,6 +243,9 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
@@ -156,8 +254,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
var items []model.SdJob
res := session.Find(&items)
if res.Error != nil {
resp.ERROR(c, types.NoData)
return
return res.Error, nil
}
var jobs = make([]vo.SdJob, 0)
@@ -168,23 +265,69 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
continue
}
if job.Progress == -1 {
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
}
if item.Progress < 100 {
// 10 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
h.db.Delete(&item)
continue
}
// 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, "")
// 从 leveldb 中获取图片预览数据
var imageData string
err = h.leveldb.Get(item.TaskId, &imageData)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
job.ImgURL = "data:image/png;base64," + imageData
}
}
jobs = append(jobs, job)
}
resp.SUCCESS(c, jobs)
return nil, jobs
}
// Remove remove task image
func (h *SdJobHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// remove job recode
res := h.DB.Delete(&model.SdJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte(sd.Finished))
}
resp.SUCCESS(c)
}
// Publish 发布/取消发布图片到画廊显示
func (h *SdJobHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@@ -1,11 +1,21 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/sms"
"geekai/utils"
"geekai/utils/resp"
"strings"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
@@ -15,22 +25,31 @@ const CodeStorePrefix = "/verify/codes/"
type SmsHandler struct {
BaseHandler
redis *redis.Client
sms *service.AliYunSmsService
sms *sms.ServiceManager
smtp *service.SmtpService
captcha *service.CaptchaService
}
func NewSmsHandler(app *core.AppServer, client *redis.Client, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler {
handler := &SmsHandler{redis: client, sms: sms, captcha: captcha}
handler.App = app
return handler
func NewSmsHandler(
app *core.AppServer,
client *redis.Client,
sms *sms.ServiceManager,
smtp *service.SmtpService,
captcha *service.CaptchaService) *SmsHandler {
return &SmsHandler{
redis: client,
sms: sms,
captcha: captcha,
smtp: smtp,
BaseHandler: BaseHandler{App: app}}
}
// SendCode 发送验证码短信
// SendCode 发送验证码
func (h *SmsHandler) SendCode(c *gin.Context) {
var data struct {
Mobile string `json:"mobile"`
Key string `json:"key"`
Dots string `json:"dots"`
Receiver string `json:"receiver"` // 接收者
Key string `json:"key"`
Dots string `json:"dots"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -43,14 +62,28 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
}
code := utils.RandomNumber(6)
err := h.sms.SendVerifyCode(data.Mobile, code)
var err error
if strings.Contains(data.Receiver, "@") { // email
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") {
resp.ERROR(c, "系统已禁用邮箱注册!")
return
}
err = h.smtp.SendVerifyCode(data.Receiver, code)
} else {
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
resp.ERROR(c, "系统已禁用手机号注册!")
return
}
err = h.sms.GetService().SendVerifyCode(data.Receiver, code)
}
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 存储验证码,等待后面注册验证
_, err = h.redis.Set(c, CodeStorePrefix+data.Mobile, code, 0).Result()
_, err = h.redis.Set(c, CodeStorePrefix+data.Receiver, code, 0).Result()
if err != nil {
resp.ERROR(c, "验证码保存失败")
return

View File

@@ -1,40 +1,17 @@
package handler
import (
"chatplus/service"
"github.com/gin-gonic/gin"
"geekai/service"
"geekai/service/payment"
"gorm.io/gorm"
)
type TestHandler struct {
db *gorm.DB
snowflake *service.Snowflake
js *payment.PayJS
}
func NewTestHandler(snowflake *service.Snowflake) *TestHandler {
return &TestHandler{snowflake: snowflake}
}
func (h *TestHandler) TestPay(c *gin.Context) {
//appId := "" //Appid
//appSecret := "" //密钥
//var host = "https://api.xunhupay.com/payment/do.html" //跳转支付页接口URL
//client := payment.NewXunHuPay(appId, appSecret) //初始化调用
//
////支付参数appid、time、nonce_str和hash这四个参数不用传调用的时候执行方法内部已经处理
//orderNo, _ := h.snowflake.Next()
//params := map[string]string{
// "version": "1.1",
// "trade_order_id": orderNo,
// "total_fee": "0.1",
// "title": "测试支付",
// "notify_url": "http://xxxxxxx.com",
// "return_url": "http://localhost:8888",
// "wap_name": "极客学长",
// "callback_url": "",
//}
//
//execute, err := client.Execute(host, params) //执行支付操作
//if err != nil {
// logger.Error(err)
//}
//resp.SUCCESS(c, execute)
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js}
}

View File

@@ -1,31 +1,101 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/service/oss"
"chatplus/utils/resp"
"geekai/core"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type UploadHandler struct {
BaseHandler
db *gorm.DB
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
handler := &UploadHandler{db: db, uploaderManager: manager}
handler.App = app
return handler
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
fileURL, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, fileURL)
userId := h.GetLoginUserId(c)
res := h.DB.Create(&model.File{
UserId: int(userId),
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}
func (h *UploadHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
var items []model.File
var files = make([]vo.File, 0)
h.DB.Where("user_id = ?", userId).Find(&items)
if len(items) > 0 {
for _, v := range items {
var file vo.File
err := utils.CopyObject(v, &file)
if err != nil {
logger.Error(err)
continue
}
file.CreatedAt = v.CreatedAt.Unix()
files = append(files, file)
}
}
resp.SUCCESS(c, files)
}
// Remove remove files
func (h *UploadHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0)
var file model.File
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
if tx.Error != nil || file.Id == 0 {
resp.ERROR(c, "file not existed")
return
}
// remove database
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
if tx.Error != nil || tx.RowsAffected == 0 {
resp.ERROR(c, "failed to update database")
return
}
// remove files
objectKey := file.ObjKey
if objectKey == "" {
objectKey = file.URL
}
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
resp.SUCCESS(c)
}

View File

@@ -1,18 +1,27 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"strings"
"time"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/gin-gonic/gin"
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
"gorm.io/gorm"
@@ -20,26 +29,31 @@ import (
type UserHandler struct {
BaseHandler
db *gorm.DB
searcher *xdb.Searcher
redis *redis.Client
searcher *xdb.Searcher
redis *redis.Client
licenseService *service.LicenseService
}
func NewUserHandler(
app *core.AppServer,
db *gorm.DB,
searcher *xdb.Searcher,
client *redis.Client) *UserHandler {
handler := &UserHandler{db: db, searcher: searcher, redis: client}
handler.App = app
return handler
client *redis.Client,
licenseService *service.LicenseService) *UserHandler {
return &UserHandler{
BaseHandler: BaseHandler{DB: db, App: app},
searcher: searcher,
redis: client,
licenseService: licenseService,
}
}
// Register user register
func (h *UserHandler) Register(c *gin.Context) {
// parameters process
var data struct {
Mobile string `json:"mobile"`
RegWay string `json:"reg_way"`
Username string `json:"username"`
Password string `json:"password"`
Code string `json:"code"`
InviteCode string `json:"invite_code"`
@@ -49,35 +63,34 @@ func (h *UserHandler) Register(c *gin.Context) {
return
}
data.Password = strings.TrimSpace(data.Password)
if len(data.Mobile) < 10 {
resp.ERROR(c, "请输入合法的手机号")
return
}
if len(data.Password) < 8 {
resp.ERROR(c, "密码长度不能少于8个字符")
return
}
// 检测最大注册人数
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if int(totalUser) >= h.licenseService.GetLicense().UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
// 检查验证码
key := CodeStorePrefix + data.Mobile
if h.App.SysConfig.EnabledMsg {
var key string
if data.RegWay == "email" || data.RegWay == "mobile" {
key = CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
resp.ERROR(c, "验证码错误")
return
}
}
// 验证邀请码
inviteCode := model.InviteCode{}
if data.InviteCode == "" {
if h.App.SysConfig.ForceInvite {
resp.ERROR(c, "当前系统设定必须使用邀请码才能注册")
return
}
} else {
res := h.db.Where("code = ?", data.InviteCode).First(&inviteCode)
if data.InviteCode != "" {
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
if res.Error != nil {
resp.ERROR(c, "无效的邀请码")
return
@@ -86,32 +99,26 @@ func (h *UserHandler) Register(c *gin.Context) {
// check if the username is exists
var item model.User
res := h.db.Where("mobile = ?", data.Mobile).First(&item)
if res.RowsAffected > 0 {
resp.ERROR(c, "该手机号码已经被注册,请更换其他手机号")
res := h.DB.Where("username = ?", data.Username).First(&item)
if item.Id > 0 {
resp.ERROR(c, "该用户名已经被注册")
return
}
salt := utils.RandString(8)
user := model.User{
Username: data.Username,
Password: utils.GenPassword(data.Password, salt),
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
Mobile: data.Mobile,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
ChatConfig: utils.JsonEncode(types.UserChatConfig{
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
Calls: h.App.SysConfig.InitChatCalls,
ImgCalls: h.App.SysConfig.InitImgCalls,
Power: h.App.SysConfig.InitPower,
}
res = h.db.Create(&user)
res = h.DB.Create(&user)
if res.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(res.Error)
@@ -121,26 +128,36 @@ func (h *UserHandler) Register(c *gin.Context) {
// 记录邀请关系
if data.InviteCode != "" {
// 增加邀请数量
h.db.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InviteChatCalls > 0 {
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("calls", gorm.Expr("calls + ?", h.App.SysConfig.InviteChatCalls))
}
if h.App.SysConfig.InviteImgCalls > 0 {
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", h.App.SysConfig.InviteImgCalls))
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 {
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
// 记录邀请算力充值日志
var inviter model.User
h.DB.Where("id", inviteCode.UserId).First(&inviter)
h.DB.Create(&model.PowerLog{
UserId: inviter.Id,
Username: inviter.Username,
Type: types.PowerInvite,
Amount: h.App.SysConfig.InvitePower,
Balance: inviter.Power,
Mark: types.PowerAdd,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
CreatedAt: time.Now(),
})
}
// 添加邀请记录
h.db.Create(&model.InviteLog{
h.DB.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Mobile,
Username: user.Username,
InviteCode: inviteCode.Code,
Reward: utils.JsonEncode(types.InviteReward{ChatCalls: h.App.SysConfig.InviteChatCalls, ImgCalls: h.App.SysConfig.InviteImgCalls}),
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
})
}
if h.App.SysConfig.EnabledMsg {
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
}
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
// 自动登录创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
@@ -164,7 +181,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// Login 用户登录
func (h *UserHandler) Login(c *gin.Context) {
var data struct {
Mobile string `json:"username"`
Username string `json:"username"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -172,7 +189,7 @@ func (h *UserHandler) Login(c *gin.Context) {
return
}
var user model.User
res := h.db.Where("mobile = ?", data.Mobile).First(&user)
res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户名不存在")
return
@@ -192,11 +209,11 @@ func (h *UserHandler) Login(c *gin.Context) {
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.db.Model(&user).Updates(user)
h.DB.Model(&user).Updates(user)
h.db.Create(&model.UserLoginLog{
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Mobile,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
@@ -222,24 +239,16 @@ func (h *UserHandler) Login(c *gin.Context) {
// Logout 注 销
func (h *UserHandler) Logout(c *gin.Context) {
sessionId := c.GetHeader(types.ChatTokenHeader)
key := h.GetUserKey(c)
if _, err := h.redis.Del(c, key).Result(); err != nil {
logger.Error("error with delete session: ", err)
}
// 删除 websocket 会话列表
h.App.ChatSession.Delete(sessionId)
// 关闭 socket 连接
client := h.App.ChatClients.Get(sessionId)
if client != nil {
client.Close()
}
resp.SUCCESS(c)
}
// Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err == nil {
var userVo vo.User
err := utils.CopyObject(user, &userVo)
@@ -255,26 +264,23 @@ func (h *UserHandler) Session(c *gin.Context) {
}
type userProfile struct {
Id uint `json:"id"`
Mobile string `json:"mobile"`
Avatar string `json:"avatar"`
ChatConfig types.UserChatConfig `json:"chat_config"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
TotalTokens int64 `json:"total_tokens"`
Tokens int64 `json:"tokens"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
Id uint `json:"id"`
Nickname string `json:"nickname"`
Username string `json:"username"`
Avatar string `json:"avatar"`
Power int `json:"power"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
}
func (h *UserHandler) Profile(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
h.DB.First(&user, user.Id)
var profile userProfile
err = utils.CopyObject(user, &profile)
if err != nil {
@@ -294,15 +300,15 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
h.DB.First(&user, user.Id)
user.Avatar = data.Avatar
user.ChatConfig = utils.JsonEncode(data.ChatConfig)
res := h.db.Updates(&user)
user.Nickname = data.Nickname
res := h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新用户信息失败")
return
@@ -327,21 +333,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
password := utils.GenPassword(data.OldPass, user.Salt)
logger.Info(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
if password != user.Password {
resp.ERROR(c, "原密码错误")
return
}
newPass := utils.GenPassword(data.Password, user.Salt)
res := h.db.Model(&user).UpdateColumn("password", newPass)
res := h.DB.Model(&user).UpdateColumn("password", newPass)
if res.Error != nil {
logger.Error("更新数据库失败: ", res.Error)
resp.ERROR(c, "更新数据库失败")
@@ -354,9 +360,9 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
// ResetPass 重置密码
func (h *UserHandler) ResetPass(c *gin.Context) {
var data struct {
Mobile string
Code string // 验证码
Password string // 新密码
Username string `json:"username"`
Code string `json:"code"` // 验证码
Password string `json:"password"` // 新密码
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -364,25 +370,23 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
var user model.User
res := h.db.Where("mobile", data.Mobile).First(&user)
res := h.DB.Where("username", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户不存在!")
return
}
// 检查验证码
key := CodeStorePrefix + data.Mobile
if h.App.SysConfig.EnabledMsg {
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
return
}
key := CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
} else {
@@ -391,11 +395,11 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
}
// BindMobile 绑定手机
func (h *UserHandler) BindMobile(c *gin.Context) {
// BindUsername 重置账
func (h *UserHandler) BindUsername(c *gin.Context) {
var data struct {
Mobile string `json:"mobile"`
Code string `json:"code"`
Username string `json:"username"`
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -403,28 +407,28 @@ func (h *UserHandler) BindMobile(c *gin.Context) {
}
// 检查验证码
key := CodeStorePrefix + data.Mobile
key := CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
resp.ERROR(c, "验证码错误")
return
}
// 检查手机号是否被其他账号绑定
var item model.User
res := h.db.Where("mobile = ?", data.Mobile).First(&item)
res := h.DB.Where("username = ?", data.Username).First(&item)
if res.Error == nil {
resp.ERROR(c, "该手机号已经被其他账号绑定")
resp.ERROR(c, "该号已经被其他账号绑定")
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
res = h.db.Model(&user).UpdateColumn("mobile", data.Mobile)
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return

View File

@@ -1,5 +1,12 @@
package logger
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"go.uber.org/zap"
"go.uber.org/zap/zapcore"

View File

@@ -1,22 +1,30 @@
package main
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/handler/admin"
"chatplus/handler/chatimpl"
logger2 "chatplus/logger"
"chatplus/service"
"chatplus/service/fun"
"chatplus/service/mj"
"chatplus/service/oss"
"chatplus/service/payment"
"chatplus/service/sd"
"chatplus/service/wx"
"chatplus/store"
"context"
"embed"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/handler/admin"
"geekai/handler/chatimpl"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/dalle"
"geekai/service/mj"
"geekai/service/oss"
"geekai/service/payment"
"geekai/service/sd"
"geekai/service/sms"
"geekai/service/wx"
"geekai/store"
"io"
"log"
"os"
@@ -43,34 +51,34 @@ type AppLifecycle struct {
// OnStart 应用程序启动时执行
func (l *AppLifecycle) OnStart(context.Context) error {
log.Println("AppLifecycle OnStart")
logger.Info("AppLifecycle OnStart")
return nil
}
// OnStop 应用程序停止时执行
func (l *AppLifecycle) OnStop(context.Context) error {
log.Println("AppLifecycle OnStop")
logger.Info("AppLifecycle OnStop")
return nil
}
func NewAppLifeCycle() *AppLifecycle {
return &AppLifecycle{}
}
func main() {
configFile := os.Getenv("CONFIG_FILE")
if configFile == "" {
configFile = "config.toml"
}
var debug bool
debugEnv := os.Getenv("DEBUG")
if debugEnv == "" {
debug = true
} else {
debug, _ = strconv.ParseBool(os.Getenv("DEBUG"))
}
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
logger.Info("Loading config file: ", configFile)
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
}
}()
if !debug {
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
}
}()
}
app := fx.New(
// 初始化配置应用配置
@@ -96,6 +104,7 @@ func main() {
fx.Provide(store.NewGormConfig),
fx.Provide(store.NewMysql),
fx.Provide(store.NewRedisClient),
fx.Provide(store.NewLevelDB),
fx.Provide(func() embed.FS {
return xdbFS
@@ -115,9 +124,6 @@ func main() {
return xdb.NewWithBuffer(cBuff)
}),
// 创建函数
fx.Provide(fun.NewFunctions),
// 创建控制器
fx.Provide(handler.NewChatRoleHandler),
fx.Provide(handler.NewUserHandler),
@@ -132,6 +138,8 @@ func main() {
fx.Provide(handler.NewPaymentHandler),
fx.Provide(handler.NewOrderHandler),
fx.Provide(handler.NewProductHandler),
fx.Provide(handler.NewConfigHandler),
fx.Provide(handler.NewPowerLogHandler),
fx.Provide(admin.NewConfigHandler),
fx.Provide(admin.NewAdminHandler),
@@ -143,14 +151,31 @@ func main() {
fx.Provide(admin.NewChatModelHandler),
fx.Provide(admin.NewProductHandler),
fx.Provide(admin.NewOrderHandler),
fx.Provide(admin.NewChatHandler),
fx.Provide(admin.NewPowerLogHandler),
// 创建服务
fx.Provide(service.NewAliYunSmsService),
fx.Provide(sms.NewSendServiceManager),
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
return service.NewCaptchaService(config.ApiConfig)
}),
fx.Provide(oss.NewUploaderManager),
fx.Provide(mj.NewService),
fx.Provide(dalle.NewService),
fx.Invoke(func(service *dalle.Service) {
service.Run()
service.CheckTaskNotify()
service.DownloadImages()
service.CheckTaskStatus()
}),
// 邮件服务
fx.Provide(service.NewSmtpService),
// License 服务
fx.Provide(service.NewLicenseService),
fx.Invoke(func(licenseService *service.LicenseService) {
licenseService.SyncLicense()
}),
// 微信机器人服务
fx.Provide(wx.NewWeChatBot),
@@ -165,12 +190,26 @@ func main() {
// MidJourney service pool
fx.Provide(mj.NewServicePool),
fx.Invoke(func(pool *mj.ServicePool) {
if pool.HasAvailableService() {
pool.DownloadImages()
pool.CheckTaskNotify()
pool.SyncTaskProgress()
}
}),
// Stable Diffusion 机器人
fx.Provide(sd.NewServicePool),
fx.Invoke(func(pool *sd.ServicePool) {
if pool.HasAvailableService() {
pool.CheckTaskNotify()
pool.CheckTaskStatus()
}
}),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
fx.Provide(payment.NewPayJS),
fx.Provide(service.NewSnowflake),
fx.Provide(service.NewXXLJobExecutor),
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
@@ -196,7 +235,7 @@ func main() {
group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass)
group.POST("bind/mobile", h.BindMobile)
group.POST("bind/username", h.BindUsername)
group.POST("resetPass", h.ResetPass)
}),
fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) {
@@ -213,6 +252,8 @@ func main() {
}),
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
s.Engine.POST("/api/upload", h.Upload)
s.Engine.GET("/api/upload/list", h.List)
s.Engine.GET("/api/upload/remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
group := s.Engine.Group("/api/sms/")
@@ -222,6 +263,8 @@ func main() {
group := s.Engine.Group("/api/captcha/")
group.GET("get", h.Get)
group.POST("check", h.Check)
group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck)
}),
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
group := s.Engine.Group("/api/reward/")
@@ -229,33 +272,53 @@ func main() {
}),
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
group := s.Engine.Group("/api/mj/")
group.Any("client", h.Client)
group.POST("image", h.Image)
group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd")
group.Any("client", h.Client)
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/")
group.GET("get", h.Get)
}),
// 管理后台控制器
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
group := s.Engine.Group("/api/admin/config/")
group.POST("update", h.Update)
group.GET("get", h.Get)
group := s.Engine.Group("/api/admin/")
group.POST("config/update", h.Update)
group.GET("config/get", h.Get)
group.POST("active", h.Active)
group.GET("config/get/license", h.GetLicense)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
group := s.Engine.Group("/api/admin/")
group.POST("login", h.Login)
group.GET("logout", h.Logout)
group.GET("session", h.Session)
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("enable", h.Enable)
group.GET("remove", h.Remove)
group.POST("resetPass", h.ResetPass)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
group := s.Engine.Group("/api/admin/apikey/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
@@ -271,11 +334,13 @@ func main() {
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
group.POST("set", h.Set)
group.POST("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
group := s.Engine.Group("/api/admin/reward/")
group.GET("list", h.List)
group.POST("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
group := s.Engine.Group("/api/admin/dashboard/")
@@ -299,8 +364,10 @@ func main() {
group.GET("payWays", h.GetPayWays)
group.POST("query", h.OrderQuery)
group.POST("qrcode", h.PayQrcode)
group.POST("mobile", h.Mobile)
group.POST("alipay/notify", h.AlipayNotify)
group.POST("hupipay/notify", h.HuPiPayNotify)
group.POST("payjs/notify", h.PayJsNotify)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
group := s.Engine.Group("/api/admin/product/")
@@ -332,25 +399,89 @@ func main() {
group.GET("hits", h.Hits)
}),
fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt/")
group.POST("rewrite", h.Rewrite)
group.POST("translate", h.Translate)
fx.Provide(admin.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
group := s.Engine.Group("/api/admin/function/")
group.POST("save", h.Save)
group.POST("set", h.Set)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("token", h.GenToken)
}),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/test/")
group.GET("pay", h.TestPay)
// 验证码
fx.Provide(admin.NewCaptchaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
group := s.Engine.Group("/api/admin/login/")
group.GET("captcha", h.GetCaptcha)
}),
fx.Provide(admin.NewUploadHandler),
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
s.Engine.POST("/api/admin/upload", h.Upload)
}),
fx.Provide(handler.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
group := s.Engine.Group("/api/function/")
group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/")
group.POST("list", h.List)
group.POST("message", h.Messages)
group.GET("history", h.History)
group.GET("remove", h.RemoveChat)
group.GET("message/remove", h.RemoveMessage)
}),
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
group := s.Engine.Group("/api/powerLog/")
group.POST("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
group := s.Engine.Group("/api/admin/powerLog/")
group.POST("list", h.List)
}),
fx.Provide(admin.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
group := s.Engine.Group("/api/admin/menu/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}),
fx.Provide(handler.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
group := s.Engine.Group("/api/menu/")
group.GET("list", h.List)
}),
fx.Provide(handler.NewMarkMapHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
group := s.Engine.Group("/api/markMap/")
group.Any("client", h.Client)
}),
fx.Provide(handler.NewDallJobHandler),
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
group := s.Engine.Group("/api/dall")
group.Any("client", h.Client)
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
err := s.Run(db)
if err != nil {
log.Fatal(err)
}
go func() {
err := s.Run(db)
if err != nil {
log.Fatal(err)
}
}()
}),
fx.Provide(NewAppLifeCycle),
// 注册生命周期回调函数
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
lifecycle.Append(fx.Hook{

View File

@@ -0,0 +1,38 @@
-----BEGIN CERTIFICATE-----
MIIDszCCApugAwIBAgIQICMRB0rBU2/rZJbfJGMYIzANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE
BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv
biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1
dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDYzNTQxWhcNMjQxMTA2MDYzNTQxWjCB
hDELMAkGA1UEBhMCQ04xHzAdBgNVBAoMFm1ib25meTkwMTVAc2FuZGJveC5jb20xDzANBgNVBAsM
BkFsaXBheTFDMEEGA1UEAww65pSv5LuY5a6dKOS4reWbvSnnvZHnu5zmioDmnK/mnInpmZDlhazl
j7gtMjA4ODcyMTAyMDc1MDU4MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKsoKcw5
sxaiyV7mpWzDtnQ1K518eQLP0+dJlZAf06aBep/Aj9DIqrba/k7DHt8dKQvILMLAMpN1+2IRxbaO
yxMa/laj3lZ1eHrB6F077O3D62oHcE3noZtXL0N1zZAxpmkNmYIHeLZS2oLMS4ANu47O/wpDC7BV
HjdpZugtdPJ4mxdCpM9GDdLs7W4s5QI4PUPK4skFNMFoKI+0cYP/9ju87UP//IHC/K510GWNl+Gn
Cvgag3AmiIB0utJNsGhxm6zT1T9tUWjW9iz/BxBKiPatsCX9VpPQzGnW7ZonRQtiZSokIlP2IPvl
H5DcwpWUz3/LUY0SmKxnKOEYeOOqCW8CAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3
DQEBCwUAA4IBAQAtgxF2EzjOndEFxBUD9tFwcSt6XKGggOp52oft1pvynPg4ALTLafOtfEPDrFBH
PwpYrSu9s9C8NJtaA2HrlCfBjIuwEFTXiN+HPvS0SwSPKt9AXEiTcOF8vDcGamEen8QI4fo5Jia7
2VRKkerkww5/+FzSaVO7ZUKuL80M1QJStmAZc8kPPwdYOTTW2bGf8BcmSDL6SPElBkt7tCCRd4sn
+jq4cZ0yb2i77rBZCwHcTvfTqIBblPwLv4uGvg3+83BxIB5w6Kqp06bKEAPmobFY5IVHa+ON0/qi
BXxXr+WQ3piKRVQEN64+PTAjSc67Ix1umvpLl3Ko6Ry7NJmpDcUn
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIDszCCApugAwIBAgIQIBkIGbgVxq210KxLJ+YA/TANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UE
BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxJTAjBgNVBAsMHENlcnRpZmljYXRpb24gQXV0
aG9yaXR5IHRlc3QxNjA0BgNVBAMMLUFudCBGaW5hbmNpYWwgQ2VydGlmaWNhdGlvbiBBdXRob3Jp
dHkgUjEgdGVzdDAeFw0xOTA4MTkxMTE2MDBaFw0yNDA4MDExMTE2MDBaMIGRMQswCQYDVQQGEwJD
TjEbMBkGA1UECgwSQW50IEZpbmFuY2lhbCB0ZXN0MSUwIwYDVQQLDBxDZXJ0aWZpY2F0aW9uIEF1
dGhvcml0eSB0ZXN0MT4wPAYDVQQDDDVBbnQgRmluYW5jaWFsIENlcnRpZmljYXRpb24gQXV0aG9y
aXR5IENsYXNzIDIgUjEgdGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMh4FKYO
ZyRQHD6eFbPKZeSAnrfjfU7xmS9Yoozuu+iuqZlb6Z0SPLUqqTZAFZejOcmr07ln/pwZxluqplxC
5+B48End4nclDMlT5HPrDr3W0frs6Xsa2ZNcyil/iKNB5MbGll8LRAxntsKvZZj6vUTMb705gYgm
VUMILwi/ZxKTQqBtkT/kQQ5y6nOZsj7XI5rYdz6qqOROrpvS/d7iypdHOMIM9Iz9DlL1mrCykbBi
t25y+gTeXmuisHUwqaRpwtCGK4BayCqxRGbNipe6W73EK9lBrrzNtTr9NaysesT/v+l25JHCL9tG
wpNr1oWFzk4IHVOg0ORiQ6SUgxZUTYcCAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3
DQEBCwUAA4IBAQBWThEoIaQoBX2YeRY/I8gu6TYnFXtyuCljANnXnM38ft+ikhE5mMNgKmJYLHvT
yWWWgwHoSAWEuml7EGbE/2AK2h3k0MdfiWLzdmpPCRG/RJHk6UB1pMHPilI+c0MVu16OPpKbg5Vf
LTv7dsAB40AzKsvyYw88/Ezi1osTXo6QQwda7uefvudirtb8FcQM9R66cJxl3kt1FXbpYwheIm/p
j1mq64swCoIYu4NrsUYtn6CV542DTQMI5QdXkn+PzUUly8F6kDp+KpMNd0avfWNL5+O++z+F5Szy
1CPta1D7EQ/eYmMP+mOQ35oifWIoFCpN6qQVBS/Hob1J/UUyg7BW
-----END CERTIFICATE-----

View File

@@ -0,0 +1,88 @@
-----BEGIN CERTIFICATE-----
MIIBszCCAVegAwIBAgIIaeL+wBcKxnswDAYIKoEcz1UBg3UFADAuMQswCQYDVQQG
EwJDTjEOMAwGA1UECgwFTlJDQUMxDzANBgNVBAMMBlJPT1RDQTAeFw0xMjA3MTQw
MzExNTlaFw00MjA3MDcwMzExNTlaMC4xCzAJBgNVBAYTAkNOMQ4wDAYDVQQKDAVO
UkNBQzEPMA0GA1UEAwwGUk9PVENBMFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAE
MPCca6pmgcchsTf2UnBeL9rtp4nw+itk1Kzrmbnqo05lUwkwlWK+4OIrtFdAqnRT
V7Q9v1htkv42TsIutzd126NdMFswHwYDVR0jBBgwFoAUTDKxl9kzG8SmBcHG5Yti
W/CXdlgwDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAQYwHQYDVR0OBBYEFEwysZfZ
MxvEpgXBxuWLYlvwl3ZYMAwGCCqBHM9VAYN1BQADSAAwRQIgG1bSLeOXp3oB8H7b
53W+CKOPl2PknmWEq/lMhtn25HkCIQDaHDgWxWFtnCrBjH16/W3Ezn7/U/Vjo5xI
pDoiVhsLwg==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIF0zCCA7ugAwIBAgIIH8+hjWpIDREwDQYJKoZIhvcNAQELBQAwejELMAkGA1UE
BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNVBAsMF0NlcnRpZmlj
YXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5jaWFsIENlcnRpZmlj
YXRpb24gQXV0aG9yaXR5IFIxMB4XDTE4MDMyMTEzNDg0MFoXDTM4MDIyODEzNDg0
MFowejELMAkGA1UEBhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNV
BAsMF0NlcnRpZmljYXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5j
aWFsIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFIxMIICIjANBgkqhkiG9w0BAQEF
AAOCAg8AMIICCgKCAgEAtytTRcBNuur5h8xuxnlKJetT65cHGemGi8oD+beHFPTk
rUTlFt9Xn7fAVGo6QSsPb9uGLpUFGEdGmbsQ2q9cV4P89qkH04VzIPwT7AywJdt2
xAvMs+MgHFJzOYfL1QkdOOVO7NwKxH8IvlQgFabWomWk2Ei9WfUyxFjVO1LVh0Bp
dRBeWLMkdudx0tl3+21t1apnReFNQ5nfX29xeSxIhesaMHDZFViO/DXDNW2BcTs6
vSWKyJ4YIIIzStumD8K1xMsoaZBMDxg4itjWFaKRgNuPiIn4kjDY3kC66Sl/6yTl
YUz8AybbEsICZzssdZh7jcNb1VRfk79lgAprm/Ktl+mgrU1gaMGP1OE25JCbqli1
Pbw/BpPynyP9+XulE+2mxFwTYhKAwpDIDKuYsFUXuo8t261pCovI1CXFzAQM2w7H
DtA2nOXSW6q0jGDJ5+WauH+K8ZSvA6x4sFo4u0KNCx0ROTBpLif6GTngqo3sj+98
SZiMNLFMQoQkjkdN5Q5g9N6CFZPVZ6QpO0JcIc7S1le/g9z5iBKnifrKxy0TQjtG
PsDwc8ubPnRm/F82RReCoyNyx63indpgFfhN7+KxUIQ9cOwwTvemmor0A+ZQamRe
9LMuiEfEaWUDK+6O0Gl8lO571uI5onYdN1VIgOmwFbe+D8TcuzVjIZ/zvHrAGUcC
AwEAAaNdMFswCwYDVR0PBAQDAgEGMAwGA1UdEwQFMAMBAf8wHQYDVR0OBBYEFF90
tATATwda6uWx2yKjh0GynOEBMB8GA1UdIwQYMBaAFF90tATATwda6uWx2yKjh0Gy
nOEBMA0GCSqGSIb3DQEBCwUAA4ICAQCVYaOtqOLIpsrEikE5lb+UARNSFJg6tpkf
tJ2U8QF/DejemEHx5IClQu6ajxjtu0Aie4/3UnIXop8nH/Q57l+Wyt9T7N2WPiNq
JSlYKYbJpPF8LXbuKYG3BTFTdOVFIeRe2NUyYh/xs6bXGr4WKTXb3qBmzR02FSy3
IODQw5Q6zpXj8prYqFHYsOvGCEc1CwJaSaYwRhTkFedJUxiyhyB5GQwoFfExCVHW
05ZFCAVYFldCJvUzfzrWubN6wX0DD2dwultgmldOn/W/n8at52mpPNvIdbZb2F41
T0YZeoWnCJrYXjq/32oc1cmifIHqySnyMnavi75DxPCdZsCOpSAT4j4lAQRGsfgI
kkLPGQieMfNNkMCKh7qjwdXAVtdqhf0RVtFILH3OyEodlk1HYXqX5iE5wlaKzDop
PKwf2Q3BErq1xChYGGVS+dEvyXc/2nIBlt7uLWKp4XFjqekKbaGaLJdjYP5b2s7N
1dM0MXQ/f8XoXKBkJNzEiM3hfsU6DOREgMc1DIsFKxfuMwX3EkVQM1If8ghb6x5Y
jXayv+NLbidOSzk4vl5QwngO/JYFMkoc6i9LNwEaEtR9PhnrdubxmrtM+RjfBm02
77q3dSWFESFQ4QxYWew4pHE0DpWbWy/iMIKQ6UZ5RLvB8GEcgt8ON7BBJeMc+Dyi
kT9qhqn+lw==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIICiDCCAgygAwIBAgIIQX76UsB/30owDAYIKoZIzj0EAwMFADB6MQswCQYDVQQG
EwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UECwwXQ2VydGlmaWNh
dGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNpYWwgQ2VydGlmaWNh
dGlvbiBBdXRob3JpdHkgRTEwHhcNMTkwNDI4MTYyMDQ0WhcNNDkwNDIwMTYyMDQ0
WjB6MQswCQYDVQQGEwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UE
CwwXQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNp
YWwgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkgRTEwdjAQBgcqhkjOPQIBBgUrgQQA
IgNiAASCCRa94QI0vR5Up9Yr9HEupz6hSoyjySYqo7v837KnmjveUIUNiuC9pWAU
WP3jwLX3HkzeiNdeg22a0IZPoSUCpasufiLAnfXh6NInLiWBrjLJXDSGaY7vaokt
rpZvAdmjXTBbMAsGA1UdDwQEAwIBBjAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBRZ
4ZTgDpksHL2qcpkFkxD2zVd16TAfBgNVHSMEGDAWgBRZ4ZTgDpksHL2qcpkFkxD2
zVd16TAMBggqhkjOPQQDAwUAA2gAMGUCMQD4IoqT2hTUn0jt7oXLdMJ8q4vLp6sg
wHfPiOr9gxreb+e6Oidwd2LDnC4OUqCWiF8CMAzwKs4SnDJYcMLf2vpkbuVE4dTH
Rglz+HGcTLWsFs4KxLsq7MuU+vJTBUeDJeDjdA==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIDxTCCAq2gAwIBAgIUEMdk6dVgOEIS2cCP0Q43P90Ps5YwDQYJKoZIhvcNAQEF
BQAwajELMAkGA1UEBhMCQ04xEzARBgNVBAoMCmlUcnVzQ2hpbmExHDAaBgNVBAsM
E0NoaW5hIFRydXN0IE5ldHdvcmsxKDAmBgNVBAMMH2lUcnVzQ2hpbmEgQ2xhc3Mg
MiBSb290IENBIC0gRzMwHhcNMTMwNDE4MDkzNjU2WhcNMzMwNDE4MDkzNjU2WjBq
MQswCQYDVQQGEwJDTjETMBEGA1UECgwKaVRydXNDaGluYTEcMBoGA1UECwwTQ2hp
bmEgVHJ1c3QgTmV0d29yazEoMCYGA1UEAwwfaVRydXNDaGluYSBDbGFzcyAyIFJv
b3QgQ0EgLSBHMzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOPPShpV
nJbMqqCw6Bz1kehnoPst9pkr0V9idOwU2oyS47/HjJXk9Rd5a9xfwkPO88trUpz5
4GmmwspDXjVFu9L0eFaRuH3KMha1Ak01citbF7cQLJlS7XI+tpkTGHEY5pt3EsQg
wykfZl/A1jrnSkspMS997r2Gim54cwz+mTMgDRhZsKK/lbOeBPpWtcFizjXYCqhw
WktvQfZBYi6o4sHCshnOswi4yV1p+LuFcQ2ciYdWvULh1eZhLxHbGXyznYHi0dGN
z+I9H8aXxqAQfHVhbdHNzi77hCxFjOy+hHrGsyzjrd2swVQ2iUWP8BfEQqGLqM1g
KgWKYfcTGdbPB1MCAwEAAaNjMGEwHQYDVR0OBBYEFG/oAMxTVe7y0+408CTAK8hA
uTyRMB8GA1UdIwQYMBaAFG/oAMxTVe7y0+408CTAK8hAuTyRMA8GA1UdEwEB/wQF
MAMBAf8wDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3DQEBBQUAA4IBAQBLnUTfW7hp
emMbuUGCk7RBswzOT83bDM6824EkUnf+X0iKS95SUNGeeSWK2o/3ALJo5hi7GZr3
U8eLaWAcYizfO99UXMRBPw5PRR+gXGEronGUugLpxsjuynoLQu8GQAeysSXKbN1I
UugDo9u8igJORYA+5ms0s5sCUySqbQ2R5z/GoceyI9LdxIVa1RjVX8pYOj8JFwtn
DJN3ftSFvNMYwRuILKuqUYSHc2GPYiHVflDh5nDymCMOQFcFG3WsEuB+EYQPFgIU
1DHmdZcz7Llx8UOZXX2JupWCYzK1XhJb+r4hK5ncf/w8qGtYlmyJpxk3hr1TfUJX
Yf4Zr0fJsGuv
-----END CERTIFICATE-----

View File

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDmTCCAoGgAwIBAgIQICMRB2LW76yahgdg3IFNPDANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE
BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv
biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1
dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDU0NjE5WhcNMjQxMTExMDU0NjE5WjBr
MQswCQYDVQQGEwJDTjEfMB0GA1UECgwWbWJvbmZ5OTAxNUBzYW5kYm94LmNvbTEPMA0GA1UECwwG
QWxpcGF5MSowKAYDVQQDDCEyMDg4NzIxMDIwNzUwNTgxLTkwMjEwMDAxMzE2NTgwMjMwggEiMA0G
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCxihQPf1Q+g9ArgM46shVqL5sbRha/df95D1PsWyEq
ANmWmG4zZ+ksYDVQrc4KzhSRoi56sm/7TDFYTmM6bW99e/nKW58WxyZB4ie5qA3F4n17psPyDqb8
IokcQmCphSFDaXQD6AoXoLNtTM0vAI2cWxAgebZ/vsrdj5Ntjt+Rp3NYMCk1i5xovHcfILzLEGbX
QXoT9fo5AhHotTWa6xHVLPUGY9qwLzQxHzBmvy5ZMfnOfJkm/mDisTSqAUB59F3dzU/1ARVkEZ1w
Mgb4XohWBw6iurQfbMnH2mIomAAwwZVFv+sXDbL9yMbSMo/SjVsTQprn0Q0EnwLo7nmmOM6HAgMB
AAGjEjAQMA4GA1UdDwEB/wQEAwIE8DANBgkqhkiG9w0BAQsFAAOCAQEAn3Y4/C1h9R6ONsBqX3/q
XfHX7yX1FM0Y1x48X3/Yxk6HivAkTukhhhVYVKJsbrbzRqHDp9vhAP/FR6o6pAevaYMmLov0VMXU
7oAuetgkaYEYkDuNen5/Hpdhqi2vTtdT+q9w8zHJd6MDQ0aoHgIxpLKw5vof2R1N4fwSgNXMiXE5
kmllKQMem/+on2p+Sj80/2asxryHIGlH87qPzkffv+kIOkZthbTApTFLLjdVri2QHGe8/cc4xy01
/9iR3IUzNahotT41lJ4bMevBY7XMAS3n5ekyABN/9ZRJqhWdXgmFCRN/u56qd6lDgu7R2M2QUoyc
LuW5DfgRItKlmUB7sw==
-----END CERTIFICATE-----

View File

@@ -0,0 +1 @@
MIIEpQIBAAKCAQEAsYoUD39UPoPQK4DOOrIVai+bG0YWv3X/eQ9T7FshKgDZlphuM2fpLGA1UK3OCs4UkaIuerJv+0wxWE5jOm1vfXv5ylufFscmQeInuagNxeJ9e6bD8g6m/CKJHEJgqYUhQ2l0A+gKF6CzbUzNLwCNnFsQIHm2f77K3Y+TbY7fkadzWDApNYucaLx3HyC8yxBm10F6E/X6OQIR6LU1musR1Sz1BmPasC80MR8wZr8uWTH5znyZJv5g4rE0qgFAefRd3c1P9QEVZBGdcDIG+F6IVgcOorq0H2zJx9piKJgAMMGVRb/rFw2y/cjG0jKP0o1bE0Ka59ENBJ8C6O55pjjOhwIDAQABAoIBAFetNfz1R7hbxjlFshMAkVzQR8wvT9qbvl+dtzdZRcaFhu89NecDIP7+QDYor0FcxoGpU0TazDyRQyk2BQD8vHt+9zv9BVLtZLJSqoWgPbUFBi1DjS8EF2ka8RVYnn35NhUhhd7L//ftL88Bh673mfembQ9srDjoEy1Z01feoABAnCMkNFl986DmEwnarvEufXSDIgeN4ioMxha4NvfIPuI0zpVdV1O9sv+SGC+VEWZBtN3GNsaf4zS/f8FVGvTiU/Abz0gSw/iwSPHclDWQDTN3yFHf/tfqlzh0mH0WfhnuOBFWXzK+R7fbnM+asI9ttvzRcfpzgRGXdPcNcOv/6cECgYEA3DVqpi1k8MYfJixju6SG5gfyhM4VFksFmCMaNPgtatDMBKLMTgV/Ej6LXREojcy29uZl83F09pVlpd41eG39ULIPktixA/BqErQ2UaWh6kOxifycpu22Jh0r09hax6UgVrcBrrnCJEjcFsuJlrZvXQSzc3PBxjWy5gjabS5h9iECgYEAzmVAIh2frF01Y95zsLueAhhZwCtPanm6kf7ivR4r1plIX3b2sNRhWGmEHFgaCE6Braa0ogQ73Hd26kw4ZW+D6QMGC/zjCBEzDLLf++SjdVUHiY5AR4WHqXzq1jdAlsVyo9R661oAOp3lhiJVGLNXkHyEfEVPHsaxJh4osYSbX6cCgYEAx32Qx0i6eDFTyLZQB46uMrgiaVN04QRH5iJuvGvUYT8UhGKjaU8rZfDJOh+wOH2rhxMEaz1uc3C2bERY9mfWI4Ob/jFWc7YZsiYWS3Mcsuhubw4tMECLUg39RWZsHw8ls8kIuixIh6yFzhTH6YQOcRswIrhMZG8DScfdcSmiz2ECgYEAkWP1t5KSpkLKl11etcKUXfl1T8+yk9jIOowIgRw92WAFAWq2AH67TCKYM7dEL1HOO9tRJ0hAOt/U3ttuZtYVYBEHM26jJ02mXm2rJrA7DS4mrxmL4lYH6LbcXqZxU0Qnq4zEQgIWYzRTORf6Rfof1uJAGaJhR9bDd4yLMfGt2cUCgYEAo216Y61xOHUTA4AF1eekk+r+uOcQgQDvLXfs9FkDdJLk0mPG48/+eIYpPFnANJ/riF/DWOp8WGEe2IzA9yUFexzDbNQK8ha9kGcxaSAyiCwzjZ/t9/+hScDSV8kNqWSRSisu/YOFleEHbokT6mbLZ+gdqES8mUUanaEBzRQYGxo=

View File

@@ -1,79 +0,0 @@
{
"data": [
"task(owy5niy1sbbnlq0)",
"A beautiful Chinese girl plays the guitar on the beach. She is dressed in a flowing dress that matches the colors of the sunset. With her eyes closed, she strums the guitar with passion and confidence, her fingers dancing gracefully on the strings. The painting employs a vibrant color palette, capturing the warmth of the setting sun blending with the serene hues of the ocean. The artist uses a combination of impressionistic and realistic brushstrokes to convey both the girl's delicate features and the dynamic movement of the waves. The rendering effect creates a dream-like atmosphere, as if the viewer is being transported to a magical realm where music and nature intertwine. The picture is bathed in a soft, golden light, casting a warm glow on the girl's face, illuminating her joy and connection to the music she creates.",
"",
[],
30,
"DPM++ 3M SDE Karras",
1,
1,
7,
512,
512,
false,
0.7,
2,
"Latent",
0,
0,
0,
"Use same checkpoint",
"Use same sampler",
"",
"",
[],
"None",
false,
"",
0.8,
-1,
false,
-1,
0,
0,
0,
null,
null,
null,
false,
false,
"positive",
"comma",
0,
false,
false,
"",
"Seed",
"",
[],
"Nothing",
"",
[],
"Nothing",
"",
[],
true,
false,
false,
false,
0,
null,
null,
false,
null,
null,
false,
null,
null,
false,
50,
[],
"",
"",
""
],
"event_data": null,
"fn_index": 316,
"session_hash": "ttr8efgt63g"
}

View File

@@ -1,19 +1,26 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"errors"
"fmt"
"geekai/core/types"
"github.com/imroc/req/v3"
"time"
)
type CaptchaService struct {
config types.ChatPlusApiConfig
config types.ApiConfig
client *req.Client
}
func NewCaptchaService(config types.ChatPlusApiConfig) *CaptchaService {
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
return &CaptchaService{
config: config,
client: req.C().SetTimeout(10 * time.Second),
@@ -60,3 +67,44 @@ func (s *CaptchaService) Check(data interface{}) bool {
return true
}
func (s *CaptchaService) SlideGet() (interface{}, error) {
if s.config.Token == "" {
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err)
}
if res.Code != types.Success {
return nil, fmt.Errorf("请求 API 失败:%s", res.Message)
}
return res.Data, nil
}
func (s *CaptchaService) SlideCheck(data interface{}) bool {
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return false
}
if res.Code != types.Success {
return false
}
return true
}

View File

@@ -0,0 +1,307 @@
package dalle
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
// DALL-E 绘画服务
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
}
}
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.DallTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
go func() {
for {
var task types.DallTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
_, err = s.Image(task, false)
if err != nil {
logger.Errorf("error with image task: %v", err)
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
}
}
}()
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Quality string `json:"quality"`
Style string `json:"style"`
}
type imgRes struct {
Created int64 `json:"created"`
Data []struct {
RevisedPrompt string `json:"revised_prompt"`
Url string `json:"url"`
} `json:"data"`
}
type ErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err != nil {
return "", fmt.Errorf("error with translate prompt: %v", err)
}
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
var user model.User
s.db.Where("id", task.UserId).First(&user)
if user.Power < task.Power {
return "", errors.New("insufficient of power")
}
// get image generation API KEY
var apiKey model.ApiKey
tx := s.db.Where("platform", types.OpenAI).
Where("type", "img").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
}
var res imgRes
var errRes ErrRes
if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
Style: task.Style,
Quality: task.Quality,
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
if err != nil {
return "", fmt.Errorf("error with send request: %v", err)
}
if r.IsErrorState() {
return "", fmt.Errorf("error with send request: %v", errRes.Error)
}
// update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// update task progress
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": 100,
"org_url": res.Data[0].Url,
"prompt": prompt,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
var content string
if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
}
// 更新用户算力
tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
}
return content, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running DALL-E task notify checking ...")
for {
var message sd.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.DallJob
for {
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
if err != nil {
logger.Error("error with download image: %s, error: %v", imgURL, err)
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
// sava image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
if err != nil {
return "", err
}
// update img_url
res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL)
if res.Error != nil {
return "", err
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed})
return imgURL, nil
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.DallJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
s.db.Delete(&job)
var user model.User
s.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d", job.Id),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}

View File

@@ -1,116 +0,0 @@
package fun
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/utils"
"fmt"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
// AI 绘画函数
type FuncImage struct {
name string
db *gorm.DB
uploadManager *oss.UploaderManager
proxyURL string
}
func NewImageFunc(db *gorm.DB, manager *oss.UploaderManager, config *types.AppConfig) FuncImage {
return FuncImage{
db: db,
name: "DALL-E3 绘画",
uploadManager: manager,
proxyURL: config.ProxyURL,
}
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
}
type imgRes struct {
Created int64 `json:"created"`
Data []struct {
RevisedPrompt string `json:"revised_prompt"`
Url string `json:"url"`
} `json:"data"`
}
type ErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
func (f FuncImage) Invoke(params map[string]interface{}) (string, error) {
logger.Infof("绘画参数:%+v", params)
prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY
var apiKey model.ApiKey
tx := f.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
return "", fmt.Errorf("error with get generation API KEY: %v", tx.Error)
}
// get image generation api URL
var conf model.Config
var chatConfig types.ChatConfig
tx = f.db.Where("marker", "chat").First(&conf)
if tx.Error != nil {
return "", fmt.Errorf("error with get chat configs: %v", tx.Error)
}
err := utils.JsonDecode(conf.Config, &chatConfig)
if err != nil {
return "", fmt.Errorf("error with decode chat config: %v", err)
}
apiURL := chatConfig.DallApiURL
if utils.IsEmptyValue(apiURL) {
apiURL = "https://api.openai.com/v1/images/generations"
}
imgNum := chatConfig.DallImgNum
if imgNum <= 0 {
imgNum = 1
}
var res imgRes
var errRes ErrRes
r, err := req.C().SetProxyURL(f.proxyURL).R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: imgNum,
Size: "1024x1024",
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiURL)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
}
// 存储图片
imgURL, err := f.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
if err != nil {
return "", fmt.Errorf("下载图片失败: %s", err.Error())
}
logger.Info(imgURL)
return fmt.Sprintf("\n\n![](%s)\n", imgURL), nil
}
func (f FuncImage) Name() string {
return f.name
}
var _ Function = &FuncImage{}

View File

@@ -1,40 +0,0 @@
package fun
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/service/oss"
"gorm.io/gorm"
)
type Function interface {
Invoke(map[string]interface{}) (string, error)
Name() string
}
var logger = logger2.GetLogger()
type resVo struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data struct {
Title string `json:"title"`
UpdatedAt string `json:"updated_at"`
Items []dataItem `json:"items"`
} `json:"data"`
}
type dataItem struct {
Title string `json:"title"`
Url string `json:"url"`
Remark string `json:"remark"`
}
func NewFunctions(config *types.AppConfig, db *gorm.DB, manager *oss.UploaderManager) map[string]Function {
return map[string]Function{
types.FuncZaoBao: NewZaoBao(config.ApiConfig),
types.FuncWeibo: NewWeiboHot(config.ApiConfig),
types.FuncHeadLine: NewHeadLines(config.ApiConfig),
types.FuncImage: NewImageFunc(db, manager, config),
}
}

View File

@@ -1,58 +0,0 @@
package fun
import (
"chatplus/core/types"
"errors"
"fmt"
"github.com/imroc/req/v3"
"strings"
"time"
)
// 今日头条函数实现
type FuncHeadlines struct {
name string
config types.ChatPlusApiConfig
client *req.Client
}
func NewHeadLines(config types.ChatPlusApiConfig) FuncHeadlines {
return FuncHeadlines{
name: "今日头条",
config: config,
client: req.C().SetTimeout(10 * time.Second)}
}
func (f FuncHeadlines) Invoke(map[string]interface{}) (string, error) {
if f.config.Token == "" {
return "", errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/headline/fetch", f.config.ApiURL)
var res resVo
r, err := f.client.R().
SetHeader("AppId", f.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("%v%v", err, r.Err)
}
if res.Code != types.Success {
return "", errors.New(res.Message)
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
for i, v := range res.Data.Items {
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [%s]", i+1, v.Title, v.Url, v.Remark))
}
return strings.Join(builder, "\n\n"), nil
}
func (f FuncHeadlines) Name() string {
return f.name
}
var _ Function = &FuncHeadlines{}

View File

@@ -1,58 +0,0 @@
package fun
import (
"chatplus/core/types"
"errors"
"fmt"
"github.com/imroc/req/v3"
"strings"
"time"
)
// 微博热搜函数实现
type FuncWeiboHot struct {
name string
config types.ChatPlusApiConfig
client *req.Client
}
func NewWeiboHot(config types.ChatPlusApiConfig) FuncWeiboHot {
return FuncWeiboHot{
name: "微博热搜",
config: config,
client: req.C().SetTimeout(10 * time.Second)}
}
func (f FuncWeiboHot) Invoke(map[string]interface{}) (string, error) {
if f.config.Token == "" {
return "", errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/weibo/fetch", f.config.ApiURL)
var res resVo
r, err := f.client.R().
SetHeader("AppId", f.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("%v%v", err, r.Err)
}
if res.Code != types.Success {
return "", errors.New(res.Message)
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
for i, v := range res.Data.Items {
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
}
return strings.Join(builder, "\n\n"), nil
}
func (f FuncWeiboHot) Name() string {
return f.name
}
var _ Function = &FuncWeiboHot{}

View File

@@ -1,59 +0,0 @@
package fun
import (
"chatplus/core/types"
"errors"
"fmt"
"github.com/imroc/req/v3"
"strings"
"time"
)
// 每日早报函数实现
type FuncZaoBao struct {
name string
config types.ChatPlusApiConfig
client *req.Client
}
func NewZaoBao(config types.ChatPlusApiConfig) FuncZaoBao {
return FuncZaoBao{
name: "每日早报",
config: config,
client: req.C().SetTimeout(10 * time.Second)}
}
func (f FuncZaoBao) Invoke(map[string]interface{}) (string, error) {
if f.config.Token == "" {
return "", errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/zaobao/fetch", f.config.ApiURL)
var res resVo
r, err := f.client.R().
SetHeader("AppId", f.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("%v%v", err, r.Err)
}
if res.Code != types.Success {
return "", errors.New(res.Message)
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
for _, v := range res.Data.Items {
builder = append(builder, v.Title)
}
builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
return strings.Join(builder, "\n\n"), nil
}
func (f FuncZaoBao) Name() string {
return f.name
}
var _ Function = &FuncZaoBao{}

View File

@@ -0,0 +1,195 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/store"
"strings"
"time"
"github.com/imroc/req/v3"
"github.com/shirou/gopsutil/host"
)
type LicenseService struct {
config types.ApiConfig
levelDB *store.LevelDB
license *types.License
urlWhiteList []string
machineId string
}
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
var license types.License
var machineId string
_ = levelDB.Get(types.LicenseKey, &license)
info, err := host.Info()
if err == nil {
machineId = info.HostID
}
logger.Infof("License: %+v", license)
return &LicenseService{
config: server.Config.ApiConfig,
levelDB: levelDB,
license: &license,
machineId: machineId,
}
}
type License struct {
Name string `json:"name"`
License string `json:"license"`
Mid string `json:"mid"`
ActiveAt int64 `json:"active_at"`
ExpiredAt int64 `json:"expired_at"`
UserNum int `json:"user_num"`
}
// ActiveLicense 激活 License
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
response, err := req.C().R().
SetBody(map[string]string{"license": license, "machine_id": machineId}).
SetSuccessResult(&res).Post(apiURL)
if err != nil {
return fmt.Errorf("发送激活请求失败: %v", err)
}
if response.IsErrorState() {
return fmt.Errorf("发送激活请求失败:%v", response.Status)
}
if res.Code != types.Success {
return fmt.Errorf("激活失败:%v", res.Message)
}
s.license = &types.License{
Key: license,
MachineId: machineId,
UserNum: res.Data.UserNum,
ExpiredAt: res.Data.ExpiredAt,
IsActive: true,
}
err = s.levelDB.Put(types.LicenseKey, s.license)
if err != nil {
return fmt.Errorf("保存许可证书失败:%v", err)
}
return nil
}
// SyncLicense 定期同步 License
func (s *LicenseService) SyncLicense() {
go func() {
retryCounter := 0
for {
license, err := s.fetchLicense()
if err != nil {
retryCounter++
if retryCounter < 5 {
logger.Error(err)
}
s.license.IsActive = false
} else {
s.license = license
}
urls, err := s.fetchUrlWhiteList()
if err == nil {
s.urlWhiteList = urls
}
time.Sleep(time.Second * 10)
}
}()
}
func (s *LicenseService) fetchLicense() (*types.License, error) {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
response, err := req.C().R().
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
SetSuccessResult(&res).Post(apiURL)
if err != nil {
return nil, fmt.Errorf("发送激活请求失败: %v", err)
}
if response.IsErrorState() {
return nil, fmt.Errorf("激活失败:%v", response.Status)
}
if res.Code != types.Success {
return nil, fmt.Errorf("激活失败:%v", res.Message)
}
return &types.License{
Key: res.Data.License,
MachineId: res.Data.Mid,
UserNum: res.Data.UserNum,
ExpiredAt: res.Data.ExpiredAt,
IsActive: true,
}, nil
}
func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data []string `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %v", err)
}
if response.IsErrorState() {
return nil, fmt.Errorf("发送请求失败:%v", response.Status)
}
if res.Code != types.Success {
return nil, fmt.Errorf("获取白名单失败:%v", res.Message)
}
return res.Data, nil
}
// GetLicense 获取许可信息
func (s *LicenseService) GetLicense() *types.License {
return s.license
}
// IsValidApiURL 判断是否合法的中转 URL
func (s *LicenseService) IsValidApiURL(uri string) error {
// 获得许可授权的直接放行
if s.license.IsActive {
if s.license.MachineId != s.machineId {
return errors.New("系统使用了盗版的许可证书")
}
if time.Now().Unix() > s.license.ExpiredAt {
return errors.New("系统许可证书已经过期")
}
return nil
}
for _, v := range s.urlWhiteList {
if strings.HasPrefix(uri, v) {
return nil
}
}
return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
}

View File

@@ -1,215 +0,0 @@
package mj
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"
"net/http"
"net/url"
"regexp"
"strings"
)
// MidJourney 机器人
var logger = logger2.GetLogger()
type Bot struct {
config *types.MidJourneyConfig
bot *discordgo.Session
name string
service *Service
}
func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.BotToken)
if err != nil {
return nil, err
}
if proxy != "" {
proxy, _ := url.Parse(proxy)
discord.Client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
discord.Dialer = &websocket.Dialer{
Proxy: http.ProxyURL(proxy),
}
}
return &Bot{
config: config,
bot: discord,
name: name,
service: service,
}, nil
}
func (b *Bot) Run() error {
b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
b.bot.AddHandler(b.messageCreate)
b.bot.AddHandler(b.messageUpdate)
logger.Infof("Starting MidJourney %s", b.name)
err := b.bot.Open()
if err != nil {
logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
return err
}
logger.Infof("Starting MidJourney %s successfully!", b.name)
return nil
}
type TaskStatus string
const (
Start = TaskStatus("Started")
Running = TaskStatus("Running")
Stopped = TaskStatus("Stopped")
Finished = TaskStatus("Finished")
)
type Image struct {
URL string `json:"url"`
ProxyURL string `json:"proxy_url"`
Filename string `json:"filename"`
Width int `json:"width"`
Height int `json:"height"`
Size int `json:"size"`
Hash string `json:"hash"`
}
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
// ignore messages for other channels
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
return
}
// ignore messages for self
if m.Author.ID == s.State.User.ID {
return
}
logger.Debugf("CREATE: %s", utils.JsonEncode(m))
var referenceId = ""
if m.ReferencedMessage != nil {
referenceId = m.ReferencedMessage.ID
}
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
// parse content
req := CBReq{
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
Content: m.Content,
Progress: 0,
Status: Start}
b.service.Notify(req)
return
}
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
// ignore messages for other channels
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
return
}
// ignore messages for self
if m.Author.ID == s.State.User.ID {
return
}
logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
var referenceId = ""
if m.ReferencedMessage != nil {
referenceId = m.ReferencedMessage.ID
}
if strings.Contains(m.Content, "(Stopped)") {
req := CBReq{
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
Content: m.Content,
Progress: extractProgress(m.Content),
Status: Stopped}
b.service.Notify(req)
return
}
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
progress := extractProgress(content)
var status TaskStatus
if progress == 100 {
status = Finished
} else {
status = Running
}
for _, attachment := range attachments {
if attachment.Width == 0 || attachment.Height == 0 {
continue
}
image := Image{
URL: attachment.URL,
Height: attachment.Height,
ProxyURL: attachment.ProxyURL,
Width: attachment.Width,
Size: attachment.Size,
Filename: attachment.Filename,
Hash: extractHashFromFilename(attachment.Filename),
}
req := CBReq{
MessageId: messageId,
ReferenceId: referenceId,
Image: image,
Prompt: extractPrompt(content),
Content: content,
Progress: progress,
Status: status,
}
b.service.Notify(req)
break // only get one image
}
}
// extract prompt from string
func extractPrompt(input string) string {
pattern := `\*\*(.*?)\*\*`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(input)
if len(matches) > 1 {
return strings.TrimSpace(matches[1])
}
return ""
}
func extractProgress(input string) int {
pattern := `\((\d+)\%\)`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(input)
if len(matches) > 1 {
return utils.IntValue(matches[1], 0)
}
return 100
}
func extractHashFromFilename(filename string) string {
if !strings.HasSuffix(filename, ".png") {
return ""
}
index := strings.LastIndex(filename, "_")
if index != -1 {
return filename[index+1 : len(filename)-4]
}
return ""
}

View File

@@ -1,145 +1,68 @@
package mj
import (
"chatplus/core/types"
"fmt"
"github.com/imroc/req/v3"
"time"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// MidJourney client
import "geekai/core/types"
type Client struct {
client *req.Client
config *types.MidJourneyConfig
type Client interface {
Imagine(task types.MjTask) (ImageRes, error)
Blend(task types.MjTask) (ImageRes, error)
SwapFace(task types.MjTask) (ImageRes, error)
Upscale(task types.MjTask) (ImageRes, error)
Variation(task types.MjTask) (ImageRes, error)
QueryTask(taskId string) (QueryRes, error)
}
func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second)
// set proxy URL
if proxy != "" {
client.SetProxyURL(proxy)
}
logger.Info(proxy)
return &Client{client: client, config: config}
type ImageReq struct {
BotType string `json:"botType,omitempty"`
Prompt string `json:"prompt,omitempty"`
Dimensions string `json:"dimensions,omitempty"`
Base64Array []string `json:"base64Array,omitempty"`
AccountFilter interface{} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook,omitempty"`
State string `json:"state,omitempty"`
}
func (c *Client) Imagine(prompt string) error {
interactionsReq := &InteractionsRequest{
Type: 2,
ApplicationID: ApplicationID,
GuildID: c.config.GuildId,
ChannelID: c.config.ChanelId,
SessionID: SessionID,
Data: map[string]any{
"version": "1166847114203123795",
"id": "938956540159881230",
"name": "imagine",
"type": "1",
"options": []map[string]any{
{
"type": 3,
"name": "prompt",
"value": prompt,
},
},
"application_command": map[string]any{
"id": "938956540159881230",
"application_id": ApplicationID,
"version": "1118961510123847772",
"default_permission": true,
"default_member_permissions": nil,
"type": 1,
"nsfw": false,
"name": "imagine",
"description": "Create images with Midjourney",
"dm_permission": true,
"options": []map[string]any{
{
"type": 3,
"name": "prompt",
"description": "The prompt to imagine",
"required": true,
},
},
"attachments": []any{},
},
},
}
url := "https://discord.com/api/v9/interactions"
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %w%v", err, r.Err)
}
return nil
type ImageRes struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
// Upscale 放大指定的图片
func (c *Client) Upscale(index int, messageId string, hash string) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.config.GuildId,
ChannelID: c.config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
url := "https://discord.com/api/v9/interactions"
var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}
return nil
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(index int, messageId string, hash string) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.config.GuildId,
ChannelID: c.config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
url := "https://discord.com/api/v9/interactions"
var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}
return nil
type QueryRes struct {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}

View File

@@ -0,0 +1,240 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
"time"
"github.com/gin-gonic/gin"
)
// PlusClient MidJourney Plus ProxyClient
type PlusClient struct {
Config types.MjPlusConfig
apiURL string
client *req.Client
}
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
return &PlusClient{
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
}
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
var res ImageRes
var errRes ErrRes
r, err := c.client.SetTimeout(time.Minute).R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Upscale 放大指定的图片
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &PlusClient{}

View File

@@ -1,59 +1,158 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"github.com/go-redis/redis/v8"
"time"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, licenseService *service.LicenseService) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.MjConfigs {
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
for k, config := range appConfig.MjPlusConfigs {
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(&config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService(name, queue, 4, 600, db, client, manager, appConfig.ProxyURL)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, &config, service)
err := licenseService.IsValidApiURL(config.ApiURL)
if err != nil {
logger.Error(err)
continue
}
err = bot.Run()
if err != nil {
continue
}
// run mj service
cli := NewPlusClient(config)
name := fmt.Sprintf("mj-plus-service-%d", k)
plusService := NewService(name, taskQueue, notifyQueue, db, cli)
go func() {
service.Run()
plusService.Run()
}()
services = append(services, plusService)
}
services = append(services, service)
for k, config := range appConfig.MjProxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := fmt.Sprintf("mj-proxy-service-%d", k)
proxyService := NewService(name, taskQueue, notifyQueue, db, cli)
go func() {
proxyService.Run()
}()
services = append(services, proxyService)
}
return &ServicePool{
taskQueue: queue,
services: services,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
uploaderManager: manager,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var message sd.NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
cli := p.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (p *ServicePool) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
var imgURL string
var err error
if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
} else {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
}
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
p.db.Updates(&v)
cli := p.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(sd.Finished))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
@@ -64,3 +163,56 @@ func (p *ServicePool) PushTask(task types.MjTask) {
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("progress < ?", 100).Find(&items)
if res.Error != nil {
continue
}
for _, job := range items {
// 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
p.db.Delete(&job)
// 退回算力
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if tx.Error == nil && tx.RowsAffected > 0 {
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}
}
time.Sleep(time.Second * 10)
}
}()
}
func (p *ServicePool) getService(name string) *Service {
for _, s := range p.services {
if s.Name == name {
return s
}
}
return nil
}

View File

@@ -0,0 +1,185 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
)
// ProxyClient MidJourney Proxy Client
type ProxyClient struct {
Config types.MjProxyConfig
apiURL string
}
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
return &ProxyClient{Config: config, apiURL: config.ApiURL}
}
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
body := ImageReq{
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能请使用 MidJourney-Plus")
}
// Upscale 放大指定的图片
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "UPSCALE",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "VARIATION",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &ProxyClient{}

View File

@@ -1,55 +1,48 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"gorm.io/gorm"
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
)
// Service MJ 绘画服务
type Service struct {
name string // service name
client *Client // MJ client
taskQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
proxyURL string
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
Name string // service Name
Client Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
}
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, proxy string) *Service {
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
return &Service{
name: name,
db: db,
taskQueue: queue,
client: client,
uploadManager: manager,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
proxyURL: proxy,
taskStartTimes: make(map[int]time.Time, 0),
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: cli,
}
}
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.name)
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for {
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
@@ -57,96 +50,143 @@ func (s *Service) Run() {
continue
}
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type {
case types.TaskImage:
err = s.client.Imagine(task.Prompt)
break
case types.TaskUpscale:
err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
break
case types.TaskVariation:
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
}
if err != nil {
logger.Error("绘画任务执行失败:", err)
// update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
atomic.AddInt32(&s.handledTaskNum, -1)
// 如果配置了多个中转平台的 API KEY
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
time.Sleep(time.Second)
continue
}
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
}
}
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum)
return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
}
}
func (s *Service) Notify(data CBReq) {
// extract the task ID
split := strings.Split(data.Prompt, " ")
var job model.MidJourneyJob
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
if res.Error == nil && data.Status == Finished {
logger.Warn("重复消息:", data.MessageId)
return
}
res = s.db.Where("task_id = ?", split[0]).First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Progress = data.Progress
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
// upload image
if data.Status == Finished {
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt))
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
job.ImgURL = imgURL
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil {
logger.Error("任务不存在任务ID", task.TaskId)
continue
}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) {
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
continue
}
logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
}
if data.Status == Finished {
// update user's img calls
s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// release lock task
atomic.AddInt32(&s.handledTaskNum, -1)
type CBReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := sd.Running
if job.Progress == 100 {
message = sd.Finished
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
}
return nil
}
func GetImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}

View File

@@ -1,34 +0,0 @@
package mj
const (
ApplicationID string = "936929561302675456"
SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe"
)
type InteractionsRequest struct {
Type int `json:"type"`
ApplicationID string `json:"application_id"`
MessageFlags *int `json:"message_flags,omitempty"`
MessageID *string `json:"message_id,omitempty"`
GuildID string `json:"guild_id"`
ChannelID string `json:"channel_id"`
SessionID string `json:"session_id"`
Data map[string]any `json:"data"`
Nonce string `json:"nonce,omitempty"`
}
type InteractionsResult struct {
Code int `json:"code"`
Message string
Error map[string]any
}
type CBReq struct {
MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"`
Image Image `json:"image"`
Content string `json:"content"`
Prompt string `json:"prompt"`
Status TaskStatus `json:"status"`
Progress int `json:"progress"`
}

View File

@@ -1,15 +1,25 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"fmt"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/gin-gonic/gin"
"geekai/core/types"
"geekai/utils"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/gin-gonic/gin"
)
type AliYunOss struct {
@@ -44,16 +54,16 @@ func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
}
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) {
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
// 解析表单
file, err := ctx.FormFile(name)
if err != nil {
return "", err
return File{}, err
}
// 打开上传文件
src, err := file.Open()
if err != nil {
return "", err
return File{}, err
}
defer src.Close()
@@ -62,10 +72,16 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) {
// 上传文件
err = s.bucket.PutObject(objectKey, src)
if err != nil {
return "", err
return File{}, err
}
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
return File{
Name: file.Filename,
ObjKey: objectKey,
URL: fmt.Sprintf("%s/%s", s.config.Domain, objectKey),
Ext: fileExt,
Size: file.Size,
}, nil
}
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
@@ -83,7 +99,7 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := filepath.Ext(parse.Path)
fileExt := utils.GetImgExt(parse.Path)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
@@ -93,10 +109,29 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
}
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
}
func (s AliYunOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL)
key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
return s.bucket.DeleteObject(key)
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else {
objectKey = fileURL
}
return s.bucket.DeleteObject(objectKey)
}
var _ Uploader = AliYunOss{}

View File

@@ -1,9 +1,17 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/gin-gonic/gin"
"net/url"
"os"
@@ -23,23 +31,30 @@ func NewLocalStorage(config *types.AppConfig) LocalStorage {
}
}
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (string, error) {
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
file, err := ctx.FormFile(name)
if err != nil {
return "", fmt.Errorf("error with get form: %v", err)
return File{}, fmt.Errorf("error with get form: %v", err)
}
filePath, err := utils.GenUploadPath(s.config.BasePath, file.Filename)
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
if err != nil {
return "", fmt.Errorf("error with generate filename: %s", err.Error())
return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
}
// 将文件保存到指定路径
err = ctx.SaveUploadedFile(file, filePath)
err = ctx.SaveUploadedFile(file, path)
if err != nil {
return "", fmt.Errorf("error with save upload file: %s", err.Error())
return File{}, fmt.Errorf("error with save upload file: %s", err.Error())
}
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
ext := filepath.Ext(file.Filename)
return File{
Name: file.Filename,
ObjKey: path,
URL: utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, path),
Ext: ext,
Size: file.Size,
}, nil
}
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
@@ -48,7 +63,7 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
filename := filepath.Base(parse.Path)
filePath, err := utils.GenUploadPath(s.config.BasePath, filename)
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
if err != nil {
return "", fmt.Errorf("error with generate image dir: %v", err)
}
@@ -65,7 +80,24 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
err = os.WriteFile(filePath, imageData, 0644)
if err != nil {
return "", fmt.Errorf("error writing to file:%v", err)
}
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) Delete(fileURL string) error {
if _, err := os.Stat(fileURL); err == nil {
return os.Remove(fileURL)
}
filePath := strings.Replace(fileURL, s.config.BaseURL, s.config.BasePath, 1)
return os.Remove(filePath)
}

View File

@@ -1,17 +1,26 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"geekai/core/types"
"geekai/utils"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
type MiniOss struct {
@@ -65,34 +74,64 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
}
func (s MiniOss) PutFile(ctx *gin.Context, name string) (string, error) {
func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
file, err := ctx.FormFile(name)
if err != nil {
return "", fmt.Errorf("error with get form: %v", err)
return File{}, fmt.Errorf("error with get form: %v", err)
}
// Open the uploaded file
fileReader, err := file.Open()
if err != nil {
return "", fmt.Errorf("error opening file: %v", err)
return File{}, fmt.Errorf("error opening file: %v", err)
}
defer fileReader.Close()
fileExt := filepath.Ext(file.Filename)
fileExt := utils.GetImgExt(file.Filename)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Content-Type"),
})
if err != nil {
return "", fmt.Errorf("error uploading to MinIO: %v", err)
return File{}, fmt.Errorf("error uploading to MinIO: %v", err)
}
return File{
Name: file.Filename,
ObjKey: info.Key,
URL: fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key),
Ext: fileExt,
Size: file.Size,
}, nil
}
func (s MiniOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
objectKey,
strings.NewReader(string(imageData)),
int64(len(imageData)),
minio.PutObjectOptions{ContentType: "image/png"})
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
}
func (s MiniOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL)
key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
return s.client.RemoveObject(context.Background(), s.config.Bucket, key, minio.RemoveObjectOptions{})
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else {
objectKey = fileURL
}
return s.client.RemoveObject(context.Background(), s.config.Bucket, objectKey, minio.RemoveObjectOptions{})
}
var _ Uploader = MiniOss{}

View File

@@ -1,17 +1,27 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core/types"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"github.com/qiniu/go-sdk/v7/storage"
"net/url"
"path/filepath"
"time"
)
type QinNiuOss struct {
@@ -50,16 +60,16 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
}
}
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) {
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
// 解析表单
file, err := ctx.FormFile(name)
if err != nil {
return "", err
return File{}, err
}
// 打开上传文件
src, err := file.Open()
if err != nil {
return "", err
return File{}, err
}
defer src.Close()
@@ -70,10 +80,17 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) {
extra := storage.PutExtra{}
err = s.uploader.Put(ctx, &ret, s.putPolicy.UploadToken(s.mac), key, src, file.Size, &extra)
if err != nil {
return "", err
return File{}, err
}
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
return File{
Name: file.Filename,
ObjKey: key,
URL: fmt.Sprintf("%s/%s", s.config.Domain, ret.Key),
Ext: fileExt,
Size: file.Size,
}, nil
}
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
@@ -91,7 +108,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := filepath.Ext(parse.Path)
fileExt := utils.GetImgExt(parse.Path)
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
ret := storage.PutRet{}
extra := storage.PutExtra{}
@@ -103,10 +120,32 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL)
key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
return s.manager.Delete(s.config.Bucket, key)
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else {
objectKey = fileURL
}
return s.manager.Delete(s.config.Bucket, objectKey)
}
var _ Uploader = QinNiuOss{}

View File

@@ -1,9 +1,29 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import "github.com/gin-gonic/gin"
const Local = "LOCAL"
const Minio = "MINIO"
const QiNiu = "QINIU"
const AliYun = "ALIYUN"
type File struct {
Name string `json:"name"`
ObjKey string `json:"obj_key"`
Size int64 `json:"size"`
URL string `json:"url"`
Ext string `json:"ext"`
}
type Uploader interface {
PutFile(ctx *gin.Context, name string) (string, error)
PutFile(ctx *gin.Context, name string) (File, error)
PutImg(imageURL string, useProxy bool) (string, error)
PutBase64(imageData string) (string, error)
Delete(fileURL string) error
}

View File

@@ -1,7 +1,14 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"geekai/core/types"
"strings"
)
@@ -9,11 +16,6 @@ type UploaderManager struct {
handler Uploader
}
const Local = "LOCAL"
const Minio = "MINIO"
const QiNiu = "QINIU"
const AliYun = "ALIYUN"
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
active := Local
if config.OSS.Active != "" {

View File

@@ -1,9 +1,16 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"github.com/smartwalle/alipay/v3"
"log"
"net/url"

View File

@@ -1,9 +1,19 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
@@ -16,57 +26,144 @@ import (
type HuPiPayService struct {
appId string
appSecret string
host string
apiURL string
}
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
return &HuPiPayService{
appId: config.HuPiPayConfig.AppId,
appSecret: config.HuPiPayConfig.AppSecret,
host: config.HuPiPayConfig.PayURL,
apiURL: config.HuPiPayConfig.ApiURL,
}
}
type HuPiPayReq struct {
AppId string `json:"appid"`
Version string `json:"version"`
TradeOrderId string `json:"trade_order_id"`
TotalFee string `json:"total_fee"`
Title string `json:"title"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
WapName string `json:"wap_name"`
CallbackURL string `json:"callback_url"`
Time string `json:"time"`
NonceStr string `json:"nonce_str"`
}
type HuPiResp struct {
Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg,omitempty"`
}
// Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params map[string]string) (string, error) {
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10)
params["appid"] = s.appId
params["time"] = simple
params["nonce_str"] = simple
for k, v := range params {
data.Add(k, v)
params.AppId = s.appId
params.Time = simple
params.NonceStr = simple
encode := utils.JsonEncode(params)
m := make(map[string]string)
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
data.Add(k, fmt.Sprintf("%v", v))
}
data.Add("hash", s.Sign(params))
resp, err := http.PostForm(s.host, data)
// 生成签名
data.Add("hash", s.Sign(data))
// 发送支付请求
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return "error", err
return HuPiResp{}, fmt.Errorf("error with requst api: %v", err)
}
defer resp.Body.Close()
all, err := io.ReadAll(resp.Body)
if err != nil {
return "error", err
return HuPiResp{}, fmt.Errorf("error with reading response: %v", err)
}
return string(all), err
var res HuPiResp
err = utils.JsonDecode(string(all), &res)
if err != nil {
return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err)
}
if res.ErrCode != 0 {
return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
}
return res, nil
}
// Sign 签名方法
func (s *HuPiPayService) Sign(params map[string]string) string {
var data string
keys := make([]string, 0, 0)
params["appid"] = s.appId
func (s *HuPiPayService) Sign(params url.Values) string {
params.Del(`Sign`)
var keys = make([]string, 0, 0)
for key := range params {
keys = append(keys, key)
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
//拼接
for _, k := range keys {
data = fmt.Sprintf("%s%s=%s&", data, k, params[k])
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += s.appSecret
md5bs := md5.Sum([]byte(src))
return hex.EncodeToString(md5bs[:])
}
// Check 校验订单状态
func (s *HuPiPayService) Check(tradeNo string) error {
data := url.Values{}
data.Add("appid", s.appId)
data.Add("open_order_id", tradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp)
data.Add("nonce_str", stamp)
data.Add("hash", s.Sign(data))
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ErrCode int `json:"errcode"`
Data struct {
Status string `json:"status"`
OpenOrderId string `json:"open_order_id"`
} `json:"data,omitempty"`
ErrMsg string `json:"errmsg"`
Hash string `json:"hash"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ErrCode == 0 && r.Data.Status == "OD" {
return nil
} else {
logger.Debugf("%+v", r)
return errors.New("order not paid" + r.ErrMsg)
}
data = strings.Trim(data, "&")
data = fmt.Sprintf("%s%s", data, s.appSecret)
m := md5.New()
m.Write([]byte(data))
sign := fmt.Sprintf("%x", m.Sum(nil))
return sign
}

Some files were not shown because too many files have changed in this diff Show More