Compare commits

...

235 Commits

Author SHA1 Message Date
RockYang
825a1b1027 文档对话不显示文档列表的 bug 2025-03-06 14:00:13 +08:00
RockYang
950e7d1b00 Merge branch 'dev' 2025-03-05 18:42:44 +08:00
RockYang
d8f9f48278 merge v4.1.6 2025-03-05 18:42:30 +08:00
RockYang
6a066f1b8e 默认允许跨域请求API 2025-02-11 14:34:31 +08:00
RockYang
b7d137247a 优化 docker 容器依赖关系 2025-02-11 10:10:43 +08:00
RockYang
4373642ebd 优化打包脚本,新增 arm64 架构打包脚本 2025-02-11 10:01:05 +08:00
RockYang
5bf90920f5 resolve conflicts 2025-02-11 09:53:41 +08:00
RockYang
edcbb3e226 merge v4.1.5 2025-02-11 09:45:26 +08:00
RockYang
5213bdf08b fixed bug for calculate chat message tokens 2025-01-19 13:08:38 +08:00
RockYang
cf817fd8ea merge code for v4.1.4 2025-01-18 23:20:41 +08:00
RockYang
a2481ff1cf update readme 2025-01-07 11:58:16 +08:00
RockYang
bc7d06d3e5 修复登录失效的 Bug 2024-12-21 21:44:52 +08:00
RockYang
8e81dfa12a update database name 2024-12-16 11:26:54 +08:00
RockYang
0ff76f0f21 update database file 2024-12-16 11:11:42 +08:00
RockYang
787caa84c8 merge config.toml file 2024-12-16 10:09:23 +08:00
RockYang
c2503e663a merge v4.1.3 2024-12-16 10:07:52 +08:00
RockYang
405a88862b Merge branch 'dev' 2024-11-27 15:04:26 +08:00
RockYang
296eabe09a merge v4.1.2 2024-11-27 15:00:02 +08:00
RockYang
54b45ec2ff update docker-compose.yaml 2024-11-14 16:37:43 +08:00
RockYang
c434f85045 update database 2024-11-14 16:01:27 +08:00
RockYang
4d10279870 merge v4.1.1 and fixed conflicts 2024-11-13 18:40:04 +08:00
RockYang
9de9489673 remove sensitive words 2024-11-04 10:05:57 +08:00
RockYang
9814fec930 update the default api url to https://api.geekai.pro 2024-10-29 14:09:55 +08:00
RockYang
53ba731159 update database sql file 2024-10-29 14:07:49 +08:00
RockYang
2081d3ce29 add databse sql file 2024-10-23 20:01:17 +08:00
RockYang
1fe1e40a43 优化实时语音对话组件,处理异常 2024-10-23 18:04:09 +08:00
RockYang
ad6e2dd370 modify text link color for register page 2024-10-21 18:26:19 +08:00
RockYang
bb63f23414 Merge branch 'main' of gitee.com:blackfox/geekai-plus 2024-10-21 18:21:21 +08:00
RockYang
43f6bf74f2 更换登录页面背景图片 2024-10-21 18:21:04 +08:00
RockYang
662d7b099e 给 realtime 语音对话增加音效 2024-10-18 06:26:05 +08:00
RockYang
d5eeeea764 优化充值产品定价逻辑,确保手机端和PC端显示的价格一致 2024-10-17 18:15:25 +08:00
RockYang
43c507c597 the relay server for openai websocket is ready 2024-10-17 16:46:41 +08:00
RockYang
e356771049 add websocket relayer for openai realtime api 2024-10-16 18:16:09 +08:00
RockYang
48139290ed integrated openai realtime console 2024-10-15 19:25:18 +08:00
RockYang
bd852c82b7 add PCM16 audio stream to wave is reday 2024-10-14 18:39:50 +08:00
RockYang
13564993d7 add voice chat test case 2024-10-12 19:07:29 +08:00
RockYang
bfc1e1bc2c suno and luma task management funtion in admin console is ready 2024-10-10 17:07:40 +08:00
RockYang
ba20717a09 image task list page for admin console is ready 2024-10-09 18:17:44 +08:00
RockYang
52e40daf23 fixed bug in FileSelect component for deleting files 2024-10-08 18:00:46 +08:00
RockYang
b2f57aa483 merge v4.1.0 and fixed conflicts 2024-10-08 17:54:08 +08:00
RockYang
4c2dba1004 merge v4.1.0 and fixed conflicts 2024-10-08 17:51:14 +08:00
RockYang
430a7b2297 fixed bug for websocket message handler rebind 2024-10-08 16:41:19 +08:00
RockYang
c91a38a882 fixed webscoket event re-bind bug 2024-10-05 21:18:59 +08:00
RockYang
6e02bee4b7 fixed alipay mobile payment 2024-10-05 11:45:44 +08:00
RockYang
b62218110e fixed alipay mobile payment 2024-10-05 10:23:00 +08:00
RockYang
e2960b2607 auto jump to mobile page when use mobile device access the page 2024-10-04 11:25:01 +08:00
RockYang
88e7c39066 fixed bug for: websocket is not auto connected when user not login 2024-10-02 07:26:34 +08:00
RockYang
2a6dd636fa update database 2024-09-30 17:12:23 +08:00
RockYang
6bf38f78d5 add message handler ONLY when websocket connect successfully 2024-09-30 16:33:26 +08:00
RockYang
5a04a935be support wechat and alipay payment for mobile page 2024-09-30 16:20:40 +08:00
RockYang
8923e938d2 optimize the vue component communication, replace event listening with share data 2024-09-30 14:20:59 +08:00
RockYang
1a1734abf0 websocket api refactor is ready 2024-09-29 19:28:47 +08:00
RockYang
8093a3eeb2 mj websocket refactor is ready 2024-09-29 07:51:08 +08:00
RockYang
9edb3d0a82 sd websocket refactor is finished 2024-09-27 18:28:54 +08:00
RockYang
d95fab11be refactor websocket message protocol, keep the only connection for all clients 2024-09-27 17:50:54 +08:00
RockYang
6ef09c8ad5 add ws handler 2024-09-25 18:43:12 +08:00
RockYang
283a023a06 update database sql file for v4.1.4 2024-09-23 15:54:22 +08:00
RockYang
d315edef5f logout the user when it has been disabled 2024-09-20 16:49:03 +08:00
RockYang
5fa17b300e add release v4.1.4 2024-09-20 15:50:04 +08:00
RockYang
32919de7a7 add email white list check in register handler 2024-09-20 14:10:40 +08:00
RockYang
7d126aab41 add email white list 2024-09-20 10:40:37 +08:00
RockYang
16ac57ced3 payment for mobile page is ready 2024-09-19 19:03:03 +08:00
RockYang
4976b967e7 wechat payment for mobile page is ready 2024-09-19 17:59:27 +08:00
RockYang
e874178782 wechat payment for pc is ready 2024-09-19 14:42:25 +08:00
RockYang
8cb66ad01b fixed bug geek-plus#6, register page first tab not auto active 2024-09-19 09:03:14 +08:00
RockYang
f887a39912 geek payment notify api is ready 2024-09-18 22:24:05 +08:00
RockYang
2beffd3dd3 urgent bug fix: remove suno and luma task will recharge user power 2024-09-18 20:33:29 +08:00
RockYang
d8cb92d8d4 Geek Pay notify is ready 2024-09-18 18:07:49 +08:00
RockYang
158db83965 add geek payment 2024-09-18 07:03:46 +08:00
RockYang
603bfa7def recommend user to use Google Chrome 2024-09-15 10:25:50 +08:00
RockYang
829fb879a6 merge app type function branch 2024-09-14 18:17:55 +08:00
RockYang
0385e60ce1 refactor AI chat message struct, allow users to set whether the AI responds in stream, compatible with the GPT-o1 model 2024-09-14 17:06:13 +08:00
胡双明
aaea23f785 feat: 应用分类功能 2024-09-14 11:05:49 +08:00
RockYang
131efd6ba5 refactor chat message body struct 2024-09-14 07:11:45 +08:00
RockYang
866564370d return at least one chat role for getUserRoles API 2024-09-14 05:55:56 +08:00
RockYang
dcdc0d8918 add tid field for chat app role 2024-09-13 18:32:13 +08:00
RockYang
6c7fa17e50 fixed bug, filelist page support pagination, do not load captcha component for user login first time 2024-09-13 17:03:05 +08:00
胡双明
38a0d00142 Merge branch 'main' into husm_2024-09-02 2024-09-13 10:25:15 +08:00
RockYang
5c77e67b0f fixed bug for reset password 2024-09-12 17:25:19 +08:00
RockYang
961cee5e41 fixed bug for register page code verification 2024-09-12 15:42:09 +08:00
胡双明
c9cc93be8c Merge branch 'main' into husm_2024-09-02 2024-09-11 09:57:01 +08:00
RockYang
49f2e1a71e optimize download function for suno 2024-09-11 08:39:28 +08:00
胡双明
97eff6085a feat: 210 AI对话页面文件列表增加分页功能 2024-09-10 18:14:34 +08:00
RockYang
8b2e2d61af file list api support pagination 2024-09-10 15:24:36 +08:00
胡双明
c096efb416 Merge branch 'main' into husm_2024-09-02 2024-09-10 14:29:08 +08:00
RockYang
cdaf6fb9dc update v4.1.3 database sql file 2024-09-10 11:15:26 +08:00
RockYang
78f443ed6d update v4.1.3 database sql file 2024-09-10 11:11:17 +08:00
RockYang
54e8d72b10 support multiple delete users, update database sql file 2024-09-10 10:56:04 +08:00
RockYang
05161f48fd update config file 2024-09-09 18:58:03 +08:00
RockYang
e971bf6b88 merge luma page code for v4.1.3 2024-09-09 18:07:10 +08:00
RockYang
55b979784c Merge remote-tracking branch 'inet/husm_2024-09-02' into dev-4.1.3 2024-09-09 10:54:43 +08:00
RockYang
79adc871ef export the newest database sql file 2024-09-05 15:29:00 +08:00
RockYang
97aa922b5f 优化聊天页面代码,刷新页面之后自动加载当前对话的 role_id 和 model_id 2024-09-05 14:50:37 +08:00
胡双明
11c760a4e8 feat: 生成视频页面2 2024-09-05 11:56:02 +08:00
RockYang
8144fada25 merged v4.0.9 and fixed conflicts 2024-09-05 11:02:32 +08:00
RockYang
87b03332d9 add auto execute task to downloading video files 2024-09-05 10:54:58 +08:00
胡双明
8b14eeadf4 feat: 生成视频页面 2024-09-05 08:50:49 +08:00
RockYang
e0ead127e0 remove unused css file 2024-09-04 22:42:56 +08:00
RockYang
0887bcdee0 adjust chat page styles 2024-09-04 18:07:39 +08:00
RockYang
67d83041d7 user can select function tools by themself 2024-09-04 14:53:21 +08:00
RockYang
1350f388f0 add page and total field for pagination vo 2024-09-03 18:32:15 +08:00
RockYang
65dde9e69d add sync lock for sub or add user's power 2024-09-03 12:09:36 +08:00
RockYang
2e5bd238b7 luma create and list api is ready 2024-09-02 18:08:50 +08:00
胡双明
8fc8fd6cba feat(Luma): 加上传前后帧调换功能 2024-09-02 17:33:03 +08:00
RockYang
dfc6c87250 optimize ChatPlus page, fixed bug for websocket reconnection 2024-09-02 16:35:15 +08:00
RockYang
b63e01225e add luma api service 2024-08-30 18:12:14 +08:00
RockYang
561b82027a update change log file 2024-08-30 16:47:52 +08:00
RockYang
f6d8fbf570 suno add new function for merging full songs and upload custom music 2024-08-30 16:46:48 +08:00
RockYang
568201ebbb add logo for favirate icon 2024-08-29 13:36:35 +08:00
RockYang
ab421f2185 luma page, upload image and remove image function is ready 2024-08-26 17:59:05 +08:00
RockYang
f71a2f5263 add download for video 2024-08-26 07:24:04 +08:00
RockYang
d000cc5a67 luma page video list component is ready 2024-08-23 18:25:58 +08:00
RockYang
04d6ba0853 luma page is ready 2024-08-20 14:31:40 +08:00
RockYang
8d7c028ca8 refactor reset password functions 2024-08-19 12:05:00 +08:00
RockYang
3ae7ebfeaf fixed styles 2024-08-19 06:42:53 +08:00
RockYang
aa42d38387 add bind mobile, bind email, bind wechat function is ready 2024-08-14 15:56:50 +08:00
RockYang
43843b92f2 add mobile and email filed for user 2024-08-13 18:40:50 +08:00
RockYang
5da879600a add verification code for login and register page 2024-08-13 14:55:47 +08:00
RockYang
87ed2064e3 add Captcha components 2024-08-13 10:01:07 +08:00
RockYang
34e96e91d4 fixed conflicts 2024-08-13 06:48:22 +08:00
RockYang
8c4c2b89ce add wechat login for login dialog 2024-08-12 18:00:34 +08:00
RockYang
373021c191 add drag icon for dragable rows 2024-08-12 14:00:50 +08:00
RockYang
740c3c1b00 release v4.1.2 2024-08-09 18:50:02 +08:00
RockYang
67c7132e6b redeem code function is ready 2024-08-09 18:19:51 +08:00
RockYang
c77843424b add redeem code function 2024-08-08 18:28:50 +08:00
RockYang
2d4959aa7d add cache for getting user info and system configs 2024-08-08 14:37:33 +08:00
RockYang
167c59a159 add clear unpaid order functions 2024-08-07 18:00:28 +08:00
RockYang
1d0006ce59 refactor stable diffusion service, use api key instead of configs 2024-08-07 17:30:59 +08:00
RockYang
6a8b4ee2f1 refactor midjourney service, use api key in database 2024-08-06 18:30:57 +08:00
RockYang
72b1515b68 show sql error message 2024-08-05 16:14:44 +08:00
RockYang
754ba02263 fixed build script 2024-08-01 18:45:53 +08:00
RockYang
7ddf57ae06 merge v4.0.8 2024-08-01 18:09:00 +08:00
RockYang
3f0252b498 update change log 2024-08-01 08:54:04 +08:00
RockYang
cc5180a6f7 update readme 2024-08-01 08:52:46 +08:00
RockYang
1d9d487f0e restore use power when removed not finish jobs 2024-07-31 16:08:46 +08:00
RockYang
96f1126d02 remove platform field for api key and chat model 2024-07-30 17:24:21 +08:00
RockYang
7f9b8d8246 update datebasesl 2024-07-30 14:55:49 +08:00
RockYang
5132d52a44 add back-to-top component for all list page 2024-07-29 11:00:53 +08:00
RockYang
1bcbf74883 remove chat debug log 2024-07-28 18:55:17 +08:00
RockYang
abdf5298fe add function to generate lyrics 2024-07-28 10:04:53 +08:00
RockYang
2129f7a8b7 song detail page is ready 2024-07-26 19:12:44 +08:00
RockYang
f6f8748521 adjust chat records layout styles 2024-07-25 11:01:27 +08:00
RockYang
59301df073 add put url file for oss interface 2024-07-23 18:36:26 +08:00
RockYang
e17dcf4d5f remove other platform supports, ONLY use chatGPT API 2024-07-22 18:36:58 +08:00
RockYang
09f44e6d9b optimize foot copyright snaps 2024-07-22 17:54:09 +08:00
RockYang
59824bffc5 add close button for music player 2024-07-22 07:12:21 +08:00
RockYang
cb0dacd5e0 enable use random pure color background for index page 2024-07-19 18:43:01 +08:00
RockYang
7463cfc66c the music player is ready 2024-07-18 18:34:11 +08:00
RockYang
b248560ba2 add suno page 2024-07-17 18:58:09 +08:00
RockYang
37368fe13f support upload file from clipboard 2024-07-17 10:23:02 +08:00
RockYang
246b023624 allow user to use chat role directly, no need to add to workspace 2024-07-16 18:28:08 +08:00
RockYang
9f44c34d34 update docs url 2024-07-16 18:15:34 +08:00
RockYang
a6b9f57a50 show error message for Midjourney task list page 2024-07-16 17:16:58 +08:00
RockYang
42bc23cacf fixed bug for function call for openai 2024-07-16 06:25:40 +08:00
RockYang
282f55c7a3 优化版权显示逻辑,允许激活用户更改自定义版权 2024-07-15 18:44:14 +08:00
RockYang
44798f89ba update docker-compose file 2024-07-12 18:13:43 +08:00
RockYang
596cb2b206 update database file, add tika host config 2024-07-12 18:10:32 +08:00
RockYang
d1965deff1 tidy apis 2024-07-12 14:39:14 +08:00
RockYang
b793b81768 update geekai image version 2024-07-05 11:09:13 +08:00
RockYang
a5ef4299ec wechat login is ready 2024-07-04 15:34:32 +08:00
RockYang
cdb1a8bde1 feat: support wechat login function 2024-07-02 18:27:06 +08:00
RockYang
233f6e00f0 fixed conflicts 2024-06-30 06:09:12 +08:00
RockYang
b7dba68549 optimize ngin configuration for chat-plus.conf 2024-06-30 05:44:39 +08:00
RockYang
64e5fc48ba docs: update change log file 2024-06-28 16:21:43 +08:00
RockYang
a692cf1338 feat: optimize chat page data list style, support list style and chat style 2024-06-28 15:53:49 +08:00
RockYang
6998dd7af4 feat: chat with file function is ready 2024-06-27 18:01:49 +08:00
RockYang
9343c73e0f enable set custom index background image 2024-06-27 10:49:31 +08:00
RockYang
739cd46539 add test code for reading pdf files 2024-06-26 18:50:48 +08:00
RockYang
f8fed83507 feat: new UI for chat file manager is ready 2024-06-25 18:59:27 +08:00
RockYang
d63536d5ef fixed bug: mobile chat list could not update chat title 2024-06-25 09:53:08 +08:00
RockYang
4905fb28d4 update version 2024-06-23 17:54:42 +08:00
RockYang
a3a2a8abcb add wechat payment configs sample 2024-06-22 16:04:04 +08:00
RockYang
839dd8dbf4 update docker-compose file 2024-06-22 12:37:41 +08:00
RockYang
0375164f40 add database files 2024-06-22 12:17:35 +08:00
RockYang
691294b444 finish mobile wechat payment 2024-06-22 12:10:43 +08:00
RockYang
bdea12c51a update docker image mirror url 2024-06-19 08:35:05 +08:00
RockYang
a27d9ea259 Merge branch 'main' of gitee.com:blackfox/geekai 2024-06-19 08:22:00 +08:00
RockYang
7cd824c284 update docker image version to 4.0.6 2024-06-15 15:35:32 +08:00
RockYang
e27d95e2b5 update docker image version to 4.0.6 2024-06-15 15:33:38 +08:00
RockYang
c24b4d7074 update change log files 2024-06-14 18:23:54 +08:00
GeekMaster
6839827db0 Merge branch 'main' of https://github.com/yangjian102621/geekai into main 2024-06-12 15:44:26 +08:00
RockYang
ab24398748 wechat payment is ready for PC 2024-06-12 14:20:37 +08:00
RockYang
6110522b54 change payment component, upgrade golang to 1.22.4 2024-06-11 11:48:41 +08:00
RockYang
bcdf5e3776 fix bug: free model not record the chat history 2024-06-06 15:01:32 +08:00
RockYang
2207830db9 fixe page styles 2024-06-05 18:08:23 +08:00
RockYang
d52dfbfef4 fixed bug markmap generation 2024-06-04 16:21:08 +08:00
RockYang
d6a04f96fe Merge pull request #208 from mari1995/main
图片墙,选择框单选问题
2024-06-04 08:33:51 +08:00
RockYang
66ccb387e8 dalle3 and gptt-4o api compatible with azure 2024-06-03 18:34:37 +08:00
RockYang
5f820b9dc1 merge v4.0.6 2024-06-03 14:22:08 +08:00
RockYang
3cc2263dc7 fixed bug for function call error None is not of type 'array' 2024-05-30 09:59:44 +08:00
RockYang
f0a3c5d8ae fixed bug for mobile chat share 2024-05-30 08:37:14 +08:00
RockYang
2a4ef27774 add v4.0.8 database sql file 2024-05-29 17:41:37 +08:00
RockYang
2b057f32aa feat: add dalle3 page for h5 2024-05-29 17:25:01 +08:00
RockYang
bc6451026f feat: add system config for enable rand background image for index page 2024-05-29 16:24:56 +08:00
RockYang
99fd596862 feat: add system config for enable rand background image for index page 2024-05-29 16:23:42 +08:00
RockYang
f0959b5df6 fix markdown formula parse plugin 2024-05-29 13:49:45 +08:00
SSMario
6788edbe9d Update Image.vue 2024-05-27 21:23:11 +08:00
SSMario
3895305882 Update Image.vue 2024-05-27 21:00:59 +08:00
RockYang
1b0938b33f micro fixs 2024-05-27 17:39:17 +08:00
SSMario
c2acbaaa94 Update ImagesWall.vue 2024-05-27 17:04:15 +08:00
RockYang
02faff461a fixed bug for dalle prompt translate 2024-05-27 11:42:14 +08:00
RockYang
e18e5a38c6 put model and app selector on the top of chat page 2024-05-24 12:33:22 +08:00
RockYang
2f9b1b7835 fixed bug for payment api authorization 2024-05-24 11:31:38 +08:00
RockYang
717b137a6d chore: use config value for order pay timeout 2024-05-22 18:15:06 +08:00
RockYang
f755bdccae feat: add sign check for PC QR code payment 2024-05-22 17:47:53 +08:00
RockYang
4bba77ab47 extract code for saving chat history 2024-05-22 15:32:44 +08:00
RockYang
6944a32ff3 check if the api url in whitelist for mj plus client 2024-05-22 11:47:04 +08:00
RockYang
5742b40aee fixed bug for mobile chat page change chat model not work 2024-05-21 17:54:03 +08:00
RockYang
7f1ec90748 auto resize the input element rows, when use inputed more than one line 2024-05-21 17:36:47 +08:00
RockYang
4a99be2f15 remove license code 2024-05-21 16:20:29 +08:00
RockYang
bee19392c1 add logs for updating database failed 2024-05-21 11:55:38 +08:00
RockYang
00d31a2379 update docker image url 2024-05-21 11:03:11 +08:00
RockYang
5d65505ab7 rename project name to geekai 2024-05-20 15:11:14 +08:00
RockYang
3dc7d0516a update database file 2024-05-19 19:37:16 +08:00
RockYang
50335ebc2d remove code for set left component fixed height 2024-05-18 08:07:09 +08:00
RockYang
bcadee7290 refactor login dialog for front page 2024-05-18 00:27:32 +08:00
RockYang
cac3194d5b finished refactor chat page UI 2024-05-17 19:25:38 +08:00
RockYang
d45f9fbad6 fixed bugs for send message captcha component 2024-05-16 22:33:00 +08:00
RockYang
d98b08d7cd feat: add top navbar for front page 2024-05-16 20:10:00 +08:00
RockYang
5a8fe5a6cf feat: support add external link menu 2024-05-16 10:53:00 +08:00
RockYang
36c27d6092 handler chat error in the chat entry func 2024-05-15 15:30:34 +08:00
RockYang
3ab29da8f0 add charge link for insufficient of power 2024-05-15 07:10:31 +08:00
RockYang
3699f024f1 fix bug for white-list api key check 2024-05-14 22:30:42 +08:00
RockYang
114d0088dc update version 2024-05-14 18:03:58 +08:00
RockYang
43b6665370 refactor: use waterflow component in mj, sd and dall image drawing page 2024-05-13 19:04:00 +08:00
RockYang
5fb9f84182 fixed sd page waterfall component 2024-05-11 18:27:35 +08:00
RockYang
e35c34ad9a enable to update AI Drawing configuarations in admin console page 2024-05-11 17:27:14 +08:00
RockYang
1a4d798f8b fix: markmap do not cost power in front page 2024-05-11 07:08:14 +08:00
RockYang
afb91a7023 optimize chat handler error handle code 2024-05-10 18:26:19 +08:00
RockYang
dc4c1f7877 fix bug: remove chat role failed 2024-05-10 17:38:55 +08:00
RockYang
bbc8fe2b40 fixed bug for dalle3 task not decrease power 2024-05-10 11:18:37 +08:00
RockYang
57c932f07c add toolbar for markmap component 2024-05-10 06:38:34 +08:00
RockYang
8b3b0139b0 use proxy for downloading discord images 2024-05-09 18:48:53 +08:00
RockYang
31828a3336 add stable diffusion default negtive prompt system config 2024-05-07 16:49:54 +08:00
RockYang
9a797bb4a5 add stable diffusion default negtive prompt system config 2024-05-07 16:21:31 +08:00
378 changed files with 30489 additions and 23026 deletions

View File

@@ -1,4 +1,121 @@
# 更新日志 # 更新日志
## v4.1.6
* 功能新增:**支持OpenAI实时语音对话功能** :rocket: :rocket: :rocket:, Beta 版,目前没有做算力计费控制,目前只有 VIP 用户可以使用。
* 功能优化优化MysQL容器配置文档解决MysQL容器资源占用过高问题
* 功能新增管理后台增加AI绘图任务管理可在管理后台浏览和删除用户的绘图任务
* 功能新增管理后台增加Suno和Luma任务管理功能
* Bug修复修复管理后台删除兑换码报 404 错误
* 功能优化:优化充值产品定价逻辑,可以设置原价和优惠价,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!****升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!****升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**。
## v4.1.5
* 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接
* Bug修复兼容手机端原生微信支付和支付宝支付渠道
* Bug修复修复删除绘图任务时候因为字段长度过短导致SQL执行失败问题
* 功能优化:优化 Vue 组件通信代码,使用共享数据来替换之前的事件订阅模式,效率更高一些
* 功能优化:优化思维导图生成功果页面,优化用户体验
## v4.1.4
* 功能优化:用户文件列表组件增加分页功能支持
* Bug修复修复用户注册失败Bug注册操作只弹出一次行为验证码
* 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
* 功能新增:给 AI 应用(角色)增加分类,前端支持分类筛选
* 功能优化:允许用户在聊天页面设置是否使用流式输出或者一次性输出,兼容 GPT-O1 模型。
* 功能优化移除PayJS支付渠道支持PayJs已经关闭注册服务请使用其他支付方式。
* 功能新增新增GeeK易支付支付渠道支持支付宝微信支付QQ钱包京东支付抖音支付Paypal支付等支付方式
* Bug修复修复注册页面 tab 组件没有自动选中问题 [#6](https://github.com/yangjian102621/geekai-plus/issues/6)
* 功能优化Luma生成视频任务增加自动翻译功能
* Bug修复Suno 和 Luma 任务没有判断用户算力
* 功能新增:邮箱注册增加邮箱后缀白名单,防止使用某些垃圾邮箱注册薅羊毛
* 功能优化清空未支付订单时只清空超过15分钟未支付的订单
## v4.1.3
* 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
* 功能优化:重构找回密码模块,支持通过手机或者邮箱找回密码
* 功能优化:管理后台给可以拖动排序的组件添加拖动图标
* 功能优化Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
* Bug修复手机端角色和模型选择不生效
* Bug修复用户登录过期之后聊天页面出现大量报错需要刷新页面才能正常
* 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
* 功能优化:给算力增减服务全部加上数据库事务和同步锁
* 功能优化:支持用户在前端对话界面选择插件
* 功能新增:支持 Luma 文生视频功能
## v4.1.2
* Bug修复修复思维导图页面获取模型失败的问题
* 功能优化优化MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力
* Bug修复修复后台拖动排序组件 Bug
* 功能优化:更新数据库失败时候显示具体的的报错信息
* Bug修复修复管理后台对话详情页内容显示异常问题
* 功能优化:管理后台新增清空所有未支付订单的功能
* 功能优化:给会话信息和系统配置数据加上缓存功能,减少 http 请求
* 功能新增:移除微信机器人收款功能,增加卡密功能,支持用户使用卡密兑换算力
## v4.1.1
* Bug修复修复 GPT 模型 function call 调用后没有输出的问题
* 功能新增:允许获取 License 授权用户可以自定义版权信息
* 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件
* 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求
* 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
* 功能新增MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
* 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式
* 功能新增:允许在管理后台设置首页显示的导航菜单
* Bug修复修复注册页面先显示关闭注册组件然后再显示注册组件
* 功能新增:增加 Suno 文生歌曲功能
* 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
* 功能优化:在所有列表页面增加返回顶部按钮
## v4.1.0
* bug修复修复移动端修改聊天标题不生效的问题
* Bug修复修复用户注册不显示用户名的问题
* Bug修复修复管理后台拖动排序不生效的问题
* 功能优化:允许用户设置自定义首页背景图片
* 功能新增:**支持AI解读 PDF, Word, Excel等文件**
* 功能优化:优化聊天界面的用户上传文件的列表样式
* 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
* 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
## v4.0.9
* 环境升级:升级 Golang 到 go1.22.4
* 功能增加:接入微信商户号支付渠道
* Bug修复修复前端页面菜单把页面撑开底部留白问题
* 功能优化:聊天页面自动根据内容调整输入框的高度
* Bug修复修复Dalle绘图失败退回算力的问题
* 功能优化:邀请码注册时被邀请人也可以获得赠送的算力
* 功能优化:允许设置邮件验证码的抬头
* Bug修复修复免费模型不会记录聊天记录的bug
* Bug修复修复聊天输入公式显示异常的Bug
## v4.0.8
* 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
* 功能优化:当数据库更新失败的时候记录错误日志
* 功能优化:聊天输入框会随着输入内容的增多自动调整高度
* Bug修复修复移动端聊天页面模型切换不生效的Bug
* 功能优化给PC端扫码支付增加签名验证和有效期验证
* Bug修复修复支付码生成API权限控制的问题
* Bug修复模型算力设置为0时不扣减用户算力并且不记录算力消费日志
* 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
* 功能新增H5端支持 Dalle 绘图
## v4.0.7
* 功能优化升级quic-go支持 Go1.21
* 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片
* Bug修复修复弹窗等于图形验证码一直验证失败的问题
* 功能重构:重构前端 UI 页面,增加顶部导航
* 功能优化:优化 Vue 非父子组件之间的通信方式
* 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
## v4.0.6
* Bug修复修复PC端画廊页面的瀑布流组件样式错乱问题
* 功能新增:给思维导图增加 ToolBar实现思维导图的放大缩小和定位
* Bug修复修复思维导图不扣费的Bug
* Bug修复修复管理后台角色删除失败的Bug
* Bug修复兼容最新版秋叶SD懒人包的 SD API新增 scheduler 参数
* 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
* Bug修复修复注册用户提示注册人数达到上限的 Bug
* 功能优化将MJ,SD,Dall绘画页面的任务列表全改成瀑布流组件
## v4.0.5 ## v4.0.5
@@ -38,6 +155,7 @@
* 功能新增:支持管理后台 Logo 修改 * 功能新增:支持管理后台 Logo 修改
## 4.0.2 ## 4.0.2
* 功能新增:支持前端菜单可以配置 * 功能新增:支持前端菜单可以配置
* 功能优化:在登录和注册界面标题显示软件版本号 * 功能优化:在登录和注册界面标题显示软件版本号
* 功能优化MJ 绘画支持 --sref 和 --cref 图片一致性参数 * 功能优化MJ 绘画支持 --sref 和 --cref 图片一致性参数

View File

@@ -1,15 +1,14 @@
# GeekAI # GeekAI
### 本项目已经正式更名为 GeekAI请大家及时更新代码克隆地址 > 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure, **GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Claude, 通义千问KimiDeepSeekGitee AI 等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
主要特性: 主要特性:
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。 - 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
- 基于 Websocket 实现,完美的打字机体验。 - 基于 Websocket 实现,完美的打字机体验。
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。 - 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
- 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型。 - 支持 OpenAI, Claude, 通义千问KimiDeepSeek等多个大语言模型**支持 Gitee AI Serverless 大模型 API**。
- 支持 Suno 文生音乐 - 支持 Suno 文生音乐
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。 - 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。 - 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
@@ -26,63 +25,16 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
- [x] 支持网站 Logo 版权等信息的修改 - [x] 支持网站 Logo 版权等信息的修改
## 功能截图 ## 功能截图
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
### PC 端聊天界面
![ChatGPT Chat Page](/docs/imgs/gpt.gif)
### AI 对话界面
![ChatGPT new Chat Page](/docs/imgs/chat-new.png)
### MidJourney 专业绘画界面
![mid-journey](/docs/imgs/mj_image.jpg)
### Stable-Diffusion 专业绘画页面
![Stable-Diffusion](/docs/imgs/sd_image.jpg)
![Stable-Diffusion](/docs/imgs/sd_image_detail.jpg)
### 绘图作品展
![ChatGPT image_list](/docs/imgs/image-list.png)
### AI应用列表
![ChatGPT-app-list](/docs/imgs/app-list.jpg)
### 会员充值
![会员充值](/docs/imgs/member.png)
### 自动调用函数插件
![ChatGPT function plugin](/docs/imgs/plugin.png)
![ChatGPT function plugin](/docs/imgs/mj.jpg)
### 管理后台
![ChatGPT admin](/docs/imgs/admin_dashboard.png)
![ChatGPT admin](/docs/imgs/admin_config.jpg)
![ChatGPT admin](/docs/imgs/admin_models.jpg)
![ChatGPT admin](/docs/imgs/admin_user.png)
### 移动端 Web 页面
![Mobile chat list](/docs/imgs/mobile_chat_list.png)
![Mobile chat session](/docs/imgs/mobile_chat_session.png)
![Mobile chat setting](/docs/imgs/mobile_user_profile.png)
![Mobile chat setting](/docs/imgs/mobile_pay.png)
### 体验地址 ### 体验地址
> 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/> > 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/>
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!** > **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
## 快速部署 ## 快速部署
请参考文档 [**GeekAI 快速部署**](https://ai.r9it.com/docs/install/)。 请参考文档 [**GeekAI 快速部署**](https://docs.geekai.me/install/)。
## 使用须知 ## 使用须知
@@ -101,14 +53,14 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
## TODOLIST ## TODOLIST
* [ ] 支持基于知识库的 AI 问答 * [ ] 支持基于知识库的 AI 问答
* [ ] 会员邀请注册推广功能 * [ ] 文生视频,文生歌曲功能
* [ ] 微信支付功能 * [ ] 微信支付功能
## 项目文档 ## 项目文档
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/) 最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
详细的部署和开发文档请参考 [**GeekAI 文档**](https://ai.r9it.com/docs/)。 详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
加微信进入微信讨论群可获取 **一键部署脚本添加好友时请注明来自Github!!!)。** 加微信进入微信讨论群可获取 **一键部署脚本添加好友时请注明来自Github!!!)。**

3
api/.gitignore vendored
View File

@@ -17,4 +17,5 @@ bin
data data
config.toml config.toml
static/upload static/upload
storage.json storage.json
res/certs/wechat/apiclient_key.pem

View File

@@ -3,8 +3,7 @@ ProxyURL = "" # 如 http://127.0.0.1:7777
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local" MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
StaticDir = "./static" # 静态资源的目录 StaticDir = "./static" # 静态资源的目录
StaticUrl = "/static" # 静态资源访问 URL StaticUrl = "/static" # 静态资源访问 URL
AesEncryptKey = "" TikaHost = "http://tika:9998"
WeChatBot = false
[Session] [Session]
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换 SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
@@ -17,7 +16,7 @@ WeChatBot = false
DB = 0 DB = 0
[ApiConfig] # 微博热搜,今日头条等函数服务 API 配置,此为第三方插件服务,如需使用请联系作者开通 [ApiConfig] # 微博热搜,今日头条等函数服务 API 配置,此为第三方插件服务,如需使用请联系作者开通
ApiURL = "" ApiURL = "https://sapi.geekai.me"
AppId = "" AppId = ""
Token = "" Token = ""
@@ -64,23 +63,6 @@ WeChatBot = false
SubDir = "" SubDir = ""
Domain = "" Domain = ""
[[MjProxyConfigs]]
Enabled = true
ApiURL = "http://midjourney-proxy:8082"
ApiKey = "sk-geekmaster"
[[MjPlusConfigs]]
Enabled = false
ApiURL = "https://api.chat-plus.net"
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
ApiKey = "sk-xxx"
[[SdConfigs]]
Enabled = false
ApiURL = ""
ApiKey = ""
Txt2ImgJsonPath = "res/sd/text2img.json"
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动 [XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动
Enabled = false # 是否启用 XXL JOB 服务 Enabled = false # 是否启用 XXL JOB 服务
ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址 ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
@@ -89,6 +71,15 @@ WeChatBot = false
AccessToken = "xxl-job-api-token" # 执行器 API 通信 token AccessToken = "xxl-job-api-token" # 执行器 API 通信 token
RegistryKey = "chatgpt-plus" # 任务注册 key RegistryKey = "chatgpt-plus" # 任务注册 key
[SmtpConfig] # 注意阿里云服务器禁用了25号端口请使用 465 端口,并开启 TLS 连接
UseTls = false
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"
From = "test@163.com" # 发件邮箱人地址
Password = "" #邮箱 stmp 服务授权码
# 支付宝商户支付
[AlipayConfig] [AlipayConfig]
Enabled = false # 启用支付宝支付通道 Enabled = false # 启用支付宝支付通道
SandBox = false # 是否启用沙盒模式 SandBox = false # 是否启用沙盒模式
@@ -98,28 +89,27 @@ WeChatBot = false
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书 PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书 AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书 RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
# 虎皮椒支付
[HuPiPayConfig] [HuPiPayConfig]
Enabled = false Enabled = false
Name = "wechat"
AppId = "" AppId = ""
AppSecret = "" AppSecret = ""
ApiURL = "https://api.xunhupay.com" ApiURL = "https://api.xunhupay.com"
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
[SmtpConfig] # 注意阿里云服务器禁用了25号端口请使用 465 端口,并开启 TLS 连接 # 微信商户支付
UseTls = false [WechatPayConfig]
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"
From = "test@163.com" # 发件邮箱人地址
Password = "" #邮箱 stmp 服务授权码
[JPayConfig] # PayJs 支付配置
Enabled = false Enabled = false
Name = "wechat" # 请不要改动 AppId = "" # 商户应用ID
AppId = "" # 商户 ID MchId = "" # 商户
PrivateKey = "" # 秘钥 SerialNo = "" # API 证书序列号
ApiURL = "https://payjs.cn" PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己 ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置
# 易支付
[GeekPayConfig]
Enabled = true
AppId = "" # 商户ID
PrivateKey = "" # 商户私钥
ApiURL = "https://pay.geekai.cn"
Methods = ["alipay", "wxpay", "qqpay", "jdpay", "douyin", "paypal"] # 支持的支付方式

View File

@@ -9,12 +9,12 @@ package core
import ( import (
"bytes" "bytes"
"context"
"fmt"
"geekai/core/types" "geekai/core/types"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"context"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@@ -32,39 +32,24 @@ import (
) )
type AppServer struct { type AppServer struct {
Debug bool Debug bool
Config *types.AppConfig Config *types.AppConfig
Engine *gin.Engine Engine *gin.Engine
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
SysConfig *types.SystemConfig // system config cache SysConfig *types.SystemConfig // system config cache
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
ChatSession *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
} }
func NewServer(appConfig *types.AppConfig) *AppServer { func NewServer(appConfig *types.AppConfig) *AppServer {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard gin.DefaultWriter = io.Discard
return &AppServer{ return &AppServer{
Debug: false, Debug: false,
Config: appConfig, Config: appConfig,
Engine: gin.Default(), Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []types.Message](),
ChatSession: types.NewLMap[string, *types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
} }
} }
func (s *AppServer) Init(debug bool, client *redis.Client) { func (s *AppServer) Init(debug bool, client *redis.Client) {
if debug { // 调试模式允许跨域请求 API // 允许跨域请求 API
s.Debug = debug
logger.Info("Enabled debug mode")
}
s.Engine.Use(corsMiddleware()) s.Engine.Use(corsMiddleware())
s.Engine.Use(staticResourceMiddleware()) s.Engine.Use(staticResourceMiddleware())
s.Engine.Use(authorizeMiddleware(s, client)) s.Engine.Use(authorizeMiddleware(s, client))
@@ -77,13 +62,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
func (s *AppServer) Run(db *gorm.DB) error { func (s *AppServer) Run(db *gorm.DB) error {
// load system configs // load system configs
var sysConfig model.Config var sysConfig model.Config
res := db.Where("marker", "system").First(&sysConfig) err := db.Where("marker", "system").First(&sysConfig).Error
if res.Error != nil {
return res.Error
}
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to load system config: %v", err)
}
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil {
return fmt.Errorf("failed to decode system config: %v", err)
} }
logger.Infof("http://%s", s.Config.Listen) logger.Infof("http://%s", s.Config.Listen)
return s.Engine.Run(s.Config.Listen) return s.Engine.Run(s.Config.Listen)
@@ -95,7 +80,7 @@ func errorHandler(c *gin.Context) {
if r := recover(); r != nil { if r := recover(); r != nil {
logger.Errorf("Handler Panic: %v", r) logger.Errorf("Handler Panic: %v", r)
debug.PrintStack() debug.PrintStack()
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) c.JSON(http.StatusBadRequest, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
c.Abort() c.Abort()
} }
}() }()
@@ -113,9 +98,9 @@ func corsMiddleware() gin.HandlerFunc {
c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
//允许跨域设置可以返回其他子段,可以自定义字段 //允许跨域设置可以返回其他子段,可以自定义字段
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, Chat-Token, Admin-Authorization") c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
// 允许浏览器(客户端)可以解析的头部 (重要) // 允许浏览器(客户端)可以解析的头部 (重要)
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
//设置缓存时间 //设置缓存时间
c.Header("Access-Control-Max-Age", "172800") c.Header("Access-Control-Max-Age", "172800")
//允许客户端传递校验信息比如 cookie (重要) //允许客户端传递校验信息比如 cookie (重要)
@@ -139,19 +124,26 @@ func corsMiddleware() gin.HandlerFunc {
// 用户授权验证 // 用户授权验证
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
var tokenString string var tokenString string
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/") isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
if isAdminApi { // 后台管理 API if isAdminApi { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader) tokenString = c.GetHeader(types.AdminAuthHeader)
} else if c.Request.URL.Path == "/api/chat/new" { } else if clientProtocols != "" { // Websocket 连接
tokenString = c.Query("token") // 解析子协议内容
protocols := strings.Split(clientProtocols, ",")
if protocols[0] == "realtime" {
tokenString = strings.TrimSpace(protocols[1][25:])
} else if protocols[0] == "token" {
tokenString = strings.TrimSpace(protocols[1])
}
} else { } else {
tokenString = c.GetHeader(types.UserAuthHeader) tokenString = c.GetHeader(types.UserAuthHeader)
} }
if tokenString == "" { if tokenString == "" {
if needLogin(c) { if needLogin(c) {
resp.ERROR(c, "You should put Authorization in request headers") resp.NotAuth(c, "You should put Authorization in request headers")
c.Abort() c.Abort()
return return
} else { // 直接放行 } else { // 直接放行
@@ -213,29 +205,33 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/admin/logout" || c.Request.URL.Path == "/api/admin/logout" ||
c.Request.URL.Path == "/api/admin/login/captcha" || c.Request.URL.Path == "/api/admin/login/captcha" ||
c.Request.URL.Path == "/api/user/register" || c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/user/session" ||
c.Request.URL.Path == "/api/chat/history" || c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" || c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" || c.Request.URL.Path == "/api/chat/list" ||
c.Request.URL.Path == "/api/role/list" || c.Request.URL.Path == "/api/app/list" ||
c.Request.URL.Path == "/api/app/type/list" ||
c.Request.URL.Path == "/api/app/list/user" ||
c.Request.URL.Path == "/api/model/list" || c.Request.URL.Path == "/api/model/list" ||
c.Request.URL.Path == "/api/mj/imgWall" || c.Request.URL.Path == "/api/mj/imgWall" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/mj/notify" || c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/imgWall" || c.Request.URL.Path == "/api/sd/imgWall" ||
c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/dall/imgWall" || c.Request.URL.Path == "/api/dall/imgWall" ||
c.Request.URL.Path == "/api/dall/client" ||
c.Request.URL.Path == "/api/config/get" ||
c.Request.URL.Path == "/api/product/list" || c.Request.URL.Path == "/api/product/list" ||
c.Request.URL.Path == "/api/menu/list" || c.Request.URL.Path == "/api/menu/list" ||
c.Request.URL.Path == "/api/markMap/client" || c.Request.URL.Path == "/api/markMap/client" ||
c.Request.URL.Path == "/api/payment/doPay" ||
c.Request.URL.Path == "/api/payment/payWays" ||
c.Request.URL.Path == "/api/suno/detail" ||
c.Request.URL.Path == "/api/suno/play" ||
c.Request.URL.Path == "/api/download" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") || strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") || strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") { strings.HasPrefix(c.Request.URL.Path, "/static/") {
return false return false
} }
@@ -370,6 +366,7 @@ func staticResourceMiddleware() gin.HandlerFunc {
// 直接输出图像数据流 // 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes()) c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
c.Abort() // 中断请求 c.Abort() // 中断请求
} }
c.Next() c.Next()
} }

View File

@@ -38,7 +38,6 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload", BasePath: "./static/upload",
}, },
}, },
WeChatBot: false,
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false}, AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
} }
} }

View File

@@ -9,14 +9,14 @@ package types
// ApiRequest API 请求实体 // ApiRequest API 请求实体
type ApiRequest struct { type ApiRequest struct {
Model string `json:"model,omitempty"` // 兼容百度文心一言 Model string `json:"model,omitempty"`
Temperature float32 `json:"temperature"` Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens,omitempty"` // 兼容百度文心一言 MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"` MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
Messages []interface{} `json:"messages,omitempty"` Stream bool `json:"stream,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM Messages []interface{} `json:"messages,omitempty"`
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台 Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
ToolChoice string `json:"tool_choice,omitempty"` ToolChoice string `json:"tool_choice,omitempty"`
@@ -52,24 +52,24 @@ type Delta struct {
// ChatSession 聊天会话对象 // ChatSession 聊天会话对象
type ChatSession struct { type ChatSession struct {
SessionId string `json:"session_id"` UserId uint `json:"user_id"`
ClientIP string `json:"client_ip"` // 客户端 IP ClientIP string `json:"client_ip"` // 客户端 IP
Username string `json:"username"` // 当前登录的 username ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
UserId uint `json:"user_id"` // 当前登录的 user ID Model ChatModel `json:"model"` // GPT 模型
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段 Start int64 `json:"start"` // 开始请求时间戳
Model ChatModel `json:"model"` // GPT 模型 Tools []int `json:"tools"` // 工具函数列表
Stream bool `json:"stream"` // 是否采用流式输出
} }
type ChatModel struct { type ChatModel struct {
Id uint `json:"id"` Id uint `json:"id"`
Platform Platform `json:"platform"` Name string `json:"name"`
Name string `json:"name"` Value string `json:"value"`
Value string `json:"value"` Power int `json:"power"`
Power int `json:"power"` MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxTokens int `json:"max_tokens"` // 最大响应长度 MaxContext int `json:"max_context"` // 最大上下文长度
MaxContext int `json:"max_context"` // 最大上下文长 Temperature float32 `json:"temperature"` // 模型温
Temperature float32 `json:"temperature"` // 模型温度 KeyId int `json:"key_id"` // 绑定 API KEY
KeyId int `json:"key_id"` // 绑定 API KEY
} }
type ApiError struct { type ApiError struct {
@@ -92,7 +92,7 @@ const (
PowerConsume = PowerType(2) // 消费 PowerConsume = PowerType(2) // 消费
PowerRefund = PowerType(3) // 任务SD,MJ执行失败退款 PowerRefund = PowerType(3) // 任务SD,MJ执行失败退款
PowerInvite = PowerType(4) // 邀请奖励 PowerInvite = PowerType(4) // 邀请奖励
PowerReward = PowerType(5) // 众筹 PowerRedeem = PowerType(5) // 众筹
PowerGift = PowerType(6) // 系统赠送 PowerGift = PowerType(6) // 系统赠送
) )
@@ -104,8 +104,8 @@ func (t PowerType) String() string {
return "消费" return "消费"
case PowerRefund: case PowerRefund:
return "退款" return "退款"
case PowerReward: case PowerRedeem:
return "众筹" return "兑换"
} }
return "其他" return "其他"

View File

@@ -17,15 +17,17 @@ var ErrConClosed = errors.New("connection Closed")
// WsClient websocket client // WsClient websocket client
type WsClient struct { type WsClient struct {
Id string
Conn *websocket.Conn Conn *websocket.Conn
lock sync.Mutex lock sync.Mutex
mt int mt int
Closed bool Closed bool
} }
func NewWsClient(conn *websocket.Conn) *WsClient { func NewWsClient(conn *websocket.Conn, id string) *WsClient {
return &WsClient{ return &WsClient{
Conn: conn, Conn: conn,
Id: id,
lock: sync.Mutex{}, lock: sync.Mutex{},
mt: 2, // fixed bug for 'Invalid UTF-8 in text frame' mt: 2, // fixed bug for 'Invalid UTF-8 in text frame'
Closed: false, Closed: false,

View File

@@ -12,28 +12,25 @@ import (
) )
type AppConfig struct { type AppConfig struct {
Path string `toml:"-"` Path string `toml:"-"`
Listen string Listen string
Session Session Session Session
AdminSession Session AdminSession Session
ProxyURL string ProxyURL string
MysqlDns string // mysql 连接地址 MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录 StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息 Redis RedisConfig // redis 连接信息
ApiConfig ApiConfig // ChatPlus API authorization configs ApiConfig ApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config OSS OSSConfig // OSS config
MjProxyConfigs []MjProxyConfig // MJ proxy config SmtpConfig SmtpConfig // 邮件发送配置
MjPlusConfigs []MjPlusConfig // MJ plus config XXLConfig XXLConfig
WeChatBot bool // 是否启用微信机器人 AlipayConfig AlipayConfig // 支付宝支付渠道配置
SdConfigs []StableDiffusionConfig // sd AI draw service pool HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
GeekPayConfig GeekPayConfig // GEEK 支付配置
XXLConfig XXLConfig WechatPayConfig WechatPayConfig // 微信支付渠道配置
AlipayConfig AlipayConfig TikaHost string // TiKa 服务器地址
HuPiPayConfig HuPiPayConfig
SmtpConfig SmtpConfig // 邮件发送配置
JPayConfig JPayConfig // payjs 支付配置
} }
type SmtpConfig struct { type SmtpConfig struct {
@@ -51,27 +48,6 @@ type ApiConfig struct {
Token string Token string
} }
type MjProxyConfig struct {
Enabled bool
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type StableDiffusionConfig struct {
Enabled bool
Model string // 模型名称
ApiURL string
ApiKey string
}
type MjPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type AlipayConfig struct { type AlipayConfig struct {
Enabled bool // 是否启用该支付通道 Enabled bool // 是否启用该支付通道
SandBox bool // 是否沙盒环境 SandBox bool // 是否沙盒环境
@@ -81,29 +57,38 @@ type AlipayConfig struct {
PublicKey string // 用户公钥文件路径 PublicKey string // 用户公钥文件路径
AlipayPublicKey string // 支付宝公钥文件路径 AlipayPublicKey string // 支付宝公钥文件路径
RootCert string // Root 秘钥路径 RootCert string // Root 秘钥路径
NotifyURL string // 异步通知回调 NotifyURL string // 异步通知地址
ReturnURL string // 支付成功返回地址 ReturnURL string // 同步通知地址
}
type WechatPayConfig struct {
Enabled bool // 是否启用该支付通道
AppId string // 公众号的APPID,如wxd678efh567hg6787
MchId string // 直连商户的商户号,由微信支付生成并下发
SerialNo string // 商户证书的证书序列号
PrivateKey string // 用户私钥文件路径
ApiV3Key string // API V3 秘钥
NotifyURL string // 异步通知地址
} }
type HuPiPayConfig struct { //虎皮椒第四方支付配置 type HuPiPayConfig struct { //虎皮椒第四方支付配置
Enabled bool // 是否启用该支付通道 Enabled bool // 是否启用该支付通道
Name string // 支付名称wechat/alipay
AppId string // App ID AppId string // App ID
AppSecret string // app 密钥 AppSecret string // app 密钥
ApiURL string // 支付网关 ApiURL string // 支付网关
NotifyURL string // 异步通知回调 NotifyURL string // 异步通知地址
ReturnURL string // 支付成功返回地址 ReturnURL string // 同步通知地址
} }
// JPayConfig PayJs 支付配置 // GeekPayConfig GEEK支付配置
type JPayConfig struct { type GeekPayConfig struct {
Enabled bool Enabled bool
Name string // 支付名称,默认 wechat AppId string // 商户 ID
AppId string // 商户 ID PrivateKey string // 私钥
PrivateKey string // 私钥 ApiURL string // API 网关
ApiURL string // API 网关 NotifyURL string // 异步通知地址
NotifyURL string // 异步回调地址 ReturnURL string // 同步通知地址
ReturnURL string // 支付成功返回地址 Methods []string // 支付方式
} }
type XXLConfig struct { // XXL 任务调度配置 type XXLConfig struct { // XXL 任务调度配置
@@ -126,30 +111,27 @@ type RedisConfig struct {
const LicenseKey = "Geek-AI-License" const LicenseKey = "Geek-AI-License"
type License struct { type License struct {
Key string `json:"key"` // 许可证书密钥 Key string `json:"key"` // 许可证书密钥
MachineId string `json:"machine_id"` // 机器码 MachineId string `json:"machine_id"` // 机器码
UserNum int `json:"user_num"` // 用户数量 ExpiredAt int64 `json:"expired_at"` // 过期时间
ExpiredAt int64 `json:"expired_at"` // 过期时间 IsActive bool `json:"is_active"` // 是否激活
IsActive bool `json:"is_active"` // 是否激活 Configs LicenseConfig `json:"configs"`
}
type LicenseConfig struct {
UserNum int `json:"user_num"` // 用户数量
DeCopy bool `json:"de_copy"` // 去版权
} }
func (c RedisConfig) Url() string { func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port) return fmt.Sprintf("%s:%d", c.Host, c.Port)
} }
type Platform string
const OpenAI = Platform("OpenAI")
const Azure = Platform("Azure")
const ChatGLM = Platform("ChatGLM")
const Baidu = Platform("Baidu")
const XunFei = Platform("XunFei")
const QWen = Platform("QWen")
type SystemConfig struct { type SystemConfig struct {
Title string `json:"title,omitempty"` Title string `json:"title,omitempty"` // 网站标题
AdminTitle string `json:"admin_title,omitempty"` Slogan string `json:"slogan,omitempty"` // 网站 slogan
Logo string `json:"logo,omitempty"` AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
Logo string `json:"logo,omitempty"` // 方形 Logo
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值 InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力 DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值 InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
@@ -158,10 +140,6 @@ type SystemConfig struct {
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式支持手机mobile邮箱注册email账号密码注册 RegisterWays []string `json:"register_ways,omitempty"` // 注册方式支持手机mobile邮箱注册email账号密码注册
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册 EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间 OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明 VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型 DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
@@ -169,10 +147,23 @@ type SystemConfig struct {
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力 MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力 MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力 SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力 DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址 WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
EnableContext bool `json:"enable_context,omitempty"` EnableContext bool `json:"enable_context,omitempty"`
ContextDeep int `json:"context_deep,omitempty"` ContextDeep int `json:"context_deep,omitempty"`
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
MjMode string `json:"mj_mode"` // midjourney 默认的API模式relax, fast, turbo
IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
Copyright string `json:"copyright"` // 版权信息
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
} }

View File

@@ -24,5 +24,4 @@ type Function struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"` Parameters map[string]interface{} `json:"parameters"`
Required interface{} `json:"required,omitempty"`
} }

View File

@@ -22,3 +22,18 @@ type OrderRemark struct {
Price float64 `json:"price"` Price float64 `json:"price"`
Discount float64 `json:"discount"` Discount float64 `json:"discount"`
} }
var PayMethods = map[string]string{
"alipay": "支付宝商号",
"wechat": "微信商号",
"hupi": "虎皮椒",
"geek": "易支付",
}
var PayNames = map[string]string{
"alipay": "支付宝",
"wxpay": "微信支付",
"qqpay": "QQ钱包",
"jdpay": "京东支付",
"douyin": "抖音支付",
"paypal": "PayPal支付",
}

View File

@@ -24,11 +24,10 @@ const (
// MjTask MidJourney 任务 // MjTask MidJourney 任务
type MjTask struct { type MjTask struct {
Id uint `json:"id"` Id uint `json:"id"` // 任务ID
TaskId string `json:"task_id"` TaskId string `json:"task_id"` // 中转任务ID
ClientId string `json:"client_id"`
ImgArr []string `json:"img_arr"` ImgArr []string `json:"img_arr"`
ChannelId string `json:"channel_id"`
SessionId string `json:"session_id"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
@@ -38,23 +37,27 @@ type MjTask struct {
MessageId string `json:"message_id,omitempty"` MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"` MessageHash string `json:"message_hash,omitempty"`
RetryCount int `json:"retry_count"` RetryCount int `json:"retry_count"`
ChannelId string `json:"channel_id"` // 渠道ID用来区分是哪个渠道创建的任务一个任务的 create 和 action 操作必须要再同一个渠道
Mode string `json:"mode"` // 绘画模式relax, fast, turbo
} }
type SdTask struct { type SdTask struct {
Id int `json:"id"` // job 数据库ID Id int `json:"id"` // job 数据库ID
SessionId string `json:"session_id"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
ClientId string `json:"client_id"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Params SdTaskParams `json:"params"` Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"` RetryCount int `json:"retry_count"`
} }
type SdTaskParams struct { type SdTaskParams struct {
ClientId string `json:"client_id"` // 客户端ID
TaskId string `json:"task_id"` TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词 Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词 NegPrompt string `json:"neg_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20 Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器 Sampler string `json:"sampler"` // 采样器
Scheduler string `json:"scheduler"` // 采样调度
FaceFix bool `json:"face_fix"` // 面部修复 FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7 CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子 Seed int64 `json:"seed"` // 随机数种子
@@ -69,13 +72,60 @@ type SdTaskParams struct {
// DallTask DALL-E task // DallTask DALL-E task
type DallTask struct { type DallTask struct {
JobId uint `json:"job_id"` ClientId string `json:"client_id"`
UserId uint `json:"user_id"` JobId uint `json:"job_id"`
Prompt string `json:"prompt"` UserId uint `json:"user_id"`
N int `json:"n"` Prompt string `json:"prompt"`
Quality string `json:"quality"` N int `json:"n"`
Size string `json:"size"` Quality string `json:"quality"`
Style string `json:"style"` Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"` Power int `json:"power"`
} }
type SunoTask struct {
ClientId string `json:"client_id"`
Id uint `json:"id"`
Channel string `json:"channel"`
UserId int `json:"user_id"`
Type int `json:"type"`
Title string `json:"title"`
RefTaskId string `json:"ref_task_id,omitempty"`
RefSongId string `json:"ref_song_id,omitempty"`
Prompt string `json:"prompt"` // 提示词/歌词
Tags string `json:"tags"`
Model string `json:"model"`
Instrumental bool `json:"instrumental"` // 是否纯音乐
ExtendSecs int `json:"extend_secs,omitempty"` // 延长秒杀
SongId string `json:"song_id,omitempty"` // 合并歌曲ID
AudioURL string `json:"audio_url"` // 用户上传音频地址
}
const (
VideoLuma = "luma"
VideoRunway = "runway"
VideoCog = "cog"
)
type VideoTask struct {
ClientId string `json:"client_id"`
Id uint `json:"id"`
Channel string `json:"channel"`
UserId int `json:"user_id"`
Type string `json:"type"`
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
Params VideoParams `json:"params"`
}
type VideoParams struct {
PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词
Loop bool `json:"loop"` // 是否循环参考图
StartImgURL string `json:"start_img_url"` // 第一帧参考图地址
EndImgURL string `json:"end_img_url"` // 最后一帧参考图地址
Model string `json:"model"` // 使用哪个模型生成视频
Radio string `json:"radio"` // 视频尺寸
Style string `json:"style"` // 风格
Duration int `json:"duration"` // 视频时长(秒)
}

View File

@@ -17,30 +17,56 @@ type BizVo struct {
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
} }
// WsMessage Websocket message // ReplyMessage 对话回复消息结构
type WsMessage struct { type ReplyMessage struct {
Type WsMsgType `json:"type"` // 消息类别start, end, img Channel WsChannel `json:"channel"` // 消息频道,目前只有 chat
Content interface{} `json:"content"` ClientId string `json:"clientId"` // 客户端ID
Type WsMsgType `json:"type"` // 消息类别
Body interface{} `json:"body"`
} }
type WsMsgType string type WsMsgType string
type WsChannel string
const ( const (
WsStart = WsMsgType("start") MsgTypeText = WsMsgType("text") // 输出内容
WsMiddle = WsMsgType("middle") MsgTypeEnd = WsMsgType("end")
WsEnd = WsMsgType("end") MsgTypeErr = WsMsgType("error")
WsErr = WsMsgType("error") MsgTypePing = WsMsgType("ping") // 心跳消息
ChPing = WsChannel("ping")
ChChat = WsChannel("chat")
ChMj = WsChannel("mj")
ChSd = WsChannel("sd")
ChDall = WsChannel("dall")
ChSuno = WsChannel("suno")
ChLuma = WsChannel("luma")
) )
// InputMessage 对话输入消息结构
type InputMessage struct {
Channel WsChannel `json:"channel"` // 消息频道
Type WsMsgType `json:"type"` // 消息类别
Body interface{} `json:"body"`
}
type ChatMessage struct {
Tools []int `json:"tools,omitempty"` // 允许调用工具列表
Stream bool `json:"stream,omitempty"` // 是否采用流式输出
RoleId int `json:"role_id"`
ModelId int `json:"model_id"`
ChatId string `json:"chat_id"`
Content string `json:"content"`
}
type BizCode int type BizCode int
const ( const (
Success = BizCode(0) Success = BizCode(0)
Failed = BizCode(1) Failed = BizCode(1)
NotAuthorized = BizCode(400) // 未授权 NotAuthorized = BizCode(401) // 未授权
NotPermission = BizCode(403) // 没有权限
OkMsg = "Success" OkMsg = "Success"
ErrorMsg = "系统开小差了" ErrorMsg = "系统开小差了"
InvalidArgs = "非法参数或参数解析失败" InvalidArgs = "非法参数或参数解析失败"
NoData = "No Data"
) )

View File

@@ -1,12 +1,13 @@
module geekai module geekai
go 1.19 go 1.21
toolchain go1.22.4
require ( require (
github.com/BurntSushi/toml v1.1.0 github.com/BurntSushi/toml v1.1.0
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
github.com/eatmoreapple/openwechat v1.2.1
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang-jwt/jwt/v5 v5.0.0
@@ -17,7 +18,6 @@ require (
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
github.com/qiniu/go-sdk/v7 v7.17.1 github.com/qiniu/go-sdk/v7 v7.17.1
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/smartwalle/alipay/v3 v3.2.15
go.uber.org/zap v1.23.0 go.uber.org/zap v1.23.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/natefinch/lumberjack.v2 v2.2.1
gorm.io/driver/mysql v1.4.7 gorm.io/driver/mysql v1.4.7
@@ -26,26 +26,35 @@ require (
require github.com/xxl-job/xxl-job-executor-go v1.2.0 require github.com/xxl-job/xxl-job-executor-go v1.2.0
require ( require (
github.com/mojocn/base64Captcha v1.3.1 github.com/go-pay/gopay v1.5.101
github.com/shirou/gopsutil v3.21.11+incompatible github.com/google/go-tika v0.3.1
github.com/microcosm-cc/bluemonday v1.0.26
github.com/shopspring/decimal v1.3.1 github.com/shopspring/decimal v1.3.1
github.com/syndtr/goleveldb v1.0.0 github.com/syndtr/goleveldb v1.0.0
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/image v0.15.0
) )
require ( require (
github.com/aymerick/douceur v0.2.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-ole/go-ole v1.2.6 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/go-pay/crypto v0.0.1 // indirect
github.com/go-pay/errgroup v0.0.2 // indirect
github.com/go-pay/util v0.0.2 // indirect
github.com/go-pay/xlog v0.0.2 // indirect
github.com/go-pay/xtime v0.0.2 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/shirou/gopsutil v3.21.11+incompatible // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect github.com/tklauser/numcpus v0.7.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.uber.org/mock v0.4.0 // indirect
) )
require ( require (
github.com/andybalholm/brotli v1.0.4 // indirect github.com/andybalholm/brotli v1.0.4 // indirect
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect github.com/dlclark/regexp2 v1.8.1 // indirect
@@ -56,7 +65,6 @@ require (
github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
github.com/google/uuid v1.3.0 // indirect github.com/google/uuid v1.3.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
@@ -73,26 +81,21 @@ require (
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect github.com/quic-go/quic-go v0.45.0 // indirect
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
github.com/quic-go/quic-go v0.35.1 // indirect
github.com/refraction-networking/utls v1.3.2 // indirect github.com/refraction-networking/utls v1.3.2 // indirect
github.com/rs/xid v1.5.0 // indirect github.com/rs/xid v1.5.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect github.com/sirupsen/logrus v1.9.3 // indirect
github.com/smartwalle/ncrypto v1.0.2 // indirect
github.com/smartwalle/ngx v1.0.6 // indirect
github.com/smartwalle/nsign v1.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
go.uber.org/dig v1.16.1 // indirect go.uber.org/dig v1.16.1 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/mod v0.11.0 // indirect golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.14.0 // indirect golang.org/x/net v0.25.0 // indirect
golang.org/x/sync v0.3.0 // indirect golang.org/x/sync v0.7.0 // indirect
golang.org/x/text v0.12.0 // indirect golang.org/x/text v0.15.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.10.0 // indirect golang.org/x/tools v0.21.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )
@@ -111,7 +114,7 @@ require (
go.uber.org/atomic v1.9.0 // indirect go.uber.org/atomic v1.9.0 // indirect
go.uber.org/fx v1.19.3 go.uber.org/fx v1.19.3
go.uber.org/multierr v1.6.0 // indirect go.uber.org/multierr v1.6.0 // indirect
golang.org/x/crypto v0.12.0 golang.org/x/crypto v0.23.0
golang.org/x/sys v0.15.0 // indirect golang.org/x/sys v0.20.0 // indirect
gorm.io/gorm v1.25.1 gorm.io/gorm v1.25.1
) )

View File

@@ -6,12 +6,15 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible h1:Sg/2xHwDrioHpxTN6WMiw
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8= github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -25,10 +28,9 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk= github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk=
@@ -40,10 +42,24 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs= github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg= github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-pay/crypto v0.0.1 h1:B6InT8CLfSLc6nGRVx9VMJRBBazFMjr293+jl0lLXUY=
github.com/go-pay/crypto v0.0.1/go.mod h1:41oEIvHMKbNcYlWUlRWtsnC6+ASgh7u29z0gJXe5bes=
github.com/go-pay/errgroup v0.0.2 h1:5mZMdm0TDClDm2S3G0/sm0f8AuQRtz0dOrTHDR9R8Cc=
github.com/go-pay/errgroup v0.0.2/go.mod h1:0+4b8mvFMS71MIzsaC+gVvB4x37I93lRb2dqrwuU8x8=
github.com/go-pay/gopay v1.5.101 h1:rVb+sfv6hiQtknAlZnTTLvU27NvFJ4p0yglN/vPpGXI=
github.com/go-pay/gopay v1.5.101/go.mod h1:AW4Yj8jDZX9BM1/GTLTY1Gy5SHjiq8kQvG5sBTN2sxI=
github.com/go-pay/util v0.0.2 h1:goJ4f6kNY5zzdtg1Cj8oWC+Cw7bfg/qq2rJangMAb9U=
github.com/go-pay/util v0.0.2/go.mod h1:qM8VbyF1n7YAPZBSJONSPMPsPedhUTktewUAdf1AjPg=
github.com/go-pay/xlog v0.0.2 h1:kUg5X8/5VZAPDg1J5eGjA3MG0/H5kK6Ew0dW/Bycsws=
github.com/go-pay/xlog v0.0.2/go.mod h1:DbjMADPK4+Sjxj28ekK9goqn4zmyY4hql/zRiab+S9E=
github.com/go-pay/xtime v0.0.2 h1:7YR4/iuELsEHpJ6LUO0SVK80hQxDO9MLCfuVYIiTCRM=
github.com/go-pay/xtime v0.0.2/go.mod h1:W1yRbJaSt4CSBcdAtLBQ8xajiN/Pl5hquGczUcUE9xE=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -66,22 +82,22 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
github.com/google/go-tika v0.3.1/go.mod h1:DJh5N8qxXIl85QkqmXknd+PeeRkUOTbvwyYf7ieDz6c=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -123,6 +139,8 @@ github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs= github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.62 h1:qNYsFZHEzl+NfH8UxW4jpmlKav1qUAgfY30YNRneVhc= github.com/minio/minio-go/v7 v7.0.62 h1:qNYsFZHEzl+NfH8UxW4jpmlKav1qUAgfY30YNRneVhc=
@@ -135,18 +153,19 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs= github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE= github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4=
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A= github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A=
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU= github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
@@ -164,12 +183,8 @@ github.com/qiniu/go-sdk/v7 v7.17.1/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYX
github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs= github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8=
github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
@@ -185,14 +200,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/smartwalle/alipay/v3 v3.2.15 h1:3fvFJnINKKAOXHR/Iv20k1Z7KJ+nOh3oK214lELPqG8=
github.com/smartwalle/alipay/v3 v3.2.15/go.mod h1:niTNB609KyUYuAx9Bex/MawEjv2yPx4XOjxSAkqmGjE=
github.com/smartwalle/ncrypto v1.0.2 h1:pTAhCqtPCMhpOwFXX+EcMdR6PNzruBNoGQrN2S1GbGI=
github.com/smartwalle/ncrypto v1.0.2/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk=
github.com/smartwalle/ngx v1.0.6 h1:JPNqNOIj+2nxxFtrSkJO+vKJfeNUSEQueck/Wworjps=
github.com/smartwalle/ngx v1.0.6/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0=
github.com/smartwalle/nsign v1.0.8 h1:78KWtwKPrdt4Xsn+tNEBVxaTLIJBX9YRX0ZSrMUeuHo=
github.com/smartwalle/nsign v1.0.8/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -221,7 +228,6 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ= github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo= github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
@@ -233,6 +239,9 @@ go.uber.org/dig v1.16.1/go.mod h1:557JTAUZT5bUK0SvCwikmLPPtdQhfvLYtO5tJgQSbnk=
go.uber.org/fx v1.19.3 h1:YqMRE4+2IepTYCMOvXqQpRa+QAVdiSTnsHU4XNWBceA= go.uber.org/fx v1.19.3 h1:YqMRE4+2IepTYCMOvXqQpRa+QAVdiSTnsHU4XNWBceA=
go.uber.org/fx v1.19.3/go.mod h1:w2HrQg26ql9fLK7hlBiZ6JsRUKV+Lj/atT1KCjT8YhM= go.uber.org/fx v1.19.3/go.mod h1:w2HrQg26ql9fLK7hlBiZ6JsRUKV+Lj/atT1KCjT8YhM=
go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY= go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
@@ -241,43 +250,42 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -286,34 +294,40 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

View File

@@ -8,19 +8,19 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"context"
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/service"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"context"
"fmt"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/mojocn/base64Captcha"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -29,37 +29,47 @@ import (
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
Captcha string `json:"captcha"` // 验证码
CaptchaId string `json:"captcha_id"` // 验证码id
}
const SuperManagerID = 1 const SuperManagerID = 1
type ManagerHandler struct { type ManagerHandler struct {
handler.BaseHandler handler.BaseHandler
redis *redis.Client redis *redis.Client
captcha *service.CaptchaService
} }
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler { func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, captcha *service.CaptchaService) *ManagerHandler {
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client} return &ManagerHandler{
BaseHandler: handler.BaseHandler{DB: db, App: app},
redis: client,
captcha: captcha,
}
} }
// Login 登录 // Login 登录
func (h *ManagerHandler) Login(c *gin.Context) { func (h *ManagerHandler) Login(c *gin.Context) {
var data Manager var data struct {
Username string `json:"username"`
Password string `json:"password"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
// add captcha if h.App.SysConfig.EnabledVerify {
if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) { var check bool
resp.ERROR(c, "验证码错误!") if data.X != 0 {
return check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
} }
var manager model.AdminUser var manager model.AdminUser

View File

@@ -8,6 +8,7 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
@@ -31,7 +32,6 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
func (h *ApiKeyHandler) Save(c *gin.Context) { func (h *ApiKeyHandler) Save(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`
Platform string `json:"platform"`
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` Type string `json:"type"`
Value string `json:"value"` Value string `json:"value"`
@@ -48,23 +48,22 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
if data.Id > 0 { if data.Id > 0 {
h.DB.Find(&apiKey, data.Id) h.DB.Find(&apiKey, data.Id)
} }
apiKey.Platform = data.Platform
apiKey.Value = data.Value apiKey.Value = data.Value
apiKey.Type = data.Type apiKey.Type = data.Type
apiKey.ApiURL = data.ApiURL apiKey.ApiURL = data.ApiURL
apiKey.Enabled = data.Enabled apiKey.Enabled = data.Enabled
apiKey.ProxyURL = data.ProxyURL apiKey.ProxyURL = data.ProxyURL
apiKey.Name = data.Name apiKey.Name = data.Name
res := h.DB.Save(&apiKey) err := h.DB.Save(&apiKey).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
var keyVo vo.ApiKey var keyVo vo.ApiKey
err := utils.CopyObject(apiKey, &keyVo) err = utils.CopyObject(apiKey, &keyVo)
if err != nil { if err != nil {
resp.ERROR(c, "数据拷贝失败!") resp.ERROR(c, fmt.Sprintf("拷贝数据失败:%v", err))
return return
} }
keyVo.Id = apiKey.Id keyVo.Id = apiKey.Id
@@ -83,7 +82,7 @@ func (h *ApiKeyHandler) List(c *gin.Context) {
if t != "" { if t != "" {
session = session.Where("type", t) session = session.Where("type", t)
} }
var items []model.ApiKey var items []model.ApiKey
var keys = make([]vo.ApiKey, 0) var keys = make([]vo.ApiKey, 0)
res := session.Find(&items) res := session.Find(&items)
@@ -116,9 +115,9 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) err := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -131,9 +130,9 @@ func (h *ApiKeyHandler) Remove(c *gin.Context) {
return return
} }
res := h.DB.Where("id", id).Delete(&model.ApiKey{}) err := h.DB.Where("id", id).Delete(&model.ApiKey{}).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@@ -1,46 +0,0 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/handler"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/mojocn/base64Captcha"
)
type CaptchaHandler struct {
handler.BaseHandler
}
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
}
type CaptchaVo struct {
CaptchaId string `json:"captcha_id"`
PicPath string `json:"pic_path"`
}
// GetCaptcha 获取验证码
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
var captchaVo CaptchaVo
driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
// b64s是图片的base64编码
id, b64s, err := cp.Generate()
if err != nil {
resp.ERROR(c, "生成验证码错误!")
return
}
captchaVo.CaptchaId = id
captchaVo.PicPath = b64s
resp.SUCCESS(c, captchaVo)
}

View File

@@ -8,6 +8,7 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
@@ -21,16 +22,16 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
type ChatRoleHandler struct { type ChatAppHandler struct {
handler.BaseHandler handler.BaseHandler
} }
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// Save 创建或者更新某个角色 // Save 创建或者更新某个角色
func (h *ChatRoleHandler) Save(c *gin.Context) { func (h *ChatAppHandler) Save(c *gin.Context) {
var data vo.ChatRole var data vo.ChatRole
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -45,10 +46,16 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
role.Id = data.Id role.Id = data.Id
if data.CreatedAt > 0 { if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0) role.CreatedAt = time.Unix(data.CreatedAt, 0)
} else {
err = h.DB.Where("marker", data.Key).First(&role).Error
if err == nil {
resp.ERROR(c, fmt.Sprintf("角色 %s 已存在", data.Key))
return
}
} }
res := h.DB.Save(&role) err = h.DB.Save(&role).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
// 填充 ID 数据 // 填充 ID 数据
@@ -57,7 +64,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
resp.SUCCESS(c, data) resp.SUCCESS(c, data)
} }
func (h *ChatRoleHandler) List(c *gin.Context) { func (h *ChatAppHandler) List(c *gin.Context) {
var items []model.ChatRole var items []model.ChatRole
var roles = make([]vo.ChatRole, 0) var roles = make([]vo.ChatRole, 0)
res := h.DB.Order("sort_num ASC").Find(&items) res := h.DB.Order("sort_num ASC").Find(&items)
@@ -68,13 +75,18 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
// initialize model mane for role // initialize model mane for role
modelIds := make([]int, 0) modelIds := make([]int, 0)
typeIds := make([]int, 0)
for _, v := range items { for _, v := range items {
if v.ModelId > 0 { if v.ModelId > 0 {
modelIds = append(modelIds, v.ModelId) modelIds = append(modelIds, v.ModelId)
} }
if v.Tid > 0 {
typeIds = append(typeIds, v.Tid)
}
} }
modelNameMap := make(map[int]string) modelNameMap := make(map[int]string)
typeNameMap := make(map[int]string)
if len(modelIds) > 0 { if len(modelIds) > 0 {
var models []model.ChatModel var models []model.ChatModel
tx := h.DB.Where("id IN ?", modelIds).Find(&models) tx := h.DB.Where("id IN ?", modelIds).Find(&models)
@@ -84,6 +96,15 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
} }
} }
} }
if len(typeIds) > 0 {
var appTypes []model.AppType
tx := h.DB.Where("id IN ?", typeIds).Find(&appTypes)
if tx.Error == nil {
for _, m := range appTypes {
typeNameMap[int(m.Id)] = m.Name
}
}
}
for _, v := range items { for _, v := range items {
var role vo.ChatRole var role vo.ChatRole
@@ -93,6 +114,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
role.CreatedAt = v.CreatedAt.Unix() role.CreatedAt = v.CreatedAt.Unix()
role.UpdatedAt = v.UpdatedAt.Unix() role.UpdatedAt = v.UpdatedAt.Unix()
role.ModelName = modelNameMap[role.ModelId] role.ModelName = modelNameMap[role.ModelId]
role.TypeName = typeNameMap[role.Tid]
roles = append(roles, role) roles = append(roles, role)
} }
} }
@@ -101,7 +123,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
} }
// Sort 更新角色排序 // Sort 更新角色排序
func (h *ChatRoleHandler) Sort(c *gin.Context) { func (h *ChatAppHandler) Sort(c *gin.Context) {
var data struct { var data struct {
Ids []uint `json:"ids"` Ids []uint `json:"ids"`
Sorts []int `json:"sorts"` Sorts []int `json:"sorts"`
@@ -113,9 +135,9 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
} }
@@ -123,7 +145,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *ChatRoleHandler) Set(c *gin.Context) { func (h *ChatAppHandler) Set(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`
Filed string `json:"filed"` Filed string `json:"filed"`
@@ -135,28 +157,24 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *ChatRoleHandler) Remove(c *gin.Context) { func (h *ChatAppHandler) Remove(c *gin.Context) {
var data struct { id := h.GetInt(c, "id", 0)
Id uint
} if id <= 0 {
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
if data.Id <= 0 { res := h.DB.Where("id", id).Delete(&model.ChatRole{})
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
if res.Error != nil { if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "删除失败!") resp.ERROR(c, "删除失败!")
return return
} }

View File

@@ -0,0 +1,148 @@
package admin
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatAppTypeHandler struct {
handler.BaseHandler
}
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// Save 创建或更新App类型
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Enabled bool `json:"enabled"`
Icon string `json:"icon"`
SortNum int `json:"sort_num"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id == 0 { // for add
err := h.DB.Where("name", data.Name).First(&model.AppType{}).Error
if err == nil {
resp.ERROR(c, "当前分类已经存在")
return
}
err = h.DB.Create(&model.AppType{
Name: data.Name,
Icon: data.Icon,
Enabled: data.Enabled,
SortNum: data.SortNum,
}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
} else { // for update
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).Updates(map[string]interface{}{
"name": data.Name,
"icon": data.Icon,
"enabled": data.Enabled,
}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c)
}
// List 获取App类型列表
func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType
var appTypes = make([]vo.AppType, 0)
err := h.DB.Order("sort_num ASC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, v := range items {
var appType vo.AppType
err = utils.CopyObject(v, &appType)
if err != nil {
continue
}
appType.Id = v.Id
appType.CreatedAt = v.CreatedAt.Unix()
appTypes = append(appTypes, appType)
}
resp.SUCCESS(c, appTypes)
}
// Remove 删除App类型
func (h *ChatAppTypeHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Where("id", id).Delete(&model.AppType{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Enable 启用|禁用
func (h *ChatAppTypeHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Sort 更新排序
func (h *ChatAppTypeHandler) Sort(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
Sorts []int `json:"sorts"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
for index, id := range data.Ids {
err := h.DB.Model(&model.AppType{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c)
}

View File

@@ -259,9 +259,9 @@ func (h *ChatHandler) RemoveChat(c *gin.Context) {
// RemoveMessage 删除聊天记录 // RemoveMessage 删除聊天记录
func (h *ChatHandler) RemoveMessage(c *gin.Context) { func (h *ChatHandler) RemoveMessage(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}) err := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}).Error
if tx.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@@ -15,6 +15,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -48,27 +49,32 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
return return
} }
item := model.ChatModel{ item := model.ChatModel{}
Platform: data.Platform, // 更新
Name: data.Name, if data.Id > 0 {
Value: data.Value, h.DB.Where("id", data.Id).First(&item)
Enabled: data.Enabled, }
SortNum: data.SortNum,
Open: data.Open, item.Name = data.Name
MaxTokens: data.MaxTokens, item.Value = data.Value
MaxContext: data.MaxContext, item.Enabled = data.Enabled
Temperature: data.Temperature, item.SortNum = data.SortNum
KeyId: data.KeyId, item.Open = data.Open
Power: data.Power} item.Power = data.Power
item.MaxTokens = data.MaxTokens
item.MaxContext = data.MaxContext
item.Temperature = data.Temperature
item.KeyId = data.KeyId
var res *gorm.DB var res *gorm.DB
if data.Id > 0 { if data.Id > 0 {
item.Id = data.Id res = h.DB.Save(&item)
res = h.DB.Select("*").Omit("created_at").Updates(&item)
} else { } else {
res = h.DB.Create(&item) res = h.DB.Create(&item)
} }
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") logger.Error("error with update database", res.Error)
resp.ERROR(c, res.Error.Error())
return return
} }
@@ -87,9 +93,13 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
func (h *ChatModelHandler) List(c *gin.Context) { func (h *ChatModelHandler) List(c *gin.Context) {
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable") enable := h.GetBool(c, "enable")
name := h.GetTrim(c, "name")
if enable { if enable {
session = session.Where("enabled", enable) session = session.Where("enabled", enable)
} }
if name != "" {
session = session.Where("name LIKE ?", name+"%")
}
var items []model.ChatModel var items []model.ChatModel
var cms = make([]vo.ChatModel, 0) var cms = make([]vo.ChatModel, 0)
res := session.Order("sort_num ASC").Find(&items) res := session.Order("sort_num ASC").Find(&items)
@@ -137,9 +147,9 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) err := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -157,9 +167,9 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) err := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
} }
@@ -174,9 +184,9 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
return return
} }
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}) err := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@@ -16,6 +16,7 @@ import (
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/shirou/gopsutil/host" "github.com/shirou/gopsutil/host"
"gorm.io/gorm" "gorm.io/gorm"
@@ -28,7 +29,11 @@ type ConfigHandler struct {
} }
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler { func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, levelDB: levelDB, licenseService: licenseService} return &ConfigHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
levelDB: levelDB,
licenseService: licenseService,
}
} }
func (h *ConfigHandler) Update(c *gin.Context) { func (h *ConfigHandler) Update(c *gin.Context) {
@@ -39,6 +44,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`
Updated bool `json:"updated,omitempty"` Updated bool `json:"updated,omitempty"`
} `json:"config"` } `json:"config"`
ConfigBak types.SystemConfig `json:"config_bak,omitempty"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
@@ -46,6 +52,12 @@ func (h *ConfigHandler) Update(c *gin.Context) {
return return
} }
// ONLY authorized user can change the copyright
if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
return
}
value := utils.JsonEncode(&data.Config) value := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: value} config := model.Config{Key: data.Key, Config: value}
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key}) res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
@@ -128,3 +140,70 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
license := h.licenseService.GetLicense() license := h.licenseService.GetLicense()
resp.SUCCESS(c, license) resp.SUCCESS(c, license)
} }
// FixData 修复数据
func (h *ConfigHandler) FixData(c *gin.Context) {
resp.ERROR(c, "当前升级版本没有数据需要修正!")
return
//var fixed bool
//version := "data_fix_4.1.4"
//err := h.levelDB.Get(version, &fixed)
//if err == nil || fixed {
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
// return
//}
//tx := h.DB.Begin()
//var users []model.User
//err = tx.Find(&users).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, user := range users {
// if user.Email != "" || user.Mobile != "" {
// continue
// }
// if utils.IsValidEmail(user.Username) {
// user.Email = user.Username
// } else if utils.IsValidMobile(user.Username) {
// user.Mobile = user.Username
// }
// err = tx.Save(&user).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//
//var orders []model.Order
//err = h.DB.Find(&orders).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, order := range orders {
// if order.PayWay == "支付宝" {
// order.PayWay = "alipay"
// order.PayType = "alipay"
// } else if order.PayWay == "微信支付" {
// order.PayWay = "wechat"
// order.PayType = "wxpay"
// } else if order.PayWay == "hupi" {
// order.PayType = "wxpay"
// }
// err = tx.Save(&order).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//tx.Commit()
//err = h.levelDB.Put(version, true)
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//resp.SUCCESS(c)
}

View File

@@ -60,13 +60,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
stats.Tokens += item.Tokens stats.Tokens += item.Tokens
} }
// 众筹收入
var rewards []model.Reward
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
for _, item := range rewards {
stats.Income += item.Amount
}
// 订单收入 // 订单收入
var orders []model.Order var orders []model.Order
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders) res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
@@ -101,13 +94,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens) historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
} }
// 浮点数相加?
// 统计最近7天的众筹
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
for _, item := range rewards {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
// 统计最近7天的订单 // 统计最近7天的订单
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders) res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
for _, item := range orders { for _, item := range orders {

View File

@@ -69,9 +69,9 @@ func (h *FunctionHandler) Set(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) err := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -101,9 +101,9 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id > 0 { if id > 0 {
res := h.DB.Delete(&model.Function{Id: uint(id)}) err := h.DB.Delete(&model.Function{Id: uint(id)}).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
} }

View File

@@ -0,0 +1,254 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ImageHandler struct {
handler.BaseHandler
userService *service.UserService
uploader *oss.UploaderManager
}
func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *ImageHandler {
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
type imageQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`
CreatedAt []string `json:"created_at"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// MjList Midjourney 任务列表
func (h *ImageHandler) MjList(c *gin.Context) {
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.MidJourneyJob{}).Count(&total)
var list []model.MidJourneyJob
var items = make([]vo.MidJourneyJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.MidJourneyJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// SdList Stable Diffusion 任务列表
func (h *ImageHandler) SdList(c *gin.Context) {
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.SdJob{}).Count(&total)
var list []model.SdJob
var items = make([]vo.SdJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.SdJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// DallList DALL-E 任务列表
func (h *ImageHandler) DallList(c *gin.Context) {
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.DallJob{}).Count(&total)
var list []model.DallJob
var items = make([]vo.DallJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.DallJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
func (h *ImageHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tab := c.Query("tab")
tx := h.DB.Begin()
var md, remark, imgURL string
var power, userId, progress int
switch tab {
case "mj":
var job model.MidJourneyJob
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
resp.ERROR(c, "记录不存在")
return
}
tx.Delete(&job)
md = "mid-journey"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "sd":
var job model.SdJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "stable-diffusion"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "dall":
var job model.DallJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "dall-e-3"
power = job.Power
userId = int(job.UserId)
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
default:
resp.ERROR(c, types.InvalidArgs)
return
}
if progress != 100 {
err := h.userService.IncreasePower(userId, power, model.PowerLog{
Type: types.PowerRefund,
Model: md,
Remark: remark,
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(imgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,200 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MediaHandler struct {
handler.BaseHandler
userService *service.UserService
uploader *oss.UploaderManager
}
func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *MediaHandler {
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
type mediaQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`
CreatedAt []string `json:"created_at"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// SunoList Suno 任务列表
func (h *MediaHandler) SunoList(c *gin.Context) {
var data mediaQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.SunoJob{}).Count(&total)
var list []model.SunoJob
var items = make([]vo.SunoJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.SunoJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// LumaList Luma 视频任务列表
func (h *MediaHandler) LumaList(c *gin.Context) {
var data mediaQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.VideoJob{}).Count(&total)
var list []model.VideoJob
var items = make([]vo.VideoJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.VideoJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
if job.VideoURL == "" {
job.VideoURL = job.WaterURL
}
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
func (h *MediaHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tab := c.Query("tab")
tx := h.DB.Begin()
var md, remark, fileURL string
var power, userId, progress int
switch tab {
case "suno":
var job model.SunoJob
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
resp.ERROR(c, "记录不存在")
return
}
tx.Delete(&job)
md = "suno"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("SUNO 任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
fileURL = job.AudioURL
break
case "luma":
var job model.VideoJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = job.Type
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("LUMA 任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
fileURL = job.VideoURL
if fileURL == "" {
fileURL = job.WaterURL
}
break
default:
resp.ERROR(c, types.InvalidArgs)
return
}
if progress != 100 {
err := h.userService.IncreasePower(userId, power, model.PowerLog{
Type: types.PowerRefund,
Model: md,
Remark: remark,
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(fileURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}

View File

@@ -41,16 +41,16 @@ func (h *MenuHandler) Save(c *gin.Context) {
return return
} }
res := h.DB.Save(&model.Menu{ err := h.DB.Save(&model.Menu{
Id: data.Id, Id: data.Id,
Name: data.Name, Name: data.Name,
Icon: data.Icon, Icon: data.Icon,
URL: data.URL, URL: data.URL,
SortNum: data.SortNum, SortNum: data.SortNum,
Enabled: data.Enabled, Enabled: data.Enabled,
}) }).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -84,9 +84,9 @@ func (h *MenuHandler) Enable(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled) err := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -104,9 +104,9 @@ func (h *MenuHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index]) err := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
} }
@@ -118,9 +118,9 @@ func (h *MenuHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id > 0 { if id > 0 {
res := h.DB.Where("id", id).Delete(&model.Menu{}) err := h.DB.Where("id", id).Delete(&model.Menu{}).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
} }

View File

@@ -15,6 +15,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -67,6 +68,16 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix() order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix() order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
if !ok {
payMethod = item.PayWay
}
payName, ok := types.PayNames[item.PayType]
if !ok {
payName = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
list = append(list, order) list = append(list, order)
} else { } else {
logger.Error(err) logger.Error(err)
@@ -92,11 +103,33 @@ func (h *OrderHandler) Remove(c *gin.Context) {
return return
} }
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{}) err := h.DB.Where("id = ?", id).Delete(&model.Order{}).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *OrderHandler) Clear(c *gin.Context) {
var orders []model.Order
err := h.DB.Where("status <> ?", 2).Where("pay_time", 0).Find(&orders).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
deleteIds := make([]uint, 0)
for _, order := range orders {
// 只删除 15 分钟内的未支付订单
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
deleteIds = append(deleteIds, order.Id)
}
}
err = h.DB.Where("id IN ?", deleteIds).Delete(&model.Order{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -55,16 +55,16 @@ func (h *ProductHandler) Save(c *gin.Context) {
if item.Id > 0 { if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0) item.CreatedAt = time.Unix(data.CreatedAt, 0)
} }
res := h.DB.Save(&item) err := h.DB.Save(&item).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
var itemVo vo.Product var itemVo vo.Product
err := utils.CopyObject(item, &itemVo) err = utils.CopyObject(item, &itemVo)
if err != nil { if err != nil {
resp.ERROR(c, "数据拷贝失败") resp.ERROR(c, "数据拷贝失败: "+err.Error())
return return
} }
itemVo.Id = item.Id itemVo.Id = item.Id
@@ -105,9 +105,9 @@ func (h *ProductHandler) Enable(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled) err := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -125,9 +125,9 @@ func (h *ProductHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index]) err := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
} }
@@ -139,9 +139,9 @@ func (h *ProductHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id > 0 { if id > 0 {
res := h.DB.Where("id", id).Delete(&model.Product{}) err := h.DB.Where("id", id).Delete(&model.Product{}).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, err.Error())
return return
} }
} }

View File

@@ -0,0 +1,160 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type RedeemHandler struct {
handler.BaseHandler
}
func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RedeemHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
code := c.Query("code")
status := h.GetInt(c, "status", -1)
session := h.DB.Session(&gorm.Session{})
if code != "" {
session.Where("code LIKE ?", "%"+code+"%")
}
if status == 0 {
session.Where("redeem_at = ?", 0)
} else if status == 1 {
session.Where("redeem_at > ?", 0)
}
var total int64
session.Model(&model.Redeem{}).Count(&total)
var redeems []model.Redeem
offset := (page - 1) * pageSize
err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&redeems).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var items = make([]vo.Redeem, 0)
userIds := make([]uint, 0)
for _, v := range redeems {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
}
for _, v := range redeems {
var r vo.Redeem
err = utils.CopyObject(v, &r)
if err != nil {
continue
}
r.Id = v.Id
r.Username = userMap[v.UserId].Username
r.CreatedAt = v.CreatedAt.Unix()
items = append(items, r)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *RedeemHandler) Create(c *gin.Context) {
var data struct {
Name string `json:"name"`
Power int `json:"power"`
Num int `json:"num"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
counter := 0
codes := make([]string, 0)
var errMsg = ""
if data.Num > 0 {
for i := 0; i < data.Num; i++ {
code, err := utils.GenRedeemCode(32)
if err != nil {
errMsg = err.Error()
continue
}
err = h.DB.Create(&model.Redeem{
Code: code,
Name: data.Name,
Power: data.Power,
Enabled: true,
}).Error
if err != nil {
errMsg = err.Error()
continue
}
codes = append(codes, code)
counter++
}
}
if counter == 0 {
resp.ERROR(c, errMsg)
return
}
resp.SUCCESS(c, gin.H{
"counter": counter,
})
}
func (h *RedeemHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
Value interface{} `json:"value"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Model(&model.Redeem{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *RedeemHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Where("id", id).Delete(&model.Redeem{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -1,80 +0,0 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type RewardHandler struct {
handler.BaseHandler
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RewardHandler) List(c *gin.Context) {
var items []model.Reward
res := h.DB.Order("id DESC").Find(&items)
var rewards = make([]vo.Reward, 0)
if res.Error == nil {
userIds := make([]uint, 0)
for _, v := range items {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
}
for _, v := range items {
var r vo.Reward
err := utils.CopyObject(v, &r)
if err != nil {
continue
}
r.Id = v.Id
r.Username = userMap[v.UserId].Username
r.CreatedAt = v.CreatedAt.Unix()
r.UpdatedAt = v.UpdatedAt.Unix()
rewards = append(rewards, r)
}
}
resp.SUCCESS(c, rewards)
}
func (h *RewardHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -8,6 +8,7 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
@@ -16,7 +17,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"fmt" "github.com/go-redis/redis/v8"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -26,10 +27,11 @@ import (
type UserHandler struct { type UserHandler struct {
handler.BaseHandler handler.BaseHandler
licenseService *service.LicenseService licenseService *service.LicenseService
redis *redis.Client
} }
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler { func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService, redisCli *redis.Client) *UserHandler {
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService} return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
} }
// List 用户列表 // List 用户列表
@@ -49,7 +51,7 @@ func (h *UserHandler) List(c *gin.Context) {
} }
session.Model(&model.User{}).Count(&total) session.Model(&model.User{}).Count(&total)
res := session.Offset(offset).Limit(pageSize).Find(&items) res := session.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var user vo.User var user vo.User
@@ -73,6 +75,8 @@ func (h *UserHandler) Save(c *gin.Context) {
Id uint `json:"id"` Id uint `json:"id"`
Password string `json:"password"` Password string `json:"password"`
Username string `json:"username"` Username string `json:"username"`
Mobile string `json:"mobile"`
Email string `json:"email"`
ChatRoles []string `json:"chat_roles"` ChatRoles []string `json:"chat_roles"`
ChatModels []int `json:"chat_models"` ChatModels []int `json:"chat_models"`
ExpiredTime string `json:"expired_time"` ExpiredTime string `json:"expired_time"`
@@ -87,7 +91,7 @@ func (h *UserHandler) Save(c *gin.Context) {
// 检测最大注册人数 // 检测最大注册人数
var totalUser int64 var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser) h.DB.Model(&model.User{}).Count(&totalUser)
if int(totalUser) >= h.licenseService.GetLicense().UserNum { if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return return
} }
@@ -102,6 +106,8 @@ func (h *UserHandler) Save(c *gin.Context) {
} }
var oldPower = user.Power var oldPower = user.Power
user.Username = data.Username user.Username = data.Username
user.Email = data.Email
user.Mobile = data.Mobile
user.Status = data.Status user.Status = data.Status
user.Vip = data.Vip user.Vip = data.Vip
user.Power = data.Power user.Power = data.Power
@@ -109,9 +115,11 @@ func (h *UserHandler) Save(c *gin.Context) {
user.ChatModels = utils.JsonEncode(data.ChatModels) user.ChatModels = utils.JsonEncode(data.ChatModels)
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime) user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user) res = h.DB.Select("username", "mobile", "email", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") logger.Error("error with update database", res.Error)
resp.ERROR(c, res.Error.Error())
return return
} }
// 记录算力日志 // 记录算力日志
@@ -134,12 +142,27 @@ func (h *UserHandler) Save(c *gin.Context) {
CreatedAt: time.Now(), CreatedAt: time.Now(),
}) })
} }
// 如果禁用了用户,则将用户踢下线
if user.Status == false {
key := fmt.Sprintf("users/%v", user.Id)
if _, err := h.redis.Del(c, key).Result(); err != nil {
logger.Error("error with delete session: ", err)
}
}
} else { } else {
// 检查用户是否已经存在
h.DB.Where("username", data.Username).First(&user)
if user.Id > 0 {
resp.ERROR(c, "用户名已存在")
return
}
salt := utils.RandString(8) salt := utils.RandString(8)
u := model.User{ u := model.User{
Username: data.Username, Username: data.Username,
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
Password: utils.GenPassword(data.Password, salt), Password: utils.GenPassword(data.Password, salt),
Mobile: data.Mobile,
Email: data.Email,
Avatar: "/images/avatar/user.png", Avatar: "/images/avatar/user.png",
Salt: salt, Salt: salt,
Power: data.Power, Power: data.Power,
@@ -148,6 +171,11 @@ func (h *UserHandler) Save(c *gin.Context) {
ChatModels: utils.JsonEncode(data.ChatModels), ChatModels: utils.JsonEncode(data.ChatModels),
ExpiredTime: utils.Str2stamp(data.ExpiredTime), ExpiredTime: utils.Str2stamp(data.ExpiredTime),
} }
if h.licenseService.GetLicense().Configs.DeCopy {
u.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
} else {
u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
}
res = h.DB.Create(&u) res = h.DB.Create(&u)
_ = utils.CopyObject(u, &userVo) _ = utils.CopyObject(u, &userVo)
userVo.Id = u.Id userVo.Id = u.Id
@@ -156,7 +184,7 @@ func (h *UserHandler) Save(c *gin.Context) {
} }
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败") resp.ERROR(c, res.Error.Error())
return return
} }
@@ -192,33 +220,69 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
} }
func (h *UserHandler) Remove(c *gin.Context) { func (h *UserHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := c.Query("id")
if id <= 0 { ids := c.QueryArray("ids[]")
if id != "" {
ids = append(ids, id)
}
if len(ids) == 0 {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
// 删除用户
res := h.DB.Where("id = ?", id).Delete(&model.User{}) tx := h.DB.Begin()
if res.Error != nil { var err error
resp.ERROR(c, "删除失败") for _, id = range ids {
// 删除用户
if err = tx.Where("id", id).Delete(&model.User{}).Error; err != nil {
break
}
// 删除聊天记录
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatItem{}).Error; err != nil {
break
}
// 删除聊天历史记录
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatMessage{}).Error; err != nil {
break
}
// 删除登录日志
if err = tx.Where("user_id = ?", id).Delete(&model.UserLoginLog{}).Error; err != nil {
break
}
// 删除算力日志
if err = tx.Where("user_id = ?", id).Delete(&model.PowerLog{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.InviteLog{}).Error; err != nil {
break
}
// 删除众筹日志
if err = tx.Where("user_id = ?", id).Delete(&model.Redeem{}).Error; err != nil {
break
}
// 删除绘图任务
if err = tx.Where("user_id = ?", id).Delete(&model.MidJourneyJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.SdJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.DallJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.SunoJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.VideoJob{}).Error; err != nil {
break
}
}
if err != nil {
resp.ERROR(c, err.Error())
tx.Rollback()
return return
} }
tx.Commit()
// 删除聊天记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
// 删除聊天历史记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
// 删除登录日志
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
// 删除算力日志
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
// 删除众筹日志
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
// 删除绘图任务
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
// 删除订单
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
resp.SUCCESS(c) resp.SUCCESS(c)
} }

View File

@@ -8,13 +8,13 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"errors"
"fmt"
"gorm.io/gorm" "gorm.io/gorm"
"strings" "strings"
@@ -85,7 +85,7 @@ func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
} }
var user model.User var user model.User
res := h.DB.First(&user, userId) res := h.DB.Where("id", userId).First(&user)
// 更新缓存 // 更新缓存
if res.Error == nil { if res.Error == nil {
c.Set(types.LoginUserCache, user) c.Set(types.LoginUserCache, user)

View File

@@ -0,0 +1,44 @@
package handler
import (
"geekai/core"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatAppTypeHandler struct {
BaseHandler
}
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 获取App类型列表
func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType
var appTypes = make([]vo.AppType, 0)
err := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, v := range items {
var appType vo.AppType
err = utils.CopyObject(v, &appType)
if err != nil {
continue
}
appType.Id = v.Id
appType.CreatedAt = v.CreatedAt.Unix()
appTypes = append(appTypes, appType)
}
resp.SUCCESS(c, appTypes)
}

522
api/handler/chat_handler.go Normal file
View File

@@ -0,0 +1,522 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"html/template"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type ChatHandler struct {
BaseHandler
redis *redis.Client
uploadManager *oss.UploaderManager
licenseService *service.LicenseService
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
userService *service.UserService
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
return &ChatHandler{
BaseHandler: BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
licenseService: licenseService,
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
ChatContexts: types.NewLMap[string, []types.Message](),
userService: userService,
}
}
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
if !h.App.Debug {
defer func() {
if r := recover(); r != nil {
logger.Error("Recover message from error: ", r)
}
}()
}
var user model.User
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
return errors.New("未授权用户,您正在进行非法操作!")
}
var userVo vo.User
err := utils.CopyObject(user, &userVo)
userVo.Id = user.Id
if err != nil {
return errors.New("User 对象转换失败," + err.Error())
}
if userVo.Status == false {
return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
}
if userVo.Power < session.Model.Power {
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d[立即购买](/member)。", userVo.Power, session.Model.Power)
}
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
return errors.New("您的账号已经过期,请联系管理员!")
}
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > session.Model.MaxContext {
return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
}
var req = types.ApiRequest{
Model: session.Model.Value,
}
// 兼容 GPT-O1 模型
if strings.HasPrefix(session.Model.Value, "o1-") {
utils.SendChunkMsg(ws, "AI 正在思考...\n")
req.Stream = false
session.Start = time.Now().Unix()
} else {
req.MaxTokens = session.Model.MaxTokens
req.Temperature = session.Model.Temperature
req.Stream = session.Stream
}
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
var items []model.Function
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
}
if v, ok := parameters["required"]; v == nil || !ok {
tool.Function.Parameters["required"] = []string{}
}
tools = append(tools, tool)
}
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
}
}
// 加载聊天上下文
chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.SysConfig.EnableContext {
if h.ChatContexts.Has(session.ChatId) {
messages = h.ChatContexts.Get(session.ChatId)
} else {
_ = utils.JsonDecode(role.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg {
ms.Role = "assistant"
}
chatCtx = append(chatCtx, ms)
}
}
}
}
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
// MaxContextLength = Response + Tool + Prompt + Context
tokens := req.MaxTokens // 最大响应长度
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for i := len(messages) - 1; i >= 0; i-- {
v := messages[i]
tks, _ = utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
break
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
logger.Debugf("聊天上下文:%+v", chatCtx)
}
reqMgs := make([]interface{}, 0)
for _, m := range chatCtx {
reqMgs = append(reqMgs, m)
}
fullPrompt := prompt
text := prompt
// extract files in prompt
files := utils.ExtractFileURLs(prompt)
logger.Debugf("detected FILES: %+v", files)
// 如果不是逆向模型,则提取文件内容
if len(files) > 0 && !(session.Model.Value == "gpt-4-all" ||
strings.HasPrefix(session.Model.Value, "gpt-4-gizmo") ||
strings.HasSuffix(session.Model.Value, "claude-3")) {
contents := make([]string, 0)
var file model.File
for _, v := range files {
h.DB.Where("url = ?", v).First(&file)
content, err := utils.ReadFileContent(v, h.App.Config.TikaHost)
if err != nil {
logger.Error("error with read file: ", err)
} else {
contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
}
text = strings.Replace(text, v, "", 1)
}
if len(contents) > 0 {
fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML)\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
}
tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
if tokens > session.Model.MaxContext {
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
}
}
logger.Debug("最终Prompt", fullPrompt)
// extract images from prompt
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
}
data = append(data, gin.H{
"type": "text",
"text": strings.TrimSpace(text),
})
content = data
} else {
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
logger.Debugf("%+v", req.Messages)
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
// Tokens 统计 token 数量
func (h *ChatHandler) Tokens(c *gin.Context) {
var data struct {
Text string `json:"text"`
Model string `json:"model"`
ChatId string `json:"chat_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
//if data.Text == "" && data.ChatId != "" {
// var item model.ChatMessage
// userId, _ := c.Get(types.LoginUserID)
// res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
// if res.Error != nil {
// resp.ERROR(c, res.Error.Error())
// return
// }
// resp.SUCCESS(c, item.Tokens)
// return
//}
tokens, err := utils.CalcTokens(data.Text, data.Model)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, tokens)
}
func getTotalTokens(req types.ApiRequest) int {
encode := utils.JsonEncode(req.Messages)
var items []map[string]interface{}
err := utils.JsonDecode(encode, &items)
if err != nil {
return 0
}
tokens := 0
for _, item := range items {
content, ok := item["content"]
if ok && !utils.IsEmptyValue(content) {
t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
if err == nil {
tokens += t
}
}
}
return tokens
}
// StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id")
if h.ReqCancelFunc.Has(sessionId) {
h.ReqCancelFunc.Get(sessionId)()
h.ReqCancelFunc.Delete(sessionId)
}
resp.SUCCESS(c, types.OkMsg)
}
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
if session.Model.KeyId > 0 {
h.DB.Where("id", session.Model.KeyId).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
if apiKey.Id == 0 {
return nil, errors.New("no available key, please import key")
}
// ONLY allow apiURL in blank list
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
logger.Debugf("对话请求消息体:%+v", req)
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request = request.WithContext(ctx)
request.Header.Set("Content-Type", "application/json")
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
// 更新API KEY 最后使用时间
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
return client.Do(request)
}
// 扣减用户算力
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
power := 1
if session.Model.Power > 0 {
power = session.Model.Power
}
err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{
Type: types.PowerConsume,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
})
if err != nil {
logger.Error(err)
}
}
func (h *ChatHandler) saveChatHistory(
req types.ApiRequest,
usage Usage,
message types.Message,
chatCtx []types.Message,
session *types.ChatSession,
role model.ChatRole,
userVo vo.User,
promptCreatedAt time.Time,
replyCreatedAt time.Time) {
useMsg := types.Message{Role: "user", Content: usage.Prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
var promptTokens, replyTokens, totalTokens int
if usage.PromptTokens > 0 {
promptTokens = usage.PromptTokens
} else {
promptTokens, _ = utils.CalcTokens(usage.Content, req.Model)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(usage.Prompt),
Tokens: promptTokens,
TotalTokens: promptTokens,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
err := h.DB.Save(&historyUserMsg).Error
if err != nil {
logger.Error("failed to save prompt history message: ", err)
}
// for reply
// 计算本次对话消耗的总 token 数量
if usage.CompletionTokens > 0 {
replyTokens = usage.CompletionTokens
totalTokens = usage.TotalTokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
totalTokens = replyTokens + getTotalTokens(req)
}
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: usage.Content,
Tokens: replyTokens,
TotalTokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
err = h.DB.Create(&historyReplyMsg).Error
if err != nil {
logger.Error("failed to save reply history message: ", err)
}
// 更新用户算力
if session.Model.Power > 0 {
h.subUserPower(userVo, session, promptTokens, replyTokens)
}
// 保存当前会话
var chatItem model.ChatItem
err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
if err != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = userVo.Id
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(usage.Prompt) > 30 {
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
} else {
chatItem.Title = usage.Prompt
}
chatItem.Model = req.Model
err = h.DB.Create(&chatItem).Error
if err != nil {
logger.Error("failed to save chat item: ", err)
}
}
}
// 将AI回复消息中生成的图片链接下载到本地
func (h *ChatHandler) extractImgUrl(text string) string {
pattern := `!\[([^\]]*)]\(([^)]+)\)`
re := regexp.MustCompile(pattern)
matches := re.FindAllStringSubmatch(text, -1)
// 下载图片并替换链接地址
for _, match := range matches {
imageURL := match[2]
logger.Debug(imageURL)
// 对于相同地址的图片,已经被替换了,就不再重复下载了
if !strings.Contains(text, imageURL) {
continue
}
newImgURL, err := h.uploadManager.GetUploadHandler().PutUrlFile(imageURL, false)
if err != nil {
logger.Error("error with download image: ", err)
continue
}
text = strings.ReplaceAll(text, imageURL, newImgURL)
}
return text
}

View File

@@ -1,4 +1,4 @@
package chatimpl package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved. // * Copyright 2023 The Geek-AI Authors. All rights reserved.
@@ -28,31 +28,40 @@ func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
var items = make([]vo.ChatItem, 0) var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem var chats []model.ChatItem
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats) h.DB.Where("user_id", userId).Order("id DESC").Find(&chats)
if res.Error == nil { if len(chats) == 0 {
var roleIds = make([]uint, 0) resp.SUCCESS(c, items)
for _, chat := range chats { return
roleIds = append(roleIds, chat.RoleId) }
}
var roles []model.ChatRole
res = h.DB.Find(&roles, roleIds)
if res.Error == nil {
roleMap := make(map[uint]model.ChatRole)
for _, role := range roles {
roleMap[role.Id] = role
}
for _, chat := range chats { var roleIds = make([]uint, 0)
var item vo.ChatItem var modelValues = make([]string, 0)
err := utils.CopyObject(chat, &item) for _, chat := range chats {
if err == nil { roleIds = append(roleIds, chat.RoleId)
item.Id = chat.Id modelValues = append(modelValues, chat.Model)
item.Icon = roleMap[chat.RoleId].Icon }
items = append(items, item)
}
}
}
var roles []model.ChatRole
var models []model.ChatModel
roleMap := make(map[uint]model.ChatRole)
modelMap := make(map[string]model.ChatModel)
h.DB.Where("id IN ?", roleIds).Find(&roles)
h.DB.Where("value IN ?", modelValues).Find(&models)
for _, role := range roles {
roleMap[role.Id] = role
}
for _, m := range models {
modelMap[m.Value] = m
}
for _, chat := range chats {
var item vo.ChatItem
err := utils.CopyObject(chat, &item)
if err == nil {
item.Id = chat.Id
item.Icon = roleMap[chat.RoleId].Icon
item.ModelId = modelMap[chat.Model].Id
items = append(items, item)
}
} }
resp.SUCCESS(c, items) resp.SUCCESS(c, items)
} }
@@ -96,7 +105,7 @@ func (h *ChatHandler) Clear(c *gin.Context) {
for _, chat := range chats { for _, chat := range chats {
chatIds = append(chatIds, chat.ChatId) chatIds = append(chatIds, chat.ChatId)
// 清空会话上下文 // 清空会话上下文
h.App.ChatContexts.Delete(chat.ChatId) h.ChatContexts.Delete(chat.ChatId)
} }
err = h.DB.Transaction(func(tx *gorm.DB) error { err = h.DB.Transaction(func(tx *gorm.DB) error {
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{}) res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
@@ -108,8 +117,6 @@ func (h *ChatHandler) Clear(c *gin.Context) {
if res.Error != nil { if res.Error != nil {
return res.Error return res.Error
} }
// TODO: 是否要删除 MidJourney 绘画记录和图片文件?
return nil return nil
}) })
@@ -175,7 +182,7 @@ func (h *ChatHandler) Remove(c *gin.Context) {
// TODO: 是否要删除 MidJourney 绘画记录和图片文件? // TODO: 是否要删除 MidJourney 绘画记录和图片文件?
// 清空会话上下文 // 清空会话上下文
h.App.ChatContexts.Delete(chatId) h.ChatContexts.Delete(chatId)
resp.SUCCESS(c, types.OkMsg) resp.SUCCESS(c, types.OkMsg)
} }

View File

@@ -13,6 +13,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -30,9 +31,14 @@ func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0) var chatModels = make([]vo.ChatModel, 0)
var res *gorm.DB var res *gorm.DB
session := h.DB.Session(&gorm.Session{}).Where("enabled", true)
t := c.Query("type")
if t != "" {
session = session.Where("type", t)
}
// 如果用户没有登录,则加载所有开放模型 // 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) { if !h.IsLogin(c) {
res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items) res = session.Where("open", true).Order("sort_num ASC").Find(&items)
} else { } else {
user, _ := h.GetLoginUser(c) user, _ := h.GetLoginUser(c)
var models []int var models []int
@@ -43,7 +49,7 @@ func (h *ChatModelHandler) List(c *gin.Context) {
} }
// 查询用户有权限访问的模型以及所有开放的模型 // 查询用户有权限访问的模型以及所有开放的模型
res = h.DB.Where("enabled = ?", true).Where( res = h.DB.Where("enabled = ?", true).Where(
h.DB.Where("id IN ?", models).Or("open =?", true), h.DB.Where("id IN ?", models).Or("open", true),
).Order("sort_num ASC").Find(&items) ).Order("sort_num ASC").Find(&items)
} }

View File

@@ -29,45 +29,63 @@ func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
// List 获取用户聊天应用列表 // List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) { func (h *ChatRoleHandler) List(c *gin.Context) {
all := h.GetBool(c, "all") tid := h.GetInt(c, "tid", 0)
var roles []model.ChatRole
session := h.DB.Where("enable", true)
if tid > 0 {
session = session.Where("tid", tid)
}
err := session.Order("sort_num ASC").Find(&roles).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
var v vo.ChatRole
err := utils.CopyObject(r, &v)
if err == nil {
v.Id = r.Id
roleVos = append(roleVos, v)
}
}
resp.SUCCESS(c, roleVos)
}
// ListByUser 获取用户添加的角色列表
func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
var roles []model.ChatRole var roles []model.ChatRole
var roleVos = make([]vo.ChatRole, 0) session := h.DB.Where("enable", true)
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles) // 如果用户没登录,则获取所有角色
if userId > 0 {
var user model.User
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
resp.ERROR(c, "角色解析失败!")
return
}
// 保证用户至少有一个角色可用
if len(roleKeys) > 0 {
session = session.Where("marker IN ?", roleKeys)
}
}
if id > 0 {
session = session.Or("id", id)
}
res := session.Order("sort_num ASC").Find(&roles)
if res.Error != nil { if res.Error != nil {
resp.SUCCESS(c, roleVos) resp.ERROR(c, res.Error.Error())
return
}
// 获取所有角色
if userId == 0 || all {
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
var v vo.ChatRole
err := utils.CopyObject(r, &v)
if err == nil {
v.Id = r.Id
roleVos = append(roleVos, v)
}
}
resp.SUCCESS(c, roleVos)
return
}
var user model.User
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
resp.ERROR(c, "角色解析失败!")
return return
} }
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles { for _, r := range roles {
if !utils.ContainsStr(roleKeys, r.Key) {
continue
}
var v vo.ChatRole var v vo.ChatRole
err := utils.CopyObject(r, &v) err := utils.CopyObject(r, &v)
if err == nil { if err == nil {
@@ -94,10 +112,9 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
return return
} }
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)) err = h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)).Error
if res.Error != nil { if err != nil {
logger.Error("添加应用失败:", err) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败!")
return return
} }

View File

@@ -1,215 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
// 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
logger.Error(err, line)
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
break
}
if len(responseBody.Choices) == 0 {
continue
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if strings.Contains(res.Error.Message, "maximum context length") {
logger.Error(res.Error.Message)
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
h.App.ChatContexts.Delete(session.ChatId)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
utils.ReplyMessage(ws, "请求 Azure API 失败:"+res.Error.Message)
}
}
return nil
}

View File

@@ -1,280 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"net/http"
"strings"
"time"
"unicode/utf8"
)
type baiduResp struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
IsTruncated bool `json:"is_truncated"`
Result string `json:"result"`
NeedClearHistory bool `json:"need_clear_history"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var content string
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
var resp baiduResp
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(resp.Result),
})
contents = append(contents, resp.Result)
if resp.IsTruncated {
utils.ReplyMessage(ws, "AI 输出异常中断")
break
}
if resp.IsEnd {
break
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
}
return nil
}
func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
ctx := context.Background()
tokenString, err := h.redis.Get(ctx, apiKey).Result()
if err == nil {
return tokenString, nil
}
expr := time.Hour * 24 * 20 // access_token 有效期
key := strings.Split(apiKey, "|")
if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
client := &http.Client{}
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return "", err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error with send request: %w", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return "", fmt.Errorf("error with read response: %w", err)
}
var r map[string]interface{}
err = json.Unmarshal(body, &r)
if err != nil {
return "", fmt.Errorf("error with parse response: %w", err)
}
if r["error"] != nil {
return "", fmt.Errorf("error with api response: %s", r["error_description"])
}
tokenString = fmt.Sprintf("%s", r["access_token"])
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, nil
}

View File

@@ -1,635 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
var ErrImg = "![](/images/wx.png)"
var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
redis *redis.Client
uploadManager *oss.UploaderManager
licenseService *service.LicenseService
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler {
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
licenseService: licenseService,
}
}
func (h *ChatHandler) Init() {
// 如果后台有上传微信客服微信二维码,则覆盖
if h.App.SysConfig.WechatCardURL != "" {
ErrImg = fmt.Sprintf("![](%s)", h.App.SysConfig.WechatCardURL)
}
}
// ChatHandle 处理聊天 WebSocket 请求
func (h *ChatHandler) ChatHandle(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
modelId := h.GetInt(c, "model_id", 0)
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res := h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
modelId = chatRole.ModelId
}
// get model info
var chatModel model.ChatModel
res = h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort()
return
}
session := h.App.ChatSession.Get(sessionId)
if session == nil {
user, err := h.GetLoginUser(c)
if err != nil {
logger.Info("用户未登录")
c.Abort()
return
}
session = &types.ChatSession{
SessionId: sessionId,
ClientIP: c.ClientIP(),
Username: user.Username,
UserId: user.Id,
}
h.App.ChatSession.Put(sessionId, session)
}
// use old chat data override the chat model and role ID
var chat model.ChatItem
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
if res.Error == nil {
chatModel.Id = chat.ModelId
roleId = int(chat.RoleId)
}
session.ChatId = chatId
session.Model = types.ChatModel{
Id: chatModel.Id,
Name: chatModel.Name,
Value: chatModel.Value,
Power: chatModel.Power,
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
KeyId: chatModel.KeyId,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
h.Init()
// 保存会话连接
h.App.ChatClients.Put(sessionId, client)
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
h.App.ChatClients.Delete(sessionId)
h.App.ChatSession.Delete(sessionId)
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
if cancelFunc != nil {
cancelFunc()
h.App.ReqCancelFunc.Delete(sessionId)
}
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 Chat 心跳消息:", message.Content)
continue
}
logger.Info("Receive a message: ", message.Content)
ctx, cancel := context.WithCancel(context.Background())
h.App.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
} else {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
logger.Infof("回答完毕: %v", message.Content)
}
}
}()
}
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
if !h.App.Debug {
defer func() {
if r := recover(); r != nil {
logger.Error("Recover message from error: ", r)
}
}()
}
var user model.User
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作!")
return res.Error
}
var userVo vo.User
err := utils.CopyObject(user, &userVo)
userVo.Id = user.Id
if err != nil {
return errors.New("User 对象转换失败," + err.Error())
}
if userVo.Status == false {
utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
return nil
}
if userVo.Power < session.Model.Power {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力(%d已不足以支付当前模型的单次对话需要消耗的算力%d", userVo.Power, session.Model.Power))
utils.ReplyMessage(ws, ErrImg)
return nil
}
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
return nil
}
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > session.Model.MaxContext {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
}
var req = types.ApiRequest{
Model: session.Model.Value,
Stream: true,
}
switch session.Model.Platform {
case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
break
case types.OpenAI:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
required := parameters["required"]
delete(parameters, "required")
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
}
// Fixed: compatible for gpt4-turbo-xxx model
if !strings.HasPrefix(req.Model, "gpt-4-turbo-") {
tool.Function.Required = required
}
tools = append(tools, tool)
}
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
case types.QWen:
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
break
default:
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
return nil
}
// 加载聊天上下文
chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.SysConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
messages = h.App.ChatContexts.Get(session.ChatId)
} else {
_ = utils.JsonDecode(role.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg {
ms.Role = "assistant"
}
chatCtx = append(chatCtx, ms)
}
}
}
}
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
// MaxContextLength = Response + Tool + Prompt + Context
tokens := req.MaxTokens // 最大响应长度
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
break
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
logger.Debugf("聊天上下文:%+v", chatCtx)
}
reqMgs := make([]interface{}, 0)
for _, m := range chatCtx {
reqMgs = append(reqMgs, m)
}
if session.Model.Platform == types.QWen {
req.Input = make(map[string]interface{})
reqMgs = append(reqMgs, types.Message{
Role: "user",
Content: prompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURL(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
text := prompt
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
}
data = append(data, gin.H{
"type": "text",
"text": text,
})
content = data
} else {
content = prompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
}
logger.Debugf("%+v", req.Messages)
switch session.Model.Platform {
case types.Azure:
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.OpenAI:
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGLM:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.Baidu:
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.QWen:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: fmt.Sprintf("Not supported platform: %s", session.Model.Platform),
})
return nil
}
// Tokens 统计 token 数量
func (h *ChatHandler) Tokens(c *gin.Context) {
var data struct {
Text string `json:"text"`
Model string `json:"model"`
ChatId string `json:"chat_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
if data.Text == "" && data.ChatId != "" {
var item model.ChatMessage
userId, _ := c.Get(types.LoginUserID)
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c, item.Tokens)
return
}
tokens, err := utils.CalcTokens(data.Text, data.Model)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, tokens)
}
func getTotalTokens(req types.ApiRequest) int {
encode := utils.JsonEncode(req.Messages)
var items []map[string]interface{}
err := utils.JsonDecode(encode, &items)
if err != nil {
return 0
}
tokens := 0
for _, item := range items {
content, ok := item["content"]
if ok && !utils.IsEmptyValue(content) {
t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
if err == nil {
tokens += t
}
}
}
return tokens
}
// StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id")
if h.App.ReqCancelFunc.Has(sessionId) {
h.App.ReqCancelFunc.Get(sessionId)()
h.App.ReqCancelFunc.Delete(sessionId)
}
resp.SUCCESS(c, types.OkMsg)
}
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
if session.Model.KeyId > 0 {
h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
h.DB.Debug().Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
if apiKey.Id == 0 {
return nil, errors.New("no available key, please import key")
}
// ONLY allow apiURL in blank list
if session.Model.Platform == types.OpenAI {
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
}
var apiURL string
switch session.Model.Platform {
case types.Azure:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break
case types.ChatGLM:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
case types.QWen:
apiURL = apiKey.ApiURL
req.Messages = nil
break
default:
apiURL = apiKey.ApiURL
}
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if session.Model.Platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
}
logger.Info("百度文心 Access_Token", token)
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
}
logger.Debugf(utils.JsonEncode(req))
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request = request.WithContext(ctx)
request.Header.Set("Content-Type", "application/json")
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
switch session.Model.Platform {
case types.Azure:
request.Header.Set("api-key", apiKey.Value)
break
case types.ChatGLM:
token, err := h.getChatGLMToken(apiKey.Value)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
case types.Baidu:
request.RequestURI = ""
case types.OpenAI:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
break
case types.QWen:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
request.Header.Set("X-DashScope-SSE", "enable")
break
}
return client.Do(request)
}
// 扣减用户算力
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
power := 1
if session.Model.Power > 0 {
power = session.Model.Power
}
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userVo.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,
Amount: power,
Mark: types.PowerSub,
Balance: u.Power,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
CreatedAt: time.Now(),
})
}
}
// 将AI回复消息中生成的图片链接下载到本地
func (h *ChatHandler) extractImgUrl(text string) string {
pattern := `!\[([^\]]*)]\(([^)]+)\)`
re := regexp.MustCompile(pattern)
matches := re.FindAllStringSubmatch(text, -1)
// 下载图片并替换链接地址
for _, match := range matches {
imageURL := match[2]
logger.Debug(imageURL)
// 对于相同地址的图片,已经被替换了,就不再重复下载了
if !strings.Contains(text, imageURL) {
continue
}
newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
if err != nil {
logger.Error("error with download image: ", err)
continue
}
text = strings.ReplaceAll(text, imageURL, newImgURL)
}
return text
}

View File

@@ -1,243 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"github.com/golang-jwt/jwt/v5"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
// 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var event, content string
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "event:") {
event = line[6:]
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
switch event {
case "add":
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
contents = append(contents, content)
case "finish":
break
case "error":
utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
break
case "interrupted":
utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"code"`
Success bool `json:"success"`
Msg string `json:"msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if !res.Success {
utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
}
}
return nil
}
func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) {
ctx := context.Background()
tokenString, err := h.redis.Get(ctx, apiKey).Result()
if err == nil {
return tokenString, nil
}
expr := time.Hour * 2
key := strings.Split(apiKey, ".")
if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"api_key": key[0],
"timestamp": time.Now().Unix(),
"exp": time.Now().Add(expr).Add(time.Second * 10).Unix(),
})
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
delete(token.Header, "typ")
// Sign and get the complete encoded token as a string using the secret
tokenString, err = token.SignedString([]byte(key[1]))
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, err
}

View File

@@ -1,311 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
req2 "github.com/imroc/req/v3"
)
// OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
logger.Error(err)
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
utils.ReplyMessage(ws, err.Error())
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var function model.Function
var toolCall = false
var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
logger.Error(err, line)
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
break
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.ReplyMessage(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
break
}
var tool types.ToolCall
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
tool = responseBody.Choices[0].Delta.ToolCalls[0]
if toolCall && tool.Function.Name == "" {
arguments = append(arguments, tool.Function.Arguments)
continue
}
}
// 兼容 Function Call
fun := responseBody.Choices[0].Delta.FunctionCall
if fun.Name != "" {
tool = *new(types.ToolCall)
tool.Function.Name = fun.Name
} else if toolCall {
arguments = append(arguments, fun.Arguments)
continue
}
if !utils.IsEmptyValue(tool) {
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
contents = append(contents, callMsg)
}
continue
}
if responseBody.Choices[0].FinishReason == "tool_calls" ||
responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break
}
// output stopped
if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
if isNew {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
if toolCall { // 调用函数完成任务
var params map[string]interface{}
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
params["user_id"] = userVo.Id
var apiRes types.BizVo
r, err := req2.C().R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", function.Token).
SetBody(params).
SetSuccessResult(&apiRes).Post(function.Action)
errMsg := ""
if err != nil {
errMsg = err.Error()
} else if r.IsErrorState() {
errMsg = r.Status
}
if errMsg != "" || apiRes.Code != types.Success {
msg := "调用函数工具出错:" + apiRes.Message + errMsg
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: apiRes.Data,
})
contents = append(contents, utils.InterfaceToString(apiRes.Data))
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext && toolCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
useContext := true
if toolCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyTokens += tokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
return fmt.Errorf("error with decode response: %v", err)
}
// OpenAI API 调用异常处理
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
logger.Error(res.Error.Message)
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
h.App.ChatContexts.Delete(session.ChatId)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
}
}
return nil
}

View File

@@ -1,248 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
type qWenResp struct {
Output struct {
FinishReason string `json:"finish_reason"`
Text string `json:"text"`
} `json:"output,omitempty"`
Usage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage,omitempty"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var content, lastText, newText string
var outPutStart = false
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") ||
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
continue
}
if !strings.HasPrefix(line, "data:") {
continue
}
content = line[5:]
var resp qWenResp
if len(contents) == 0 { // 发送消息头
if !outPutStart {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
outPutStart = true
continue
} else {
// 处理代码换行
content = "\n"
}
} else {
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", content)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if resp.Message != "" {
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
break
}
}
//通过比较 lastText上一次的文本和 currentText当前的文本
//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
//每次循环结束后lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
currentText := resp.Output.Text
if currentText != lastText {
// 提取新增文本
newText = strings.Replace(currentText, lastText, "", 1)
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(newText),
})
lastText = currentText // 更新 lastText
}
contents = append(contents, newText)
if resp.Output.FinishReason == "stop" {
break
}
} //end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg)
}
return nil
}

View File

@@ -1,336 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"html/template"
"io"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
)
type xunFeiResp struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
} `json:"text"`
} `json:"choices"`
Usage struct {
Text struct {
QuestionTokens int `json:"question_tokens"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
var Model2URL = map[string]string{
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"generalv3.5": "v3.5",
}
// 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
var res *gorm.DB
// use the bind key
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
}
// use the last unused key
if res.Error != nil {
res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
}
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
key := strings.Split(apiKey.Value, "|")
if len(key) != 3 {
utils.ReplyMessage(ws, "非法的 API KEY")
return nil
}
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)
if err != nil {
logger.Error(readResp(resp) + err.Error())
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
} else if resp.StatusCode != 101 {
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
}
data := buildRequest(key[0], req)
fmt.Printf("%+v", data)
fmt.Println(apiURL)
err = conn.WriteJSON(data)
if err != nil {
utils.ReplyMessage(ws, "发送消息失败:"+err.Error())
return nil
}
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var content string
for {
_, msg, err := conn.ReadMessage()
if err != nil {
logger.Error("error with read message:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err))
break
}
// 解析数据
var result xunFeiResp
err = json.Unmarshal(msg, &result)
if err != nil {
logger.Error("error with parsing JSON:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
return nil
}
if result.Header.Code != 0 {
utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message))
return nil
}
content = result.Payload.Choices.Text[0].Content
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
contents = append(contents, content)
// 第一个结果
if result.Payload.Choices.Status == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
if result.Payload.Choices.Status == 2 { // 最终结果
_ = conn.Close() // 关闭连接
break
}
select {
case <-ctx.Done():
utils.ReplyMessage(ws, "**用户取消了生成指令!**")
return nil
default:
continue
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
return nil
}
// 构建 websocket 请求实体
func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
return map[string]interface{}{
"header": map[string]interface{}{
"app_id": appid,
},
"parameter": map[string]interface{}{
"chat": map[string]interface{}{
"domain": req.Model,
"temperature": req.Temperature,
"top_k": int64(6),
"max_tokens": int64(req.MaxTokens),
"auditing": "default",
},
},
"payload": map[string]interface{}{
"message": map[string]interface{}{
"text": req.Messages,
},
},
}
}
// 创建鉴权 URL
func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) {
ul, err := url.Parse(hostURL)
if err != nil {
return "", err
}
date := time.Now().UTC().Format(time.RFC1123)
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
//拼接签名字符串
signStr := strings.Join(signString, "\n")
sha := hmacWithSha256(signStr, apiSecret)
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
"hmac-sha256", "host date request-line", sha)
//将请求参数使用base64编码
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
v := url.Values{}
v.Add("host", ul.Host)
v.Add("date", date)
v.Add("authorization", authorization)
//将编码后的字符串url encode后添加到url后面
return hostURL + "?" + v.Encode(), nil
}
// 使用 sha256 签名
func hmacWithSha256(data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(data))
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
}
// 读取响应
func readResp(resp *http.Response) string {
if resp == nil {
return ""
}
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
}

View File

@@ -9,6 +9,7 @@ package handler
import ( import (
"geekai/core" "geekai/core"
"geekai/service"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
@@ -19,10 +20,11 @@ import (
type ConfigHandler struct { type ConfigHandler struct {
BaseHandler BaseHandler
licenseService *service.LicenseService
} }
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler { func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *ConfigHandler {
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
} }
// Get 获取指定的系统配置 // Get 获取指定的系统配置
@@ -44,3 +46,9 @@ func (h *ConfigHandler) Get(c *gin.Context) {
resp.SUCCESS(c, value) resp.SUCCESS(c, value)
} }
// License 获取 License 配置
func (h *ConfigHandler) License(c *gin.Context) {
license := h.licenseService.GetLicense()
resp.SUCCESS(c, license.Configs)
}

View File

@@ -8,18 +8,16 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service"
"geekai/service/dalle" "geekai/service/dalle"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"net/http"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
@@ -27,15 +25,17 @@ import (
type DallJobHandler struct { type DallJobHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
service *dalle.Service dallService *dalle.Service
uploader *oss.UploaderManager uploader *oss.UploaderManager
userService *service.UserService
} }
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler { func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
return &DallJobHandler{ return &DallJobHandler{
service: service, dallService: service,
uploader: manager, uploader: manager,
userService: userService,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
@@ -43,57 +43,13 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
} }
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *DallJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.service.Clients.Delete(uint(userId))
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 DallE 心跳消息:", message.Content)
continue
}
}
}()
}
func (h *DallJobHandler) preCheck(c *gin.Context) bool { func (h *DallJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return false return false
} }
if user.Power < h.App.SysConfig.DallPower {
if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false return false
} }
@@ -127,20 +83,16 @@ func (h *DallJobHandler) Image(c *gin.Context) {
return return
} }
h.service.PushTask(types.DallTask{ h.dallService.PushTask(types.DallTask{
JobId: job.Id, ClientId: data.ClientId,
UserId: uint(userId), JobId: job.Id,
Prompt: data.Prompt, UserId: uint(userId),
Quality: data.Quality, Prompt: data.Prompt,
Size: data.Size, Quality: data.Quality,
Style: data.Style, Size: data.Size,
Power: job.Power, Style: data.Style,
Power: job.Power,
}) })
client := h.service.Clients.Get(job.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -159,13 +111,13 @@ func (h *DallJobHandler) ImgWall(c *gin.Context) {
// JobList 获取 SD 任务列表 // JobList 获取 SD 任务列表
func (h *DallJobHandler) JobList(c *gin.Context) { func (h *DallJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status") finish := h.GetBool(c, "finish")
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0) page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0) pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish") publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish) err, jobs := h.getData(finish, userId, page, pageSize, publish)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
@@ -175,11 +127,11 @@ func (h *DallJobHandler) JobList(c *gin.Context) {
} }
// JobList 获取任务列表 // JobList 获取任务列表
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) { func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if finish { if finish {
session = session.Where("progress = ?", 100).Order("id DESC") session = session.Where("progress >= ?", 100).Order("id DESC")
} else { } else {
session = session.Where("progress < ?", 100).Order("id ASC") session = session.Where("progress < ?", 100).Order("id ASC")
} }
@@ -193,11 +145,14 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize) session = session.Offset(offset).Limit(pageSize)
} }
// 统计总数
var total int64
session.Model(&model.DallJob{}).Count(&total)
var items []model.DallJob var items []model.DallJob
res := session.Find(&items) res := session.Find(&items)
if res.Error != nil { if res.Error != nil {
return res.Error, nil return res.Error, vo.Page{}
} }
var jobs = make([]vo.DallJob, 0) var jobs = make([]vo.DallJob, 0)
@@ -210,30 +165,39 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
jobs = append(jobs, job) jobs = append(jobs, job)
} }
return nil, jobs return nil, vo.NewPage(total, page, pageSize, jobs)
} }
// Remove remove task image // Remove remove task image
func (h *DallJobHandler) Remove(c *gin.Context) { func (h *DallJobHandler) Remove(c *gin.Context) {
var data struct { id := h.GetInt(c, "id", 0)
Id uint `json:"id"` userId := h.GetLoginUserId(c)
UserId uint `json:"user_id"` var job model.DallJob
ImgURL string `json:"img_url"` if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
} resp.ERROR(c, "记录不存在")
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return return
} }
// remove job recode // 删除任务
res := h.DB.Delete(&model.DallJob{Id: data.Id}) tx := h.DB.Begin()
if res.Error != nil { tx.Delete(&job)
resp.ERROR(c, res.Error.Error()) // 如果任务未完成,或者任务失败,则恢复用户算力
return if job.Progress != 100 {
err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
} }
tx.Commit()
// remove image // remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL) err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
if err != nil { if err != nil {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
@@ -243,18 +207,13 @@ func (h *DallJobHandler) Remove(c *gin.Context) {
// Publish 发布/取消发布图片到画廊显示 // Publish 发布/取消发布图片到画廊显示
func (h *DallJobHandler) Publish(c *gin.Context) { func (h *DallJobHandler) Publish(c *gin.Context) {
var data struct { id := h.GetInt(c, "id", 0)
Id uint `json:"id"` userId := h.GetLoginUserId(c)
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享 action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true) err := h.DB.Model(&model.DallJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败") resp.ERROR(c, err.Error())
return return
} }

View File

@@ -8,15 +8,16 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service/dalle" "geekai/service/dalle"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"errors"
"fmt"
"strings" "strings"
"time" "time"
@@ -224,3 +225,27 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
resp.SUCCESS(c, content) resp.SUCCESS(c, content)
} }
// List 获取所有的工具函数列表
func (h *FunctionHandler) List(c *gin.Context) {
var items []model.Function
err := h.DB.Where("enabled", true).Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
tools := make([]vo.Function, 0)
for _, v := range items {
var f vo.Function
err = utils.CopyObject(v, &f)
if err != nil {
continue
}
f.Action = ""
f.Token = ""
tools = append(tools, f)
}
resp.SUCCESS(c, tools)
}

View File

@@ -9,7 +9,6 @@ package handler
import ( import (
"geekai/core" "geekai/core"
"geekai/core/types"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
@@ -59,23 +58,16 @@ func (h *InviteHandler) Code(c *gin.Context) {
// List Log 用户邀请记录 // List Log 用户邀请记录
func (h *InviteHandler) List(c *gin.Context) { func (h *InviteHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)
var data struct { pageSize := h.GetInt(c, "page_size", 20)
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId) session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
var total int64 var total int64
session.Model(&model.InviteLog{}).Count(&total) session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog var items []model.InviteLog
var list = make([]vo.InviteLog, 0) var list = make([]vo.InviteLog, 0)
offset := (data.Page - 1) * data.PageSize offset := (page - 1) * pageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var v vo.InviteLog var v vo.InviteLog
@@ -89,7 +81,7 @@ func (h *InviteHandler) List(c *gin.Context) {
} }
} }
} }
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
} }
// Hits 访问邀请码 // Hits 访问邀请码

View File

@@ -8,110 +8,66 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
) )
// MarkMapHandler 生成思维导图 // MarkMapHandler 生成思维导图
type MarkMapHandler struct { type MarkMapHandler struct {
BaseHandler BaseHandler
clients *types.LMap[int, *types.WsClient] clients *types.LMap[int, *types.WsClient]
userService *service.UserService
} }
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *MarkMapHandler {
return &MarkMapHandler{ return &MarkMapHandler{
BaseHandler: BaseHandler{App: app, DB: db}, BaseHandler: BaseHandler{App: app, DB: db},
clients: types.NewLMap[int, *types.WsClient](), clients: types.NewLMap[int, *types.WsClient](),
userService: userService,
} }
} }
func (h *MarkMapHandler) Client(c *gin.Context) { // Generate 生成思维导图
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) func (h *MarkMapHandler) Generate(c *gin.Context) {
if err != nil { var data struct {
logger.Error(err) Prompt string `json:"prompt"`
ModelId int `json:"model_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return return
} }
modelId := h.GetInt(c, "model_id", 0) userId := h.GetLoginUserId(c)
userId := h.GetInt(c, "user_id", 0)
client := types.NewWsClient(ws)
h.clients.Put(userId, client)
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.clients.Delete(userId)
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
continue
}
// change model
if message.Type == "model_id" {
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
continue
}
logger.Info("Receive a message: ", message.Content)
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
}
}
}()
}
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
var user model.User var user model.User
res := h.DB.Model(&model.User{}).First(&user, userId) err := h.DB.Where("id", userId).First(&user, userId).Error
if res.Error != nil { if err != nil {
return fmt.Errorf("error with query user info: %v", res.Error) resp.ERROR(c, "error with query user info")
return
} }
var chatModel model.ChatModel var chatModel model.ChatModel
res = h.DB.Where("id", modelId).First(&chatModel) err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
if res.Error != nil { if err != nil {
return fmt.Errorf("error with query chat model: %v", res.Error) resp.ERROR(c, "error with query chat model")
} return
if user.Status == false {
return errors.New("当前用户被禁用")
} }
if user.Power < chatModel.Power { if user.Power < chatModel.Power {
return fmt.Errorf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power) resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power))
return
} }
messages := make([]interface{}, 0) messages := make([]interface{}, 0)
messages = append(messages, types.Message{Role: "system", Content: ` messages = append(messages, types.Message{Role: "system", Content: `
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子: 你是一位非常优秀的思维导图助手, 你能帮助用户整理思路,根据用户提供的主题或内容,快速生成结构清晰,有条理的思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
# Geek-AI 助手 # Geek-AI 助手
## 完整的开源系统 ## 完整的开源系统
@@ -128,128 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
### 支付宝 ### 支付宝
### 微信 ### 微信
另外,除此之外不要任何解释性语句。 请直接生成结果,不要任何解释性语句。
`}) `})
messages = append(messages, types.Message{Role: "user", Content: prompt}) messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图要求结构清晰有条理", data.Prompt)})
var req = types.ApiRequest{ content, err := utils.SendOpenAIMessage(h.DB, messages, chatModel.Value, chatModel.KeyId)
Model: chatModel.Value,
Stream: true,
Messages: messages,
}
var apiKey model.ApiKey
response, err := h.doRequest(req, chatModel, &apiKey)
if err != nil { if err != nil {
return fmt.Errorf("请求 OpenAI API 失败: %s", err) resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
return
} }
defer response.Body.Close() // 扣减算力
if chatModel.Power > 0 {
contentType := response.Header.Get("Content-Type") err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
if strings.Contains(contentType, "text/event-stream") { Type: types.PowerConsume,
// 循环读取 Chunk 消息 Model: chatModel.Value,
var message = types.Message{} Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
scanner := bufio.NewScanner(response.Body) })
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
return fmt.Errorf("error with decode data: %v", err)
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
if isNew {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(client, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
} // end for
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
} else {
body, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
return fmt.Errorf("读取响应失败: %v", err) resp.ERROR(c, "error with save power log, "+err.Error())
} return
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("解析响应失败: %v", err)
}
// OpenAI API 调用异常处理
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
// remove key
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
return errors.New("请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
return errors.New("请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else {
return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message)
} }
} }
return nil resp.SUCCESS(c, content)
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
var res *gorm.DB
if chatModel.KeyId > 0 {
res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", types.OpenAI).
Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := apiKey.ApiURL
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
} else {
client = http.DefaultClient
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
return client.Do(request)
} }

View File

@@ -27,9 +27,15 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
// List 数据列表 // List 数据列表
func (h *MenuHandler) List(c *gin.Context) { func (h *MenuHandler) List(c *gin.Context) {
index := h.GetBool(c, "index")
var items []model.Menu var items []model.Menu
var list = make([]vo.Menu, 0) var list = make([]vo.Menu, 0)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items) session := h.DB.Session(&gorm.Session{})
session = session.Where("enabled", true)
if index {
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
}
res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var product vo.Menu var product vo.Menu

View File

@@ -8,6 +8,7 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
@@ -17,29 +18,27 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"encoding/base64"
"fmt"
"net/http"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
) )
type MidJourneyHandler struct { type MidJourneyHandler struct {
BaseHandler BaseHandler
pool *mj.ServicePool mjService *mj.Service
snowflake *service.Snowflake snowflake *service.Snowflake
uploader *oss.UploaderManager uploader *oss.UploaderManager
userService *service.UserService
} }
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler { func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
return &MidJourneyHandler{ return &MidJourneyHandler{
snowflake: snowflake, snowflake: snowflake,
pool: pool, mjService: service,
uploader: manager, uploader: manager,
userService: userService,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
@@ -59,52 +58,26 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
return false return false
} }
if !h.pool.HasAvailableService() {
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
return false
}
return true return true
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) { func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct { var data struct {
SessionId string `json:"session_id"`
TaskType string `json:"task_type"` TaskType string `json:"task_type"`
ClientId string `json:"client_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"` NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"` Rate string `json:"rate"`
Model string `json:"model"` Model string `json:"model"` // 模型
Chaos int `json:"chaos"` Chaos int `json:"chaos"` // 创意度取值范围: 0-100
Raw bool `json:"raw"` Raw bool `json:"raw"` // 是否开启原始模型
Seed int64 `json:"seed"` Seed int64 `json:"seed"` // 随机数
Stylize int `json:"stylize"` Stylize int `json:"stylize"` // 风格化
ImgArr []string `json:"img_arr"` ImgArr []string `json:"img_arr"`
Tile bool `json:"tile"` Tile bool `json:"tile"` // 重复平铺
Quality float32 `json:"quality"` Quality float32 `json:"quality"` // 画质
Iw float32 `json:"iw"` Iw float32 `json:"iw"`
CRef string `json:"cref"` //生成角色一致的图像 CRef string `json:"cref"` //生成角色一致的图像
SRef string `json:"sref"` //生成风格一致的图像 SRef string `json:"sref"` //生成风格一致的图像
@@ -202,59 +175,45 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return return
} }
h.pool.PushTask(types.MjTask{ h.mjService.PushTask(types.MjTask{
Id: job.Id, Id: job.Id,
ClientId: data.ClientId,
TaskId: taskId, TaskId: taskId,
SessionId: data.SessionId,
Type: types.TaskType(data.TaskType), Type: types.TaskType(data.TaskType),
Prompt: data.Prompt, Prompt: data.Prompt,
NegPrompt: data.NegPrompt, NegPrompt: data.NegPrompt,
Params: params, Params: params,
UserId: userId, UserId: userId,
ImgArr: data.ImgArr, ImgArr: data.ImgArr,
Mode: h.App.SysConfig.MjMode,
}) })
client := h.pool.Clients.Get(uint(job.UserId)) // update user's power
if client != nil { err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
_ = client.Send([]byte("Task Updated")) Type: types.PowerConsume,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
} }
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
type reqVo struct { type reqVo struct {
Index int `json:"index"` Index int `json:"index"`
ClientId string `json:"client_id"`
ChannelId string `json:"channel_id"` ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"` MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"` MessageHash string `json:"message_hash"`
SessionId string `json:"session_id"`
Prompt string `json:"prompt"`
ChatId string `json:"chat_id"`
RoleId int `json:"role_id"`
Icon string `json:"icon"`
} }
// Upscale send upscale command to MidJourney Bot // Upscale send upscale command to MidJourney Bot
func (h *MidJourneyHandler) Upscale(c *gin.Context) { func (h *MidJourneyHandler) Upscale(c *gin.Context) {
var data reqVo var data reqVo
if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
@@ -272,7 +231,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
UserId: userId, UserId: userId,
TaskId: taskId, TaskId: taskId,
Progress: 0, Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower, Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
@@ -281,46 +239,36 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
return return
} }
h.pool.PushTask(types.MjTask{ h.mjService.PushTask(types.MjTask{
Id: job.Id, Id: job.Id,
SessionId: data.SessionId, ClientId: data.ClientId,
Type: types.TaskUpscale, Type: types.TaskUpscale,
Prompt: data.Prompt,
UserId: userId, UserId: userId,
ChannelId: data.ChannelId, ChannelId: data.ChannelId,
Index: data.Index, Index: data.Index,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
}) })
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power // update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
// 记录算力变化日志 Type: types.PowerConsume,
if tx.Error == nil && tx.RowsAffected > 0 { Model: "mid-journey",
user, _ := h.GetLoginUser(c) Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
h.DB.Create(&model.PowerLog{ })
UserId: user.Id, if err != nil {
Username: user.Username, resp.ERROR(c, err.Error())
Type: types.PowerConsume, return
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// Variation send variation command to MidJourney Bot // Variation send variation command to MidJourney Bot
func (h *MidJourneyHandler) Variation(c *gin.Context) { func (h *MidJourneyHandler) Variation(c *gin.Context) {
var data reqVo var data reqVo
if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
@@ -339,7 +287,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
UserId: userId, UserId: userId,
TaskId: taskId, TaskId: taskId,
Progress: 0, Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower, Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
@@ -348,40 +295,28 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
return return
} }
h.pool.PushTask(types.MjTask{ h.mjService.PushTask(types.MjTask{
Id: job.Id, Id: job.Id,
SessionId: data.SessionId,
Type: types.TaskVariation, Type: types.TaskVariation,
Prompt: data.Prompt, ClientId: data.ClientId,
UserId: userId, UserId: userId,
Index: data.Index, Index: data.Index,
ChannelId: data.ChannelId, ChannelId: data.ChannelId,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
}) })
client := h.pool.Clients.Get(uint(job.UserId)) err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
if client != nil { Type: types.PowerConsume,
_ = client.Send([]byte("Task Updated")) Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
} }
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -400,13 +335,13 @@ func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
// JobList 获取 MJ 任务列表 // JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) JobList(c *gin.Context) { func (h *MidJourneyHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status") finish := h.GetBool(c, "finish")
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0) page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0) pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish") publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish) err, jobs := h.getData(finish, userId, page, pageSize, publish)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
@@ -416,10 +351,10 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
} }
// JobList 获取 MJ 任务列表 // JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) { func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if finish { if finish {
session = session.Where("progress = ?", 100).Order("id DESC") session = session.Where("progress >= ?", 100).Order("id DESC")
} else { } else {
session = session.Where("progress < ?", 100).Order("id ASC") session = session.Where("progress < ?", 100).Order("id ASC")
} }
@@ -434,10 +369,14 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
session = session.Offset(offset).Limit(pageSize) session = session.Offset(offset).Limit(pageSize)
} }
// 统计总数
var total int64
session.Model(&model.MidJourneyJob{}).Count(&total)
var items []model.MidJourneyJob var items []model.MidJourneyJob
res := session.Find(&items) res := session.Find(&items)
if res.Error != nil { if res.Error != nil {
return res.Error, nil return res.Error, vo.Page{}
} }
var jobs = make([]vo.MidJourneyJob, 0) var jobs = make([]vo.MidJourneyJob, 0)
@@ -447,71 +386,56 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
if err != nil { if err != nil {
continue continue
} }
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
// discord 服务器图片需要使用代理转发图片数据流
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
} else {
job.ImgURL = job.OrgURL
}
}
jobs = append(jobs, job) jobs = append(jobs, job)
} }
return nil, jobs return nil, vo.NewPage(total, page, pageSize, jobs)
} }
// Remove remove task image // Remove remove task image
func (h *MidJourneyHandler) Remove(c *gin.Context) { func (h *MidJourneyHandler) Remove(c *gin.Context) {
var data struct { id := h.GetInt(c, "id", 0)
Id uint `json:"id"` userId := h.GetInt(c, "user_id", 0)
UserId uint `json:"user_id"` var job model.MidJourneyJob
ImgURL string `json:"img_url"` if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
} resp.ERROR(c, "记录不存在")
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return return
} }
// remove job recode // remove job recode
res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id}) tx := h.DB.Begin()
if res.Error != nil { tx.Delete(&job)
resp.ERROR(c, res.Error.Error()) // 如果任务未完成,或者任务失败,则恢复用户算力
return if job.Progress != 100 {
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "mid-journey",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
} }
tx.Commit()
// remove image // remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL) err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
if err != nil { if err != nil {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// Publish 发布图片到画廊显示 // Publish 发布图片到画廊显示
func (h *MidJourneyHandler) Publish(c *gin.Context) { func (h *MidJourneyHandler) Publish(c *gin.Context) {
var data struct { id := h.GetInt(c, "id", 0)
Id uint `json:"id"` userId := h.GetInt(c, "user_id", 0)
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享 action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享
} err := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
if err := c.ShouldBindJSON(&data); err != nil { if err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, err.Error())
return
}
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return return
} }

161
api/handler/net_handler.go Normal file
View File

@@ -0,0 +1,161 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"io"
"net/http"
"time"
)
type NetHandler struct {
BaseHandler
uploaderManager *oss.UploaderManager
}
func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *NetHandler {
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
func (h *NetHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
logger.Info("upload file: ", file.Name)
// cut the file name if it's too long
if len(file.Name) > 100 {
file.Name = file.Name[:90] + file.Ext
}
userId := h.GetLoginUserId(c)
res := h.DB.Create(&model.File{
UserId: int(userId),
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}
func (h *NetHandler) List(c *gin.Context) {
var data struct {
Urls []string `json:"urls,omitempty"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c)
var items []model.File
var files = make([]vo.File, 0)
session := h.DB.Session(&gorm.Session{})
session = session.Where("user_id = ?", userId)
if len(data.Urls) > 0 {
session = session.Where("url IN ?", data.Urls)
}
// 统计总数
var total int64
session.Model(&model.File{}).Count(&total)
if data.Page > 0 && data.PageSize > 0 {
offset := (data.Page - 1) * data.PageSize
session = session.Offset(offset).Limit(data.PageSize)
}
err := session.Order("id desc").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, v := range items {
var file vo.File
err := utils.CopyObject(v, &file)
if err != nil {
logger.Error(err)
continue
}
file.CreatedAt = v.CreatedAt.Unix()
files = append(files, file)
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, files))
}
// Remove remove files
func (h *NetHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0)
var file model.File
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
if tx.Error != nil || file.Id == 0 {
resp.ERROR(c, "file not existed")
return
}
// remove database
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
if tx.Error != nil || tx.RowsAffected == 0 {
resp.ERROR(c, "failed to update database")
return
}
// remove files
objectKey := file.ObjKey
if objectKey == "" {
objectKey = file.URL
}
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
resp.SUCCESS(c)
}
func (h *NetHandler) Download(c *gin.Context) {
fileUrl := c.Query("url")
// 使用http工具下载文件
if fileUrl == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
// 使用http.Get下载文件
r, err := http.Get(fileUrl)
if err != nil {
resp.ERROR(c, err.Error())
return
}
defer r.Body.Close()
if r.StatusCode != http.StatusOK {
resp.ERROR(c, "error status"+r.Status)
return
}
c.Status(http.StatusOK)
// 将下载的文件内容写入响应
_, _ = io.Copy(c.Writer, r.Body)
}

View File

@@ -0,0 +1,227 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
req2 "github.com/imroc/req/v3"
"io"
"strings"
"time"
)
type Usage struct {
Prompt string `json:"prompt,omitempty"`
Content string `json:"content,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIResVo struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage Usage `json:"usage"`
}
// OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
if response.StatusCode != 200 {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{Role: "assistant"}
var contents = make([]string, 0)
var function model.Function
var toolCall = false
var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return errors.New(line)
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].Delta.Content == nil && responseBody.Choices[0].Delta.ToolCalls == nil {
continue
}
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.SendChunkMsg(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
break
}
var tool types.ToolCall
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
tool = responseBody.Choices[0].Delta.ToolCalls[0]
if toolCall && tool.Function.Name == "" {
arguments = append(arguments, tool.Function.Arguments)
continue
}
}
// 兼容 Function Call
fun := responseBody.Choices[0].Delta.FunctionCall
if fun.Name != "" {
tool = *new(types.ToolCall)
tool.Function.Name = fun.Name
} else if toolCall {
arguments = append(arguments, fun.Arguments)
continue
}
if !utils.IsEmptyValue(tool) {
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.SendChunkMsg(ws, callMsg)
contents = append(contents, callMsg)
}
continue
}
if responseBody.Choices[0].FinishReason == "tool_calls" ||
responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break
}
// output stopped
if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
utils.SendChunkMsg(ws, responseBody.Choices[0].Delta.Content)
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
if toolCall { // 调用函数完成任务
params := make(map[string]interface{})
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
params["user_id"] = userVo.Id
var apiRes types.BizVo
r, err := req2.C().R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", function.Token).
SetBody(params).
SetSuccessResult(&apiRes).Post(function.Action)
errMsg := ""
if err != nil {
errMsg = err.Error()
} else if r.IsErrorState() {
errMsg = r.Status
}
if errMsg != "" || apiRes.Code != types.Success {
errMsg = "调用函数工具出错:" + apiRes.Message + errMsg
contents = append(contents, errMsg)
} else {
errMsg = utils.InterfaceToString(apiRes.Data)
contents = append(contents, errMsg)
}
utils.SendChunkMsg(ws, errMsg)
}
// 消息发送成功
if len(contents) > 0 {
usage := Usage{
Prompt: prompt,
Content: strings.Join(contents, ""),
PromptTokens: 0,
CompletionTokens: 0,
TotalTokens: 0,
}
message.Content = usage.Content
h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else { // 非流式输出
var respVo OpenAIResVo
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("读取响应失败:%v", body)
}
err = json.Unmarshal(body, &respVo)
if err != nil {
return fmt.Errorf("解析响应失败:%v", body)
}
content := respVo.Choices[0].Message.Content
if strings.HasPrefix(req.Model, "o1-") {
content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
}
utils.SendChunkMsg(ws, content)
respVo.Usage.Prompt = prompt
respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
}
return nil
}

View File

@@ -14,6 +14,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -27,23 +28,18 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
} }
// List 订单列表
func (h *OrderHandler) List(c *gin.Context) { func (h *OrderHandler) List(c *gin.Context) {
var data struct { page := h.GetInt(c, "page", 1)
Page int `json:"page"` pageSize := h.GetInt(c, "page_size", 20)
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess) session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
var total int64 var total int64
session.Model(&model.Order{}).Count(&total) session.Model(&model.Order{}).Count(&total)
var items []model.Order var items []model.Order
var list = make([]vo.Order, 0) var list = make([]vo.Order, 0)
offset := (data.Page - 1) * data.PageSize offset := (page - 1) * pageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var order vo.Order var order vo.Order
@@ -52,11 +48,51 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix() order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix() order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
if !ok {
payMethod = item.PayWay
}
payName, ok := types.PayNames[item.PayType]
if !ok {
payName = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
list = append(list, order) list = append(list, order)
} else { } else {
logger.Error(err) logger.Error(err)
} }
} }
} }
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
}
// Query 查询订单状态
func (h *OrderHandler) Query(c *gin.Context) {
orderNo := h.GetTrim(c, "order_no")
var order model.Order
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
}
if order.Status == types.OrderPaidSuccess {
resp.SUCCESS(c, gin.H{"status": order.Status})
return
}
counter := 0
for {
time.Sleep(time.Second)
var item model.Order
h.DB.Where("order_no = ?", orderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
}
counter++
}
resp.SUCCESS(c, gin.H{"status": order.Status})
} }

View File

@@ -8,6 +8,8 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"embed"
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
@@ -15,13 +17,7 @@ import (
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"embed"
"encoding/base64"
"fmt"
"github.com/shopspring/decimal"
"math"
"net/http" "net/http"
"net/url"
"sync" "sync"
"time" "time"
@@ -29,343 +25,213 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
const ( type PayWay struct {
PayWayAlipay = "支付宝" Name string `json:"name"`
PayWayXunHu = "虎皮椒" Value string `json:"value"`
PayWayJs = "PayJS" }
)
// PaymentHandler 支付服务回调 handler // PaymentHandler 支付服务回调 handler
type PaymentHandler struct { type PaymentHandler struct {
BaseHandler BaseHandler
alipayService *payment.AlipayService alipayService *payment.AlipayService
huPiPayService *payment.HuPiPayService huPiPayService *payment.HuPiPayService
js *payment.PayJS geekPayService *payment.GeekPayService
snowflake *service.Snowflake wechatPayService *payment.WechatPayService
fs embed.FS snowflake *service.Snowflake
lock sync.Mutex userService *service.UserService
fs embed.FS
lock sync.Mutex
signKey string // 用来签名的随机秘钥
} }
func NewPaymentHandler( func NewPaymentHandler(
server *core.AppServer, server *core.AppServer,
alipayService *payment.AlipayService, alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService, huPiPayService *payment.HuPiPayService,
js *payment.PayJS, geekPayService *payment.GeekPayService,
wechatPayService *payment.WechatPayService,
db *gorm.DB, db *gorm.DB,
userService *service.UserService,
snowflake *service.Snowflake, snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler { fs embed.FS) *PaymentHandler {
return &PaymentHandler{ return &PaymentHandler{
alipayService: alipayService, alipayService: alipayService,
huPiPayService: huPiPayService, huPiPayService: huPiPayService,
js: js, geekPayService: geekPayService,
snowflake: snowflake, wechatPayService: wechatPayService,
fs: fs, snowflake: snowflake,
lock: sync.Mutex{}, userService: userService,
fs: fs,
lock: sync.Mutex{},
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: server, App: server,
DB: db, DB: db,
}, },
signKey: utils.RandString(32),
} }
} }
func (h *PaymentHandler) DoPay(c *gin.Context) { func (h *PaymentHandler) Pay(c *gin.Context) {
orderNo := h.GetTrim(c, "order_no") var data struct {
payWay := h.GetTrim(c, "pay_way") PayWay string `json:"pay_way"`
PayType string `json:"pay_type"`
if orderNo == "" { ProductId int `json:"product_id"`
UserId int `json:"user_id"`
Device string `json:"device"`
Host string `json:"host"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
var order model.Order var product model.Product
res := h.DB.Where("order_no = ?", orderNo).First(&order) err := h.DB.Where("id", data.ProductId).First(&product).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "Order not found") resp.ERROR(c, "Product not found")
return return
} }
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回 orderNo, err := h.snowflake.Next(false)
if order.Status == types.OrderPaidSuccess { if err != nil {
resp.ERROR(c, "This order had been paid, please do not pay twice") resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
err = h.DB.Where("id", data.UserId).First(&user).Error
if err != nil {
resp.NotAuth(c)
return return
} }
// 更新扫码状态 amount := product.Discount
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned) var payURL, returnURL, notifyURL string
if payWay == "alipay" { // 支付宝 switch data.PayWay {
// 生成支付链接 case "alipay":
notifyURL := h.App.Config.AlipayConfig.NotifyURL if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
returnURL := "" // 关闭同步回跳 notifyURL = h.App.Config.AlipayConfig.NotifyURL
amount := fmt.Sprintf("%.2f", order.Amount) } else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
}
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
returnURL = h.App.Config.AlipayConfig.ReturnURL
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
money := fmt.Sprintf("%.2f", amount)
if data.Device == "wechat" {
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
} else {
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
}
uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject)
if err != nil { if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error()) resp.ERROR(c, "error with generate pay url: "+err.Error())
return return
} }
break
c.Redirect(302, uri) case "wechat":
return if h.App.Config.WechatPayConfig.NotifyURL != "" {
} else if payWay == "hupi" { // 虎皮椒支付 notifyURL = h.App.Config.WechatPayConfig.NotifyURL
params := payment.HuPiPayReq{ } else {
Version: "1.1", notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
TradeOrderId: orderNo, }
TotalFee: fmt.Sprintf("%f", order.Amount), if data.Device == "wechat" {
Title: order.Subject, payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL, OutTradeNo: orderNo,
WapName: "极客学长", TotalFee: int(amount * 100),
Subject: product.Name,
NotifyURL: notifyURL,
ClientIP: c.ClientIP(),
})
} else {
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
OutTradeNo: orderNo,
TotalFee: int(amount * 100),
Subject: product.Name,
NotifyURL: notifyURL,
})
} }
r, err := h.huPiPayService.Pay(params)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
break
c.Redirect(302, r.URL)
}
resp.ERROR(c, "Invalid operations")
}
// OrderQuery 查询订单状态
func (h *PaymentHandler) OrderQuery(c *gin.Context) {
var data struct {
OrderNo string `json:"order_no"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var order model.Order
res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
}
if order.Status == types.OrderPaidSuccess {
resp.SUCCESS(c, gin.H{"status": order.Status})
return
}
counter := 0
for {
time.Sleep(time.Second)
var item model.Order
h.DB.Where("order_no = ?", data.OrderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
}
counter++
}
resp.SUCCESS(c, gin.H{"status": order.Status})
}
// PayQrcode 生成支付 URL 二维码
func (h *PaymentHandler) PayQrcode(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
var payWay string
var notifyURL string
switch data.PayWay {
case "hupi": case "hupi":
payWay = PayWayXunHu if h.App.Config.HuPiPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
case "payjs":
payWay = PayWayJs
notifyURL = h.App.Config.JPayConfig.NotifyURL
default:
payWay = PayWayAlipay
notifyURL = h.App.Config.AlipayConfig.NotifyURL
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
}
// PayJs 单独处理,只能用官方生成的二维码
if data.PayWay == "payjs" {
params := payment.JPayReq{
TotalFee: int(math.Ceil(order.Amount * 100)),
OutTradeNo: order.OrderNo,
Subject: product.Name,
}
r := h.js.Pay(params)
if r.IsOK() {
resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
return
} else { } else {
resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg) notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
return
} }
} if h.App.Config.HuPiPayConfig.ReturnURL != "" {
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
var logo string
if data.PayWay == "alipay" {
logo = "res/img/alipay.jpg"
} else if data.PayWay == "hupi" {
if h.App.Config.HuPiPayConfig.Name == "wechat" {
logo = "res/img/wechat-pay.jpg"
} else { } else {
logo = "res/img/alipay.jpg" returnURL = fmt.Sprintf("%s/payReturn", data.Host)
} }
} r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
file, err := h.fs.Open(logo)
if err != nil {
resp.ERROR(c, "error with open qrcode log file: "+err.Error())
return
}
parse, err := url.Parse(notifyURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s", parse.Scheme, parse.Host, orderNo, data.PayWay)
imgData, err := utils.GenQrcode(imageURL, 400, file)
if err != nil {
resp.ERROR(c, err.Error())
return
}
imgDataBase64 := base64.StdEncoding.EncodeToString(imgData)
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
}
// Mobile 移动端支付
func (h *PaymentHandler) Mobile(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
var payWay string
var notifyURL, returnURL string
var payURL string
switch data.PayWay {
case "hupi":
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
params := payment.HuPiPayReq{
Version: "1.1", Version: "1.1",
TradeOrderId: orderNo, TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount), TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name, Title: product.Name,
NotifyURL: notifyURL, NotifyURL: notifyURL,
ReturnURL: returnURL, ReturnURL: returnURL,
CallbackURL: returnURL, WapName: "GeekAI助手",
WapName: "极客学长", })
}
r, err := h.huPiPayService.Pay(params)
if err != nil { if err != nil {
logger.Error("error with generating Pay URL: ", err.Error()) resp.ERROR(c, err.Error())
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
return return
} }
payURL = r.URL payURL = r.URL
case "payjs": break
payWay = PayWayJs case "geek":
notifyURL = h.App.Config.JPayConfig.NotifyURL if h.App.Config.GeekPayConfig.NotifyURL != "" {
returnURL = h.App.Config.JPayConfig.ReturnURL notifyURL = h.App.Config.GeekPayConfig.NotifyURL
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart() } else {
params := url.Values{} notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
params.Add("total_fee", fmt.Sprintf("%d", totalFee)) }
params.Add("out_trade_no", orderNo) if h.App.Config.GeekPayConfig.ReturnURL != "" {
params.Add("body", product.Name) data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
params.Add("notify_url", notifyURL) }
params.Add("auto", "0") if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
payURL = h.js.PayH5(params) returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
case "alipay": } else {
payWay = PayWayAlipay returnURL = fmt.Sprintf("%s/payReturn", data.Host)
notifyURL = h.App.Config.AlipayConfig.NotifyURL }
returnURL = h.App.Config.AlipayConfig.ReturnURL params := payment.GeekPayParams{
payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name) OutTradeNo: orderNo,
Method: "web",
Name: product.Name,
Money: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
Type: data.PayType,
ReturnURL: returnURL,
NotifyURL: notifyURL,
}
res, err := h.geekPayService.Pay(params)
if err != nil { if err != nil {
resp.ERROR(c, "error with generating Pay URL: "+err.Error()) resp.ERROR(c, err.Error())
return return
} }
payURL = res.PayURL
default: default:
resp.ERROR(c, "Unsupported pay way: "+data.PayWay) resp.ERROR(c, "不支持的支付渠道")
return return
} }
// 创建订单 // 创建订单
remark := types.OrderRemark{ remark := types.OrderRemark{
Days: product.Days, Days: product.Days,
@@ -374,7 +240,6 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
Price: product.Price, Price: product.Price,
Discount: product.Discount, Discount: product.Discount,
} }
order := model.Order{ order := model.Order{
UserId: user.Id, UserId: user.Id,
Username: user.Username, Username: user.Username,
@@ -383,26 +248,24 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
Subject: product.Name, Subject: product.Name,
Amount: amount, Amount: amount,
Status: types.OrderNotPaid, Status: types.OrderNotPaid,
PayWay: payWay, PayWay: data.PayWay,
PayType: data.PayType,
Remark: utils.JsonEncode(remark), Remark: utils.JsonEncode(remark),
} }
res = h.DB.Create(&order) err = h.DB.Create(&order).Error
if res.Error != nil || res.RowsAffected == 0 { if err != nil {
resp.ERROR(c, "error with create order: "+res.Error.Error()) resp.ERROR(c, "error with create order: "+err.Error())
return return
} }
resp.SUCCESS(c, payURL) resp.SUCCESS(c, payURL)
} }
// 异步通知回调公共逻辑 // 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var order model.Order var order model.Order
res := h.DB.Where("order_no = ?", orderNo).First(&order) err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
if res.Error != nil { if err != nil {
err := fmt.Errorf("error with fetch order: %v", res.Error) return fmt.Errorf("error with fetch order: %v", err)
logger.Error(err)
return err
} }
h.lock.Lock() h.lock.Lock()
@@ -414,45 +277,24 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
} }
var user model.User var user model.User
res = h.DB.First(&user, order.UserId) err = h.DB.First(&user, order.UserId).Error
if res.Error != nil { if err != nil {
err := fmt.Errorf("error with fetch user info: %v", res.Error) return fmt.Errorf("error with fetch user info: %v", err)
logger.Error(err)
return err
} }
var remark types.OrderRemark var remark types.OrderRemark
err := utils.JsonDecode(order.Remark, &remark) err = utils.JsonDecode(order.Remark, &remark)
if err != nil { if err != nil {
err := fmt.Errorf("error with decode order remark: %v", err) return fmt.Errorf("error with decode order remark: %v", err)
logger.Error(err)
return err
} }
var opt string // 增加用户算力
var power int err = h.userService.IncreasePower(int(order.UserId), remark.Power, model.PowerLog{
if remark.Days > 0 { // VIP 充值 Type: types.PowerRecharge,
if user.ExpiredTime >= time.Now().Unix() { Model: order.PayWay,
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix() Remark: fmt.Sprintf("充值算力,金额:%f订单号%s", order.Amount, order.OrderNo),
opt = "VIP充值VIP 没到期,只延期不增加算力" })
} else { if err != nil {
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Power += h.App.SysConfig.VipMonthPower
power = h.App.SysConfig.VipMonthPower
opt = "VIP充值"
}
user.Vip = true
} else { // 充值点卡,直接增加次数即可
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
// 更新用户信息
res = h.DB.Updates(&user)
if res.Error != nil {
err := fmt.Errorf("error with update user info: %v", res.Error)
logger.Error(err)
return err return err
} }
@@ -460,29 +302,16 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
order.PayTime = time.Now().Unix() order.PayTime = time.Now().Unix()
order.Status = types.OrderPaidSuccess order.Status = types.OrderPaidSuccess
order.TradeNo = tradeNo order.TradeNo = tradeNo
res = h.DB.Updates(&order) err = h.DB.Updates(&order).Error
if res.Error != nil { if err != nil {
err := fmt.Errorf("error with update order info: %v", res.Error) return fmt.Errorf("error with update order info: %v", err)
logger.Error(err)
return err
} }
// 更新产品销量 // 更新产品销量
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1)) err = h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).
UpdateColumn("sales", gorm.Expr("sales + ?", 1)).Error
// 记录算力充值日志 if err != nil {
if opt != "" { return fmt.Errorf("error with update product sales: %v", err)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRecharge,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: order.PayWay,
Remark: fmt.Sprintf("%s金额%f订单号%s", opt, order.Amount, order.OrderNo),
CreatedAt: time.Now(),
})
} }
return nil return nil
@@ -490,17 +319,22 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
// GetPayWays 获取支付方式 // GetPayWays 获取支付方式
func (h *PaymentHandler) GetPayWays(c *gin.Context) { func (h *PaymentHandler) GetPayWays(c *gin.Context) {
data := gin.H{} payWays := make([]gin.H, 0)
if h.App.Config.AlipayConfig.Enabled { if h.App.Config.AlipayConfig.Enabled {
data["alipay"] = gin.H{"name": "alipay"} payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
} }
if h.App.Config.HuPiPayConfig.Enabled { if h.App.Config.HuPiPayConfig.Enabled {
data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name} payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
} }
if h.App.Config.JPayConfig.Enabled { if h.App.Config.GeekPayConfig.Enabled {
data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name} for _, v := range h.App.Config.GeekPayConfig.Methods {
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
}
} }
resp.SUCCESS(c, data) if h.App.Config.WechatPayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
}
resp.SUCCESS(c, payWays)
} }
// HuPiPayNotify 虎皮椒支付异步回调 // HuPiPayNotify 虎皮椒支付异步回调
@@ -513,15 +347,17 @@ func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
orderNo := c.Request.Form.Get("trade_order_id") orderNo := c.Request.Form.Get("trade_order_id")
tradeNo := c.Request.Form.Get("open_order_id") tradeNo := c.Request.Form.Get("open_order_id")
logger.Infof("收到虎皮椒订单支付回调,订单 NO%s交易流水号%s", orderNo, tradeNo) logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
if err = h.huPiPayService.Check(tradeNo); err != nil { if err = h.huPiPayService.Check(orderNo); err != nil {
logger.Error("订单校验失败:", err) logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
err = h.notify(orderNo, tradeNo) err = h.notify(orderNo, tradeNo)
if err != nil { if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
@@ -537,18 +373,18 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
return return
} }
// TODO验证交易签名 result := h.alipayService.TradeVerify(c.Request)
res := h.alipayService.TradeVerify(c.Request.Form) logger.Infof("收到支付宝商号订单支付回调:%+v", result)
logger.Infof("验证支付结果:%+v", res) if !result.Success() {
if !res.Success() { logger.Error("订单校验失败:", result.Message)
logger.Error("订单校验失败:", res.Message)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
tradeNo := c.Request.Form.Get("trade_no") tradeNo := c.Request.Form.Get("trade_no")
err = h.notify(res.OutTradeNo, tradeNo) err = h.notify(result.OutTradeNo, tradeNo)
if err != nil { if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
@@ -556,33 +392,59 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
c.String(http.StatusOK, "success") c.String(http.StatusOK, "success")
} }
// PayJsNotify PayJs 支付异步回调 // GeekPayNotify 支付异步回调
func (h *PaymentHandler) PayJsNotify(c *gin.Context) { func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
var params = make(map[string]string)
for k := range c.Request.URL.Query() {
params[k] = c.Query(k)
}
logger.Infof("收到GeekPay订单支付回调%+v", params)
// 检查支付状态
if params["trade_status"] != "TRADE_SUCCESS" {
c.String(http.StatusOK, "success")
return
}
sign := h.geekPayService.Sign(params)
if sign != c.Query("sign") {
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
c.String(http.StatusOK, "fail")
return
}
err := h.notify(params["out_trade_no"], params["trade_no"])
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
}
// WechatPayNotify 微信商户支付异步回调
func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
err := c.Request.ParseForm() err := c.Request.ParseForm()
if err != nil { if err != nil {
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
orderNo := c.Request.Form.Get("out_trade_no") result := h.wechatPayService.TradeVerify(c.Request)
returnCode := c.Request.Form.Get("return_code") logger.Infof("收到微信商号订单支付回调:%+v", result)
logger.Infof("收到订单支付回调,订单 NO%s支付结果代码%v", orderNo, returnCode) if !result.Success() {
// 支付失败
if returnCode != "1" {
return
}
// 校验订单支付状态
tradeNo := c.Request.Form.Get("payjs_order_id")
err = h.js.Check(tradeNo)
if err != nil {
logger.Error("订单校验失败:", err) logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail") c.JSON(http.StatusBadRequest, gin.H{
"code": "FAIL",
"message": err.Error(),
})
return return
} }
err = h.notify(orderNo, tradeNo) err = h.notify(result.OutTradeNo, result.TradeId)
if err != nil { if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }

View File

@@ -0,0 +1,128 @@
package handler
import (
"fmt"
"geekai/core"
"geekai/store/model"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
"time"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// OpenAI Realtime API Relay Server
type RealtimeHandler struct {
BaseHandler
}
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler {
return &RealtimeHandler{BaseHandler{App: server, DB: db}}
}
func (h *RealtimeHandler) Connection(c *gin.Context) {
// 获取客户端请求中指定的子协议
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
md := c.Query("model")
userId := h.GetLoginUserId(c)
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
c.Abort()
return
}
// 将 HTTP 协议升级为 Websocket 协议
subProtocols := strings.Split(clientProtocols, ",")
ws, err := (&websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: subProtocols,
}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
defer ws.Close()
// 目前只针对 VIP 用户可以访问
if !user.Vip {
sendError(ws, "当前功能只针对 VIP 用户开放")
c.Abort()
return
}
var apiKey model.ApiKey
h.DB.Where("type", "realtime").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
if apiKey.Id == 0 {
sendError(ws, "管理员未配置 Realtime API KEY")
c.Abort()
return
}
apiURL := fmt.Sprintf("%s/v1/realtime?model=%s", apiKey.ApiURL, md)
// 连接到真实的后端服务器,传入相同的子协议
headers := http.Header{}
// 修正子协议内容
subProtocols[1] = "openai-insecure-api-key." + apiKey.Value
if clientProtocols != "" {
headers.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ","))
}
backendConn, _, err := websocket.DefaultDialer.Dial(apiURL, headers)
if err != nil {
sendError(ws, "桥接后端 API 失败:"+err.Error())
c.Abort()
return
}
defer backendConn.Close()
// 确保协议一致性,如果失败返回
if ws.Subprotocol() != backendConn.Subprotocol() {
sendError(ws, "Websocket 子协议不匹配")
c.Abort()
return
}
// 更新API KEY 最后使用时间
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
// 开始双向转发
errorChan := make(chan error, 2)
go relay(ws, backendConn, errorChan)
go relay(backendConn, ws, errorChan)
// 等待其中一个连接关闭
err = <-errorChan
logger.Infof("Relay ended: %v", err)
}
func relay(src, dst *websocket.Conn, errorChan chan error) {
for {
messageType, message, err := src.ReadMessage()
if err != nil {
errorChan <- err
return
}
err = dst.WriteMessage(messageType, message)
if err != nil {
errorChan <- err
return
}
}
}
func sendError(ws *websocket.Conn, message string) {
err := ws.WriteJSON(map[string]string{"event_id": "event_01", "type": "error", "error": message})
if err != nil {
logger.Error(err)
}
}

View File

@@ -0,0 +1,88 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"sync"
"time"
)
type RedeemHandler struct {
BaseHandler
lock sync.Mutex
userService *service.UserService
}
func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *RedeemHandler {
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
}
func (h *RedeemHandler) Verify(c *gin.Context) {
var data struct {
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c)
h.lock.Lock()
defer h.lock.Unlock()
var item model.Redeem
res := h.DB.Where("code", data.Code).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的兑换码!")
return
}
if !item.Enabled {
resp.ERROR(c, "当前兑换码已被禁用!")
return
}
if item.RedeemedAt > 0 {
resp.ERROR(c, "当前兑换码已使用,请勿重复使用!")
return
}
tx := h.DB.Begin()
err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{
Type: types.PowerRedeem,
Model: "兑换码",
Remark: fmt.Sprintf("兑换码核销,算力:%d兑换码%s...", item.Power, item.Code[:10]),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 更新核销状态
item.RedeemedAt = time.Now().Unix()
item.UserId = userId
err = tx.Updates(&item).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Commit()
resp.SUCCESS(c)
}

View File

@@ -1,106 +0,0 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"math"
"strings"
"sync"
"time"
)
type RewardHandler struct {
BaseHandler
lock sync.Mutex
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Verify 打赏码核销
func (h *RewardHandler) Verify(c *gin.Context) {
var data struct {
TxId string `json:"tx_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.HACKER(c)
return
}
// 移除转账单号中间的空格,防止有人复制的时候多复制了空格
data.TxId = strings.ReplaceAll(data.TxId, " ", "")
h.lock.Lock()
defer h.lock.Unlock()
var item model.Reward
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的众筹交易流水号!")
return
}
if item.Status {
resp.ERROR(c, "当前众筹交易流水号已经被核销,请不要重复核销!")
return
}
tx := h.DB.Begin()
exchange := vo.RewardExchange{}
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
exchange.Power = int(power)
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
// 更新核销状态
item.Status = true
item.UserId = user.Id
item.Exchange = utils.JsonEncode(exchange)
res = tx.Updates(&item)
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力充值日志
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerReward,
Amount: exchange.Power,
Balance: user.Power + exchange.Power,
Mark: types.PowerAdd,
Model: "众筹支付",
Remark: fmt.Sprintf("众筹充值算力,金额:%f价格%f", item.Amount, h.App.SysConfig.PowerPrice),
CreatedAt: time.Now(),
})
tx.Commit()
resp.SUCCESS(c)
}

View File

@@ -8,6 +8,7 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
@@ -18,12 +19,8 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"fmt"
"net/http"
"time" "time"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
@@ -31,19 +28,27 @@ import (
type SdJobHandler struct { type SdJobHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
pool *sd.ServicePool sdService *sd.Service
uploader *oss.UploaderManager uploader *oss.UploaderManager
snowflake *service.Snowflake snowflake *service.Snowflake
leveldb *store.LevelDB leveldb *store.LevelDB
userService *service.UserService
} }
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler { func NewSdJobHandler(app *core.AppServer,
db *gorm.DB,
service *sd.Service,
manager *oss.UploaderManager,
snowflake *service.Snowflake,
userService *service.UserService,
levelDB *store.LevelDB) *SdJobHandler {
return &SdJobHandler{ return &SdJobHandler{
pool: pool, sdService: service,
uploader: manager, uploader: manager,
snowflake: snowflake, snowflake: snowflake,
leveldb: levelDB, leveldb: levelDB,
userService: userService,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
@@ -51,27 +56,6 @@ func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, man
} }
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SdJobHandler) preCheck(c *gin.Context) bool { func (h *SdJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
@@ -79,11 +63,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
return false return false
} }
if !h.pool.HasAvailableService() {
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
return false
}
if user.Power < h.App.SysConfig.SdPower { if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false return false
@@ -99,10 +78,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return return
} }
var data struct { var data types.SdTaskParams
SessionId string `json:"session_id"`
types.SdTaskParams
}
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
@@ -167,35 +143,23 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return return
} }
h.pool.PushTask(types.SdTask{ h.sdService.PushTask(types.SdTask{
Id: int(job.Id), Id: int(job.Id),
SessionId: data.SessionId, ClientId: data.ClientId,
Type: types.TaskImage, Type: types.TaskImage,
Params: params, Params: params,
UserId: userId, UserId: userId,
}) })
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power // update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
// 记录算力变化日志 Type: types.PowerConsume,
if tx.Error == nil && tx.RowsAffected > 0 { Model: "stable-diffusion",
user, _ := h.GetLoginUser(c) Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
h.DB.Create(&model.PowerLog{ })
UserId: user.Id, if err != nil {
Username: user.Username, resp.ERROR(c, err.Error())
Type: types.PowerConsume, return
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
} }
resp.SUCCESS(c) resp.SUCCESS(c)
@@ -216,13 +180,13 @@ func (h *SdJobHandler) ImgWall(c *gin.Context) {
// JobList 获取 SD 任务列表 // JobList 获取 SD 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) { func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status") finish := h.GetBool(c, "finish")
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0) page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0) pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish") publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish) err, jobs := h.getData(finish, userId, page, pageSize, publish)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
@@ -232,11 +196,11 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
} }
// JobList 获取 MJ 任务列表 // JobList 获取 MJ 任务列表
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) { func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if finish { if finish {
session = session.Where("progress = ?", 100).Order("id DESC") session = session.Where("progress >= ?", 100).Order("id DESC")
} else { } else {
session = session.Where("progress < ?", 100).Order("id ASC") session = session.Where("progress < ?", 100).Order("id ASC")
} }
@@ -251,10 +215,14 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
session = session.Offset(offset).Limit(pageSize) session = session.Offset(offset).Limit(pageSize)
} }
// 统计总数
var total int64
session.Model(&model.SdJob{}).Count(&total)
var items []model.SdJob var items []model.SdJob
res := session.Find(&items) res := session.Find(&items)
if res.Error != nil { if res.Error != nil {
return res.Error, nil return res.Error, vo.Page{}
} }
var jobs = make([]vo.SdJob, 0) var jobs = make([]vo.SdJob, 0)
@@ -264,68 +232,58 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
if err != nil { if err != nil {
continue continue
} }
if item.Progress < 100 {
// 从 leveldb 中获取图片预览数据
var imageData string
err = h.leveldb.Get(item.TaskId, &imageData)
if err == nil {
job.ImgURL = "data:image/png;base64," + imageData
}
}
jobs = append(jobs, job) jobs = append(jobs, job)
} }
return nil, jobs return nil, vo.NewPage(total, page, pageSize, jobs)
} }
// Remove remove task image // Remove remove task image
func (h *SdJobHandler) Remove(c *gin.Context) { func (h *SdJobHandler) Remove(c *gin.Context) {
var data struct { id := h.GetInt(c, "id", 0)
Id uint `json:"id"` userId := h.GetLoginUserId(c)
UserId uint `json:"user_id"` var job model.SdJob
ImgURL string `json:"img_url"` if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
} resp.ERROR(c, "记录不存在")
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return return
} }
// remove job recode // 删除任务
res := h.DB.Delete(&model.SdJob{Id: data.Id}) tx := h.DB.Begin()
if res.Error != nil { tx.Delete(&job)
resp.ERROR(c, res.Error.Error()) // 如果任务未完成,或者任务失败,则恢复用户算力
return if job.Progress != 100 {
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d Err: %s", job.Id, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
} }
tx.Commit()
// remove image // remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL) err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
if err != nil { if err != nil {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte(sd.Finished))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// Publish 发布/取消发布图片到画廊显示 // Publish 发布/取消发布图片到画廊显示
func (h *SdJobHandler) Publish(c *gin.Context) { func (h *SdJobHandler) Publish(c *gin.Context) {
var data struct { id := h.GetInt(c, "id", 0)
Id uint `json:"id"` userId := h.GetLoginUserId(c)
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享 action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true) err := h.DB.Model(&model.SdJob{Id: uint(id), UserId: int(userId)}).UpdateColumn("publish", action).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "更新数据库失败") resp.ERROR(c, err.Error())
return return
} }

View File

@@ -49,28 +49,50 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
var data struct { var data struct {
Receiver string `json:"receiver"` // 接收者 Receiver string `json:"receiver"` // 接收者
Key string `json:"key"` Key string `json:"key"`
Dots string `json:"dots"` Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
if h.App.SysConfig.EnabledVerify {
if !h.captcha.Check(data) { var check bool
resp.ERROR(c, "验证码错误,请先完人机验证") if data.X != 0 {
return check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
} }
code := utils.RandomNumber(6) code := utils.RandomNumber(6)
var err error var err error
if strings.Contains(data.Receiver, "@") { // email if strings.Contains(data.Receiver, "@") { // email
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") { if !utils.Contains(h.App.SysConfig.RegisterWays, "email") {
resp.ERROR(c, "系统已禁用邮箱注册!") resp.ERROR(c, "系统已禁用邮箱注册!")
return return
} }
// 检查邮箱后缀是否在白名单
if len(h.App.SysConfig.EmailWhiteList) > 0 {
inWhiteList := false
for _, suffix := range h.App.SysConfig.EmailWhiteList {
if strings.HasSuffix(data.Receiver, suffix) {
inWhiteList = true
break
}
}
if !inWhiteList {
resp.ERROR(c, "邮箱后缀不在白名单中")
return
}
}
err = h.smtp.SendVerifyCode(data.Receiver, code) err = h.smtp.SendVerifyCode(data.Receiver, code)
} else { } else {
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") { if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
resp.ERROR(c, "系统已禁用手机号注册!") resp.ERROR(c, "系统已禁用手机号注册!")
return return
} }
@@ -89,5 +111,9 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
return return
} }
resp.SUCCESS(c) if h.App.Debug {
resp.SUCCESS(c, code)
} else {
resp.SUCCESS(c)
}
} }

373
api/handler/suno_handler.go Normal file
View File

@@ -0,0 +1,373 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/service/suno"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type SunoHandler struct {
BaseHandler
sunoService *suno.Service
uploader *oss.UploaderManager
userService *service.UserService
}
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
return &SunoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
sunoService: service,
uploader: uploader,
userService: userService,
}
}
func (h *SunoHandler) Create(c *gin.Context) {
var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"`
Instrumental bool `json:"instrumental"`
Lyrics string `json:"lyrics"`
Model string `json:"model"`
Tags string `json:"tags"`
Title string `json:"title"`
Type int `json:"type"`
RefTaskId string `json:"ref_task_id"` // 续写的任务id
ExtendSecs int `json:"extend_secs"` // 续写秒数
RefSongId string `json:"ref_song_id"` // 续写的歌曲id
SongId string `json:"song_id,omitempty"` // 要拼接的歌曲id
AudioURL string `json:"audio_url,omitempty"` // 上传自己创作的歌曲
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.SunoPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
// 歌曲拼接
if data.SongId != "" && data.Type == 3 {
var song model.SunoJob
if err := h.DB.Where("song_id = ?", data.SongId).First(&song).Error; err == nil {
data.Instrumental = song.Instrumental
data.Model = song.ModelName
data.Tags = song.Tags
}
// 拼接歌词
var refSong model.SunoJob
if err := h.DB.Where("song_id = ?", data.RefSongId).First(&refSong).Error; err == nil {
data.Prompt = fmt.Sprintf("%s\n%s", song.Prompt, refSong.Prompt)
}
}
// 插入数据库
job := model.SunoJob{
UserId: int(h.GetLoginUserId(c)),
Prompt: data.Prompt,
Instrumental: data.Instrumental,
ModelName: data.Model,
Tags: data.Tags,
Title: data.Title,
Type: data.Type,
RefSongId: data.RefSongId,
RefTaskId: data.RefTaskId,
ExtendSecs: data.ExtendSecs,
Power: h.App.SysConfig.SunoPower,
SongId: utils.RandString(32),
}
if data.Lyrics != "" {
job.Prompt = data.Lyrics
}
tx := h.DB.Create(&job)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
// 创建任务
h.sunoService.PushTask(types.SunoTask{
ClientId: data.ClientId,
Id: job.Id,
UserId: job.UserId,
Type: job.Type,
Title: job.Title,
RefTaskId: data.RefTaskId,
RefSongId: data.RefSongId,
ExtendSecs: data.ExtendSecs,
Prompt: job.Prompt,
Tags: data.Tags,
Model: data.Model,
Instrumental: data.Instrumental,
SongId: data.SongId,
AudioURL: data.AudioURL,
})
// update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
CreatedAt: time.Now(),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *SunoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
// 统计总数
var total int64
session.Model(&model.SunoJob{}).Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var list []model.SunoJob
err := session.Order("id desc").Find(&list).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 初始化续写关系
songIds := make([]string, 0)
for _, v := range list {
if v.RefTaskId != "" {
songIds = append(songIds, v.RefSongId)
}
}
var tasks []model.SunoJob
h.DB.Where("song_id IN ?", songIds).Find(&tasks)
songMap := make(map[string]model.SunoJob)
for _, t := range tasks {
songMap[t.SongId] = t
}
// 转换为 VO
items := make([]vo.SunoJob, 0)
for _, v := range list {
var item vo.SunoJob
err = utils.CopyObject(v, &item)
if err != nil {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
if s, ok := songMap[v.RefSongId]; ok {
item.RefSong = map[string]interface{}{
"id": s.Id,
"title": s.Title,
"cover": s.CoverURL,
"audio": s.AudioURL,
}
}
items = append(items, item)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *SunoHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.SunoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 只有失败,或者超时的任务才能删除
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*10)) {
resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
return
}
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 恢复用户算力
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: job.ModelName,
Remark: fmt.Sprintf("Suno 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Commit()
// 删除文件
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
_ = h.uploader.GetUploadHandler().Delete(job.AudioURL)
}
func (h *SunoHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
publish := h.GetBool(c, "publish")
err := h.DB.Model(&model.SunoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *SunoHandler) Update(c *gin.Context) {
var data struct {
Id int `json:"id"`
Title string `json:"title"`
Cover string `json:"cover"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id == 0 || data.Title == "" || data.Cover == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c)
var item model.SunoJob
if err := h.DB.Where("id", data.Id).Where("user_id", userId).First(&item).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
item.Title = data.Title
item.CoverURL = data.Cover
if err := h.DB.Updates(&item).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Detail 歌曲详情
func (h *SunoHandler) Detail(c *gin.Context) {
songId := c.Query("song_id")
if songId == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
var item model.SunoJob
if err := h.DB.Where("song_id", songId).First(&item).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
// 读取用户信息
var user model.User
if err := h.DB.Where("id", item.UserId).First(&user).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
var itemVo vo.SunoJob
if err := utils.CopyObject(item, &itemVo); err != nil {
resp.ERROR(c, err.Error())
return
}
itemVo.CreatedAt = item.CreatedAt.Unix()
itemVo.User = map[string]interface{}{
"nickname": user.Nickname,
"avatar": user.Avatar,
}
resp.SUCCESS(c, itemVo)
}
// Play 增加歌曲播放次数
func (h *SunoHandler) Play(c *gin.Context) {
songId := c.Query("song_id")
if songId == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1))
}
const genLyricTemplate = `
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
请以【%s】为主题创作一首歌曲歌曲时间不要太短3分钟左右不要输出任何解释性的内容。
输出格式如下:
歌曲名称
第一节:
{{歌词内容}}
副歌:
{{歌词内容}}
第二节:
{{歌词内容}}
副歌:
{{歌词内容}}
尾声:
{{歌词内容}}
`
// Lyric 生成歌词
func (h *SunoHandler) Lyric(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini", 0)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}

View File

@@ -3,15 +3,52 @@ package handler
import ( import (
"geekai/service" "geekai/service"
"geekai/service/payment" "geekai/service/payment"
"github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
) )
type TestHandler struct { type TestHandler struct {
db *gorm.DB db *gorm.DB
snowflake *service.Snowflake snowflake *service.Snowflake
js *payment.PayJS js *payment.GeekPayService
} }
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler { func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js} return &TestHandler{db: db, snowflake: snowflake, js: js}
} }
func (h *TestHandler) SseTest(c *gin.Context) {
//c.Header("Body-Type", "text/event-stream")
//c.Header("Cache-Control", "no-cache")
//c.Header("Connection", "keep-alive")
//
//
//// 模拟实时数据更新
//for i := 0; i < 10; i++ {
// // 发送 SSE 数据
// _, err := fmt.Fprintf(c.Writer, "data: %v\n\n", data)
// if err != nil {
// return
// }
// c.Writer.Flush() // 确保立即发送数据
// time.Sleep(1 * time.Second) // 每秒发送一次数据
//}
//c.Abort()
}
func (h *TestHandler) PostTest(c *gin.Context) {
var data struct {
Message string `json:"message"`
UserId uint `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 将参数存储在上下文中
c.Set("data", data)
c.Next()
}

View File

@@ -1,101 +0,0 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type UploadHandler struct {
BaseHandler
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
userId := h.GetLoginUserId(c)
res := h.DB.Create(&model.File{
UserId: int(userId),
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}
func (h *UploadHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
var items []model.File
var files = make([]vo.File, 0)
h.DB.Where("user_id = ?", userId).Find(&items)
if len(items) > 0 {
for _, v := range items {
var file vo.File
err := utils.CopyObject(v, &file)
if err != nil {
logger.Error(err)
continue
}
file.CreatedAt = v.CreatedAt.Unix()
files = append(files, file)
}
}
resp.SUCCESS(c, files)
}
// Remove remove files
func (h *UploadHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0)
var file model.File
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
if tx.Error != nil || file.Id == 0 {
resp.ERROR(c, "file not existed")
return
}
// remove database
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
if tx.Error != nil || tx.RowsAffected == 0 {
resp.ERROR(c, "failed to update database")
return
}
// remove files
objectKey := file.ObjKey
if objectKey == "" {
objectKey = file.URL
}
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
resp.SUCCESS(c)
}

View File

@@ -8,6 +8,7 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
@@ -15,7 +16,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"fmt" "github.com/imroc/req/v3"
"strings" "strings"
"time" "time"
@@ -32,6 +33,8 @@ type UserHandler struct {
searcher *xdb.Searcher searcher *xdb.Searcher
redis *redis.Client redis *redis.Client
licenseService *service.LicenseService licenseService *service.LicenseService
captcha *service.CaptchaService
userService *service.UserService
} }
func NewUserHandler( func NewUserHandler(
@@ -39,12 +42,16 @@ func NewUserHandler(
db *gorm.DB, db *gorm.DB,
searcher *xdb.Searcher, searcher *xdb.Searcher,
client *redis.Client, client *redis.Client,
captcha *service.CaptchaService,
userService *service.UserService,
licenseService *service.LicenseService) *UserHandler { licenseService *service.LicenseService) *UserHandler {
return &UserHandler{ return &UserHandler{
BaseHandler: BaseHandler{DB: db, App: app}, BaseHandler: BaseHandler{DB: db, App: app},
searcher: searcher, searcher: searcher,
redis: client, redis: client,
captcha: captcha,
licenseService: licenseService, licenseService: licenseService,
userService: userService,
} }
} }
@@ -54,14 +61,33 @@ func (h *UserHandler) Register(c *gin.Context) {
var data struct { var data struct {
RegWay string `json:"reg_way"` RegWay string `json:"reg_way"`
Username string `json:"username"` Username string `json:"username"`
Mobile string `json:"mobile"`
Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
Code string `json:"code"` Code string `json:"code"`
InviteCode string `json:"invite_code"` InviteCode string `json:"invite_code"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
data.Password = strings.TrimSpace(data.Password) data.Password = strings.TrimSpace(data.Password)
if len(data.Password) < 8 { if len(data.Password) < 8 {
resp.ERROR(c, "密码长度不能少于8个字符") resp.ERROR(c, "密码长度不能少于8个字符")
@@ -71,15 +97,22 @@ func (h *UserHandler) Register(c *gin.Context) {
// 检测最大注册人数 // 检测最大注册人数
var totalUser int64 var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser) h.DB.Model(&model.User{}).Count(&totalUser)
if int(totalUser) >= h.licenseService.GetLicense().UserNum { if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return return
} }
// 检查验证码 // 检查验证码
var key string var key string
if data.RegWay == "email" || data.RegWay == "mobile" { if data.RegWay == "email" {
key = CodeStorePrefix + data.Username key = CodeStorePrefix + data.Email
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
return
}
} else if data.RegWay == "mobile" {
key = CodeStorePrefix + data.Mobile
code, err := h.redis.Get(c, key).Result() code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code { if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误") resp.ERROR(c, "验证码错误")
@@ -97,9 +130,19 @@ func (h *UserHandler) Register(c *gin.Context) {
} }
} }
// check if the username is exists // check if the username is existing
var item model.User var item model.User
res := h.DB.Where("username = ?", data.Username).First(&item) session := h.DB.Session(&gorm.Session{})
if data.Mobile != "" {
session = session.Where("mobile = ?", data.Mobile)
data.Username = data.Mobile
} else if data.Email != "" {
session = session.Where("email = ?", data.Email)
data.Username = data.Email
} else if data.Username != "" {
session = session.Where("username = ?", data.Username)
}
session.First(&item)
if item.Id > 0 { if item.Id > 0 {
resp.ERROR(c, "该用户名已经被注册") resp.ERROR(c, "该用户名已经被注册")
return return
@@ -108,8 +151,9 @@ func (h *UserHandler) Register(c *gin.Context) {
salt := utils.RandString(8) salt := utils.RandString(8)
user := model.User{ user := model.User{
Username: data.Username, Username: data.Username,
Mobile: data.Mobile,
Email: data.Email,
Password: utils.GenPassword(data.Password, salt), Password: utils.GenPassword(data.Password, salt),
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
Avatar: "/images/avatar/user.png", Avatar: "/images/avatar/user.png",
Salt: salt, Salt: salt,
Status: true, Status: true,
@@ -118,10 +162,19 @@ func (h *UserHandler) Register(c *gin.Context) {
Power: h.App.SysConfig.InitPower, Power: h.App.SysConfig.InitPower,
} }
res = h.DB.Create(&user) // 被邀请人也获得赠送算力
if res.Error != nil { if data.InviteCode != "" {
resp.ERROR(c, "保存数据失败") user.Power += h.App.SysConfig.InvitePower
logger.Error(res.Error) }
if h.licenseService.GetLicense().Configs.DeCopy {
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
} else {
user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
}
tx := h.DB.Begin()
if err := tx.Create(&user).Error; err != nil {
resp.ERROR(c, err.Error())
return return
} }
@@ -130,35 +183,35 @@ func (h *UserHandler) Register(c *gin.Context) {
// 增加邀请数量 // 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1)) h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 { if h.App.SysConfig.InvitePower > 0 {
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower)) err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{
// 记录邀请算力充值日志 Type: types.PowerInvite,
var inviter model.User Model: "",
h.DB.Where("id", inviteCode.UserId).First(&inviter) Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
h.DB.Create(&model.PowerLog{
UserId: inviter.Id,
Username: inviter.Username,
Type: types.PowerInvite,
Amount: h.App.SysConfig.InvitePower,
Balance: inviter.Power,
Mark: types.PowerAdd,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
CreatedAt: time.Now(),
}) })
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
} }
// 添加邀请记录 // 添加邀请记录
h.DB.Create(&model.InviteLog{ err := tx.Create(&model.InviteLog{
InviterId: inviteCode.UserId, InviterId: inviteCode.UserId,
UserId: user.Id, UserId: user.Id,
Username: user.Username, Username: user.Username,
InviteCode: inviteCode.Code, InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower), Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
}) }).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
} }
tx.Commit()
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码 _ = h.redis.Del(c, key) // 注册成功,删除短信验证码
// 自动登录创建 token // 自动登录创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id, "user_id": user.Id,
@@ -175,7 +228,7 @@ func (h *UserHandler) Register(c *gin.Context) {
resp.ERROR(c, "error with save token: "+err.Error()) resp.ERROR(c, "error with save token: "+err.Error())
return return
} }
resp.SUCCESS(c, tokenString) resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
} }
// Login 用户登录 // Login 用户登录
@@ -183,20 +236,41 @@ func (h *UserHandler) Login(c *gin.Context) {
var data struct { var data struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
needVerify, err := h.redis.Get(c, verifyKey).Bool()
if h.App.SysConfig.EnabledVerify && needVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
var user model.User var user model.User
res := h.DB.Where("username = ?", data.Username).First(&user) res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil { if res.Error != nil {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名不存在") resp.ERROR(c, "用户名不存在")
return return
} }
password := utils.GenPassword(data.Password, user.Salt) password := utils.GenPassword(data.Password, user.Salt)
if password != user.Password { if password != user.Password {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名或密码错误") resp.ERROR(c, "用户名或密码错误")
return return
} }
@@ -229,12 +303,14 @@ func (h *UserHandler) Login(c *gin.Context) {
return return
} }
// 保存到 redis // 保存到 redis
key := fmt.Sprintf("users/%d", user.Id) sessionKey := fmt.Sprintf("users/%d", user.Id)
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil { if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error()) resp.ERROR(c, "error with save token: "+err.Error())
return return
} }
resp.SUCCESS(c, tokenString) // 移除登录行为验证码
h.redis.Del(c, verifyKey)
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
} }
// Logout 注 销 // Logout 注 销
@@ -246,21 +322,176 @@ func (h *UserHandler) Logout(c *gin.Context) {
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// CLogin 第三方登录请求二维码
func (h *UserHandler) CLogin(c *gin.Context) {
returnURL := h.GetTrim(c, "return_url")
var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL)
r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}).
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status)
return
}
if res.Code != types.Success {
resp.ERROR(c, "error with http response: "+res.Message)
return
}
resp.SUCCESS(c, res.Data)
}
// CLoginCallback 第三方登录回调
func (h *UserHandler) CLoginCallback(c *gin.Context) {
loginType := c.Query("login_type")
code := c.Query("code")
userId := h.GetInt(c, "user_id", 0)
action := c.Query("action")
var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status)
return
}
if res.Code != types.Success {
resp.ERROR(c, "error with http response: "+res.Message)
return
}
// login successfully
data := res.Data.(map[string]interface{})
var user model.User
if action == "bind" && userId > 0 {
err = h.DB.Where("openid", data["openid"]).First(&user).Error
if err == nil {
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
return
}
err = h.DB.Where("id", userId).First(&user).Error
if err != nil {
resp.ERROR(c, "绑定用户不存在")
return
}
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
if err != nil {
resp.ERROR(c, "更新用户信息失败,"+err.Error())
return
}
resp.SUCCESS(c, gin.H{"token": ""})
return
}
session := gin.H{}
tx := h.DB.Where("openid", data["openid"]).First(&user)
if tx.Error != nil {
// create new user
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
salt := utils.RandString(8)
password := fmt.Sprintf("%d", utils.RandomNumber(8))
user = model.User{
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
Password: utils.GenPassword(password, salt),
Avatar: fmt.Sprintf("%s", data["avatar"]),
Salt: salt,
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
Power: h.App.SysConfig.InitPower,
OpenId: fmt.Sprintf("%s", data["openid"]),
Nickname: fmt.Sprintf("%s", data["nickname"]),
}
tx = h.DB.Create(&user)
if tx.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(tx.Error)
return
}
session["username"] = user.Username
session["password"] = password
} else { // login directly
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.DB.Model(&user).Updates(user)
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
}
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := fmt.Sprintf("users/%d", user.Id)
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
session["token"] = tokenString
resp.SUCCESS(c, session)
}
// Session 获取/验证会话 // Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) { func (h *UserHandler) Session(c *gin.Context) {
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err == nil { if err != nil {
var userVo vo.User resp.NotAuth(c, err.Error())
err := utils.CopyObject(user, &userVo) return
if err != nil {
resp.ERROR(c)
}
userVo.Id = user.Id
resp.SUCCESS(c, userVo)
} else {
resp.NotAuth(c)
} }
var userVo vo.User
err = utils.CopyObject(user, &userVo)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 用户 VIP 到期
if user.ExpiredTime > 0 && user.ExpiredTime < time.Now().Unix() {
h.DB.Model(&user).UpdateColumn("vip", false)
}
userVo.Id = user.Id
resp.SUCCESS(c, userVo)
} }
type userProfile struct { type userProfile struct {
@@ -347,20 +578,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
} }
newPass := utils.GenPassword(data.Password, user.Salt) newPass := utils.GenPassword(data.Password, user.Salt)
res := h.DB.Model(&user).UpdateColumn("password", newPass) err = h.DB.Model(&user).UpdateColumn("password", newPass).Error
if res.Error != nil { if err != nil {
logger.Error("更新数据库失败: ", res.Error) resp.ERROR(c, err.Error())
resp.ERROR(c, "更新数据库失败")
return return
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// ResetPass 重置密码 // ResetPass 找回密码
func (h *UserHandler) ResetPass(c *gin.Context) { func (h *UserHandler) ResetPass(c *gin.Context) {
var data struct { var data struct {
Username string `json:"username"` Type string `json:"type"` // 验证类别mobile, email
Mobile string `json:"mobile"` // 手机号
Email string `json:"email"` // 邮箱地址
Code string `json:"code"` // 验证码 Code string `json:"code"` // 验证码
Password string `json:"password"` // 新密码 Password string `json:"password"` // 新密码
} }
@@ -369,37 +601,47 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
return return
} }
session := h.DB.Session(&gorm.Session{})
var key string
if data.Type == "email" {
session = session.Where("email", data.Email)
key = CodeStorePrefix + data.Email
} else if data.Type == "mobile" {
session = session.Where("mobile", data.Mobile)
key = CodeStorePrefix + data.Mobile
} else {
resp.ERROR(c, "验证类别错误")
return
}
var user model.User var user model.User
res := h.DB.Where("username", data.Username).First(&user) err := session.First(&user).Error
if res.Error != nil { if err != nil {
resp.ERROR(c, "用户不存在!") resp.ERROR(c, "用户不存在!")
return return
} }
// 检查验证码 // 检查验证码
key := CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result() code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code { if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误") resp.ERROR(c, "验证码错误")
return return
} }
password := utils.GenPassword(data.Password, user.Salt) password := utils.GenPassword(data.Password, user.Salt)
user.Password = password err = h.DB.Model(&user).UpdateColumn("password", password).Error
res = h.DB.Updates(&user) if err != nil {
if res.Error != nil { resp.ERROR(c, err.Error())
resp.ERROR(c)
} else { } else {
h.redis.Del(c, key) h.redis.Del(c, key)
resp.SUCCESS(c) resp.SUCCESS(c)
} }
} }
// BindUsername 重置账 // BindMobile 绑定手机
func (h *UserHandler) BindUsername(c *gin.Context) { func (h *UserHandler) BindMobile(c *gin.Context) {
var data struct { var data struct {
Username string `json:"username"` Mobile string `json:"mobile"`
Code string `json:"code"` Code string `json:"code"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -407,7 +649,7 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
} }
// 检查验证码 // 检查验证码
key := CodeStorePrefix + data.Username key := CodeStorePrefix + data.Mobile
code, err := h.redis.Get(c, key).Result() code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code { if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误") resp.ERROR(c, "验证码错误")
@@ -416,21 +658,56 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
// 检查手机号是否被其他账号绑定 // 检查手机号是否被其他账号绑定
var item model.User var item model.User
res := h.DB.Where("username = ?", data.Username).First(&item) res := h.DB.Where("mobile", data.Mobile).First(&item)
if res.Error == nil { if res.Error == nil {
resp.ERROR(c, "该号已经其他账号绑定") resp.ERROR(c, "该手机号已经绑定了其他账号,请更换手机号")
return return
} }
user, err := h.GetLoginUser(c) userId := h.GetLoginUserId(c)
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("mobile", data.Mobile).Error
if err != nil { if err != nil {
resp.NotAuth(c) resp.ERROR(c, err.Error())
return return
} }
res = h.DB.Model(&user).UpdateColumn("username", data.Username) _ = h.redis.Del(c, key) // 删除短信验证码
if res.Error != nil { resp.SUCCESS(c)
resp.ERROR(c, "更新数据库失败") }
// BindEmail 绑定邮箱
func (h *UserHandler) BindEmail(c *gin.Context) {
var data struct {
Email string `json:"email"`
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 检查验证码
key := CodeStorePrefix + data.Email
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
return
}
// 检查手机号是否被其他账号绑定
var item model.User
res := h.DB.Where("email", data.Email).First(&item)
if res.Error == nil {
resp.ERROR(c, "该邮箱地址已经绑定了其他账号,请更邮箱地址")
return
}
userId := h.GetLoginUserId(c)
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("email", data.Email).Error
if err != nil {
resp.ERROR(c, err.Error())
return return
} }

View File

@@ -0,0 +1,227 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/service/video"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type VideoHandler struct {
BaseHandler
videoService *video.Service
uploader *oss.UploaderManager
userService *service.UserService
}
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
return &VideoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
videoService: service,
uploader: uploader,
userService: userService,
}
}
func (h *VideoHandler) LumaCreate(c *gin.Context) {
var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"`
FirstFrameImg string `json:"first_frame_img,omitempty"`
EndFrameImg string `json:"end_frame_img,omitempty"`
ExpandPrompt bool `json:"expand_prompt,omitempty"`
Loop bool `json:"loop,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.LumaPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
if data.Prompt == "" {
resp.ERROR(c, "prompt is needed")
return
}
userId := int(h.GetLoginUserId(c))
params := types.VideoParams{
PromptOptimize: data.ExpandPrompt,
Loop: data.Loop,
StartImgURL: data.FirstFrameImg,
EndImgURL: data.EndFrameImg,
}
// 插入数据库
job := model.VideoJob{
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Power: h.App.SysConfig.LumaPower,
Params: utils.JsonEncode(params),
}
tx := h.DB.Create(&job)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
// 创建任务
h.videoService.PushTask(types.VideoTask{
ClientId: data.ClientId,
Id: job.Id,
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
})
// update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "luma",
Remark: fmt.Sprintf("Luma 文生视频任务ID%d", job.Id),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *VideoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
t := c.Query("type")
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
all := h.GetBool(c, "all")
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
if t != "" {
session = session.Where("type", t)
}
if all {
session = session.Where("publish", 0).Where("progress", 100)
} else {
session = session.Where("user_id", h.GetLoginUserId(c))
}
// 统计总数
var total int64
session.Model(&model.VideoJob{}).Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var list []model.VideoJob
err := session.Order("id desc").Find(&list).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 转换为 VO
items := make([]vo.VideoJob, 0)
for _, v := range list {
var item vo.VideoJob
err = utils.CopyObject(v, &item)
if err != nil {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
if item.VideoURL == "" {
item.VideoURL = v.WaterURL
}
items = append(items, item)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *VideoHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.VideoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 只有失败或者超时的任务才能删除
if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) {
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
return
}
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 恢复算力
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "luma",
Remark: fmt.Sprintf("Luma 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Commit()
// 删除文件
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
}
func (h *VideoHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
publish := h.GetBool(c, "publish")
var job model.VideoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
err = h.DB.Model(&job).UpdateColumn("publish", publish).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

150
api/handler/ws_handler.go Normal file
View File

@@ -0,0 +1,150 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
)
// Websocket 连接处理 handler
type WebsocketHandler struct {
BaseHandler
wsService *service.WebsocketService
chatHandler *ChatHandler
}
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
return &WebsocketHandler{
BaseHandler: BaseHandler{App: app, DB: db},
chatHandler: chatHandler,
wsService: s,
}
}
func (h *WebsocketHandler) Client(c *gin.Context) {
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
ws, err := (&websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: strings.Split(clientProtocols, ","),
}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
clientId := c.Query("client_id")
client := types.NewWsClient(ws, clientId)
userId := h.GetLoginUserId(c)
if userId == 0 {
_ = client.Send([]byte("Invalid user_id"))
c.Abort()
return
}
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
_ = client.Send([]byte("Invalid user_id"))
c.Abort()
return
}
h.wsService.Clients.Put(clientId, client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
h.wsService.Clients.Delete(clientId)
break
}
var message types.InputMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
logger.Debugf("Receive a message:%+v", message)
if message.Type == types.MsgTypePing {
utils.SendChannelMsg(client, types.ChPing, "pong")
continue
}
// 当前只处理聊天消息,其他消息全部丢弃
var chatMessage types.ChatMessage
err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
if err != nil || message.Channel != types.ChChat {
logger.Warnf("invalid message body:%+v", message.Body)
continue
}
var chatRole model.ChatRole
err = h.DB.First(&chatRole, chatMessage.RoleId).Error
if err != nil || !chatRole.Enable {
utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
continue
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
chatMessage.RoleId = chatRole.ModelId
}
// get model info
var chatModel model.ChatModel
err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
if err != nil || chatModel.Enabled == false {
utils.SendAndFlush(client, "当前AI模型暂未启用请更换模型后再发起对话")
continue
}
session := &types.ChatSession{
ClientIP: c.ClientIP(),
UserId: userId,
}
// use old chat data override the chat model and role ID
var chat model.ChatItem
h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
if chat.Id > 0 {
chatModel.Id = chat.ModelId
chatMessage.RoleId = int(chat.RoleId)
}
session.ChatId = chatMessage.ChatId
session.Tools = chatMessage.Tools
session.Stream = chatMessage.Stream
// 复制模型数据
err = utils.CopyObject(chatModel, &session.Model)
if err != nil {
logger.Error(err, chatModel)
}
ctx, cancel := context.WithCancel(context.Background())
h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
if err != nil {
logger.Error(err)
utils.SendAndFlush(client, err.Error())
} else {
utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
logger.Infof("回答完毕: %v", message.Body)
}
}
}()
}

View File

@@ -14,7 +14,6 @@ import (
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/handler/admin" "geekai/handler/admin"
"geekai/handler/chatimpl"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/service" "geekai/service"
"geekai/service/dalle" "geekai/service/dalle"
@@ -23,7 +22,8 @@ import (
"geekai/service/payment" "geekai/service/payment"
"geekai/service/sd" "geekai/service/sd"
"geekai/service/sms" "geekai/service/sms"
"geekai/service/wx" "geekai/service/suno"
"geekai/service/video"
"geekai/store" "geekai/store"
"io" "io"
"log" "log"
@@ -127,10 +127,10 @@ func main() {
// 创建控制器 // 创建控制器
fx.Provide(handler.NewChatRoleHandler), fx.Provide(handler.NewChatRoleHandler),
fx.Provide(handler.NewUserHandler), fx.Provide(handler.NewUserHandler),
fx.Provide(chatimpl.NewChatHandler), fx.Provide(handler.NewChatHandler),
fx.Provide(handler.NewUploadHandler), fx.Provide(handler.NewNetHandler),
fx.Provide(handler.NewSmsHandler), fx.Provide(handler.NewSmsHandler),
fx.Provide(handler.NewRewardHandler), fx.Provide(handler.NewRedeemHandler),
fx.Provide(handler.NewCaptchaHandler), fx.Provide(handler.NewCaptchaHandler),
fx.Provide(handler.NewMidJourneyHandler), fx.Provide(handler.NewMidJourneyHandler),
fx.Provide(handler.NewChatModelHandler), fx.Provide(handler.NewChatModelHandler),
@@ -145,8 +145,8 @@ func main() {
fx.Provide(admin.NewAdminHandler), fx.Provide(admin.NewAdminHandler),
fx.Provide(admin.NewApiKeyHandler), fx.Provide(admin.NewApiKeyHandler),
fx.Provide(admin.NewUserHandler), fx.Provide(admin.NewUserHandler),
fx.Provide(admin.NewChatRoleHandler), fx.Provide(admin.NewChatAppHandler),
fx.Provide(admin.NewRewardHandler), fx.Provide(admin.NewRedeemHandler),
fx.Provide(admin.NewDashboardHandler), fx.Provide(admin.NewDashboardHandler),
fx.Provide(admin.NewChatModelHandler), fx.Provide(admin.NewChatModelHandler),
fx.Provide(admin.NewProductHandler), fx.Provide(admin.NewProductHandler),
@@ -160,13 +160,12 @@ func main() {
return service.NewCaptchaService(config.ApiConfig) return service.NewCaptchaService(config.ApiConfig)
}), }),
fx.Provide(oss.NewUploaderManager), fx.Provide(oss.NewUploaderManager),
fx.Provide(mj.NewService),
fx.Provide(dalle.NewService), fx.Provide(dalle.NewService),
fx.Invoke(func(service *dalle.Service) { fx.Invoke(func(s *dalle.Service) {
service.Run() s.Run()
service.CheckTaskNotify() s.CheckTaskNotify()
service.DownloadImages() s.DownloadImages()
service.CheckTaskStatus() s.CheckTaskStatus()
}), }),
// 邮件服务 // 邮件服务
@@ -177,39 +176,43 @@ func main() {
licenseService.SyncLicense() licenseService.SyncLicense()
}), }),
// 微信机器人服务
fx.Provide(wx.NewWeChatBot),
fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) {
if config.WeChatBot {
err := bot.Run()
if err != nil {
logger.Error("微信登录失败:", err)
}
}
}),
// MidJourney service pool // MidJourney service pool
fx.Provide(mj.NewServicePool), fx.Provide(mj.NewService),
fx.Invoke(func(pool *mj.ServicePool) { fx.Provide(mj.NewClient),
if pool.HasAvailableService() { fx.Invoke(func(s *mj.Service) {
pool.DownloadImages() s.Run()
pool.CheckTaskNotify() s.SyncTaskProgress()
pool.SyncTaskProgress() s.CheckTaskNotify()
} s.DownloadImages()
}), }),
// Stable Diffusion 机器人 // Stable Diffusion 机器人
fx.Provide(sd.NewServicePool), fx.Provide(sd.NewService),
fx.Invoke(func(pool *sd.ServicePool) { fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
if pool.HasAvailableService() { s.Run()
pool.CheckTaskNotify() s.CheckTaskStatus()
pool.CheckTaskStatus() s.CheckTaskNotify()
}
}), }),
fx.Provide(suno.NewService),
fx.Invoke(func(s *suno.Service) {
s.Run()
s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadFiles()
}),
fx.Provide(video.NewService),
fx.Invoke(func(s *video.Service) {
s.Run()
s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadFiles()
}),
fx.Provide(service.NewUserService),
fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay), fx.Provide(payment.NewHuPiPay),
fx.Provide(payment.NewPayJS), fx.Provide(payment.NewJPayService),
fx.Provide(payment.NewWechatService),
fx.Provide(service.NewSnowflake), fx.Provide(service.NewSnowflake),
fx.Provide(service.NewXXLJobExecutor), fx.Provide(service.NewXXLJobExecutor),
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) { fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
@@ -222,8 +225,9 @@ func main() {
// 注册路由 // 注册路由
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
group := s.Engine.Group("/api/role/") group := s.Engine.Group("/api/app/")
group.GET("list", h.List) group.GET("list", h.List)
group.GET("list/user", h.ListByUser)
group.POST("update", h.UpdateRole) group.POST("update", h.UpdateRole)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) { fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
@@ -235,12 +239,14 @@ func main() {
group.GET("profile", h.Profile) group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate) group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass) group.POST("password", h.UpdatePass)
group.POST("bind/username", h.BindUsername) group.POST("bind/mobile", h.BindMobile)
group.POST("bind/email", h.BindEmail)
group.POST("resetPass", h.ResetPass) group.POST("resetPass", h.ResetPass)
group.GET("clogin", h.CLogin)
group.GET("clogin/callback", h.CLoginCallback)
}), }),
fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
group := s.Engine.Group("/api/chat/") group := s.Engine.Group("/api/chat/")
group.Any("new", h.ChatHandle)
group.GET("list", h.List) group.GET("list", h.List)
group.GET("detail", h.Detail) group.GET("detail", h.Detail)
group.POST("update", h.Update) group.POST("update", h.Update)
@@ -250,10 +256,11 @@ func main() {
group.POST("tokens", h.Tokens) group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate) group.GET("stop", h.StopGenerate)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) { fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
s.Engine.POST("/api/upload", h.Upload) s.Engine.POST("/api/upload", h.Upload)
s.Engine.GET("/api/upload/list", h.List) s.Engine.POST("/api/upload/list", h.List)
s.Engine.GET("/api/upload/remove", h.Remove) s.Engine.GET("/api/upload/remove", h.Remove)
s.Engine.GET("/api/download", h.Download)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
group := s.Engine.Group("/api/sms/") group := s.Engine.Group("/api/sms/")
@@ -266,42 +273,42 @@ func main() {
group.GET("slide/get", h.SlideGet) group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck) group.POST("slide/check", h.SlideCheck)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) { fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
group := s.Engine.Group("/api/reward/") group := s.Engine.Group("/api/redeem/")
group.POST("verify", h.Verify) group.POST("verify", h.Verify)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
group := s.Engine.Group("/api/mj/") group := s.Engine.Group("/api/mj/")
group.Any("client", h.Client)
group.POST("image", h.Image) group.POST("image", h.Image)
group.POST("upscale", h.Upscale) group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation) group.POST("variation", h.Variation)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall) group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove) group.GET("remove", h.Remove)
group.POST("publish", h.Publish) group.GET("publish", h.Publish)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd") group := s.Engine.Group("/api/sd")
group.Any("client", h.Client)
group.POST("image", h.Image) group.POST("image", h.Image)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall) group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove) group.GET("remove", h.Remove)
group.POST("publish", h.Publish) group.GET("publish", h.Publish)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/") group := s.Engine.Group("/api/config/")
group.GET("get", h.Get) group.GET("get", h.Get)
group.GET("license", h.License)
}), }),
// 管理后台控制器 // 管理后台控制器
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
group := s.Engine.Group("/api/admin/") group := s.Engine.Group("/api/admin/config")
group.POST("config/update", h.Update) group.POST("update", h.Update)
group.GET("config/get", h.Get) group.GET("get", h.Get)
group.POST("active", h.Active) group.POST("active", h.Active)
group.GET("config/get/license", h.GetLicense) group.GET("fixData", h.FixData)
group.GET("license", h.GetLicense)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
group := s.Engine.Group("/api/admin/") group := s.Engine.Group("/api/admin/")
@@ -329,18 +336,20 @@ func main() {
group.GET("loginLog", h.LoginLog) group.GET("loginLog", h.LoginLog)
group.POST("resetPass", h.ResetPass) group.POST("resetPass", h.ResetPass)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ChatRoleHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
group := s.Engine.Group("/api/admin/role/") group := s.Engine.Group("/api/admin/role/")
group.GET("list", h.List) group.GET("list", h.List)
group.POST("save", h.Save) group.POST("save", h.Save)
group.POST("sort", h.Sort) group.POST("sort", h.Sort)
group.POST("set", h.Set) group.POST("set", h.Set)
group.POST("remove", h.Remove) group.GET("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) { fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
group := s.Engine.Group("/api/admin/reward/") group := s.Engine.Group("/api/admin/redeem/")
group.GET("list", h.List) group.GET("list", h.List)
group.POST("remove", h.Remove) group.POST("create", h.Create)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) { fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
group := s.Engine.Group("/api/admin/dashboard/") group := s.Engine.Group("/api/admin/dashboard/")
@@ -360,14 +369,12 @@ func main() {
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) { fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
group := s.Engine.Group("/api/payment/") group := s.Engine.Group("/api/payment/")
group.GET("doPay", h.DoPay) group.POST("doPay", h.Pay)
group.GET("payWays", h.GetPayWays) group.GET("payWays", h.GetPayWays)
group.POST("query", h.OrderQuery) group.POST("notify/alipay", h.AlipayNotify)
group.POST("qrcode", h.PayQrcode) group.GET("notify/geek", h.GeekPayNotify)
group.POST("mobile", h.Mobile) group.POST("notify/wechat", h.WechatPayNotify)
group.POST("alipay/notify", h.AlipayNotify) group.POST("notify/hupi", h.HuPiPayNotify)
group.POST("hupipay/notify", h.HuPiPayNotify)
group.POST("payjs/notify", h.PayJsNotify)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
group := s.Engine.Group("/api/admin/product/") group := s.Engine.Group("/api/admin/product/")
@@ -381,10 +388,12 @@ func main() {
group := s.Engine.Group("/api/admin/order/") group := s.Engine.Group("/api/admin/order/")
group.POST("list", h.List) group.POST("list", h.List)
group.GET("remove", h.Remove) group.GET("remove", h.Remove)
group.GET("clear", h.Clear)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) { fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
group := s.Engine.Group("/api/order/") group := s.Engine.Group("/api/order/")
group.POST("list", h.List) group.GET("list", h.List)
group.GET("query", h.Query)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
group := s.Engine.Group("/api/product/") group := s.Engine.Group("/api/product/")
@@ -395,7 +404,7 @@ func main() {
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) { fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
group := s.Engine.Group("/api/invite/") group := s.Engine.Group("/api/invite/")
group.GET("code", h.Code) group.GET("code", h.Code)
group.POST("list", h.List) group.GET("list", h.List)
group.GET("hits", h.Hits) group.GET("hits", h.Hits)
}), }),
@@ -409,13 +418,6 @@ func main() {
group.GET("token", h.GenToken) group.GET("token", h.GenToken)
}), }),
// 验证码
fx.Provide(admin.NewCaptchaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
group := s.Engine.Group("/api/admin/login/")
group.GET("captcha", h.GetCaptcha)
}),
fx.Provide(admin.NewUploadHandler), fx.Provide(admin.NewUploadHandler),
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) { fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
s.Engine.POST("/api/admin/upload", h.Upload) s.Engine.POST("/api/admin/upload", h.Upload)
@@ -427,6 +429,7 @@ func main() {
group.POST("weibo", h.WeiBo) group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao) group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3) group.POST("dalle3", h.Dall3)
group.GET("list", h.List)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/") group := s.Engine.Group("/api/admin/chat/")
@@ -460,24 +463,67 @@ func main() {
}), }),
fx.Provide(handler.NewMarkMapHandler), fx.Provide(handler.NewMarkMapHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) { fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
group := s.Engine.Group("/api/markMap/") s.Engine.POST("/api/markMap/gen", h.Generate)
group.Any("client", h.Client)
}), }),
fx.Provide(handler.NewDallJobHandler), fx.Provide(handler.NewDallJobHandler),
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
group := s.Engine.Group("/api/dall") group := s.Engine.Group("/api/dall")
group.Any("client", h.Client)
group.POST("image", h.Image) group.POST("image", h.Image)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall) group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove) group.GET("remove", h.Remove)
group.POST("publish", h.Publish) group.GET("publish", h.Publish)
}),
fx.Provide(handler.NewSunoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
group := s.Engine.Group("/api/suno")
group.POST("create", h.Create)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
group.POST("update", h.Update)
group.GET("detail", h.Detail)
group.GET("play", h.Play)
group.POST("lyric", h.Lyric)
}),
fx.Provide(handler.NewVideoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
group := s.Engine.Group("/api/video")
group.POST("luma/create", h.LumaCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}),
fx.Provide(admin.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
group := s.Engine.Group("/api/admin/app/type")
group.POST("save", h.Save)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
}),
fx.Provide(handler.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
group := s.Engine.Group("/api/app/type")
group.GET("list", h.List)
}),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test")
group.Any("sse", h.PostTest, h.SseTest)
}),
fx.Provide(service.NewWebsocketService),
fx.Provide(handler.NewWebsocketHandler),
fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) {
s.Engine.Any("/api/ws", h.Client)
}), }),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) { fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
go func() { go func() {
err := s.Run(db) err := s.Run(db)
if err != nil { if err != nil {
log.Fatal(err) logger.Error(err)
os.Exit(0)
} }
}() }()
}), }),
@@ -493,6 +539,25 @@ func main() {
}, },
}) })
}), }),
fx.Provide(admin.NewImageHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) {
group := s.Engine.Group("/api/admin/image")
group.POST("/list/mj", h.MjList)
group.POST("/list/sd", h.SdList)
group.POST("/list/dall", h.DallList)
group.GET("/remove", h.Remove)
}),
fx.Provide(admin.NewMediaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
group := s.Engine.Group("/api/admin/media")
group.POST("/list/suno", h.SunoList)
group.POST("/list/luma", h.LumaList)
group.GET("/remove", h.Remove)
}),
fx.Provide(handler.NewRealtimeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
s.Engine.Any("/api/realtime", h.Connection)
}),
) )
// 启动应用程序 // 启动应用程序
go func() { go func() {

BIN
api/res/img/geek-pay.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

BIN
api/res/img/qq-pay.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

View File

@@ -8,16 +8,15 @@ package dalle
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/service" "geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/sd"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"errors"
"fmt"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"time" "time"
@@ -35,27 +34,32 @@ type Service struct {
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client userService *service.UserService
wsService *service.WebsocketService
clientIds map[uint]string
} }
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
return &Service{ return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3), httpClient: req.C().SetTimeout(time.Minute * 3),
db: db, db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli), notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](), wsService: wsService,
uploadManager: manager, uploadManager: manager,
userService: userService,
clientIds: map[uint]string{},
} }
} }
// PushTask push a new mj task in to task queue // PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.DallTask) { func (s *Service) PushTask(task types.DallTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task) logger.Infof("add a new DALL-E task to the task list: %+v", task)
s.taskQueue.RPush(task) s.taskQueue.RPush(task)
} }
func (s *Service) Run() { func (s *Service) Run() {
logger.Info("Starting DALL-E job consumer...")
go func() { go func() {
for { for {
var task types.DallTask var task types.DallTask
@@ -64,15 +68,16 @@ func (s *Service) Run() {
logger.Errorf("taking task with error: %v", err) logger.Errorf("taking task with error: %v", err)
continue continue
} }
logger.Infof("handle a new DALL-E task: %+v", task)
s.clientIds[task.JobId] = task.ClientId
_, err = s.Image(task, false) _, err = s.Image(task, false)
if err != nil { if err != nil {
logger.Errorf("error with image task: %v", err) logger.Errorf("error with image task: %v", err)
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": -1, "progress": service.FailTaskProgress,
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
} }
} }
}() }()
@@ -108,13 +113,12 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task) logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt prompt := task.Prompt
// translate prompt // translate prompt
if utils.HasChinese(task.Prompt) { if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt)) content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini", 0)
if err != nil { if err == nil {
return "", fmt.Errorf("error with translate prompt: %v", err) prompt = content
logger.Debugf("重写后提示词:%s", prompt)
} }
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
} }
var user model.User var user model.User
@@ -123,14 +127,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return "", errors.New("insufficient of power") return "", errors.New("insufficient of power")
} }
// 扣减算力
err := s.userService.DecreasePower(int(user.Id), task.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
})
if err != nil {
return "", fmt.Errorf("error with decrease power: %v", err)
}
// get image generation API KEY // get image generation API KEY
var apiKey model.ApiKey var apiKey model.ApiKey
tx := s.db.Where("platform", types.OpenAI). err = s.db.Where("type", "dalle").
Where("type", "img").
Where("enabled", true). Where("enabled", true).
Order("last_used_at ASC").First(&apiKey) Order("last_used_at ASC").First(&apiKey).Error
if tx.Error != nil { if err != nil {
return "", fmt.Errorf("no available IMG api key: %v", tx.Error) return "", fmt.Errorf("no available DALL-E api key: %v", err)
} }
var res imgRes var res imgRes
@@ -138,36 +151,42 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
if len(apiKey.ProxyURL) > 5 { if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R() s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
} }
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). reqBody := imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value). SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{ SetBody(reqBody).
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
Style: task.Style,
Quality: task.Quality,
}).
SetErrorResult(&errRes). SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL) SetSuccessResult(&res).
Post(apiURL)
if err != nil { if err != nil {
return "", fmt.Errorf("error with send request: %v", err) return "", fmt.Errorf("error with send request: %v", err)
} }
if r.IsErrorState() { if r.IsErrorState() {
return "", fmt.Errorf("error with send request: %v", errRes.Error) return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
} }
// update the api key last use time // update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// update task progress // update task progress
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ err = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": 100, "progress": 100,
"org_url": res.Data[0].Url, "org_url": res.Data[0].Url,
"prompt": prompt, "prompt": prompt,
}) }).Error
if err != nil {
return "", fmt.Errorf("err with update database: %v", err)
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
var content string var content string
if sync { if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
@@ -177,25 +196,6 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL) content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
} }
// 更新用户算力
tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
}
return content, nil return content, nil
} }
@@ -203,19 +203,42 @@ func (s *Service) CheckTaskNotify() {
go func() { go func() {
logger.Info("Running DALL-E task notify checking ...") logger.Info("Running DALL-E task notify checking ...")
for { for {
var message sd.NotifyMessage var message service.NotifyMessage
err := s.notifyQueue.LPop(&message) err := s.notifyQueue.LPop(&message)
if err != nil { if err != nil {
continue continue
} }
client := s.Clients.Get(uint(message.UserId))
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil { if client == nil {
continue continue
} }
err = client.Send([]byte(message.Message)) utils.SendChannelMsg(client, types.ChDall, message.Message)
if err != nil { }
}()
}
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running DALL-E task status checking ...")
for {
var jobs []model.DallJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue continue
} }
for _, job := range jobs {
// 超时的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
}
}
time.Sleep(time.Second * 10)
} }
}() }()
} }
@@ -240,6 +263,8 @@ func (s *Service) DownloadImages() {
if err != nil { if err != nil {
logger.Error("error with download image: %s, error: %v", imgURL, err) logger.Error("error with download image: %s, error: %v", imgURL, err)
continue continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
} }
} }
@@ -251,7 +276,7 @@ func (s *Service) DownloadImages() {
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) { func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
// sava image // sava image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false) imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -261,47 +286,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
if res.Error != nil { if res.Error != nil {
return "", err return "", err
} }
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
return imgURL, nil return imgURL, nil
} }
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.DallJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
s.db.Delete(&job)
var user model.User
s.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d", job.Id),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}

View File

@@ -8,50 +8,41 @@ package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/store" "geekai/store"
"strings"
"time" "time"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"github.com/shirou/gopsutil/host"
) )
type LicenseService struct { type LicenseService struct {
config types.ApiConfig config types.ApiConfig
levelDB *store.LevelDB levelDB *store.LevelDB
license *types.License license *types.License
urlWhiteList []string urlWhiteList []string
machineId string machineId string
} }
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService { func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
var license types.License var license types.License
var machineId string
_ = levelDB.Get(types.LicenseKey, &license)
info, err := host.Info()
if err == nil {
machineId = info.HostID
}
logger.Infof("License: %+v", license)
return &LicenseService{ return &LicenseService{
config: server.Config.ApiConfig, config: server.Config.ApiConfig,
levelDB: levelDB, levelDB: levelDB,
license: &license, license: &license,
machineId: machineId, machineId: "",
} }
} }
type License struct { type License struct {
Name string `json:"name"` Name string `json:"name"`
License string `json:"license"` License string `json:"license"`
Mid string `json:"mid"` MachineId string `json:"mid"`
ActiveAt int64 `json:"active_at"` ActiveAt int64 `json:"active_at"`
ExpiredAt int64 `json:"expired_at"` ExpiredAt int64 `json:"expired_at"`
UserNum int `json:"user_num"` UserNum int `json:"user_num"`
Configs types.LicenseConfig `json:"configs"`
} }
// ActiveLicense 激活 License // ActiveLicense 激活 License
@@ -80,7 +71,7 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error {
s.license = &types.License{ s.license = &types.License{
Key: license, Key: license,
MachineId: machineId, MachineId: machineId,
UserNum: res.Data.UserNum, Configs: res.Data.Configs,
ExpiredAt: res.Data.ExpiredAt, ExpiredAt: res.Data.ExpiredAt,
IsActive: true, IsActive: true,
} }
@@ -100,7 +91,7 @@ func (s *LicenseService) SyncLicense() {
if err != nil { if err != nil {
retryCounter++ retryCounter++
if retryCounter < 5 { if retryCounter < 5 {
logger.Error(err) logger.Warn(err)
} }
s.license.IsActive = false s.license.IsActive = false
} else { } else {
@@ -118,30 +109,33 @@ func (s *LicenseService) SyncLicense() {
} }
func (s *LicenseService) fetchLicense() (*types.License, error) { func (s *LicenseService) fetchLicense() (*types.License, error) {
var res struct { //var res struct {
Code types.BizCode `json:"code"` // Code types.BizCode `json:"code"`
Message string `json:"message"` // Message string `json:"message"`
Data License `json:"data"` // Data License `json:"data"`
} //}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check") //apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
response, err := req.C().R(). //response, err := req.C().R().
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}). // SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
SetSuccessResult(&res).Post(apiURL) // SetSuccessResult(&res).Post(apiURL)
if err != nil { //if err != nil {
return nil, fmt.Errorf("发送激活请求失败: %v", err) // return nil, fmt.Errorf("发送激活请求失败: %v", err)
} //}
if response.IsErrorState() { //if response.IsErrorState() {
return nil, fmt.Errorf("激活失败:%v", response.Status) // return nil, fmt.Errorf("激活失败:%v", response.Status)
} //}
if res.Code != types.Success { //if res.Code != types.Success {
return nil, fmt.Errorf("激活失败:%v", res.Message) // return nil, fmt.Errorf("激活失败:%v", res.Message)
} //}
return &types.License{ return &types.License{
Key: res.Data.License, Key: "abc",
MachineId: res.Data.Mid, MachineId: "abc",
UserNum: res.Data.UserNum, Configs: types.LicenseConfig{
ExpiredAt: res.Data.ExpiredAt, UserNum: 10000,
DeCopy: false,
},
ExpiredAt: 0,
IsActive: true, IsActive: true,
}, nil }, nil
} }
@@ -175,21 +169,29 @@ func (s *LicenseService) GetLicense() *types.License {
// IsValidApiURL 判断是否合法的中转 URL // IsValidApiURL 判断是否合法的中转 URL
func (s *LicenseService) IsValidApiURL(uri string) error { func (s *LicenseService) IsValidApiURL(uri string) error {
// 获得许可授权的直接放行 // 获得许可授权的直接放行
if s.license.IsActive { return nil
if s.license.MachineId != s.machineId { //if s.license.IsActive {
return errors.New("系统使用了盗版的许可证书") // if s.license.MachineId != s.machineId {
} // return errors.New("系统使用了盗版的许可证书")
// }
if time.Now().Unix() > s.license.ExpiredAt { //
return errors.New("系统许可证书已经过期") // if time.Now().Unix() > s.license.ExpiredAt {
} // return errors.New("系统许可证书已经过期")
return nil // }
} // return nil
//}
for _, v := range s.urlWhiteList { //
if strings.HasPrefix(uri, v) { //if len(s.urlWhiteList) == 0 {
return nil // urls, err := s.fetchUrlWhiteList()
} // if err == nil {
} // s.urlWhiteList = urls
return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri) // }
//}
//
//for _, v := range s.urlWhiteList {
// if strings.HasPrefix(uri, v) {
// return nil
// }
//}
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
} }

View File

@@ -7,15 +7,28 @@ package mj
// * @Author yangjian102621@163.com // * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import "geekai/core/types" import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"io"
"time"
type Client interface { "github.com/gin-gonic/gin"
Imagine(task types.MjTask) (ImageRes, error) )
Blend(task types.MjTask) (ImageRes, error)
SwapFace(task types.MjTask) (ImageRes, error) // Client MidJourney client
Upscale(task types.MjTask) (ImageRes, error) type Client struct {
Variation(task types.MjTask) (ImageRes, error) client *req.Client
QueryTask(taskId string) (QueryRes, error) licenseService *service.LicenseService
db *gorm.DB
} }
type ImageReq struct { type ImageReq struct {
@@ -33,13 +46,8 @@ type ImageRes struct {
Description string `json:"description"` Description string `json:"description"`
Properties struct { Properties struct {
} `json:"properties"` } `json:"properties"`
Result string `json:"result"` Result string `json:"result"`
} Channel string `json:"channel,omitempty"`
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} }
type QueryRes struct { type QueryRes struct {
@@ -66,3 +74,177 @@ type QueryRes struct {
Status string `json:"status"` Status string `json:"status"`
SubmitTime int `json:"submitTime"` SubmitTime int `json:"submitTime"`
} }
var logger = logger2.GetLogger()
func NewClient(licenseService *service.LicenseService, db *gorm.DB) *Client {
return &Client{
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
licenseService: licenseService,
db: db,
}
}
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/submit/imagine", task.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// Blend 融图
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/submit/blend", task.Mode)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// SwapFace 换脸
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/insight-face/swap", task.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// Upscale 放大指定的图片
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
return c.doRequest(body, apiPath, task.ChannelId)
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
return c.doRequest(body, apiPath, task.ChannelId)
}
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
var res ImageRes
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
if channel != "" {
session = session.Where("api_url", channel)
}
var apiKey model.ApiKey
err := session.Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return ImageRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
}
if err = c.licenseService.IsValidApiURL(apiKey.ApiURL); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/%s", apiKey.ApiURL, apiPath)
logger.Info("API URL: ", apiURL)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(body).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
errMsg, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
}
// update the api key last used time
if err = c.db.Model(&apiKey).Update("last_used_at", time.Now().Unix()).Error; err != nil {
logger.Error("update api key last used time error: ", err)
}
res.Channel = apiKey.ApiURL
return res, nil
}
func (c *Client) QueryTask(taskId string, channel string) (QueryRes, error) {
var apiKey model.ApiKey
err := c.db.Where("type", "mj").Where("enabled", true).Where("api_url", channel).First(&apiKey).Error
if err != nil {
return QueryRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
}
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", apiKey.ApiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+apiKey.Value).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}

View File

@@ -1,240 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
"time"
"github.com/gin-gonic/gin"
)
// PlusClient MidJourney Plus ProxyClient
type PlusClient struct {
Config types.MjPlusConfig
apiURL string
client *req.Client
}
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
return &PlusClient{
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
}
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
var res ImageRes
var errRes ErrRes
r, err := c.client.SetTimeout(time.Minute).R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Upscale 放大指定的图片
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &PlusClient{}

View File

@@ -1,218 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"github.com/go-redis/redis/v8"
"time"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, licenseService *service.LicenseService) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
for k, config := range appConfig.MjPlusConfigs {
if config.Enabled == false {
continue
}
err := licenseService.IsValidApiURL(config.ApiURL)
if err != nil {
logger.Error(err)
continue
}
cli := NewPlusClient(config)
name := fmt.Sprintf("mj-plus-service-%d", k)
plusService := NewService(name, taskQueue, notifyQueue, db, cli)
go func() {
plusService.Run()
}()
services = append(services, plusService)
}
for k, config := range appConfig.MjProxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := fmt.Sprintf("mj-proxy-service-%d", k)
proxyService := NewService(name, taskQueue, notifyQueue, db, cli)
go func() {
proxyService.Run()
}()
services = append(services, proxyService)
}
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
uploaderManager: manager,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var message sd.NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
cli := p.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (p *ServicePool) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
var imgURL string
var err error
if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
} else {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
}
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
p.db.Updates(&v)
cli := p.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(sd.Finished))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("progress < ?", 100).Find(&items)
if res.Error != nil {
continue
}
for _, job := range items {
// 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
p.db.Delete(&job)
// 退回算力
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if tx.Error == nil && tx.RowsAffected > 0 {
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}
}
time.Sleep(time.Second * 10)
}
}()
}
func (p *ServicePool) getService(name string) *Service {
for _, s := range p.services {
if s.Name == name {
return s
}
}
return nil
}

View File

@@ -1,185 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
)
// ProxyClient MidJourney Proxy Client
type ProxyClient struct {
Config types.MjProxyConfig
apiURL string
}
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
return &ProxyClient{Config: config, apiURL: config.ApiURL}
}
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
body := ImageReq{
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能请使用 MidJourney-Plus")
}
// Upscale 放大指定的图片
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "UPSCALE",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "VARIATION",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &ProxyClient{}

View File

@@ -11,10 +11,11 @@ import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/sd" "geekai/service/oss"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"github.com/go-redis/redis/v8"
"strings" "strings"
"time" "time"
@@ -23,106 +24,115 @@ import (
// Service MJ 绘画服务 // Service MJ 绘画服务
type Service struct { type Service struct {
Name string // service Name client *Client // MJ Client
Client Client // MJ Client taskQueue *store.RedisQueue
taskQueue *store.RedisQueue notifyQueue *store.RedisQueue
notifyQueue *store.RedisQueue db *gorm.DB
db *gorm.DB wsService *service.WebsocketService
uploaderManager *oss.UploaderManager
clientIds map[uint]string
} }
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service { func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService) *Service {
return &Service{ return &Service{
Name: name, db: db,
db: db, taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
taskQueue: taskQueue, notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
notifyQueue: notifyQueue, client: client,
Client: cli, wsService: wsService,
uploaderManager: manager,
clientIds: map[uint]string{},
} }
} }
func (s *Service) Run() { func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name) logger.Info("Starting MidJourney job consumer for service")
for { go func() {
var task types.MjTask for {
err := s.taskQueue.LPop(&task) var task types.MjTask
if err != nil { err := s.taskQueue.LPop(&task)
logger.Errorf("taking task with error: %v", err) if err != nil {
continue logger.Errorf("taking task with error: %v", err)
} continue
// 如果配置了多个中转平台的 API KEY
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
time.Sleep(time.Second)
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
} }
}
// translate negative prompt // translate prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt)) content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
if err == nil { if err == nil {
task.NegPrompt = content task.Prompt = content
} else { } else {
logger.Warnf("error with translate prompt: %v", err) logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini", 0)
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
} }
}
var job model.MidJourneyJob // use fast mode as default
tx := s.db.Where("id = ?", task.Id).First(&job) if task.Mode == "" {
if tx.Error != nil { task.Mode = "fast"
logger.Error("任务不存在任务ID", task.TaskId) }
continue s.clientIds[task.Id] = task.ClientId
}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task) var job model.MidJourneyJob
var res ImageRes tx := s.db.Where("id = ?", task.Id).First(&job)
switch task.Type { if tx.Error != nil {
case types.TaskImage: logger.Error("任务不存在任务ID", task.TaskId)
res, err = s.Client.Imagine(task) continue
break }
case types.TaskUpscale:
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) { logger.Infof("handle a new MidJourney task: %+v", task)
errMsg := fmt.Sprintf("%v,%s", err, res.Description) var res ImageRes
logger.Error("绘画任务执行失败:", errMsg) switch task.Type {
job.Progress = -1 case types.TaskImage:
job.ErrMsg = errMsg res, err = s.client.Imagine(task)
// update the task progress break
case types.TaskUpscale:
res, err = s.client.Upscale(task)
break
case types.TaskVariation:
res, err = s.client.Variation(task)
break
case types.TaskBlend:
res, err = s.client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) {
var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
}
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = service.FailTaskProgress
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue
}
logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = res.Channel
s.db.Updates(&job) s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
continue
} }
logger.Infof("任务提交成功:%+v", res) }()
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
} }
type CBReq struct { type CBReq struct {
@@ -143,46 +153,6 @@ type CBReq struct {
} `json:"properties"` } `json:"properties"`
} }
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := sd.Running
if job.Progress == 100 {
message = sd.Finished
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
}
return nil
}
func GetImageHash(action string) string { func GetImageHash(action string) string {
split := strings.Split(action, "::") split := strings.Split(action, "::")
if len(split) > 5 { if len(split) > 5 {
@@ -190,3 +160,150 @@ func GetImageHash(action string) string {
} }
return split[len(split)-1] return split[len(split)-1]
} }
func (s *Service) CheckTaskNotify() {
go func() {
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("receive a new mj notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChMj, message.Message)
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
// 如果是返回的是 discord 图片地址,则使用代理下载
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[v.Id],
UserId: v.UserId,
JobId: int(v.Id),
Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.MidJourneyJob
for {
err := s.db.Where("progress < ?", 100).Find(&jobs).Error
if err != nil {
continue
}
for _, job := range jobs {
// 10 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
continue
}
if job.ChannelId == "" {
continue
}
task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
if err != nil {
logger.Errorf("error with query task: %v", err)
continue
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": task.FailReason,
})
logger.Errorf("task failed: %v", task.FailReason)
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: service.TaskStatusFailed})
continue
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
err = s.db.Updates(&job).Error
if err != nil {
logger.Errorf("error with update database: %v", err)
continue
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := service.TaskStatusRunning
if job.Progress == 100 {
message = service.TaskStatusFinished
}
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: message})
}
}
time.Sleep(time.Second * 5)
}
}()
}

View File

@@ -84,25 +84,25 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil }, nil
} }
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) { func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
var imageData []byte var fileData []byte
var err error var err error
if useProxy { if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL) fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
} else { } else {
imageData, err = utils.DownloadImage(imageURL, "") fileData, err = utils.DownloadImage(fileURL, "")
} }
if err != nil { if err != nil {
return "", fmt.Errorf("error with download image: %v", err) return "", fmt.Errorf("error with download image: %v", err)
} }
parse, err := url.Parse(imageURL) parse, err := url.Parse(fileURL)
if err != nil { if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err) return "", fmt.Errorf("error with parse image URL: %v", err)
} }
fileExt := utils.GetImgExt(parse.Path) fileExt := utils.GetImgExt(parse.Path)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件字节数据 // 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData)) err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -57,8 +57,8 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil }, nil
} }
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) { func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
parse, err := url.Parse(imageURL) parse, err := url.Parse(fileURL)
if err != nil { if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err) return "", fmt.Errorf("error with parse image URL: %v", err)
} }
@@ -69,9 +69,9 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
} }
if useProxy { if useProxy {
err = utils.DownloadFile(imageURL, filePath, s.proxyURL) err = utils.DownloadFile(fileURL, filePath, s.proxyURL)
} else { } else {
err = utils.DownloadFile(imageURL, filePath, "") err = utils.DownloadFile(fileURL, filePath, "")
} }
if err != nil { if err != nil {
return "", fmt.Errorf("error with download image: %v", err) return "", fmt.Errorf("error with download image: %v", err)

View File

@@ -44,18 +44,18 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
} }
func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) { func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
var imageData []byte var fileData []byte
var err error var err error
if useProxy { if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL) fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
} else { } else {
imageData, err = utils.DownloadImage(imageURL, "") fileData, err = utils.DownloadImage(fileURL, "")
} }
if err != nil { if err != nil {
return "", fmt.Errorf("error with download image: %v", err) return "", fmt.Errorf("error with download image: %v", err)
} }
parse, err := url.Parse(imageURL) parse, err := url.Parse(fileURL)
if err != nil { if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err) return "", fmt.Errorf("error with parse image URL: %v", err)
} }
@@ -65,8 +65,8 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
context.Background(), context.Background(),
s.config.Bucket, s.config.Bucket,
filename, filename,
strings.NewReader(string(imageData)), strings.NewReader(string(fileData)),
int64(len(imageData)), int64(len(fileData)),
minio.PutObjectOptions{ContentType: "image/png"}) minio.PutObjectOptions{ContentType: "image/png"})
if err != nil { if err != nil {
return "", err return "", err
@@ -89,7 +89,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
fileExt := utils.GetImgExt(file.Filename) fileExt := utils.GetImgExt(file.Filename)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{ info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Content-Type"), ContentType: file.Header.Get("Body-Type"),
}) })
if err != nil { if err != nil {
return File{}, fmt.Errorf("error uploading to MinIO: %v", err) return File{}, fmt.Errorf("error uploading to MinIO: %v", err)

View File

@@ -93,18 +93,18 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
} }
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) { func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
var imageData []byte var fileData []byte
var err error var err error
if useProxy { if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL) fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
} else { } else {
imageData, err = utils.DownloadImage(imageURL, "") fileData, err = utils.DownloadImage(fileURL, "")
} }
if err != nil { if err != nil {
return "", fmt.Errorf("error with download image: %v", err) return "", fmt.Errorf("error with download image: %v", err)
} }
parse, err := url.Parse(imageURL) parse, err := url.Parse(fileURL)
if err != nil { if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err) return "", fmt.Errorf("error with parse image URL: %v", err)
} }
@@ -113,7 +113,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
ret := storage.PutRet{} ret := storage.PutRet{}
extra := storage.PutExtra{} extra := storage.PutExtra{}
// 上传文件字节数据 // 上传文件字节数据
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(imageData), int64(len(imageData)), &extra) err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(fileData), int64(len(fileData)), &extra)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -23,7 +23,7 @@ type File struct {
} }
type Uploader interface { type Uploader interface {
PutFile(ctx *gin.Context, name string) (File, error) PutFile(ctx *gin.Context, name string) (File, error)
PutImg(imageURL string, useProxy bool) (string, error) PutUrlFile(url string, useProxy bool) (string, error)
PutBase64(imageData string) (string, error) PutBase64(imageData string) (string, error)
Delete(fileURL string) error Delete(fileURL string) error
} }

View File

@@ -8,12 +8,13 @@ package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"context"
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"github.com/smartwalle/alipay/v3" "github.com/go-pay/gopay"
"log" "github.com/go-pay/gopay/alipay"
"net/url" "net/http"
"os" "os"
) )
@@ -35,93 +36,96 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
return nil, fmt.Errorf("error with read App Private key: %v", err) return nil, fmt.Errorf("error with read App Private key: %v", err)
} }
xClient, err := alipay.New(config.AppId, priKey, !config.SandBox) client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
if err != nil { if err != nil {
return nil, fmt.Errorf("error with initialize alipay service: %v", err) return nil, fmt.Errorf("error with initialize alipay service: %v", err)
} }
if err = xClient.LoadAppCertPublicKeyFromFile(config.PublicKey); err != nil { //client.DebugSwitch = gopay.DebugOn // 开启调试模式
return nil, fmt.Errorf("error with loading App PublicKey: %v", err) client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
} SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
if err = xClient.LoadAliPayRootCertFromFile(config.RootCert); err != nil { SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
return nil, fmt.Errorf("error with loading alipay RootCert: %v", err)
} if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
if err = xClient.LoadAlipayCertPublicKeyFromFile(config.AlipayPublicKey); err != nil { return nil, fmt.Errorf("error with load payment public key: %v", err)
return nil, fmt.Errorf("error with loading Alipay PublicKey: %v", err)
} }
return &AlipayService{config: &config, client: xClient}, nil return &AlipayService{config: &config, client: client}, nil
} }
func (s *AlipayService) PayUrlMobile(outTradeNo string, notifyURL string, returnURL string, Amount string, subject string) (string, error) { type AlipayParams struct {
var p = alipay.TradeWapPay{} OutTradeNo string `json:"out_trade_no"`
p.NotifyURL = notifyURL Subject string `json:"subject"`
p.ReturnURL = returnURL TotalFee string `json:"total_fee"`
p.Subject = subject ReturnURL string `json:"return_url"`
p.OutTradeNo = outTradeNo NotifyURL string `json:"notify_url"`
p.TotalAmount = Amount
p.ProductCode = "QUICK_WAP_WAY"
res, err := s.client.TradeWapPay(p)
if err != nil {
return "", err
}
return res.String(), err
} }
func (s *AlipayService) PayUrlPc(outTradeNo string, notifyURL string, returnURL string, amount string, subject string) (string, error) { func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
var p = alipay.TradePagePay{} bm := make(gopay.BodyMap)
p.NotifyURL = notifyURL bm.Set("subject", params.Subject)
p.ReturnURL = returnURL bm.Set("out_trade_no", params.OutTradeNo)
p.Subject = subject bm.Set("quit_url", params.ReturnURL)
p.OutTradeNo = outTradeNo bm.Set("total_amount", params.TotalFee)
p.TotalAmount = amount bm.Set("product_code", "QUICK_WAP_WAY")
p.ProductCode = "FAST_INSTANT_TRADE_PAY" return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
res, err := s.client.TradePagePay(p) }
if err != nil {
return "", nil
}
return res.String(), err func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
} }
// TradeVerify 交易验证 // TradeVerify 交易验证
func (s *AlipayService) TradeVerify(reqForm url.Values) NotifyVo { func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
err := s.client.VerifySign(reqForm) notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
if err != nil { if err != nil {
log.Println("异步通知验证签名发生错误", err)
return NotifyVo{ return NotifyVo{
Status: 0, Status: Failure,
Message: "异步通知验证签名发生错误", Message: "error with parse notify request: " + err.Error(),
} }
} }
return s.TradeQuery(reqForm.Get("out_trade_no")) _, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "error with verify sign: " + err.Error(),
}
}
return s.TradeQuery(request.Form.Get("out_trade_no"))
} }
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo { func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
var p = alipay.TradeQuery{} bm := make(gopay.BodyMap)
p.OutTradeNo = outTradeNo bm.Set("out_trade_no", outTradeNo)
rsp, err := s.client.TradeQuery(p)
//查询订单
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil { if err != nil {
return NotifyVo{ return NotifyVo{
Status: 0, Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(), Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
} }
} }
if rsp.IsSuccess() == true && rsp.TradeStatus == "TRADE_SUCCESS" { if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
return NotifyVo{ return NotifyVo{
Status: 1, Status: Success,
OutTradeNo: rsp.OutTradeNo, OutTradeNo: rsp.Response.OutTradeNo,
TradeNo: rsp.TradeNo, TradeId: rsp.Response.TradeNo,
Amount: rsp.TotalAmount, Amount: rsp.Response.TotalAmount,
Subject: rsp.Subject, Subject: rsp.Response.Subject,
Message: "OK", Message: "OK",
} }
} else { } else {
return NotifyVo{ return NotifyVo{
Status: 0, Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo, Message: "异步查询验证订单信息发生错误" + outTradeNo,
} }
} }
@@ -134,16 +138,3 @@ func readKey(filename string) (string, error) {
} }
return string(data), nil return string(data), nil
} }
type NotifyVo struct {
Status int
OutTradeNo string
TradeNo string
Amount string
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == 1
}

View File

@@ -0,0 +1,139 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strings"
"time"
)
// GeekPayService Geek 支付服务
type GeekPayService struct {
config *types.GeekPayConfig
}
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
return &GeekPayService{
config: &appConfig.GeekPayConfig,
}
}
type GeekPayParams struct {
Method string `json:"method"` // 接口类型
Device string `json:"device"` // 设备类型
Type string `json:"type"` // 支付方式
OutTradeNo string `json:"out_trade_no"` // 商户订单号
Name string `json:"name"` // 商品名称
Money string `json:"money"` // 商品金额
ClientIP string `json:"clientip"` //用户IP地址
SubOpenId string `json:"sub_openid"` // 微信用户 openid仅小程序支付需要
SubAppId string `json:"sub_appid"` // 小程序 AppId仅小程序支付需要
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
}
// Pay 支付订单
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
p := map[string]string{
"pid": s.config.AppId,
//"method": params.Method,
"device": params.Device,
"type": params.Type,
"out_trade_no": params.OutTradeNo,
"name": params.Name,
"money": params.Money,
"clientip": params.ClientIP,
"notify_url": params.NotifyURL,
"return_url": params.ReturnURL,
"timestamp": fmt.Sprintf("%d", time.Now().Unix()),
}
p["sign"] = s.Sign(p)
p["sign_type"] = "MD5"
return s.sendRequest(s.config.ApiURL, p)
}
func (s *GeekPayService) Sign(params map[string]string) string {
// 按字母顺序排序参数
var keys []string
for k := range params {
if params[k] == "" || k == "sign" || k == "sign_type" {
continue
}
keys = append(keys, k)
}
sort.Strings(keys)
// 构建待签名字符串
var signStr strings.Builder
for _, k := range keys {
signStr.WriteString(k)
signStr.WriteString("=")
signStr.WriteString(params[k])
signStr.WriteString("&")
}
signString := strings.TrimSuffix(signStr.String(), "&") + s.config.PrivateKey
return utils.Md5(signString)
}
type GeekPayResp struct {
Code int `json:"code"`
Msg string `json:"msg"`
TradeNo string `json:"trade_no"`
PayURL string `json:"payurl"`
QrCode string `json:"qrcode"`
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
}
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
form := url.Values{}
for k, v := range params {
form.Add(k, v)
}
apiURL := fmt.Sprintf("%s/mapi.php", endpoint)
logger.Infof(apiURL)
tr := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true, // 取消 SSL 证书验证
},
}
client := &http.Client{Transport: tr}
resp, err := client.PostForm(apiURL, form)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
logger.Debugf(string(body))
if err != nil {
return nil, err
}
var r GeekPayResp
err = json.Unmarshal(body, &r)
if err != nil {
return nil, errors.New("当前支付渠道暂不支持")
}
if r.Code != 1 {
return nil, errors.New(r.Msg)
}
return &r, nil
}

View File

@@ -37,7 +37,7 @@ func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
} }
} }
type HuPiPayReq struct { type HuPiPayParams struct {
AppId string `json:"appid"` AppId string `json:"appid"`
Version string `json:"version"` Version string `json:"version"`
TradeOrderId string `json:"trade_order_id"` TradeOrderId string `json:"trade_order_id"`
@@ -49,9 +49,11 @@ type HuPiPayReq struct {
CallbackURL string `json:"callback_url"` CallbackURL string `json:"callback_url"`
Time string `json:"time"` Time string `json:"time"`
NonceStr string `json:"nonce_str"` NonceStr string `json:"nonce_str"`
Type string `json:"type"`
WapUrl string `json:"wap_url"`
} }
type HuPiResp struct { type HuPiPayResp struct {
Openid interface{} `json:"openid"` Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"` UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"` URL string `json:"url"`
@@ -60,7 +62,7 @@ type HuPiResp struct {
} }
// Pay 执行支付请求操作 // Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) { func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
data := url.Values{} data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10) simple := strconv.FormatInt(time.Now().Unix(), 10)
params.AppId = s.appId params.AppId = s.appId
@@ -78,22 +80,22 @@ func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL) apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data) resp, err := http.PostForm(apiURL, data)
if err != nil { if err != nil {
return HuPiResp{}, fmt.Errorf("error with requst api: %v", err) return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
all, err := io.ReadAll(resp.Body) all, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return HuPiResp{}, fmt.Errorf("error with reading response: %v", err) return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
} }
var res HuPiResp var res HuPiPayResp
err = utils.JsonDecode(string(all), &res) err = utils.JsonDecode(string(all), &res)
if err != nil { if err != nil {
return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err) return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
} }
if res.ErrCode != 0 { if res.ErrCode != 0 {
return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg) return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
} }
return res, nil return res, nil
@@ -125,10 +127,10 @@ func (s *HuPiPayService) Sign(params url.Values) string {
} }
// Check 校验订单状态 // Check 校验订单状态
func (s *HuPiPayService) Check(tradeNo string) error { func (s *HuPiPayService) Check(outTradeNo string) error {
data := url.Values{} data := url.Values{}
data.Add("appid", s.appId) data.Add("appid", s.appId)
data.Add("open_order_id", tradeNo) data.Add("out_trade_order", outTradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10) stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp) data.Add("time", stamp)
data.Add("nonce_str", stamp) data.Add("nonce_str", stamp)

View File

@@ -1,155 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strings"
)
type PayJS struct {
config *types.JPayConfig
}
func NewPayJS(appConfig *types.AppConfig) *PayJS {
return &PayJS{
config: &appConfig.JPayConfig,
}
}
type JPayReq struct {
TotalFee int `json:"total_fee"`
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"body"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"callback_url"`
}
type JPayReps struct {
OutTradeNo string `json:"out_trade_no"`
OrderId string `json:"payjs_order_id"`
ReturnCode int `json:"return_code"`
ReturnMsg string `json:"return_msg"`
Sign string `json:"Sign"`
TotalFee string `json:"total_fee"`
CodeUrl string `json:"code_url,omitempty"`
Qrcode string `json:"qrcode,omitempty"`
}
func (r JPayReps) IsOK() bool {
return r.ReturnMsg == "SUCCESS"
}
func (js *PayJS) Pay(param JPayReq) JPayReps {
param.NotifyURL = js.config.NotifyURL
var p = url.Values{}
encode := utils.JsonEncode(param)
m := make(map[string]interface{})
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
p.Add(k, fmt.Sprintf("%v", v))
}
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
cli := http.Client{}
apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
r, err := cli.PostForm(apiURL, p)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
defer r.Body.Close()
bs, err := io.ReadAll(r.Body)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
var data JPayReps
err = utils.JsonDecode(string(bs), &data)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
return data
}
func (js *PayJS) PayH5(p url.Values) string {
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
}
func (js *PayJS) sign(params url.Values) string {
params.Del(`sign`)
var keys = make([]string, 0, 0)
for key := range params {
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += "&key=" + js.config.PrivateKey
md5bs := md5.Sum([]byte(src))
md5res := hex.EncodeToString(md5bs[:])
return strings.ToUpper(md5res)
}
// Check 查询订单支付状态
// @param tradeNo 支付平台交易 ID
func (js *PayJS) Check(tradeNo string) error {
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
params := url.Values{}
params.Add("payjs_order_id", tradeNo)
params.Add("sign", js.sign(params))
data := strings.NewReader(params.Encode())
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
defer resp.Body.Close()
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ReturnCode int `json:"return_code"`
Status int `json:"status"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ReturnCode == 1 && r.Status == 1 {
return nil
} else {
logger.Errorf("PayJs 支付验证响应:%s", string(body))
return errors.New("order not paid")
}
}

View File

@@ -0,0 +1,19 @@
package payment
type NotifyVo struct {
Status int
OutTradeNo string // 商户订单号
TradeId string // 交易ID
Amount string // 交易金额
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == Success
}
const (
Success = 0
Failure = 1
)

View File

@@ -0,0 +1,144 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
"net/http"
"time"
)
type WechatPayService struct {
config *types.WechatPayConfig
client *wechat.ClientV3
}
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
config := appConfig.WechatPayConfig
if !config.Enabled {
logger.Info("Disabled WechatPay service")
return nil, nil
}
priKey, err := readKey(config.PrivateKey)
if err != nil {
return nil, fmt.Errorf("error with read App Private key: %v", err)
}
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
if err != nil {
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
}
//client.DebugSwitch = gopay.DebugOn
return &WechatPayService{config: &config, client: client}, nil
}
type WechatPayParams struct {
OutTradeNo string `json:"out_trade_no"`
TotalFee int `json:"total_fee"`
Subject string `json:"subject"`
ClientIP string `json:"client_ip"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
})
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
}).
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP).
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
bm.Set("type", "Wap")
})
})
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.H5Url, nil
}
type NotifyResponse struct {
Code string `json:"code"`
Message string `xml:"message"`
}
// TradeVerify 交易验证
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
}
// TODO: 这里验签程序有 Bug一直报错crypto/rsa: verification error先暂时取消验签
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
//if err != nil {
// return fmt.Errorf("error with client v3 verify sign: %v", err)
//}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
}
return NotifyVo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
}
}

View File

@@ -1,131 +0,0 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core/types"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, levelDB *store.LevelDB) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
// create mj client and service
for _, config := range appConfig.SdConfigs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
service := NewService(name, config, taskQueue, notifyQueue, db, manager, levelDB)
// run sd service
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := p.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
p.db.Delete(&job)
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@@ -10,88 +10,89 @@ package sd
import ( import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger"
"geekai/service" "geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"strings" "github.com/go-redis/redis/v8"
"time" "time"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"gorm.io/gorm" "gorm.io/gorm"
) )
var logger = logger2.GetLogger()
// SD 绘画服务 // SD 绘画服务
type Service struct { type Service struct {
httpClient *req.Client httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
name string // service name wsService *service.WebsocketService
leveldb *store.LevelDB
} }
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
return &Service{ return &Service{
name: name,
config: config,
httpClient: req.C(), httpClient: req.C(),
taskQueue: taskQueue, taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: notifyQueue, notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db, db: db,
leveldb: levelDB, wsService: wsService,
uploadManager: manager, uploadManager: manager,
} }
} }
func (s *Service) Run() { func (s *Service) Run() {
for { logger.Infof("Starting Stable-Diffusion job consumer")
var task types.SdTask go func() {
err := s.taskQueue.LPop(&task) for {
if err != nil { var task types.SdTask
logger.Errorf("taking task with error: %v", err) err := s.taskQueue.LPop(&task)
continue if err != nil {
} logger.Errorf("taking task with error: %v", err)
continue
}
// translate prompt // translate prompt
if utils.HasChinese(task.Params.Prompt) { if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt)) content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0)
if err == nil { if err == nil {
task.Params.Prompt = content task.Params.Prompt = content
} else { } else {
logger.Warnf("error with translate prompt: %v", err) logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0)
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("handle a new Stable-Diffusion task: %+v", task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue
} }
} }
}()
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt))
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
continue
}
}
} }
// Txt2ImgReq 文生图请求实体 // Txt2ImgReq 文生图请求实体
@@ -104,6 +105,7 @@ type Txt2ImgReq struct {
Width int `json:"width"` Width int `json:"width"`
Height int `json:"height"` Height int `json:"height"`
SamplerName string `json:"sampler_name"` SamplerName string `json:"sampler_name"`
Scheduler string `json:"scheduler"`
EnableHr bool `json:"enable_hr,omitempty"` EnableHr bool `json:"enable_hr,omitempty"`
HrScale int `json:"hr_scale,omitempty"` HrScale int `json:"hr_scale,omitempty"`
HrUpscaler string `json:"hr_upscaler,omitempty"` HrUpscaler string `json:"hr_upscaler,omitempty"`
@@ -122,9 +124,8 @@ type Txt2ImgResp struct {
// TaskProgressResp 任务进度响应实体 // TaskProgressResp 任务进度响应实体
type TaskProgressResp struct { type TaskProgressResp struct {
Progress float64 `json:"progress"` Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"` EtaRelative float64 `json:"eta_relative"`
CurrentImage string `json:"current_image"`
} }
// Txt2Img 文生图 API // Txt2Img 文生图 API
@@ -137,6 +138,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
Width: task.Params.Width, Width: task.Params.Width,
Height: task.Params.Height, Height: task.Params.Height,
SamplerName: task.Params.Sampler, SamplerName: task.Params.Sampler,
Scheduler: task.Params.Scheduler,
ForceTaskId: task.Params.TaskId, ForceTaskId: task.Params.TaskId,
} }
if task.Params.Seed > 0 { if task.Params.Seed > 0 {
@@ -151,12 +153,19 @@ func (s *Service) Txt2Img(task types.SdTask) error {
} }
var res Txt2ImgResp var res Txt2ImgResp
var errChan = make(chan error) var errChan = make(chan error)
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
var apiKey model.ApiKey
err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
}
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
logger.Debugf("send image request to %s", apiURL) logger.Debugf("send image request to %s", apiURL)
// send a request to sd api endpoint // send a request to sd api endpoint
go func() { go func() {
response, err := s.httpClient.R(). response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey). SetHeader("Authorization", apiKey.Value).
SetBody(body). SetBody(body).
SetSuccessResult(&res). SetSuccessResult(&res).
Post(apiURL) Post(apiURL)
@@ -169,6 +178,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
return return
} }
// update the last used time
apiKey.LastUsedAt = time.Now().Unix()
s.db.Updates(&apiKey)
// 保存 Base64 图片 // 保存 Base64 图片
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0]) imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
if err != nil { if err != nil {
@@ -183,7 +196,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
return return
} }
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1)) task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)}) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params), Prompt: task.Params.Prompt})
errChan <- nil errChan <- nil
}() }()
@@ -197,21 +210,15 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished // task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil return nil
default: default:
err, resp := s.checkTaskProgress() err, resp := s.checkTaskProgress(apiKey)
// 更新任务进度 // 更新任务进度
if err == nil && resp.Progress > 0 { if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号 // 发送更新状态信号
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
}
} }
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -220,11 +227,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
} }
// 执行任务 // 执行任务
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL) apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
var res TaskProgressResp var res TaskProgressResp
response, err := s.httpClient.R(). response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey). SetHeader("Authorization", apiKey.Value).
SetSuccessResult(&res). SetSuccessResult(&res).
Get(apiURL) Get(apiURL)
if err != nil { if err != nil {
@@ -236,3 +243,52 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
return nil, &res return nil, &res
} }
func (s *Service) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChSd, message.Message)
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
}
}
time.Sleep(time.Second * 5)
}
}()
}

View File

@@ -1,24 +0,0 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import logger2 "geekai/logger"
var logger = logger2.GetLogger()
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const (
Running = "RUNNING"
Finished = "FINISH"
Failed = "FAIL"
)

View File

@@ -28,8 +28,8 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
} }
func (s *SmtpService) SendVerifyCode(to string, code int) error { func (s *SmtpService) SendVerifyCode(to string, code int) error {
subject := "Geek-AI 注册验证码" subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
body := fmt.Sprintf("您正在注册 Geek-AI 助手账户,注册验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", code) body := fmt.Sprintf("【%s】您的验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", s.config.AppName, code)
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host) auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
if s.config.UseTls { if s.config.UseTls {

459
api/service/suno/service.go Normal file
View File

@@ -0,0 +1,459 @@
package suno
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"io"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
wsService *service.WebsocketService
clientIds map[string]string
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
uploadManager: manager,
wsService: wsService,
clientIds: map[string]string{},
}
}
func (s *Service) PushTask(task types.SunoTask) {
logger.Infof("add a new Suno task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.SunoJob
s.db.Where("task_id", "").Find(&jobs)
for _, v := range jobs {
s.PushTask(types.SunoTask{
Id: v.Id,
Channel: v.Channel,
UserId: v.UserId,
Type: v.Type,
Title: v.Title,
RefTaskId: v.RefTaskId,
RefSongId: v.RefSongId,
Prompt: v.Prompt,
Tags: v.Tags,
Model: v.ModelName,
Instrumental: v.Instrumental,
ExtendSecs: v.ExtendSecs,
})
}
logger.Info("Starting Suno job consumer...")
go func() {
for {
var task types.SunoTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
var r RespVo
if task.Type == 3 && task.SongId != "" { // 歌曲拼接
r, err = s.Merge(task)
} else if task.Type == 4 && task.AudioURL != "" { // 上传歌曲
r, err = s.Upload(task)
} else { // 歌曲创作
r, err = s.Create(task)
}
if err != nil {
logger.Errorf("create task with error: %v", err)
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
// 更新任务信息
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Data,
"channel": r.Channel,
})
s.clientIds[r.Data] = task.ClientId
}
}()
}
type RespVo struct {
Code string `json:"code"`
Message string `json:"message"`
Data string `json:"data"`
Channel string `json:"channel,omitempty"`
}
func (s *Service) Create(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"task_id": task.RefTaskId,
"continue_clip_id": task.RefSongId,
"continue_at": task.ExtendSecs,
"make_instrumental": task.Instrumental,
}
// 灵感模式
if task.Type == 1 {
reqBody["gpt_description_prompt"] = task.Prompt
} else { // 自定义模式
reqBody["prompt"] = task.Prompt
reqBody["tags"] = task.Tags
reqBody["mv"] = task.Model
reqBody["title"] = task.Title
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/submit/music", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) Merge(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"clip_id": task.SongId,
"is_infill": false,
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/submit/concat", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"url": task.AudioURL,
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/uploads/audio-url", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 {
return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("notify message: %+v", message)
logger.Debugf("client id: %+v", s.wsService.Clients)
client := s.wsService.Clients.Get(message.ClientId)
logger.Debugf("%+v", client)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChSuno, message.Message)
}
}()
}
func (s *Service) DownloadFiles() {
go func() {
var items []model.SunoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
// 下载图片和音频
logger.Infof("try download cover image: %s", v.CoverURL)
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
if err != nil {
logger.Errorf("download image with error: %v", err)
continue
}
logger.Infof("try download audio: %s", v.AudioURL)
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
if err != nil {
logger.Errorf("download audio with error: %v", err)
continue
}
v.CoverURL = coverURL
v.AudioURL = audioURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 10)
}
}()
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.SunoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
task, err := s.QueryTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
continue
}
if task.Code != "success" {
logger.Errorf("query task with error: %v", task.Message)
continue
}
logger.Debugf("task: %+v", task.Data.Status)
// 任务完成,删除旧任务插入两条新任务
if task.Data.Status == "SUCCESS" {
var jobId = job.Id
var flag = false
tx := s.db.Begin()
for _, v := range task.Data.Data {
job.Id = 0
job.Progress = 102 // 102 表示资源未下载完成
job.Title = v.Title
job.SongId = v.Id
job.Duration = int(v.Metadata.Duration)
job.Prompt = v.Metadata.Prompt
job.Tags = v.Metadata.Tags
job.ModelName = v.ModelName
job.RawData = utils.JsonEncode(v)
job.CoverURL = v.ImageLargeUrl
job.AudioURL = v.AudioUrl
if err = tx.Create(&job).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
break
}
flag = true
}
// 删除旧任务
if flag {
if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
continue
}
}
tx.Commit()
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
} else if task.Data.FailReason != "" {
job.Progress = service.FailTaskProgress
job.ErrMsg = task.Data.FailReason
s.db.Updates(&job)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
}
}
time.Sleep(time.Second * 10)
}
}()
}
type QueryRespVo struct {
Code string `json:"code"`
Message string `json:"message"`
Data struct {
TaskId string `json:"task_id"`
Action string `json:"action"`
Status string `json:"status"`
FailReason string `json:"fail_reason"`
SubmitTime int `json:"submit_time"`
StartTime int `json:"start_time"`
FinishTime int `json:"finish_time"`
Progress string `json:"progress"`
Data []struct {
Id string `json:"id"`
Title string `json:"title"`
Status string `json:"status"`
Metadata struct {
Tags string `json:"tags"`
Type string `json:"type"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
Duration float64 `json:"duration"`
ErrorMessage interface{} `json:"error_message"`
} `json:"metadata"`
AudioUrl string `json:"audio_url"`
ImageUrl string `json:"image_url"`
VideoUrl string `json:"video_url"`
ModelName string `json:"model_name"`
DisplayName string `json:"display_name"`
ImageLargeUrl string `json:"image_large_url"`
MajorModelVersion string `json:"major_model_version"`
} `json:"data"`
} `json:"data"`
}
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
err := s.db.Session(&gorm.Session{}).Where("type", "suno").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return QueryRespVo{}, errors.New("no available API KEY for Suno")
}
apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
var res QueryRespVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return QueryRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil
}

View File

@@ -1,4 +1,18 @@
package service package service
const FailTaskProgress = 101
const (
TaskStatusRunning = "RUNNING"
TaskStatusFinished = "FINISH"
TaskStatusFailed = "FAIL"
)
type NotifyMessage struct {
UserId int `json:"user_id"`
ClientId string `json:"client_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]" const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"

View File

@@ -0,0 +1,83 @@
package service
import (
"fmt"
"geekai/core/types"
"geekai/store/model"
"gorm.io/gorm"
"sync"
"time"
)
type UserService struct {
db *gorm.DB
lock sync.Mutex
}
func NewUserService(db *gorm.DB) *UserService {
return &UserService{db: db, lock: sync.Mutex{}}
}
// IncreasePower 增加用户算力
func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error {
s.lock.Lock()
defer s.lock.Unlock()
tx := s.db.Begin()
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", power)).Error
if err != nil {
tx.Rollback()
return err
}
var user model.User
tx.Where("id", userId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: log.Type,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: log.Model,
Remark: log.Remark,
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return nil
}
// DecreasePower 减少用户算力
func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error {
s.lock.Lock()
defer s.lock.Unlock()
tx := s.db.Begin()
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error
if err != nil {
tx.Rollback()
return fmt.Errorf("扣减算力失败:%v", err)
}
var user model.User
tx.Where("id", userId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: log.Type,
Amount: power,
Balance: user.Power,
Mark: types.PowerSub,
Model: log.Model,
Remark: log.Remark,
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
return fmt.Errorf("记录算力日志失败:%v", err)
}
tx.Commit()
return nil
}

345
api/service/video/luma.go Normal file
View File

@@ -0,0 +1,345 @@
package video
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"io"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
wsService *service.WebsocketService
clientIds map[uint]string
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
wsService: wsService,
uploadManager: manager,
clientIds: map[uint]string{},
}
}
func (s *Service) PushTask(task types.VideoTask) {
logger.Infof("add a new Video task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.VideoJob
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var params types.VideoParams
if err := utils.JsonDecode(v.Params, &params); err != nil {
logger.Errorf("unmarshal params failed: %v", err)
continue
}
s.PushTask(types.VideoTask{
Id: v.Id,
Channel: v.Channel,
UserId: v.UserId,
Type: v.Type,
TaskId: v.TaskId,
Prompt: v.Prompt,
Params: params,
})
}
logger.Info("Starting Video job consumer...")
go func() {
for {
var task types.VideoTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
if task.ClientId != "" {
s.clientIds[task.Id] = task.ClientId
}
var r LumaRespVo
r, err = s.LumaCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
"cover_url": "/images/failed.jpg",
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
// 更新任务信息
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Id,
"channel": r.Channel,
"prompt_ext": r.Prompt,
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
s.PushTask(task)
}
}
}()
}
type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
CreatedAt time.Time `json:"created_at"`
Video interface{} `json:"video"`
Liked interface{} `json:"liked"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
Channel string `json:"channel,omitempty"`
}
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return LumaRespVo{}, errors.New("no available API KEY for Luma")
}
reqBody := map[string]interface{}{
"user_prompt": task.Prompt,
"expand_prompt": task.Params.PromptOptimize,
"loop": task.Params.Loop,
"image_url": task.Params.StartImgURL,
"image_end_url": task.Params.EndImgURL,
}
var res LumaRespVo
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 && r.StatusCode != 201 {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("Receive notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChLuma, message.Message)
}
}()
}
func (s *Service) DownloadFiles() {
go func() {
var items []model.VideoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
if v.WaterURL == "" {
continue
}
logger.Infof("try download video: %s", v.WaterURL)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
logger.Infof("download video success: %s", videoURL)
v.WaterURL = videoURL
if v.VideoURL != "" {
logger.Infof("try download no water video: %s", v.VideoURL)
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
}
logger.Info("download no water video success: %s", videoURL)
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 10)
}
}()
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.VideoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
// 更新任务信息
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
"err_msg": err.Error(),
})
continue
}
logger.Debugf("task: %+v", task)
if task.State == "completed" { // 更新任务信息
data := map[string]interface{}{
"progress": 102, // 102 表示资源未下载完成,
"water_url": task.Video.Url,
"raw_data": utils.JsonEncode(task),
"prompt_ext": task.Prompt,
}
if task.Video.DownloadUrl != "" {
data["video_url"] = task.Video.DownloadUrl
}
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
if err != nil {
logger.Errorf("更新数据库失败:%v", err)
continue
}
}
}
time.Sleep(time.Second * 10)
}
}()
}
type LumaTaskVo struct {
Id string `json:"id"`
Liked interface{} `json:"liked"`
State string `json:"state"`
Video struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
DownloadUrl string `json:"download_url"`
} `json:"video"`
Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
}
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
}
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
var res LumaTaskVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
if r.StatusCode != 200 {
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil
}

13
api/service/ws_service.go Normal file
View File

@@ -0,0 +1,13 @@
package service
import "geekai/core/types"
type WebsocketService struct {
Clients *types.LMap[string, *types.WsClient] // clientId => Client
}
func NewWebsocketService() *WebsocketService {
return &WebsocketService{
Clients: types.NewLMap[string, *types.WsClient](),
}
}

View File

@@ -1,101 +0,0 @@
package wx
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
logger2 "geekai/logger"
"geekai/store/model"
"github.com/eatmoreapple/openwechat"
"github.com/skip2/go-qrcode"
"gorm.io/gorm"
"os"
"strconv"
)
// 微信收款机器人
var logger = logger2.GetLogger()
type Bot struct {
bot *openwechat.Bot
token string
db *gorm.DB
}
func NewWeChatBot(db *gorm.DB) *Bot {
bot := openwechat.DefaultBot(openwechat.Desktop)
return &Bot{
bot: bot,
db: db,
}
}
func (b *Bot) Run() error {
logger.Info("Starting WeChat Bot...")
// set message handler
b.bot.MessageHandler = func(msg *openwechat.Message) {
b.messageHandler(msg)
}
// scan code login callback
b.bot.UUIDCallback = b.qrCodeCallBack
debug, err := strconv.ParseBool(os.Getenv("APP_DEBUG"))
if debug {
reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json")
err = b.bot.HotLogin(reloadStorage, true)
} else {
err = b.bot.Login()
}
if err != nil {
return err
}
logger.Info("微信登录成功!")
return nil
}
// message handler
func (b *Bot) messageHandler(msg *openwechat.Message) {
sender, err := msg.Sender()
if err != nil {
return
}
// 只处理微信支付的推送消息
if sender.NickName == "微信支付" ||
msg.MsgType == openwechat.MsgTypeApp ||
msg.AppMsgType == openwechat.AppMsgTypeUrl {
// 解析支付金额
message := parseTransactionMessage(msg.Content)
transaction := extractTransaction(message)
logger.Infof("解析到收款信息:%+v", transaction)
if transaction.TransId != "" {
var item model.Reward
res := b.db.Where("tx_id = ?", transaction.TransId).First(&item)
if item.Id > 0 {
logger.Error("当前交易 ID 己经存在!")
return
}
res = b.db.Create(&model.Reward{
TxId: transaction.TransId,
Amount: transaction.Amount,
Remark: transaction.Remark,
Status: false,
})
if res.Error != nil {
logger.Errorf("交易保存失败: %v", res.Error)
}
}
}
}
func (b *Bot) qrCodeCallBack(uuid string) {
logger.Info("请使用微信扫描下面二维码登录")
q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium)
logger.Info(q.ToString(true))
}

View File

@@ -1,112 +0,0 @@
package wx
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/xml"
"net/url"
"strconv"
"strings"
)
// Message 转账消息
type Message struct {
Des string
Url string
}
// Transaction 解析后的交易信息
type Transaction struct {
TransId string `json:"trans_id"` // 微信转账交易 ID
Amount float64 `json:"amount"` // 微信转账交易金额
Remark string `json:"remark"` // 转账备注
}
// 解析微信转账消息
func parseTransactionMessage(xmlData string) *Message {
decoder := xml.NewDecoder(strings.NewReader(xmlData))
message := Message{}
for {
token, err := decoder.Token()
if err != nil {
break
}
switch se := token.(type) {
case xml.StartElement:
var value string
if se.Name.Local == "des" && message.Des == "" {
if err := decoder.DecodeElement(&value, &se); err == nil {
message.Des = strings.TrimSpace(value)
}
break
}
if se.Name.Local == "weapp_path" || se.Name.Local == "url" {
if err := decoder.DecodeElement(&value, &se); err == nil {
if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") {
message.Url = value
}
}
break
}
}
}
// 兼容旧版消息记录
if message.Url == "" {
var msg struct {
XMLName xml.Name `xml:"msg"`
AppMsg struct {
Des string `xml:"des"`
Url string `xml:"url"`
} `xml:"appmsg"`
}
if err := xml.Unmarshal([]byte(xmlData), &msg); err == nil {
message.Url = msg.AppMsg.Url
}
}
return &message
}
// 导出交易信息
func extractTransaction(message *Message) Transaction {
var tx = Transaction{}
// 导出交易金额和备注
lines := strings.Split(message.Des, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
// 解析收款金额
prefix := "收款金额¥"
if strings.HasPrefix(line, prefix) {
if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil {
tx.Amount = value
continue
}
}
// 解析收款备注
prefix = "付款方备注"
if strings.HasPrefix(line, prefix) {
tx.Remark = line[len(prefix):]
break
}
}
// 解析交易 ID
parse, err := url.Parse(message.Url)
if err == nil {
tx.TransId = parse.Query().Get("id")
if tx.TransId == "" {
tx.TransId = parse.Query().Get("trans_id")
}
}
return tx
}

Some files were not shown because too many files have changed in this diff Show More