Compare commits

..

197 Commits

Author SHA1 Message Date
RockYang
754ba02263 fixed build script 2024-08-01 18:45:53 +08:00
RockYang
7ddf57ae06 merge v4.0.8 2024-08-01 18:09:00 +08:00
RockYang
cc5180a6f7 update readme 2024-08-01 08:52:46 +08:00
RockYang
9f44c34d34 update docs url 2024-07-16 18:15:34 +08:00
RockYang
b793b81768 update geekai image version 2024-07-05 11:09:13 +08:00
RockYang
233f6e00f0 fixed conflicts 2024-06-30 06:09:12 +08:00
RockYang
b7dba68549 optimize ngin configuration for chat-plus.conf 2024-06-30 05:44:39 +08:00
RockYang
bdea12c51a update docker image mirror url 2024-06-19 08:35:05 +08:00
RockYang
a27d9ea259 Merge branch 'main' of gitee.com:blackfox/geekai 2024-06-19 08:22:00 +08:00
RockYang
7cd824c284 update docker image version to 4.0.6 2024-06-15 15:35:32 +08:00
RockYang
e27d95e2b5 update docker image version to 4.0.6 2024-06-15 15:33:38 +08:00
GeekMaster
6839827db0 Merge branch 'main' of https://github.com/yangjian102621/geekai into main 2024-06-12 15:44:26 +08:00
RockYang
d6a04f96fe Merge pull request #208 from mari1995/main
图片墙,选择框单选问题
2024-06-04 08:33:51 +08:00
RockYang
5f820b9dc1 merge v4.0.6 2024-06-03 14:22:08 +08:00
RockYang
3cc2263dc7 fixed bug for function call error None is not of type 'array' 2024-05-30 09:59:44 +08:00
RockYang
f0a3c5d8ae fixed bug for mobile chat share 2024-05-30 08:37:14 +08:00
RockYang
2a4ef27774 add v4.0.8 database sql file 2024-05-29 17:41:37 +08:00
RockYang
2b057f32aa feat: add dalle3 page for h5 2024-05-29 17:25:01 +08:00
RockYang
bc6451026f feat: add system config for enable rand background image for index page 2024-05-29 16:24:56 +08:00
RockYang
99fd596862 feat: add system config for enable rand background image for index page 2024-05-29 16:23:42 +08:00
RockYang
f0959b5df6 fix markdown formula parse plugin 2024-05-29 13:49:45 +08:00
SSMario
6788edbe9d Update Image.vue 2024-05-27 21:23:11 +08:00
SSMario
3895305882 Update Image.vue 2024-05-27 21:00:59 +08:00
RockYang
1b0938b33f micro fixs 2024-05-27 17:39:17 +08:00
SSMario
c2acbaaa94 Update ImagesWall.vue 2024-05-27 17:04:15 +08:00
RockYang
02faff461a fixed bug for dalle prompt translate 2024-05-27 11:42:14 +08:00
RockYang
e18e5a38c6 put model and app selector on the top of chat page 2024-05-24 12:33:22 +08:00
RockYang
2f9b1b7835 fixed bug for payment api authorization 2024-05-24 11:31:38 +08:00
RockYang
717b137a6d chore: use config value for order pay timeout 2024-05-22 18:15:06 +08:00
RockYang
f755bdccae feat: add sign check for PC QR code payment 2024-05-22 17:47:53 +08:00
RockYang
4bba77ab47 extract code for saving chat history 2024-05-22 15:32:44 +08:00
RockYang
6944a32ff3 check if the api url in whitelist for mj plus client 2024-05-22 11:47:04 +08:00
RockYang
5742b40aee fixed bug for mobile chat page change chat model not work 2024-05-21 17:54:03 +08:00
RockYang
7f1ec90748 auto resize the input element rows, when use inputed more than one line 2024-05-21 17:36:47 +08:00
RockYang
4a99be2f15 remove license code 2024-05-21 16:20:29 +08:00
RockYang
bee19392c1 add logs for updating database failed 2024-05-21 11:55:38 +08:00
RockYang
27c816cf3b merge conflicts for v4.0.5 2024-05-21 11:30:40 +08:00
RockYang
0d81776212 update docker image name 2024-05-21 11:21:27 +08:00
RockYang
00d31a2379 update docker image url 2024-05-21 11:03:11 +08:00
RockYang
cccab31c0f rename project name to geekai 2024-05-20 15:14:02 +08:00
RockYang
5d65505ab7 rename project name to geekai 2024-05-20 15:11:14 +08:00
RockYang
3dc7d0516a update database file 2024-05-19 19:37:16 +08:00
RockYang
50335ebc2d remove code for set left component fixed height 2024-05-18 08:07:09 +08:00
RockYang
bcadee7290 refactor login dialog for front page 2024-05-18 00:27:32 +08:00
RockYang
cac3194d5b finished refactor chat page UI 2024-05-17 19:25:38 +08:00
RockYang
4ddf3bf2bf fixed bugs for send message captcha component 2024-05-17 08:41:09 +08:00
RockYang
d45f9fbad6 fixed bugs for send message captcha component 2024-05-16 22:33:00 +08:00
RockYang
d98b08d7cd feat: add top navbar for front page 2024-05-16 20:10:00 +08:00
RockYang
5a8fe5a6cf feat: support add external link menu 2024-05-16 10:53:00 +08:00
RockYang
36c27d6092 handler chat error in the chat entry func 2024-05-15 15:30:34 +08:00
RockYang
3ab29da8f0 add charge link for insufficient of power 2024-05-15 07:10:31 +08:00
RockYang
3699f024f1 fix bug for white-list api key check 2024-05-14 22:30:42 +08:00
RockYang
3d37a3d367 update readme 2024-05-14 18:23:12 +08:00
RockYang
73d8236697 update change log 2024-05-14 18:20:52 +08:00
RockYang
114d0088dc update version 2024-05-14 18:03:58 +08:00
RockYang
43b6665370 refactor: use waterflow component in mj, sd and dall image drawing page 2024-05-13 19:04:00 +08:00
RockYang
5fb9f84182 fixed sd page waterfall component 2024-05-11 18:27:35 +08:00
RockYang
e35c34ad9a enable to update AI Drawing configuarations in admin console page 2024-05-11 17:27:14 +08:00
RockYang
1a4d798f8b fix: markmap do not cost power in front page 2024-05-11 07:08:14 +08:00
RockYang
afb91a7023 optimize chat handler error handle code 2024-05-10 18:26:19 +08:00
RockYang
dc4c1f7877 fix bug: remove chat role failed 2024-05-10 17:38:55 +08:00
RockYang
bbc8fe2b40 fixed bug for dalle3 task not decrease power 2024-05-10 11:18:37 +08:00
RockYang
3c34e8e0e7 fixed bug for dalle3 task not decrease power 2024-05-10 11:17:52 +08:00
RockYang
57c932f07c add toolbar for markmap component 2024-05-10 06:38:34 +08:00
RockYang
922202734a fixed bug for markmap 2024-05-09 21:55:40 +08:00
RockYang
8b3b0139b0 use proxy for downloading discord images 2024-05-09 18:48:53 +08:00
RockYang
31828a3336 add stable diffusion default negtive prompt system config 2024-05-07 16:49:54 +08:00
RockYang
b270960a04 remove license code 2024-05-07 16:41:35 +08:00
RockYang
5c4899df6e upgrade to v4.0.4 2024-05-07 16:32:05 +08:00
RockYang
9a797bb4a5 add stable diffusion default negtive prompt system config 2024-05-07 16:21:31 +08:00
RockYang
b0c9ffc5a6 handle the exception for web front page 2024-05-06 17:39:58 +08:00
RockYang
f527cc5b98 fixed conflicts 2024-05-06 14:44:09 +08:00
RockYang
debe8dc209 chore: change module name to geekai, add copyright in source code 2024-05-06 14:41:27 +08:00
RockYang
2f0215ac87 Update README.md 2024-05-06 10:45:50 +08:00
RockYang
dd5cc206e5 update LICENSE.
Signed-off-by: RockYang <yangjian102621@gmail.com>
2024-05-05 03:05:23 +00:00
RockYang
142cd553a3 fix bug for waterflow component 2024-05-05 10:52:29 +08:00
RockYang
657ecccee3 update changelog 2024-05-04 21:32:43 +08:00
RockYang
1232c3cd9c fix bug for license synchronize 2024-05-04 21:30:29 +08:00
RockYang
3ac04a3938 opt: optimize mobile images page styles 2024-05-03 08:14:33 +08:00
RockYang
b7abc42209 fix: fix bug for dalle power refund 2024-05-01 08:48:32 +08:00
RockYang
a48179ce0e chore: update docker image name 2024-05-01 07:59:20 +08:00
RockYang
e589f25a05 opt: styles and view micro optimization 2024-05-01 07:40:56 +08:00
RockYang
cc1a3ce343 opt: close unused websocket connections 2024-04-30 22:54:39 +08:00
RockYang
7bb76d581c chore: replace docker image url with AliYun 2024-04-30 19:08:33 +08:00
RockYang
0d733c0be0 feat: change theme for mobile site is ready 2024-04-30 18:57:15 +08:00
RockYang
8b40ac5b5c feat: mobile page refactor is finished 2024-04-29 19:22:00 +08:00
RockYang
24479814e9 opt: add chat config for mobile chat session 2024-04-29 09:39:23 +08:00
RockYang
99df028237 feat: add index page for mobile 2024-04-28 19:09:26 +08:00
RockYang
b354b88876 theme change is ready 2024-04-27 13:13:28 +08:00
RockYang
5e0be4d10e feat: admin console dark theme 2024-04-26 18:10:17 +08:00
RockYang
468b48151f opt: optimize index and login page UI 2024-04-26 16:07:02 +08:00
RockYang
fa5c036041 synchronize license every 10 secs 2024-04-26 14:35:01 +08:00
RockYang
0fdc588167 change index page background 2024-04-25 18:54:33 +08:00
RockYang
2e023cb8dc update readme 2024-04-25 06:27:30 +08:00
RockYang
e933f32d9c remove unused files 2024-04-24 21:22:56 +08:00
RockYang
bd4b0c4d65 show license info in admin active page, optimize markdown generate prompt 2024-04-24 19:00:28 +08:00
RockYang
0b2501c1d8 output the error to chat page directly, replace the common error message 'AI开小差' 2024-04-24 10:10:03 +08:00
RockYang
9d28e62142 fixed bug for markmap 2024-04-23 20:47:06 +08:00
RockYang
c1d892069e release v4.0.4 2024-04-23 18:47:23 +08:00
RockYang
61b2dbc9f1 optimize styles and release v4.0.4 2024-04-23 18:46:32 +08:00
RockYang
be3245666e feat: hide more navigator items 2024-04-22 16:27:53 +08:00
RockYang
dacdd6fe74 feat: support markmap svg export and download as png image 2024-04-22 14:20:51 +08:00
RockYang
6807f7e88a allow users to select a chatApp to chat in chat app list page 2024-04-22 11:18:55 +08:00
RockYang
087f5ab2d1 optimize index page UI 2024-04-22 10:43:33 +08:00
RockYang
47c5a0387b optimize code for remove timeout and failed image drawing job 2024-04-21 21:44:28 +08:00
RockYang
f9da18ad52 image wall page add dalle 2024-04-21 20:42:42 +08:00
RockYang
5c9025ca22 dalle image page is ready 2024-04-21 20:23:47 +08:00
RockYang
d02cb573fd DO NOT refresh finished jobs when job is running 2024-04-20 21:30:55 +08:00
RockYang
caa538a1d0 fixed markdown generating styles 2024-04-19 18:22:45 +08:00
RockYang
b584b4bfb6 support send email use TLS 2024-04-19 12:04:59 +08:00
RockYang
bda335212d Merge branch 'main' of github.com:yangjian102621/chatgpt-plus 2024-04-19 10:56:05 +08:00
RockYang
06f4cdc649 fixed bug for QWen response blank quotes 2024-04-19 10:55:29 +08:00
RockYang
336a7d5b56 fixed bug for QWen response blank quotes 2024-04-19 10:30:02 +08:00
RockYang
a0f464830f fixed bug for chat handler doRequest method 2024-04-16 23:49:56 +08:00
RockYang
9bf7fa4081 compatible freeGPT35 API 2024-04-15 21:03:19 +08:00
RockYang
96ead65774 fixed bug for websocket close tip message 2024-04-15 18:15:15 +08:00
RockYang
7ad41927aa feat: markmap function is ready 2024-04-15 17:23:59 +08:00
RockYang
4ca9dfd9c0 fixed upscale and variation action url 2024-04-15 15:14:49 +08:00
RockYang
8a9f386d8f opt: close the old connection for mj and sd clients 2024-04-15 09:34:20 +08:00
RockYang
adfee8bf58 update version 2024-04-15 09:05:54 +08:00
RockYang
fbfa2a71a9 release v4.0.3 2024-04-15 08:26:07 +08:00
RockYang
9a1368ef17 markmap enable to select ai model 2024-04-15 06:16:53 +08:00
RockYang
31b02b97d3 fixed chat page styles 2024-04-12 21:57:41 +08:00
RockYang
42da38c5c3 feat: markmap page view is ready 2024-04-12 18:49:24 +08:00
RockYang
0a01b55713 feat: allow chat model bind a fixed api key 2024-04-12 17:09:22 +08:00
RockYang
3b292c2a12 chore: update build.sh 2024-04-11 21:26:33 +08:00
RockYang
db0ba0d9a0 feat: use custom mode for mj upscale and variation operarions 2024-04-11 17:32:34 +08:00
RockYang
3a23ff6b42 feat: add index page 2024-04-10 18:23:55 +08:00
RockYang
1e9c5adb0a feat: support for freeGPT35 API 2024-04-10 14:49:07 +08:00
RockYang
abab76ccc6 feat: support gpt-4-turbo-2014-04-09 vision function 2024-04-10 11:47:10 +08:00
RockYang
6efd92806f fixed bug for gpt-4-turbo-2024-0409 model function calls 2024-04-10 10:23:45 +08:00
RockYang
cfe333e89f fix bug: remove timeout task ONLY for unfinished(progress < 100) 2024-04-10 06:25:54 +08:00
RockYang
a7237fe62f Merge branch 'main' of gitee.com:blackfox/chatgpt-plus 2024-04-07 18:33:55 +08:00
RockYang
c3c454b7d7 Merge branch 'main' of github.com:yangjian102621/chatgpt-plus 2024-04-07 18:32:20 +08:00
RockYang
d4d708d44b update database files 2024-04-07 18:26:45 +08:00
RockYang
7f0b6a3a46 set the enable status for adding new api key with default value true 2024-04-07 10:17:10 +08:00
RockYang
c2a7c089d2 update change log 2024-04-07 08:40:43 +08:00
RockYang
df5bd4df60 Merge branch 'dev' 2024-04-07 08:03:14 +08:00
RockYang
79b6010104 optimize ngin configuration for chat-plus.conf 2024-04-07 08:02:47 +08:00
RockYang
97b0a98793 fix 404 error with remove api keys 2024-04-07 08:01:25 +08:00
RockYang
5230f90540 fix bug for remove api 404 errorc 2024-04-06 20:36:52 +08:00
RockYang
803db4e895 feat: show bind model on chat role list page 2024-04-05 21:21:28 +08:00
RockYang
7cee9f2ebb feat: support uploading role icon 2024-04-05 17:41:23 +08:00
RockYang
8be9a21efd feat: allow bind a chat model for chat role 2024-04-05 12:51:18 +08:00
RockYang
6a3e26b566 update change log 2024-04-05 06:58:13 +08:00
RockYang
0355c37bef feat: stable diffusion image drawing on mobile is ready 2024-04-03 18:13:48 +08:00
RockYang
9b7ee538c4 feat: midjourney role and style consistency is ready 2024-04-02 19:01:28 +08:00
RockYang
d900a3d08e update README.md.
Signed-off-by: RockYang <yangjian102621@gmail.com>
2024-04-02 10:00:45 +00:00
RockYang
cdf5b66729 Update README.md 2024-04-02 17:59:53 +08:00
RockYang
1cff4b63cd feat: image preview for stable-diffusion task is ready 2024-04-02 17:24:38 +08:00
RockYang
da14309ef9 feat: support midjourney --cref and --sref for role consistency 2024-04-02 14:59:53 +08:00
RockYang
fbb216fe3b feat: update menu icons, add version in site titles 2024-04-01 18:20:00 +08:00
RockYang
95efbd5659 show power for chat and imaging page 2024-03-31 20:49:12 +08:00
RockYang
4596c1049c opt: change the relative path with absolute path for midjourney image uploading 2024-03-31 17:45:22 +08:00
RockYang
b35d95f0c7 Merge branch 'main' into dev 2024-03-30 11:57:31 +08:00
RockYang
01419df998 remove dead code 2024-03-30 11:57:23 +08:00
RockYang
a6c00c42fa fixed conflicts 2024-03-29 18:10:59 +08:00
RockYang
4cc9db7115 update readme file 2024-03-29 18:06:55 +08:00
RockYang
4f1ed54059 update databases 2024-03-29 17:43:38 +08:00
RockYang
8227a73e35 feat: support custom menu 2024-03-29 15:41:58 +08:00
RockYang
adfd8c1939 add tip for mj image buttons 2024-03-28 21:41:49 +08:00
RockYang
8eed7ff534 fix: fix overflow hidden for admin page 2024-03-28 18:51:09 +08:00
RockYang
c79c4e74d0 fix: fix overflow hidden for mobile page 2024-03-28 18:13:33 +08:00
RockYang
f1855fd0a1 fix: use slide captcha for iphone 2024-03-28 15:00:53 +08:00
RockYang
1f964c74e9 feat: add mj_action_power system config item 2024-03-28 09:53:41 +08:00
RockYang
4fb2c5803c feat: change midjourney origin implements, replace midjourney bot with midjourney-proxy 2024-03-27 18:57:15 +08:00
RockYang
b5947545cb feat: auto translate and rewrite prompt for midjourney and stable-diffusion 2024-03-27 13:45:52 +08:00
RockYang
342b76f666 feat: stable-diffusion refactored, replace websocket api with sdapi 2024-03-26 18:23:08 +08:00
RockYang
be8a0ec184 opt: remove global keyup event bind 2024-03-10 09:45:59 +08:00
RockYang
b02e3aad95 docs: update config.toml 2024-03-10 09:45:59 +08:00
RockYang
08eca511ad docs: add database sql file for v3.2.7 2024-03-10 09:45:59 +08:00
RockYang
c34e911596 feat: allow to view chat message in manager console 2024-03-10 09:45:59 +08:00
RockYang
8a452c3072 feat: download image which ai generated in dialog and replace the image url 2024-03-10 09:45:59 +08:00
RockYang
13bfb14107 feat: image-wall page for mobile is ready 2024-03-10 09:45:59 +08:00
RockYang
4188b0969e feat: allow user config third-party platform openai and mj api key 2024-03-10 09:45:59 +08:00
RockYang
0c27795a10 feat: added delete file function 2024-03-10 09:45:59 +08:00
RockYang
d05693c5c1 fix: verifycation component touch event coordinates misplace in iphone browser 2024-03-10 09:45:59 +08:00
RockYang
c0b2063b38 fix: fix bug for regenerate button did not work 2024-03-10 09:45:59 +08:00
RockYang
4d183747b1 fix: Upscale and Variation task overrite each other 2024-03-10 09:45:59 +08:00
RockYang
08fe1b2f75 feat: midjourney mobile page all function is ready 2024-03-10 09:45:59 +08:00
RockYang
db3e8a267e feat: mobile mj list page is ready 2024-03-10 09:45:59 +08:00
RockYang
8fc62682c4 feat: add mj image list component for mobile page. fixed bug for html tag escape 2024-03-10 09:45:59 +08:00
RockYang
75031914a3 feat: add functions mj page for mobile 2024-03-10 09:45:59 +08:00
RockYang
a4c9fdd95a opt: add default extension for mj image 2024-03-10 09:45:59 +08:00
RockYang
6a9bfeb5aa feat: mj for mobile page payout is ready 2024-03-10 09:45:59 +08:00
RockYang
e654766f60 add docs and github link 2024-03-10 09:45:59 +08:00
RockYang
0ef6955f96 opt: enable use cdn url for mj-plus 2024-03-10 09:45:59 +08:00
RockYang
b4501557c9 feat: LaTeX parse is ready 2024-03-10 09:45:59 +08:00
RockYang
a2ed99e6cb feat: add model field for chat_item and and chat_history data table 2024-03-10 09:45:59 +08:00
RockYang
6bd6bb3885 feat: add err_msg field for mj and sd jobs 2024-03-10 09:45:59 +08:00
RockYang
399cf65fc9 feat: blend and swap face function for midjourney-plus is ready 2024-03-10 09:45:59 +08:00
RockYang
24906a6df1 feat: add blend and swapface task implements for midjourney 2024-03-10 09:45:59 +08:00
RockYang
d772bbebe6 opt: refactor chat session page for mobile device 2024-03-10 09:45:59 +08:00
RockYang
14988853a3 opt: optimize chat list page for mobile 2024-03-10 09:45:59 +08:00
RockYang
7b3f16ac9f feat: use vant replace element-plus as mobile UI framework 2024-03-10 09:45:59 +08:00
RockYang
82b2755c18 feat: add websocket heartbeat message for mj page 2024-03-10 09:45:59 +08:00
RockYang
4e4dc4cb73 Update README.md 2024-03-01 23:08:05 +08:00
287 changed files with 23043 additions and 7261 deletions

View File

@@ -1,5 +1,5 @@
name: Bug 报告 🐛
description: chatgpt-plus 提交错误报告
description: geekai 提交错误报告
labels: ['Bug']
body:
- type: checkboxes

View File

@@ -1,5 +1,5 @@
name: 功能优化 🚀
description: chatgpt-plus 提交优化建议
description: geekai 提交优化建议
labels: ['feature']
body:
- type: checkboxes

View File

@@ -1,7 +1,97 @@
# 更新日志
## v4.0.8
* 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
* 功能优化:当数据库更新失败的时候记录错误日志
* 功能优化:聊天输入框会随着输入内容的增多自动调整高度
* Bug修复修复移动端聊天页面模型切换不生效的Bug
* 功能优化给PC端扫码支付增加签名验证和有效期验证
* Bug修复修复支付码生成API权限控制的问题
* Bug修复模型算力设置为0时不扣减用户算力并且不记录算力消费日志
* 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
* 功能新增H5端支持 Dalle 绘图
## v4.0.7
* 功能优化升级quic-go支持 Go1.21
* 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片
* Bug修复修复弹窗等于图形验证码一直验证失败的问题
* 功能重构:重构前端 UI 页面,增加顶部导航
* 功能优化:优化 Vue 非父子组件之间的通信方式
* 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
## v4.0.6
* Bug修复修复PC端画廊页面的瀑布流组件样式错乱问题
* 功能新增:给思维导图增加 ToolBar实现思维导图的放大缩小和定位
* Bug修复修复思维导图不扣费的Bug
* Bug修复修复管理后台角色删除失败的Bug
* Bug修复兼容最新版秋叶SD懒人包的 SD API新增 scheduler 参数
* 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
* Bug修复修复注册用户提示注册人数达到上限的 Bug
* 功能优化将MJ,SD,Dall绘画页面的任务列表全改成瀑布流组件
## v4.0.5
* 功能优化:已授权系统在后台显示授权信息
* 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
* 功能优化:优化首页登录注册页面的 UI
* BUG修复修复License验证的逻辑漏洞
* Bug修复后台添加用户的时候密码规则限制跟前台注册保持一致
* 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
* 功能新增:移动端新增 DALL-E 绘画功能
* 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
* 功能新增:移动支持免登录预览功能
* Bug修复解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
* Bug修复修复部分中转 API 模型会出现第一输出的字符被淹没的Bug
## v4.0.4
* Bug修复修复统一千问第二句不回复的问题
* 功能优化MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
* 功能新增Dalle AI 绘画功能实现
* Bug修复修复思维导图格式乱码问题
* 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
* 功能新增:支持从应用列表直接和某个应用对话
* 功能优化优化算力日志的页面和首页的UI
* 功能新增:支持思维导图导出 PNG 图片下载
## v4.0.3
* 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
* Bug修复兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
* Bug修复修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题
* 功能新增:支持上传图片和视觉模型
* 功能优化:优化聊天页面的复制代码按钮样式乱码
* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
* 功能新增支持为角色绑定对话模型比如绑定某个角色只能用GPT3.5或者 GPT4
* 功能新增:支持为模型绑定 API KEY比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
* 功能新增:支持管理后台 Logo 修改
## 4.0.2
* 功能新增:支持前端菜单可以配置
* 功能优化:在登录和注册界面标题显示软件版本号
* 功能优化MJ 绘画支持 --sref 和 --cref 图片一致性参数
* 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
* Bug修复解决因为图片上传使用相对路径而导致融图失败的问题。
* 功能新增:手机端支持 Stable-Diffusion 绘画
* 功能新增:管理后台登录页面增加行为验证码,防止爆破
## v4.0.1
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口SDAPI 兼容各种 stable-diffusion
发行版,稳定性更强一些
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API兼容
MJ-Plus 中转
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
* Bug修复修复 iphone 手机无法通过图形验证码的Bug使用滑动验证码替换
* Bug修复修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
## v4.0.0
非兼容版本重大重构引入算力概念将系统中所有的能力AI对话MJ绘画SD绘画DALL绘画全部使用算力来兑换。
只要你的算力值余额不为0你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力...
只要你的算力值余额不为0你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力一次 GPT4 对话消耗10个算力。一次 MJ
对话消耗15个算力...
* 功能重构:重构整体系统,全部采用算力来进行结算
* 功能优化SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
@@ -16,6 +106,7 @@
* 功能新增管理后台新增7日内新增用户和新增订单统计
## v3.2.7
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
* Bug修复修复 issue [
@@ -30,6 +121,7 @@
* 功能新增:后台管理新怎对话查看和检索功能
## v3.2.6
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
* 功能优化:兼用旧版本微信收款消息解析
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
@@ -43,16 +135,18 @@
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
* 功能新增:新增短信宝短信平台发送平台集成
## v3.2.5
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅 Plus 账号了!!!
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
Plus 账号了!!!
* 功能优化:增强 markdown 图片和引用块解析。
* 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
* 功能优化function call 兼用中转 API。
* Bug修复修复部分已知的 Bug。
## v3.2.4.1
* 功能新增:新增 PayJs 支付通道
* Bug修复紧急修复后台添加用户失败问题
* Bug修复紧急修复使用中转 API-KEY 无法绘图的问题

214
LICENSE
View File

@@ -1,21 +1,201 @@
MIT License
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
Copyright (c) 2023 RockYang
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
1. Definitions.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

123
README.md
View File

@@ -1,124 +1,67 @@
# ChatGPT-Plus
# GeekAI
> 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
**ChatGPT-PLUS** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。主要有如下特性:
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
* 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
* 基于 Websocket 实现,完美的打字机体验。
* 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求
* 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型
* 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用
* 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道
* 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
* 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
主要特性:
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用
- 基于 Websocket 实现,完美的打字机体验
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求
- 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型
- 支持 Suno 文生音乐
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
绘画函数插件。
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus)
- [x] 更友好的 UI 界面
- [x] 支持 Dall-E 文生图功能
- [x] 支持文生思维导图
- [x] 支持为模型绑定指定的 API KEY支持为角色绑定指定的模型等功能
- [x] 支持网站 Logo 版权等信息的修改
## 功能截图
### PC 端聊天界面
![ChatGPT Chat Page](/docs/imgs/gpt.gif)
### AI 对话界面
![ChatGPT new Chat Page](/docs/imgs/chat-new.png)
### MidJourney 专业绘画界面
![mid-journey](/docs/imgs/mj_image.jpg)
### Stable-Diffusion 专业绘画页面
![Stable-Diffusion](/docs/imgs/sd_image.jpg)
![Stable-Diffusion](/docs/imgs/sd_image_detail.jpg)
### 绘图作品展
![ChatGPT image_list](/docs/imgs/image-list.png)
### AI应用列表
![ChatGPT-app-list](/docs/imgs/app-list.jpg)
### 会员充值
![会员充值](/docs/imgs/member.png)
### 自动调用函数插件
![ChatGPT function plugin](/docs/imgs/plugin.png)
![ChatGPT function plugin](/docs/imgs/mj.jpg)
### 管理后台
![ChatGPT admin](/docs/imgs/admin_dashboard.png)
![ChatGPT admin](/docs/imgs/admin_config.jpg)
![ChatGPT admin](/docs/imgs/admin_models.jpg)
![ChatGPT admin](/docs/imgs/admin_user.png)
### 移动端 Web 页面
![Mobile chat list](/docs/imgs/mobile_chat_list.png)
![Mobile chat session](/docs/imgs/mobile_chat_session.png)
![Mobile chat setting](/docs/imgs/mobile_user_profile.png)
![Mobile chat setting](/docs/imgs/mobile_pay.png)
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
### 体验地址
> 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/>
> 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/>
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
## 快速部署
**演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套
```shell
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.7-6c232bdaf8.sh)"
```
最新版本的一键部署脚本请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/install/)。
目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址
* 前端访问地址http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。
* 后台管理地址http://localhost:8080/admin
* 移动端地址http://localhost:8080/mobile
* 初始后台管理账号admin/admin123
* 初始前端体验账号18575670125/12345678
服务启动成功之后不能立刻使用,需要先登录管理后台 -> API-KEY 去添加一个 OpenAI 或者文心一言,科大讯飞等至少一个平台的 API
KEY。
![](https://ai.r9it.com/docs/images/env/admin_api_keys.png)
另外,如果您目前还没有 OpenAI 的 API KEY的推荐您去 https://gpt.bemore.lol 购买,**无需魔法,高速稳定,且价格还远低于 OpenAI
官方**。
请参考文档 [**GeekAI 快速部署**](https://docs.geekai.me/install/)
## 使用须知
1. 本项目基于 MIT 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
## 项目地址
* Github 地址https://github.com/yangjian102621/chatgpt-plus
* 码云地址https://gitee.com/blackfox/chatgpt-plus
* Github 地址https://github.com/yangjian102621/geekai
* 码云地址https://gitee.com/blackfox/geekai
## 客户端下载
目前已经支持 Win/Linux/Mac/Android 客户端下载地址为https://github.com/yangjian102621/chatgpt-plus/releases/tag/v3.1.2
目前已经支持 Win/Linux/Mac/Android 客户端下载地址为https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
## TODOLIST
* [ ] 支持基于知识库的 AI 问答
* [ ] 会员邀请注册推广功能
* [ ] 文生视频,文生歌曲功能
* [ ] 微信支付功能
## 项目文档
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
详细的部署和开发文档请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/)。
详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
加微信进入微信讨论群可获取 **一键部署脚本添加好友时请注明来自Github!!!)。**
@@ -146,4 +89,4 @@ KEY。
![打赏](docs/imgs/donate.png)
![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/chatgpt-plus&type=Date)
![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/geekai&type=Date)

View File

@@ -1,5 +1,5 @@
SHELL=/usr/bin/env bash
NAME := chatgpt-plus
NAME := geekai
all: amd64 arm64
amd64:

View File

@@ -1,6 +1,6 @@
Listen = "0.0.0.0:5678"
ProxyURL = "" # 如 http://127.0.0.1:7777
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8&parseTime=True&loc=Local"
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
StaticDir = "./static" # 静态资源的目录
StaticUrl = "/static" # 静态资源访问 URL
AesEncryptKey = ""
@@ -64,24 +64,16 @@ WeChatBot = false
SubDir = ""
Domain = ""
[[MjConfigs]]
Enabled = false
UserToken = ""
BotToken = ""
GuildId = ""
ChanelId = ""
UseCDN = false #是否使用反向代理访问设置为true下面的设置才会生效
DiscordAPI = "" # discord API 反代地址
DiscordCDN = "" # mj 图片反代地址
DiscordGateway = "" # discord 机器人反代地址
[[MjProxyConfigs]]
Enabled = true
ApiURL = "http://midjourney-proxy:8082"
ApiKey = "sk-geekmaster"
[[MjPlusConfigs]]
Enabled = false
ApiURL = "https://api.chat-plus.net"
CdnURL = "" # CND 加速的 URL如果有的话就设置
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
ApiKey = "sk-xxx"
NotifyURL = "https://ai.r9it.com/api/mj/notify" # 这里需要改成你的域名
[[SdConfigs]]
Enabled = false
@@ -116,7 +108,8 @@ WeChatBot = false
ApiURL = "https://api.xunhupay.com"
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
[SmtpConfig] # 注意阿里云服务器禁用了25号端口所以如果需要使用邮件功能,请别用阿里云服务器
[SmtpConfig] # 注意阿里云服务器禁用了25号端口请使用 465 端口,并开启 TLS 连接
UseTls = false
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"

View File

@@ -1,22 +1,29 @@
package core
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"context"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/nfnt/resize"
"golang.org/x/image/webp"
"gorm.io/gorm"
"image"
"image/jpeg"
"io"
"log"
"net/http"
"os"
"runtime/debug"
@@ -200,10 +207,13 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
func needLogin(c *gin.Context) bool {
if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/user/logout" ||
c.Request.URL.Path == "/api/user/resetPass" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/admin/logout" ||
c.Request.URL.Path == "/api/admin/login/captcha" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/user/session" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" ||
@@ -215,13 +225,21 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/imgWall" ||
c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/config/get" ||
c.Request.URL.Path == "/api/dall/imgWall" ||
c.Request.URL.Path == "/api/dall/client" ||
c.Request.URL.Path == "/api/product/list" ||
c.Request.URL.Path == "/api/menu/list" ||
c.Request.URL.Path == "/api/markMap/client" ||
c.Request.URL.Path == "/api/payment/alipay/notify" ||
c.Request.URL.Path == "/api/payment/hupipay/notify" ||
c.Request.URL.Path == "/api/payment/payjs/notify" ||
c.Request.URL.Path == "/api/payment/doPay" ||
c.Request.URL.Path == "/api/payment/payWays" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
return false
}
@@ -326,6 +344,10 @@ func staticResourceMiddleware() gin.HandlerFunc {
// 解码图片
img, _, err := image.Decode(file)
// for .webp image
if err != nil {
img, err = webp.Decode(file)
}
if err != nil {
c.String(http.StatusInternalServerError, "Error decoding image")
return
@@ -342,7 +364,9 @@ func staticResourceMiddleware() gin.HandlerFunc {
var buffer bytes.Buffer
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
if err != nil {
log.Fatal(err)
logger.Error(err)
c.String(http.StatusInternalServerError, err.Error())
return
}
// 设置图片缓存有效期为一年 (365天)

View File

@@ -1,10 +1,17 @@
package core
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/utils"
"os"
"github.com/BurntSushi/toml"
@@ -23,7 +30,7 @@ func NewDefaultConfig() *types.AppConfig {
SecretKey: utils.RandString(64),
MaxAge: 86400,
},
ApiConfig: types.ChatPlusApiConfig{},
ApiConfig: types.ApiConfig{},
OSS: types.OSSConfig{
Active: "local",
Local: types.LocalStorageConfig{

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// ApiRequest API 请求实体
type ApiRequest struct {
Model string `json:"model,omitempty"` // 兼容百度文心一言
@@ -8,7 +15,7 @@ type ApiRequest struct {
Stream bool `json:"stream"`
Messages []interface{} `json:"messages,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
Tools []interface{} `json:"tools,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
ToolChoice string `json:"tool_choice,omitempty"`
@@ -54,14 +61,15 @@ type ChatSession struct {
}
type ChatModel struct {
Id uint `json:"id"`
Platform Platform `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
Id uint `json:"id"`
Platform string `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
KeyId int `json:"key_id"` // 绑定 API KEY
}
type ApiError struct {

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"github.com/gorilla/websocket"

View File

@@ -1,26 +1,33 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
)
type AppConfig struct {
Path string `toml:"-"`
Listen string
Session Session
AdminSession Session
ProxyURL string
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
MjConfigs []MidJourneyConfig // mj AI draw service pool
MjPlusConfigs []MidJourneyPlusConfig // MJ plus config
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
Path string `toml:"-"`
Listen string
Session Session
AdminSession Session
ProxyURL string
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
MjProxyConfigs []MjProxyConfig // MJ proxy config
MjPlusConfigs []MjPlusConfig // MJ plus config
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig
AlipayConfig AlipayConfig
@@ -30,6 +37,7 @@ type AppConfig struct {
}
type SmtpConfig struct {
UseTls bool // 是否使用 TLS 发送
Host string
Port int
AppName string // 应用名称
@@ -37,38 +45,31 @@ type SmtpConfig struct {
Password string // 发件人邮箱密码
}
type ChatPlusApiConfig struct {
type ApiConfig struct {
ApiURL string
AppId string
Token string
}
type MidJourneyConfig struct {
Enabled bool
UserToken string
BotToken string
GuildId string // Server ID
ChanelId string // Chanel ID
UseCDN bool
ImgCdnURL string // 图片反代加速地址
DiscordAPI string
DiscordGateway string
type MjProxyConfig struct {
Enabled bool
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type StableDiffusionConfig struct {
Enabled bool
ApiURL string
ApiKey string
Txt2ImgJsonPath string
Enabled bool
Model string // 模型名称
ApiURL string
ApiKey string
}
type MidJourneyPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
CdnURL string // CDN 加速地址
ApiKey string
NotifyURL string // 任务进度更新回调地址
type MjPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type AlipayConfig struct {
@@ -121,18 +122,64 @@ type RedisConfig struct {
DB int
}
// LicenseKey 存储许可证书的 KEY
const LicenseKey = "Geek-AI-License"
type License struct {
Key string `json:"key"` // 许可证书密钥
MachineId string `json:"machine_id"` // 机器码
ExpiredAt int64 `json:"expired_at"` // 过期时间
IsActive bool `json:"is_active"` // 是否激活
Configs LicenseConfig `json:"configs"`
}
type LicenseConfig struct {
UserNum int `json:"user_num"` // 用户数量
DeCopy bool `json:"de_copy"` // 去版权
}
func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port)
}
type Platform string
type Platform struct {
Name string `json:"name"`
Value string `json:"value"`
ChatURL string `json:"chat_url"`
ImgURL string `json:"img_url"`
}
const OpenAI = Platform("OpenAI")
const Azure = Platform("Azure")
const ChatGLM = Platform("ChatGLM")
const Baidu = Platform("Baidu")
const XunFei = Platform("XunFei")
const QWen = Platform("QWen")
var OpenAI = Platform{
Name: "OpenAI - GPT",
Value: "OpenAI",
ChatURL: "https://api.chat-plus.net/v1/chat/completions",
ImgURL: "https://api.chat-plus.net/v1/images/generations",
}
var Azure = Platform{
Name: "微软 - Azure",
Value: "Azure",
ChatURL: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15",
}
var ChatGLM = Platform{
Name: "智谱 - ChatGLM",
Value: "ChatGLM",
ChatURL: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke",
}
var Baidu = Platform{
Name: "百度 - 文心大模型",
Value: "Baidu",
ChatURL: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}",
}
var XunFei = Platform{
Name: "讯飞 - 星火大模型",
Value: "XunFei",
ChatURL: "wss://spark-api.xf-yun.com/{version}/chat",
}
var QWen = Platform{
Name: "阿里 - 通义千问",
Value: "QWen",
ChatURL: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
}
type SystemConfig struct {
Title string `json:"title,omitempty"`
@@ -143,7 +190,7 @@ type SystemConfig struct {
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机,邮箱注册,账号密码注册
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机mobile邮箱注册email,账号密码注册
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
@@ -151,15 +198,20 @@ type SystemConfig struct {
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
VipInfoText string `json:"vip_info_text"` // 会员页面充值说明
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3消耗算力
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
EnableContext bool `json:"enable_context,omitempty"`
ContextDeep int `json:"context_deep,omitempty"`
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
RandBg bool `json:"rand_bg"` // 前端首页是否启用随机背景
}

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type ToolCall struct {
Type string `json:"type"`
Function struct {
@@ -8,19 +15,13 @@ type ToolCall struct {
} `json:"function"`
}
type Tool struct {
Type string `json:"type"`
Function Function `json:"function"`
}
type Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters Parameters `json:"parameters"`
}
type Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]Property `json:"properties"`
}
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"sync"

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type OrderStatus int
const (

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type OSSConfig struct {
Active string
Local LocalStorageConfig

View File

@@ -1,11 +1,17 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
const LoginUserID = "LOGIN_USER_ID"
const LoginUserCache = "LOGIN_USER_CACHE"
const UserAuthHeader = "Authorization"
const AdminAuthHeader = "Admin-Authorization"
const ChatTokenHeader = "Chat-Token"
// Session configs struct
type Session struct {

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type SMSConfig struct {
Active string
Ali SmsConfigAli

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// TaskType 任务类别
type TaskType string
@@ -25,6 +32,8 @@ type MjTask struct {
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
NegPrompt string `json:"neg_prompt,omitempty"`
Params string `json:"full_prompt"`
Index int `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`
@@ -36,25 +45,38 @@ type SdTask struct {
SessionId string `json:"session_id"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"`
}
type SdTaskParams struct {
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
NegativePrompt string `json:"negative_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // 启用高清修复
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
HdScale int `json:"hd_scale"` // 放大倍数
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
Scheduler string `json:"scheduler"`
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // 启用高清修复
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
HdScale int `json:"hd_scale"` // 放大倍数
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
}
// DallTask DALL-E task
type DallTask struct {
JobId uint `json:"job_id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
}

View File

@@ -1,5 +1,12 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// BizVo 业务返回 VO
type BizVo struct {
Code BizCode `json:"code"`
@@ -21,7 +28,7 @@ const (
WsStart = WsMsgType("start")
WsMiddle = WsMsgType("middle")
WsEnd = WsMsgType("end")
WsMjImg = WsMsgType("mj")
WsErr = WsMsgType("error")
)
type BizCode int

View File

@@ -1,6 +1,8 @@
module chatplus
module geekai
go 1.19
go 1.21
toolchain go1.22.4
require (
github.com/BurntSushi/toml v1.1.0
@@ -25,22 +27,28 @@ require (
require github.com/xxl-job/xxl-job-executor-go v1.2.0
require github.com/bg5t/mydiscordgo v0.28.1
require (
github.com/mojocn/base64Captcha v1.3.1
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.3.1
github.com/syndtr/goleveldb v1.0.0
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
)
require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.uber.org/mock v0.4.0 // indirect
)
require (
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
@@ -51,7 +59,6 @@ require (
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
@@ -68,9 +75,7 @@ require (
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
github.com/quic-go/quic-go v0.35.1 // indirect
github.com/quic-go/quic-go v0.45.0 // indirect
github.com/refraction-networking/utls v1.3.2 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
@@ -80,14 +85,14 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
go.uber.org/dig v1.16.1 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
golang.org/x/mod v0.11.0 // indirect
golang.org/x/net v0.14.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/text v0.12.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.10.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
@@ -106,7 +111,7 @@ require (
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/fx v1.19.3
go.uber.org/multierr v1.6.0 // indirect
golang.org/x/crypto v0.12.0
golang.org/x/sys v0.11.0 // indirect
golang.org/x/crypto v0.23.0
golang.org/x/sys v0.20.0 // indirect
gorm.io/gorm v1.25.1
)

View File

@@ -7,13 +7,12 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/bg5t/mydiscordgo v0.28.1 h1:mVH0ZWstVdJffCi/EXJAYQDtXwIKAJYVXLmECu1hEK8=
github.com/bg5t/mydiscordgo v0.28.1/go.mod h1:n3aba73N18k1DzM0t0mGE8rwW3Z+vwTvI8pcsBgxN/8=
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -29,7 +28,9 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk=
@@ -41,8 +42,12 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -67,18 +72,18 @@ github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJ
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -86,6 +91,7 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -136,10 +142,16 @@ github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4=
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A=
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
@@ -157,12 +169,8 @@ github.com/qiniu/go-sdk/v7 v7.17.1/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYX
github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8=
github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
@@ -170,6 +178,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
@@ -196,6 +206,12 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU=
github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=
github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY=
github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
@@ -206,8 +222,9 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
@@ -216,6 +233,9 @@ go.uber.org/dig v1.16.1/go.mod h1:557JTAUZT5bUK0SvCwikmLPPtdQhfvLYtO5tJgQSbnk=
go.uber.org/fx v1.19.3 h1:YqMRE4+2IepTYCMOvXqQpRa+QAVdiSTnsHU4XNWBceA=
go.uber.org/fx v1.19.3/go.mod h1:w2HrQg26ql9fLK7hlBiZ6JsRUKV+Lj/atT1KCjT8YhM=
go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
@@ -224,39 +244,35 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -265,8 +281,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -276,34 +292,32 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM=
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -1,14 +1,21 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"context"
"fmt"
"github.com/go-redis/redis/v8"

View File

@@ -1,13 +1,21 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -49,6 +57,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey.Name = data.Name
res := h.DB.Save(&apiKey)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -65,14 +74,24 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
}
func (h *ApiKeyHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
status := h.GetBool(c, "status")
t := h.GetTrim(c, "type")
platform := h.GetTrim(c, "platform")
session := h.DB.Session(&gorm.Session{})
if status {
session = session.Where("enabled", true)
}
if t != "" {
session = session.Where("type", t)
}
if platform != "" {
session = session.Where("platform", platform)
}
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)
res := h.DB.Find(&items)
res := session.Find(&items)
if res.Error == nil {
for _, item := range items {
var key vo.ApiKey
@@ -104,6 +123,7 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -111,19 +131,17 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
}
func (h *ApiKeyHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.ApiKey{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
res := h.DB.Where("id", id).Delete(&model.ApiKey{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -1,9 +1,16 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/handler"
"chatplus/utils/resp"
"geekai/core"
"geekai/handler"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/mojocn/base64Captcha"
)

View File

@@ -1,13 +1,20 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -33,11 +40,6 @@ type chatItemVo struct {
}
func (h *ChatHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var data struct {
Title string `json:"title"`
UserId uint `json:"user_id"`
@@ -259,6 +261,7 @@ func (h *ChatHandler) RemoveMessage(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
if tx.Error != nil {
logger.Error("error with update database", tx.Error)
resp.ERROR(c, "更新数据库失败!")
return
}

View File

@@ -1,16 +1,23 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type ChatModelHandler struct {
@@ -34,6 +41,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
KeyId int `json:"key_id,omitempty"`
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -51,13 +59,17 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: data.Temperature,
KeyId: data.KeyId,
Power: data.Power}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
var res *gorm.DB
if data.Id > 0 {
item.Id = data.Id
res = h.DB.Select("*").Omit("created_at").Updates(&item)
} else {
res = h.DB.Create(&item)
}
res := h.DB.Save(&item)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -75,31 +87,45 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
platform := h.GetTrim(c, "platform")
if enable {
session = session.Where("enabled", enable)
}
if platform != "" {
session = session.Where("platform", platform)
}
var items []model.ChatModel
var cms = make([]vo.ChatModel, 0)
res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var cm vo.ChatModel
err := utils.CopyObject(item, &cm)
if err == nil {
cm.Id = item.Id
cm.CreatedAt = item.CreatedAt.Unix()
cm.UpdatedAt = item.UpdatedAt.Unix()
cms = append(cms, cm)
} else {
logger.Error(err)
}
if res.Error != nil {
resp.SUCCESS(c, cms)
return
}
// initialize key name
keyIds := make([]int, 0)
for _, v := range items {
keyIds = append(keyIds, v.KeyId)
}
var keys []model.ApiKey
keyMap := make(map[uint]string)
h.DB.Where("id IN ?", keyIds).Find(&keys)
for _, v := range keys {
keyMap[v.Id] = v.Name
}
for _, item := range items {
var cm vo.ChatModel
err := utils.CopyObject(item, &cm)
if err == nil {
cm.Id = item.Id
cm.CreatedAt = item.CreatedAt.Unix()
cm.UpdatedAt = item.UpdatedAt.Unix()
cm.KeyName = keyMap[uint(item.KeyId)]
cms = append(cms, cm)
} else {
logger.Error(err)
}
}
resp.SUCCESS(c, cms)
@@ -119,6 +145,7 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -139,6 +166,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
for index, id := range data.Ids {
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -156,6 +184,7 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}

View File

@@ -1,16 +1,24 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type ChatRoleHandler struct {
@@ -40,6 +48,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
}
res := h.DB.Save(&role)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -50,11 +59,6 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
}
func (h *ChatRoleHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.ChatRole
var roles = make([]vo.ChatRole, 0)
res := h.DB.Order("sort_num ASC").Find(&items)
@@ -63,6 +67,25 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
return
}
// initialize model mane for role
modelIds := make([]int, 0)
for _, v := range items {
if v.ModelId > 0 {
modelIds = append(modelIds, v.ModelId)
}
}
modelNameMap := make(map[int]string)
if len(modelIds) > 0 {
var models []model.ChatModel
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
if tx.Error == nil {
for _, m := range models {
modelNameMap[int(m.Id)] = m.Name
}
}
}
for _, v := range items {
var role vo.ChatRole
err := utils.CopyObject(v, &role)
@@ -70,6 +93,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
role.Id = v.Id
role.CreatedAt = v.CreatedAt.Unix()
role.UpdatedAt = v.UpdatedAt.Unix()
role.ModelName = modelNameMap[role.ModelId]
roles = append(roles, role)
}
}
@@ -92,6 +116,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
for index, id := range data.Ids {
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -114,6 +139,7 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -121,19 +147,15 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
}
func (h *ChatRoleHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
res := h.DB.Where("id", id).Delete(&model.ChatRole{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "删除失败!")
return
}

View File

@@ -1,23 +1,45 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/mj"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/shirou/gopsutil/host"
"gorm.io/gorm"
)
type ConfigHandler struct {
handler.BaseHandler
levelDB *store.LevelDB
licenseService *service.LicenseService
mjServicePool *mj.ServicePool
sdServicePool *sd.ServicePool
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler {
return &ConfigHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
levelDB: levelDB,
mjServicePool: mjPool,
sdServicePool: sdPool,
licenseService: licenseService,
}
}
func (h *ConfigHandler) Update(c *gin.Context) {
@@ -70,11 +92,6 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
key := c.Query("key")
var config model.Config
res := h.DB.Where("marker", key).First(&config)
@@ -92,3 +109,89 @@ func (h *ConfigHandler) Get(c *gin.Context) {
resp.SUCCESS(c, value)
}
// Active 激活系统
func (h *ConfigHandler) Active(c *gin.Context) {
var data struct {
License string `json:"license"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
info, err := host.Info()
if err != nil {
resp.ERROR(c, err.Error())
return
}
err = h.licenseService.ActiveLicense(data.License, info.HostID)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, info.HostID)
}
// GetLicense 获取 License 信息
func (h *ConfigHandler) GetLicense(c *gin.Context) {
license := h.licenseService.GetLicense()
resp.SUCCESS(c, license)
}
// GetAppConfig 获取内置配置
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
resp.SUCCESS(c, gin.H{
"mj_plus": h.App.Config.MjPlusConfigs,
"mj_proxy": h.App.Config.MjProxyConfigs,
"sd": h.App.Config.SdConfigs,
"platforms": Platforms,
})
}
// SaveDrawingConfig 保存AI绘画配置
func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) {
var data struct {
Sd []types.StableDiffusionConfig `json:"sd"`
MjPlus []types.MjPlusConfig `json:"mj_plus"`
MjProxy []types.MjProxyConfig `json:"mj_proxy"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
changed := false
if configChanged(data.Sd, h.App.Config.SdConfigs) {
logger.Debugf("SD 配置变动了")
h.App.Config.SdConfigs = data.Sd
h.sdServicePool.InitServices(data.Sd)
changed = true
}
if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) {
logger.Debugf("MidJourney 配置变动了")
h.App.Config.MjPlusConfigs = data.MjPlus
h.App.Config.MjProxyConfigs = data.MjProxy
h.mjServicePool.InitServices(data.MjPlus, data.MjProxy)
changed = true
}
if changed {
err := core.SaveConfig(h.App.Config)
if err != nil {
resp.ERROR(c, "更新配置文档失败!")
return
}
}
resp.SUCCESS(c)
}
func configChanged(c1 interface{}, c2 interface{}) bool {
encode1 := utils.JsonEncode(c1)
encode2 := utils.JsonEncode(c2)
return utils.Md5(encode1) != utils.Md5(encode2)
}

View File

@@ -1,11 +1,18 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"gorm.io/gorm"

View File

@@ -1,13 +1,20 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/golang-jwt/jwt/v5"
@@ -64,6 +71,7 @@ func (h *FunctionHandler) Set(c *gin.Context) {
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -71,11 +79,6 @@ func (h *FunctionHandler) Set(c *gin.Context) {
}
func (h *FunctionHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Function
res := h.DB.Find(&items)
if res.Error != nil {
@@ -101,6 +104,7 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
if id > 0 {
res := h.DB.Delete(&model.Function{Id: uint(id)})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}

View File

@@ -0,0 +1,132 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MenuHandler struct {
handler.BaseHandler
}
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *MenuHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Icon string `json:"icon"`
URL string `json:"url"`
SortNum int `json:"sort_num"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Save(&model.Menu{
Id: data.Id,
Name: data.Name,
Icon: data.Icon,
URL: data.URL,
SortNum: data.SortNum,
Enabled: data.Enabled,
})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
// List 数据列表
func (h *MenuHandler) List(c *gin.Context) {
var items []model.Menu
var list = make([]vo.Menu, 0)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Menu
err := utils.CopyObject(item, &product)
if err == nil {
list = append(list, product)
}
}
}
resp.SUCCESS(c, list)
}
func (h *MenuHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *MenuHandler) Sort(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
Sorts []int `json:"sorts"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
for index, id := range data.Ids {
res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}
func (h *MenuHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.DB.Where("id", id).Delete(&model.Menu{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -1,13 +1,20 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -22,11 +29,6 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
}
func (h *OrderHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var data struct {
OrderNo string `json:"order_no"`
Status int `json:"status"`
@@ -92,6 +94,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}

View File

@@ -1,13 +1,20 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -67,5 +74,11 @@ func (h *PowerLogHandler) List(c *gin.Context) {
list = append(list, log)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
// 统计消费算力总和
var totalPower float64
if len(data.Date) == 2 {
session.Where("mark", 0).Select("SUM(amount) as total_sum").Scan(&totalPower)
}
resp.SUCCESS(c, gin.H{"data": vo.NewPage(total, data.Page, data.PageSize, list), "stat": totalPower})
}

View File

@@ -1,13 +1,20 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
@@ -50,6 +57,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
}
res := h.DB.Save(&item)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -65,21 +73,11 @@ func (h *ProductHandler) Save(c *gin.Context) {
resp.SUCCESS(c, itemVo)
}
// List 模型列表
// List 数据列表
func (h *ProductHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
if enable {
session = session.Where("enabled", enable)
}
var items []model.Product
var list = make([]vo.Product, 0)
res := session.Order("sort_num ASC").Find(&items)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Product
@@ -110,6 +108,7 @@ func (h *ProductHandler) Enable(c *gin.Context) {
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -128,8 +127,9 @@ func (h *ProductHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.DB.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -142,8 +142,9 @@ func (h *ProductHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.DB.Where("id = ?", id).Delete(&model.Product{})
res := h.DB.Where("id", id).Delete(&model.Product{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}

View File

@@ -1,13 +1,20 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -21,11 +28,6 @@ func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
}
func (h *RewardHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Reward
res := h.DB.Order("id DESC").Find(&items)
var rewards = make([]vo.Reward, 0)
@@ -70,6 +72,7 @@ func (h *RewardHandler) Remove(c *gin.Context) {
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}

View File

@@ -0,0 +1,12 @@
package admin
import "geekai/core/types"
var Platforms = []types.Platform{
types.OpenAI,
types.QWen,
types.XunFei,
types.ChatGLM,
types.Baidu,
types.Azure,
}

View File

@@ -1,11 +1,18 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/handler"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/utils/resp"
"geekai/core"
"geekai/handler"
"geekai/service/oss"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"

View File

@@ -1,14 +1,22 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
@@ -17,19 +25,15 @@ import (
type UserHandler struct {
handler.BaseHandler
licenseService *service.LicenseService
}
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler {
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService}
}
// List 用户列表
func (h *UserHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
username := h.GetTrim(c, "username")
@@ -80,6 +84,13 @@ func (h *UserHandler) Save(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
// 检测最大注册人数
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
var user = model.User{}
var res *gorm.DB
var userVo vo.User
@@ -100,6 +111,7 @@ func (h *UserHandler) Save(c *gin.Context) {
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -145,6 +157,7 @@ func (h *UserHandler) Save(c *gin.Context) {
}
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败")
return
}

View File

@@ -1,11 +1,18 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/utils"
"geekai/core"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/store/model"
"geekai/utils"
"errors"
"fmt"
"gorm.io/gorm"

View File

@@ -1,9 +1,16 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/service"
"chatplus/utils/resp"
"geekai/core/types"
"geekai/service"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
)
@@ -45,3 +52,33 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
}
}
// SlideGet 获取滑动验证图片
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
data, err := h.service.SlideGet()
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// SlideCheck 滑动验证结果校验
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
var data struct {
Key string `json:"key"`
X int `json:"x"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if h.service.SlideCheck(data) {
resp.SUCCESS(c)
} else {
resp.ERROR(c)
}
}

View File

@@ -1,11 +1,19 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -25,7 +33,7 @@ func (h *ChatModelHandler) List(c *gin.Context) {
var res *gorm.DB
// 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) {
res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items)
res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items)
} else {
user, _ := h.GetLoginUser(c)
var models []int
@@ -36,7 +44,7 @@ func (h *ChatModelHandler) List(c *gin.Context) {
}
// 查询用户有权限访问的模型以及所有开放的模型
res = h.DB.Where("enabled = ?", true).Where(
h.DB.Where("id IN ?", models).Or("open =?", true),
h.DB.Where("id IN ?", models).Or("open", true),
).Order("sort_num ASC").Find(&items)
}

View File

@@ -1,12 +1,19 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -25,9 +32,10 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
all := h.GetBool(c, "all")
userId := h.GetLoginUserId(c)
var roles []model.ChatRole
var roleVos = make([]vo.ChatRole, 0)
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
if res.Error != nil {
resp.ERROR(c, "No roles found,"+res.Error.Error())
resp.SUCCESS(c, roleVos)
return
}
@@ -55,8 +63,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
resp.ERROR(c, "角色解析失败!")
return
}
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
if !utils.ContainsStr(roleKeys, r.Key) {
continue
@@ -89,7 +96,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
if res.Error != nil {
logger.Error("添加应用失败:", err)
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}

View File

@@ -1,19 +1,25 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"errors"
"fmt"
"html/template"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"io"
"strings"
"time"
"unicode/utf8"
)
// 微软 Azure 模型消息发送实现
@@ -30,21 +36,14 @@ func (h *ChatHandler) sendAzureMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
@@ -66,10 +65,7 @@ func (h *ChatHandler) sendAzureMessage(
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
logger.Error(err, line)
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
break
return errors.New(line)
}
if len(responseBody.Choices) == 0 {
@@ -103,105 +99,12 @@ func (h *ChatHandler) sendAzureMessage(
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if strings.Contains(res.Error.Message, "maximum context length") {
logger.Error(res.Error.Message)
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
h.App.ChatContexts.Delete(session.ChatId)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
utils.ReplyMessage(ws, "请求 Azure API 失败:"+res.Error.Message)
}
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil

View File

@@ -1,20 +1,26 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"errors"
"fmt"
"html/template"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"io"
"net/http"
"strings"
"time"
"unicode/utf8"
)
type baiduResp struct {
@@ -47,21 +53,15 @@ func (h *ChatHandler) sendBaiduMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
logger.Error(err)
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
@@ -128,99 +128,11 @@ func (h *ChatHandler) sendBaiduMessage(
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil

View File

@@ -1,25 +1,35 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"html/template"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
@@ -27,30 +37,21 @@ import (
"gorm.io/gorm"
)
const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
var ErrImg = "![](/images/wx.png)"
var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
redis *redis.Client
uploadManager *oss.UploaderManager
redis *redis.Client
uploadManager *oss.UploaderManager
licenseService *service.LicenseService
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler {
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
}
}
func (h *ChatHandler) Init() {
// 如果后台有上传微信客服微信二维码,则覆盖
if h.App.SysConfig.WechatCardURL != "" {
ErrImg = fmt.Sprintf("![](%s)", h.App.SysConfig.WechatCardURL)
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
licenseService: licenseService,
}
}
@@ -68,9 +69,20 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
modelId := h.GetInt(c, "model_id", 0)
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res := h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
modelId = chatRole.ModelId
}
// get model info
var chatModel model.ChatModel
res := h.DB.First(&chatModel, modelId)
res = h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort()
@@ -111,17 +123,9 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
Platform: types.Platform(chatModel.Platform)}
KeyId: chatModel.KeyId,
Platform: chatModel.Platform}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole
res = h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
return
}
h.Init()
// 保存会话连接
h.App.ChatClients.Put(sessionId, client)
@@ -129,8 +133,10 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
h.App.ChatClients.Delete(sessionId)
h.App.ChatSession.Delete(sessionId)
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
if cancelFunc != nil {
cancelFunc()
@@ -159,7 +165,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
utils.ReplyMessage(client, err.Error())
} else {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
logger.Infof("回答完毕: %v", message.Content)
@@ -181,8 +187,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
var user model.User
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作!")
return res.Error
return errors.New("未授权用户,您正在进行非法操作!")
}
var userVo vo.User
err := utils.CopyObject(user, &userVo)
@@ -192,28 +197,22 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
if userVo.Status == false {
utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
return nil
return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
}
if userVo.Power < session.Model.Power {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力%d已不足以支付当前模型的单次对话需要消耗的算力%d", userVo.Power, session.Model.Power))
utils.ReplyMessage(ws, ErrImg)
return nil
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d[立即购买](/member)。", userVo.Power, session.Model.Power)
}
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
return nil
return errors.New("您的账号已经过期,请联系管理员!")
}
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > session.Model.MaxContext {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
}
var req = types.ApiRequest{
@@ -221,11 +220,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
Stream: true,
}
switch session.Model.Platform {
case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
break
case types.OpenAI:
case types.OpenAI.Value:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
@@ -235,31 +234,32 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
break
}
var tools = make([]interface{}, 0)
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
required := parameters["required"]
delete(parameters, "required")
tools = append(tools, gin.H{
"type": "function",
"function": gin.H{
"name": v.Name,
"description": v.Description,
"parameters": parameters,
"required": required,
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
})
}
if v, ok := parameters["required"]; v == nil || !ok {
tool.Function.Parameters["required"] = []string{}
}
tools = append(tools, tool)
}
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
case types.QWen:
case types.QWen.Value:
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
@@ -267,9 +267,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
break
default:
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
return nil
return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
}
// 加载聊天上下文
@@ -325,11 +323,41 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
reqMgs = append(reqMgs, m)
}
if session.Model.Platform == types.QWen {
req.Input = map[string]interface{}{"prompt": prompt}
if len(reqMgs) > 0 {
req.Input["messages"] = reqMgs
if session.Model.Platform == types.QWen.Value {
req.Input = make(map[string]interface{})
reqMgs = append(reqMgs, types.Message{
Role: "user",
Content: prompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURL(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
text := prompt
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
}
data = append(data, gin.H{
"type": "text",
"text": text,
})
content = data
} else {
content = prompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
@@ -337,24 +365,23 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
})
}
logger.Debugf("%+v", req.Messages)
switch session.Model.Platform {
case types.Azure:
case types.Azure.Value:
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.OpenAI:
case types.OpenAI.Value:
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGLM:
case types.ChatGLM.Value:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.Baidu:
case types.Baidu.Value:
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei:
case types.XunFei.Value:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.QWen:
case types.QWen.Value:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: fmt.Sprintf("Not supported platform: %s", session.Model.Platform),
})
return nil
}
@@ -424,26 +451,42 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
if res.Error != nil {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
if session.Model.KeyId > 0 {
h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
if apiKey.Id == 0 {
return nil, errors.New("no available key, please import key")
}
// ONLY allow apiURL in blank list
if session.Model.Platform == types.OpenAI.Value {
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
}
var apiURL string
switch platform {
case types.Azure:
switch session.Model.Platform {
case types.Azure.Value:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break
case types.ChatGLM:
case types.ChatGLM.Value:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu:
case types.Baidu.Value:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
case types.QWen:
case types.QWen.Value:
apiURL = apiKey.ApiURL
req.Messages = nil
break
@@ -453,7 +496,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if platform == types.Baidu {
if session.Model.Platform == types.Baidu.Value {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
@@ -477,8 +520,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request = request.WithContext(ctx)
request.Header.Set("Content-Type", "application/json")
var proxyURL string
if apiKey.ProxyURL != "" { // 使用代理
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
@@ -488,24 +530,24 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
switch platform {
case types.Azure:
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
switch session.Model.Platform {
case types.Azure.Value:
request.Header.Set("api-key", apiKey.Value)
break
case types.ChatGLM:
case types.ChatGLM.Value:
token, err := h.getChatGLMToken(apiKey.Value)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
case types.Baidu:
case types.Baidu.Value:
request.RequestURI = ""
case types.OpenAI:
case types.OpenAI.Value:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
break
case types.QWen:
case types.QWen.Value:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
request.Header.Set("X-DashScope-SSE", "enable")
break
@@ -539,6 +581,99 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
}
func (h *ChatHandler) saveChatHistory(
req types.ApiRequest,
prompt string,
contents []string,
message types.Message,
chatCtx []types.Message,
session *types.ChatSession,
role model.ChatRole,
userVo vo.User,
promptCreatedAt time.Time,
replyCreatedAt time.Time) {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
if session.Model.Power > 0 {
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
}
// 将AI回复消息中生成的图片链接下载到本地
func (h *ChatHandler) extractImgUrl(text string) string {
pattern := `!\[([^\]]*)]\(([^)]+)\)`

View File

@@ -1,11 +1,18 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -187,12 +194,20 @@ func (h *ChatHandler) Detail(c *gin.Context) {
return
}
// 填充角色名称
var role model.ChatRole
res = h.DB.Where("id", chatItem.RoleId).First(&role)
if res.Error != nil {
resp.ERROR(c, "Role not found")
return
}
var chatItemVo vo.ChatItem
err := utils.CopyObject(chatItem, &chatItemVo)
if err != nil {
resp.ERROR(c, err.Error())
return
}
chatItemVo.RoleName = role.Name
resp.SUCCESS(c, chatItemVo)
}

View File

@@ -1,20 +1,25 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/golang-jwt/jwt/v5"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
// 清华大学 ChatGML 消息发送实现
@@ -31,21 +36,14 @@ func (h *ChatHandler) sendChatGLMMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
@@ -107,103 +105,11 @@ func (h *ChatHandler) sendChatGLMMessage(
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"code"`
Success bool `json:"success"`
Msg string `json:"msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if !res.Success {
utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
}
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil

View File

@@ -1,21 +1,26 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"errors"
"fmt"
"html/template"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
req2 "github.com/imroc/req/v3"
"io"
"strings"
"time"
"unicode/utf8"
req2 "github.com/imroc/req/v3"
)
// OPenAI 消息发送实现
@@ -31,24 +36,13 @@ func (h *ChatHandler) sendOpenAiMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
if response.Body != nil {
all, _ := io.ReadAll(response.Body)
logger.Error(string(all))
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
@@ -65,6 +59,7 @@ func (h *ChatHandler) sendOpenAiMessage(
var toolCall = false
var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -73,10 +68,15 @@ func (h *ChatHandler) sendOpenAiMessage(
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
logger.Error(err, line)
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
if err != nil { // 数据解析出错
return errors.New(line)
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.ReplyMessage(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
break
}
@@ -103,8 +103,10 @@ func (h *ChatHandler) sendOpenAiMessage(
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
contents = append(contents, callMsg)
}
continue
}
@@ -114,16 +116,16 @@ func (h *ChatHandler) sendOpenAiMessage(
break
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue
} else if responseBody.Choices[0].FinishReason != "" {
// output stopped
if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
if isNew {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
@@ -173,126 +175,11 @@ func (h *ChatHandler) sendOpenAiMessage(
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext && toolCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
useContext := true
if toolCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyTokens += tokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
return fmt.Errorf("error with decode response: %v", err)
}
// OpenAI API 调用异常处理
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
logger.Error(res.Error.Message)
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
h.App.ChatContexts.Delete(session.ChatId)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
}
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%s", body)
}
return nil

View File

@@ -1,19 +1,24 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"fmt"
"html/template"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/syndtr/goleveldb/leveldb/errors"
"io"
"strings"
"time"
"unicode/utf8"
)
type qWenResp struct {
@@ -45,21 +50,14 @@ func (h *ChatHandler) sendQWenMessage(
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
@@ -82,10 +80,11 @@ func (h *ChatHandler) sendQWenMessage(
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
if !strings.HasPrefix(line, "data:") {
continue
}
content = line[5:]
var resp qWenResp
if len(contents) == 0 { // 发送消息头
if !outPutStart {
@@ -140,100 +139,11 @@ func (h *ChatHandler) sendQWenMessage(
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg)
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil

View File

@@ -1,24 +1,31 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/gorilla/websocket"
"html/template"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
)
type xunFeiResp struct {
@@ -69,10 +76,17 @@ func (h *ChatHandler) sendXunFeiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
var res *gorm.DB
// use the bind key
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
}
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
@@ -165,89 +179,10 @@ func (h *ChatHandler) sendXunFeiMessage(
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
return nil
}

View File

@@ -1,10 +1,18 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -12,10 +20,11 @@ import (
type ConfigHandler struct {
BaseHandler
licenseService *service.LicenseService
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *ConfigHandler {
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
}
// Get 获取指定的系统配置
@@ -37,3 +46,9 @@ func (h *ConfigHandler) Get(c *gin.Context) {
resp.SUCCESS(c, value)
}
// License 获取 License 配置
func (h *ConfigHandler) License(c *gin.Context) {
license := h.licenseService.GetLicense()
resp.SUCCESS(c, license.Configs)
}

View File

@@ -0,0 +1,262 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"net/http"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type DallJobHandler struct {
BaseHandler
redis *redis.Client
service *dalle.Service
uploader *oss.UploaderManager
}
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
return &DallJobHandler{
service: service,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *DallJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.service.Clients.Delete(uint(userId))
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 DallE 心跳消息:", message.Content)
continue
}
}
}()
}
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
return true
}
// Image 创建一个绘画任务
func (h *DallJobHandler) Image(c *gin.Context) {
if !h.preCheck(c) {
return
}
var data types.DallTask
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
job := model.DallJob{
UserId: uint(userId),
Prompt: data.Prompt,
Power: h.App.SysConfig.DallPower,
}
res := h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error())
return
}
h.service.PushTask(types.DallTask{
JobId: job.Id,
UserId: uint(userId),
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
Power: job.Power,
})
client := h.service.Clients.Get(job.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
// ImgWall 照片墙
func (h *DallJobHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 SD 任务列表
func (h *DallJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取任务列表
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
}
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var items []model.DallJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
}
var jobs = make([]vo.DallJob, 0)
for _, item := range items {
var job vo.DallJob
err := utils.CopyObject(item, &job)
if err != nil {
continue
}
jobs = append(jobs, job)
}
return nil, jobs
}
// Remove remove task image
func (h *DallJobHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// remove job recode
res := h.DB.Delete(&model.DallJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}
// Publish 发布/取消发布图片到画廊显示
func (h *DallJobHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@@ -1,29 +1,44 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"errors"
"fmt"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"strings"
"time"
)
type FunctionHandler struct {
BaseHandler
config types.ChatPlusApiConfig
config types.ApiConfig
uploadManager *oss.UploaderManager
dallService *dalle.Service
}
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
func NewFunctionHandler(
server *core.AppServer,
db *gorm.DB,
config *types.AppConfig,
manager *oss.UploaderManager,
dallService *dalle.Service) *FunctionHandler {
return &FunctionHandler{
BaseHandler: BaseHandler{
App: server,
@@ -31,6 +46,7 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo
},
config: config.ApiConfig,
uploadManager: manager,
dallService: dallService,
}
}
@@ -151,30 +167,6 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
}
type imgRes struct {
Created int64 `json:"created"`
Data []struct {
RevisedPrompt string `json:"revised_prompt"`
Url string `json:"url"`
} `json:"data"`
}
type ErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
// Dall3 DallE3 AI 绘图
func (h *FunctionHandler) Dall3(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
@@ -190,85 +182,45 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
logger.Debugf("绘画参数:%+v", params)
var user model.User
tx := h.DB.Where("id = ?", params["user_id"]).First(&user)
if tx.Error != nil {
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
if res.Error != nil {
resp.ERROR(c, "当前用户不存在!")
return
}
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
return
}
// create dall task
prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY
var apiKey model.ApiKey
tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
job := model.DallJob{
UserId: user.Id,
Prompt: prompt,
Power: h.App.SysConfig.DallPower,
}
res = h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error())
return
}
// translate prompt
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
if err == nil {
logger.Debugf("翻译绘画提示词,原文:%s译文%s", prompt, pt)
prompt = pt
}
var res imgRes
var errRes ErrRes
var request *req.Request
if apiKey.ProxyURL != "" {
request = req.C().SetProxyURL(apiKey.ProxyURL).R()
} else {
request = req.C().R()
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := request.SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
if r.IsErrorState() {
resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message)
return
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
logger.Debugf("%+v", res)
// 存储图片
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
content, err := h.dallService.Image(types.DallTask{
JobId: job.Id,
UserId: user.Id,
Prompt: job.Prompt,
N: 1,
Quality: "standard",
Size: "1024x1024",
Style: "vivid",
Power: job.Power,
}, true)
if err != nil {
resp.ERROR(c, "下载图片失败: "+err.Error())
resp.ERROR(c, "任务执行失败:"+err.Error())
return
}
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL)
// 更新用户算力
tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
h.DB.Where("id", user.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: h.App.SysConfig.DallPower,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c, content)
}

View File

@@ -1,12 +1,19 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"strings"

View File

@@ -0,0 +1,273 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/utils"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// MarkMapHandler 生成思维导图
type MarkMapHandler struct {
BaseHandler
clients *types.LMap[int, *types.WsClient]
}
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler {
return &MarkMapHandler{
BaseHandler: BaseHandler{App: app, DB: db},
clients: types.NewLMap[int, *types.WsClient](),
}
}
func (h *MarkMapHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
modelId := h.GetInt(c, "model_id", 0)
userId := h.GetInt(c, "user_id", 0)
client := types.NewWsClient(ws)
h.clients.Put(userId, client)
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.clients.Delete(userId)
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
continue
}
// change model
if message.Type == "model_id" {
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
continue
}
logger.Info("Receive a message: ", message.Content)
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
}
}
}()
}
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
var user model.User
res := h.DB.Model(&model.User{}).First(&user, userId)
if res.Error != nil {
return fmt.Errorf("error with query user info: %v", res.Error)
}
var chatModel model.ChatModel
res = h.DB.Where("id", modelId).First(&chatModel)
if res.Error != nil {
return fmt.Errorf("error with query chat model: %v", res.Error)
}
if user.Status == false {
return errors.New("当前用户被禁用")
}
if user.Power < chatModel.Power {
return fmt.Errorf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power)
}
messages := make([]interface{}, 0)
messages = append(messages, types.Message{Role: "system", Content: `
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
# Geek-AI 助手
## 完整的开源系统
### 前端开源
### 后端开源
## 支持各种大模型
### OpenAI
### Azure
### 文心一言
### 通义千问
## 集成多种收费方式
### 支付宝
### 微信
另外,除此之外不要任何解释性语句。
`})
messages = append(messages, types.Message{Role: "user", Content: prompt})
var req = types.ApiRequest{
Model: chatModel.Value,
Stream: true,
Messages: messages,
}
var apiKey model.ApiKey
response, err := h.doRequest(req, chatModel, &apiKey)
if err != nil {
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
}
defer response.Body.Close()
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
// 循环读取 Chunk 消息
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return fmt.Errorf("error with decode data: %v", line)
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].FinishReason == "stop" {
break
}
if isNew {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(client, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} // end for
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("读取响应失败: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("解析响应失败: %v", err)
}
// OpenAI API 调用异常处理
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
// remove key
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
return errors.New("请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
return errors.New("请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else {
return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message)
}
}
// 扣减算力
res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userId).First(&u)
h.DB.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerConsume,
Amount: chatModel.Power,
Mark: types.PowerSub,
Balance: u.Power,
Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
CreatedAt: time.Now(),
})
}
return nil
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
var res *gorm.DB
if chatModel.KeyId > 0 {
res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", types.OpenAI).
Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := apiKey.ApiURL
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
} else {
client = http.DefaultClient
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
return client.Do(request)
}

View File

@@ -0,0 +1,43 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MenuHandler struct {
BaseHandler
}
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 数据列表
func (h *MenuHandler) List(c *gin.Context) {
var items []model.Menu
var list = make([]vo.Menu, 0)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Menu
err := utils.CopyObject(item, &product)
if err == nil {
list = append(list, product)
}
}
}
resp.SUCCESS(c, list)
}

View File

@@ -1,18 +1,24 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/mj"
"chatplus/service/mj/plus"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"encoding/base64"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/mj"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"net/http"
"strings"
"time"
@@ -99,7 +105,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
ImgArr []string `json:"img_arr"`
Tile bool `json:"tile"`
Quality float32 `json:"quality"`
Weight float32 `json:"weight"`
Iw float32 `json:"iw"`
CRef string `json:"cref"` //生成角色一致的图像
SRef string `json:"sref"` //生成风格一致的图像
Cw int `json:"cw"` // 参考程度
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -109,41 +118,57 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
var prompt = data.Prompt
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
prompt += " --ar " + data.Rate
var params = ""
if data.Rate != "" && !strings.Contains(params, "--ar") {
params += " --ar " + data.Rate
}
if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
prompt += fmt.Sprintf(" --seed %d", data.Seed)
if data.Seed > 0 && !strings.Contains(params, "--seed") {
params += fmt.Sprintf(" --seed %d", data.Seed)
}
if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
prompt += fmt.Sprintf(" --s %d", data.Stylize)
if data.Stylize > 0 && !strings.Contains(params, "--s") && !strings.Contains(params, "--stylize") {
params += fmt.Sprintf(" --s %d", data.Stylize)
}
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
prompt += fmt.Sprintf(" --c %d", data.Chaos)
if data.Chaos > 0 && !strings.Contains(params, "--c") && !strings.Contains(params, "--chaos") {
params += fmt.Sprintf(" --c %d", data.Chaos)
}
if data.Weight > 0 {
prompt += fmt.Sprintf(" --iw %f", data.Weight)
if len(data.ImgArr) > 0 && data.Iw > 0 {
params += fmt.Sprintf(" --iw %.2f", data.Iw)
}
if data.Raw {
prompt += " --style raw"
params += " --style raw"
}
if data.Quality > 0 {
prompt += fmt.Sprintf(" --q %.2f", data.Quality)
}
if data.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", data.NegPrompt)
params += fmt.Sprintf(" --q %.2f", data.Quality)
}
if data.Tile {
prompt += " --tile "
params += " --tile "
}
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
prompt += fmt.Sprintf(" %s", data.Model)
if data.CRef != "" {
params += fmt.Sprintf(" --cref %s", data.CRef)
if data.Cw > 0 {
params += fmt.Sprintf(" --cw %d", data.Cw)
} else {
params += " --cw 100"
}
}
if data.SRef != "" {
params += fmt.Sprintf(" --sref %s", data.SRef)
}
if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
params += fmt.Sprintf(" %s", data.Model)
}
// 处理融图和换脸的提示词
if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() {
prompt = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
params = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
}
// 如果本地图片上传的是相对地址,处理成绝对地址
for k, v := range data.ImgArr {
if !strings.HasPrefix(v, "http") {
data.ImgArr[k] = fmt.Sprintf("http://localhost:5678/%s", strings.TrimLeft(v, "/"))
}
}
idValue, _ := c.Get(types.LoginUserID)
@@ -159,7 +184,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: prompt,
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(),
}
@@ -182,7 +207,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
TaskId: taskId,
SessionId: data.SessionId,
Type: types.TaskType(data.TaskType),
Prompt: prompt,
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
})
@@ -246,6 +273,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
@@ -269,7 +297,23 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
@@ -296,7 +340,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjPower,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
@@ -454,27 +498,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
resp.SUCCESS(c)
}
// Notify MidJourney Plus 服务任务回调处理
func (h *MidJourneyHandler) Notify(c *gin.Context) {
var data plus.CBReq
if err := c.ShouldBindJSON(&data); err != nil {
logger.Error("非法任务回调:%+v", err)
return
}
err := h.pool.Notify(data)
if err != nil {
logger.Error(err)
} else {
userId := h.GetLoginUserId(c)
client := h.pool.Clients.Get(userId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
}
resp.SUCCESS(c)
}
// Publish 发布图片到画廊显示
func (h *MidJourneyHandler) Publish(c *gin.Context) {
var data struct {
@@ -488,6 +511,7 @@ func (h *MidJourneyHandler) Publish(c *gin.Context) {
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败")
return
}

View File

@@ -1,12 +1,19 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"

View File

@@ -1,16 +1,23 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/payment"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"embed"
"encoding/base64"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/payment"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/shopspring/decimal"
"math"
"net/http"
@@ -37,6 +44,7 @@ type PaymentHandler struct {
snowflake *service.Snowflake
fs embed.FS
lock sync.Mutex
signKey string // 用来签名的随机秘钥
}
func NewPaymentHandler(
@@ -58,12 +66,27 @@ func NewPaymentHandler(
App: server,
DB: db,
},
signKey: utils.RandString(32),
}
}
func (h *PaymentHandler) DoPay(c *gin.Context) {
orderNo := h.GetTrim(c, "order_no")
payWay := h.GetTrim(c, "pay_way")
t := h.GetInt(c, "t", 0)
sign := h.GetTrim(c, "sign")
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, payWay, t, h.signKey)
newSign := utils.Sha256(signStr)
if newSign != sign {
resp.ERROR(c, "订单签名错误!")
return
}
// 检查二维码是否过期
if time.Now().Unix()-int64(t) > int64(h.App.SysConfig.OrderPayTimeout) {
resp.ERROR(c, "支付二维码已过期,请重新生成!")
return
}
if orderNo == "" {
resp.ERROR(c, types.InvalidArgs)
@@ -266,8 +289,10 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
resp.ERROR(c, err.Error())
return
}
imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s", parse.Scheme, parse.Host, orderNo, data.PayWay)
timestamp := time.Now().Unix()
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, data.PayWay, timestamp, h.signKey)
sign := utils.Sha256(signStr)
imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign)
imgData, err := utils.GenQrcode(imageURL, 400, file)
if err != nil {
resp.ERROR(c, err.Error())
@@ -317,6 +342,8 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
parse, _ := url.Parse(h.App.Config.HuPiPayConfig.ReturnURL)
baseURL := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
params := payment.HuPiPayReq{
Version: "1.1",
TradeOrderId: orderNo,
@@ -326,6 +353,8 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
ReturnURL: returnURL,
CallbackURL: returnURL,
WapName: "极客学长",
WapUrl: baseURL,
Type: "WAP",
}
r, err := h.huPiPayService.Pay(params)
if err != nil {
@@ -424,27 +453,21 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var opt string
var power int
if user.Vip { // 已经是 VIP 用户
if remark.Days > 0 { // 只延期 VIP不增加调用次数
if remark.Days > 0 { // VIP 充值
if user.ExpiredTime >= time.Now().Unix() {
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
} else { // 充值点卡,直接增加次数即可
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
} else { // 非 VIP 用户
if remark.Days > 0 { // vip 套餐days > 0, power == 0
opt = "VIP充值VIP 没到期,只延期不增加算力"
} else {
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Power += h.App.SysConfig.VipMonthPower
user.Vip = true
opt = "VIP充值"
power = h.App.SysConfig.VipMonthPower
} else { //点卡days == 0, calls > 0
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
opt = "VIP充值"
}
user.Vip = true
} else { // 充值点卡,直接增加次数即可
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
// 更新用户信息

View File

@@ -1,12 +1,19 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"

View File

@@ -1,11 +1,18 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)

View File

@@ -1,60 +0,0 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
type PromptHandler struct {
BaseHandler
}
func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler {
return &PromptHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Rewrite translate and rewrite prompt with ChatGPT
func (h *PromptHandler) Rewrite(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(rewritePromptTemplate, data.Prompt))
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) Translate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, data.Prompt))
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}

View File

@@ -1,13 +1,20 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"math"
@@ -66,6 +73,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
if res.Error != nil {
tx.Rollback()
logger.Error("添加应用失败:", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -77,6 +85,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
res = tx.Updates(&item)
if res.Error != nil {
tx.Rollback()
logger.Error("添加应用失败:", res.Error)
resp.ERROR(c, "更新数据库失败!")
return
}

View File

@@ -1,16 +1,24 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/service/sd"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"encoding/base64"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"net/http"
"time"
@@ -23,15 +31,19 @@ import (
type SdJobHandler struct {
BaseHandler
redis *redis.Client
pool *sd.ServicePool
uploader *oss.UploaderManager
redis *redis.Client
pool *sd.ServicePool
uploader *oss.UploaderManager
snowflake *service.Snowflake
leveldb *store.LevelDB
}
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler {
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
return &SdJobHandler{
pool: pool,
uploader: manager,
pool: pool,
uploader: manager,
snowflake: snowflake,
leveldb: levelDB,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -60,7 +72,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
@@ -83,7 +95,7 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
// Image 创建一个绘画任务
func (h *SdJobHandler) Image(c *gin.Context) {
if !h.checkLimits(c) {
if !h.preCheck(c) {
return
}
@@ -116,23 +128,29 @@ func (h *SdJobHandler) Image(c *gin.Context) {
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
params := types.SdTaskParams{
TaskId: fmt.Sprintf("task(%s)", utils.RandString(15)),
Prompt: data.Prompt,
NegativePrompt: data.NegativePrompt,
Steps: data.Steps,
Sampler: data.Sampler,
FaceFix: data.FaceFix,
CfgScale: data.CfgScale,
Seed: data.Seed,
Height: data.Height,
Width: data.Width,
HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
taskId, err := h.snowflake.Next(true)
if err != nil {
resp.ERROR(c, "error with generate task id: "+err.Error())
return
}
params := types.SdTaskParams{
TaskId: taskId,
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Steps: data.Steps,
Sampler: data.Sampler,
FaceFix: data.FaceFix,
CfgScale: data.CfgScale,
Seed: data.Seed,
Height: data.Height,
Width: data.Width,
HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
}
job := model.SdJob{
UserId: userId,
Type: types.TaskImage.String(),
@@ -153,7 +171,6 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Id: int(job.Id),
SessionId: data.SessionId,
Type: types.TaskImage,
Prompt: data.Prompt,
Params: params,
UserId: userId,
})
@@ -249,10 +266,11 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
}
if item.Progress < 100 {
// 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, "")
// 从 leveldb 中获取图片预览数据
var imageData string
err = h.leveldb.Get(item.TaskId, &imageData)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
job.ImgURL = "data:image/png;base64," + imageData
}
}
jobs = append(jobs, job)
@@ -288,7 +306,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
_ = client.Send([]byte(sd.Finished))
}
resp.SUCCESS(c)
@@ -307,6 +325,7 @@ func (h *SdJobHandler) Publish(c *gin.Context) {
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败")
return
}

View File

@@ -1,12 +1,19 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/sms"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/sms"
"geekai/utils"
"geekai/utils/resp"
"strings"
"github.com/gin-gonic/gin"

View File

@@ -1,8 +1,8 @@
package handler
import (
"chatplus/service"
"chatplus/service/payment"
"geekai/service"
"geekai/service/payment"
"gorm.io/gorm"
)

View File

@@ -1,12 +1,19 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"geekai/core"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"

View File

@@ -1,13 +1,21 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"strings"
"time"
@@ -21,16 +29,23 @@ import (
type UserHandler struct {
BaseHandler
searcher *xdb.Searcher
redis *redis.Client
searcher *xdb.Searcher
redis *redis.Client
licenseService *service.LicenseService
}
func NewUserHandler(
app *core.AppServer,
db *gorm.DB,
searcher *xdb.Searcher,
client *redis.Client) *UserHandler {
return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client}
client *redis.Client,
licenseService *service.LicenseService) *UserHandler {
return &UserHandler{
BaseHandler: BaseHandler{DB: db, App: app},
searcher: searcher,
redis: client,
licenseService: licenseService,
}
}
// Register user register
@@ -53,9 +68,17 @@ func (h *UserHandler) Register(c *gin.Context) {
return
}
// 检测最大注册人数
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
// 检查验证码
var key string
if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" {
if data.RegWay == "email" || data.RegWay == "mobile" {
key = CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
@@ -216,18 +239,10 @@ func (h *UserHandler) Login(c *gin.Context) {
// Logout 注 销
func (h *UserHandler) Logout(c *gin.Context) {
sessionId := c.GetHeader(types.ChatTokenHeader)
key := h.GetUserKey(c)
if _, err := h.redis.Del(c, key).Result(); err != nil {
logger.Error("error with delete session: ", err)
}
// 删除 websocket 会话列表
h.App.ChatSession.Delete(sessionId)
// 关闭 socket 连接
client := h.App.ChatClients.Get(sessionId)
if client != nil {
client.Close()
}
resp.SUCCESS(c)
}
@@ -334,7 +349,7 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
newPass := utils.GenPassword(data.Password, user.Salt)
res := h.DB.Model(&user).UpdateColumn("password", newPass)
if res.Error != nil {
logger.Error("更新数据库失败: ", res.Error)
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败")
return
}
@@ -415,6 +430,7 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
if res.Error != nil {
logger.Error(res.Error)
resp.ERROR(c, "更新数据库失败")
return
}

View File

@@ -1,5 +1,12 @@
package logger
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"go.uber.org/zap"
"go.uber.org/zap/zapcore"

View File

@@ -1,22 +1,30 @@
package main
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/handler/admin"
"chatplus/handler/chatimpl"
logger2 "chatplus/logger"
"chatplus/service"
"chatplus/service/mj"
"chatplus/service/oss"
"chatplus/service/payment"
"chatplus/service/sd"
"chatplus/service/sms"
"chatplus/service/wx"
"chatplus/store"
"context"
"embed"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/handler/admin"
"geekai/handler/chatimpl"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/dalle"
"geekai/service/mj"
"geekai/service/oss"
"geekai/service/payment"
"geekai/service/sd"
"geekai/service/sms"
"geekai/service/wx"
"geekai/store"
"io"
"log"
"os"
@@ -43,16 +51,20 @@ type AppLifecycle struct {
// OnStart 应用程序启动时执行
func (l *AppLifecycle) OnStart(context.Context) error {
log.Println("AppLifecycle OnStart")
logger.Info("AppLifecycle OnStart")
return nil
}
// OnStop 应用程序停止时执行
func (l *AppLifecycle) OnStop(context.Context) error {
log.Println("AppLifecycle OnStop")
logger.Info("AppLifecycle OnStop")
return nil
}
func NewAppLifeCycle() *AppLifecycle {
return &AppLifecycle{}
}
func main() {
configFile := os.Getenv("CONFIG_FILE")
if configFile == "" {
@@ -92,6 +104,7 @@ func main() {
fx.Provide(store.NewGormConfig),
fx.Provide(store.NewMysql),
fx.Provide(store.NewRedisClient),
fx.Provide(store.NewLevelDB),
fx.Provide(func() embed.FS {
return xdbFS
@@ -148,9 +161,21 @@ func main() {
}),
fx.Provide(oss.NewUploaderManager),
fx.Provide(mj.NewService),
fx.Provide(dalle.NewService),
fx.Invoke(func(service *dalle.Service) {
service.Run()
service.CheckTaskNotify()
service.DownloadImages()
service.CheckTaskStatus()
}),
// 邮件服务
fx.Provide(service.NewSmtpService),
// License 服务
fx.Provide(service.NewLicenseService),
fx.Invoke(func(licenseService *service.LicenseService) {
licenseService.SyncLicense()
}),
// 微信机器人服务
fx.Provide(wx.NewWeChatBot),
@@ -165,7 +190,8 @@ func main() {
// MidJourney service pool
fx.Provide(mj.NewServicePool),
fx.Invoke(func(pool *mj.ServicePool) {
fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) {
pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs)
if pool.HasAvailableService() {
pool.DownloadImages()
pool.CheckTaskNotify()
@@ -175,6 +201,13 @@ func main() {
// Stable Diffusion 机器人
fx.Provide(sd.NewServicePool),
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
pool.InitServices(config.SdConfigs)
if pool.HasAvailableService() {
pool.CheckTaskNotify()
pool.CheckTaskStatus()
}
}),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
@@ -232,6 +265,8 @@ func main() {
group := s.Engine.Group("/api/captcha/")
group.GET("get", h.Get)
group.POST("check", h.Check)
group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck)
}),
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
group := s.Engine.Group("/api/reward/")
@@ -246,7 +281,6 @@ func main() {
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("notify", h.Notify)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
@@ -261,13 +295,18 @@ func main() {
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/")
group.GET("get", h.Get)
group.GET("license", h.License)
}),
// 管理后台控制器
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
group := s.Engine.Group("/api/admin/config/")
group.POST("update", h.Update)
group.GET("get", h.Get)
group := s.Engine.Group("/api/admin/")
group.POST("config/update", h.Update)
group.GET("config/get", h.Get)
group.POST("active", h.Active)
group.GET("config/get/license", h.GetLicense)
group.GET("config/get/app", h.GetAppConfig)
group.POST("config/update/draw", h.SaveDrawingConfig)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
group := s.Engine.Group("/api/admin/")
@@ -285,7 +324,7 @@ func main() {
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("set", h.Set)
group.POST("remove", h.Remove)
group.GET("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
group := s.Engine.Group("/api/admin/user/")
@@ -301,7 +340,7 @@ func main() {
group.POST("save", h.Save)
group.POST("sort", h.Sort)
group.POST("set", h.Set)
group.POST("remove", h.Remove)
group.GET("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
group := s.Engine.Group("/api/admin/reward/")
@@ -365,13 +404,6 @@ func main() {
group.GET("hits", h.Hits)
}),
fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt/")
group.POST("rewrite", h.Rewrite)
group.POST("translate", h.Translate)
}),
fx.Provide(admin.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
group := s.Engine.Group("/api/admin/function/")
@@ -417,12 +449,44 @@ func main() {
group := s.Engine.Group("/api/admin/powerLog/")
group.POST("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
err := s.Run(db)
if err != nil {
log.Fatal(err)
}
fx.Provide(admin.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
group := s.Engine.Group("/api/admin/menu/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}),
fx.Provide(handler.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
group := s.Engine.Group("/api/menu/")
group.GET("list", h.List)
}),
fx.Provide(handler.NewMarkMapHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
group := s.Engine.Group("/api/markMap/")
group.Any("client", h.Client)
}),
fx.Provide(handler.NewDallJobHandler),
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
group := s.Engine.Group("/api/dall")
group.Any("client", h.Client)
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
go func() {
err := s.Run(db)
if err != nil {
log.Fatal(err)
}
}()
}),
fx.Provide(NewAppLifeCycle),
// 注册生命周期回调函数
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
lifecycle.Append(fx.Hook{

View File

@@ -1,80 +0,0 @@
{
"data": [
"task(cxvkpawy8onnfti)",
"a cute girl",
"",
[],
20,
"DPM++ 2M Karras",
1,
1,
7,
512,
512,
false,
0.7,
2,
"Latent",
0,
0,
0,
"Use same checkpoint",
"Use same sampler",
"",
"",
[],
"None",
false,
"",
0.8,
-1,
false,
-1,
0,
0,
0,
null,
null,
null,
null,
false,
false,
"positive",
"comma",
0,
false,
false,
"",
"Seed",
"",
[],
"Nothing",
"",
[],
"Nothing",
"",
[],
true,
false,
false,
false,
0,
null,
null,
false,
null,
null,
false,
null,
null,
false,
50,
[],
"",
"",
""
],
"event_data": null,
"fn_index": 446,
"session_hash": "nk5noh1rz1o"
}

View File

@@ -1,19 +1,26 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"errors"
"fmt"
"geekai/core/types"
"github.com/imroc/req/v3"
"time"
)
type CaptchaService struct {
config types.ChatPlusApiConfig
config types.ApiConfig
client *req.Client
}
func NewCaptchaService(config types.ChatPlusApiConfig) *CaptchaService {
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
return &CaptchaService{
config: config,
client: req.C().SetTimeout(10 * time.Second),
@@ -60,3 +67,44 @@ func (s *CaptchaService) Check(data interface{}) bool {
return true
}
func (s *CaptchaService) SlideGet() (interface{}, error) {
if s.config.Token == "" {
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err)
}
if res.Code != types.Success {
return nil, fmt.Errorf("请求 API 失败:%s", res.Message)
}
return res.Data, nil
}
func (s *CaptchaService) SlideCheck(data interface{}) bool {
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return false
}
if res.Code != types.Success {
return false
}
return true
}

View File

@@ -0,0 +1,313 @@
package dalle
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
// DALL-E 绘画服务
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
}
}
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.DallTask) {
logger.Infof("add a new DALL-E task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
logger.Info("Starting DALL-E job consumer...")
go func() {
for {
var task types.DallTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("handle a new DALL-E task: %+v", task)
_, err = s.Image(task, false)
if err != nil {
logger.Errorf("error with image task: %v", err)
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
}
}
}()
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Quality string `json:"quality"`
Style string `json:"style"`
}
type imgRes struct {
Created int64 `json:"created"`
Data []struct {
RevisedPrompt string `json:"revised_prompt"`
Url string `json:"url"`
} `json:"data"`
}
type ErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err != nil {
return "", fmt.Errorf("error with translate prompt: %v", err)
}
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
var user model.User
s.db.Where("id", task.UserId).First(&user)
if user.Power < task.Power {
return "", errors.New("insufficient of power")
}
// get image generation API KEY
var apiKey model.ApiKey
tx := s.db.Where("platform", types.OpenAI.Value).
Where("type", "img").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
}
var res imgRes
var errRes ErrRes
if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
}
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
if err != nil {
return "", fmt.Errorf("error with send request: %v", err)
}
if r.IsErrorState() {
return "", fmt.Errorf("error with send request: %v", errRes.Error)
}
// update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// update task progress
tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": 100,
"org_url": res.Data[0].Url,
"prompt": prompt,
})
if tx.Error != nil {
return "", fmt.Errorf("err with update database: %v", tx.Error)
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
var content string
if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
}
// 更新用户算力
tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
}
return content, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running DALL-E task notify checking ...")
for {
var message sd.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.DallJob
for {
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
if err != nil {
logger.Error("error with download image: %s, error: %v", imgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
}
time.Sleep(time.Second * 5)
}
}()
}
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
// sava image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
if err != nil {
return "", err
}
// update img_url
res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL)
if res.Error != nil {
return "", err
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished})
return imgURL, nil
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.DallJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
s.db.Delete(&job)
var user model.User
s.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d", job.Id),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}

View File

@@ -0,0 +1,197 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/store"
"time"
"github.com/imroc/req/v3"
)
type LicenseService struct {
config types.ApiConfig
levelDB *store.LevelDB
license *types.License
urlWhiteList []string
machineId string
}
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
var license types.License
return &LicenseService{
config: server.Config.ApiConfig,
levelDB: levelDB,
license: &license,
machineId: "",
}
}
type License struct {
Name string `json:"name"`
License string `json:"license"`
MachineId string `json:"mid"`
ActiveAt int64 `json:"active_at"`
ExpiredAt int64 `json:"expired_at"`
UserNum int `json:"user_num"`
Configs types.LicenseConfig `json:"configs"`
}
// ActiveLicense 激活 License
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
response, err := req.C().R().
SetBody(map[string]string{"license": license, "machine_id": machineId}).
SetSuccessResult(&res).Post(apiURL)
if err != nil {
return fmt.Errorf("发送激活请求失败: %v", err)
}
if response.IsErrorState() {
return fmt.Errorf("发送激活请求失败:%v", response.Status)
}
if res.Code != types.Success {
return fmt.Errorf("激活失败:%v", res.Message)
}
s.license = &types.License{
Key: license,
MachineId: machineId,
Configs: res.Data.Configs,
ExpiredAt: res.Data.ExpiredAt,
IsActive: true,
}
err = s.levelDB.Put(types.LicenseKey, s.license)
if err != nil {
return fmt.Errorf("保存许可证书失败:%v", err)
}
return nil
}
// SyncLicense 定期同步 License
func (s *LicenseService) SyncLicense() {
go func() {
retryCounter := 0
for {
license, err := s.fetchLicense()
if err != nil {
retryCounter++
if retryCounter < 5 {
logger.Error(err)
}
s.license.IsActive = false
} else {
s.license = license
}
urls, err := s.fetchUrlWhiteList()
if err == nil {
s.urlWhiteList = urls
}
time.Sleep(time.Second * 10)
}
}()
}
func (s *LicenseService) fetchLicense() (*types.License, error) {
//var res struct {
// Code types.BizCode `json:"code"`
// Message string `json:"message"`
// Data License `json:"data"`
//}
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
//response, err := req.C().R().
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
// SetSuccessResult(&res).Post(apiURL)
//if err != nil {
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
//}
//if response.IsErrorState() {
// return nil, fmt.Errorf("激活失败:%v", response.Status)
//}
//if res.Code != types.Success {
// return nil, fmt.Errorf("激活失败:%v", res.Message)
//}
return &types.License{
Key: "abc",
MachineId: "abc",
Configs: types.LicenseConfig{
UserNum: 10000,
DeCopy: false,
},
ExpiredAt: 0,
IsActive: true,
}, nil
}
func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data []string `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %v", err)
}
if response.IsErrorState() {
return nil, fmt.Errorf("发送请求失败:%v", response.Status)
}
if res.Code != types.Success {
return nil, fmt.Errorf("获取白名单失败:%v", res.Message)
}
return res.Data, nil
}
// GetLicense 获取许可信息
func (s *LicenseService) GetLicense() *types.License {
return s.license
}
// IsValidApiURL 判断是否合法的中转 URL
func (s *LicenseService) IsValidApiURL(uri string) error {
// 获得许可授权的直接放行
return nil
//if s.license.IsActive {
// if s.license.MachineId != s.machineId {
// return errors.New("系统使用了盗版的许可证书")
// }
//
// if time.Now().Unix() > s.license.ExpiredAt {
// return errors.New("系统许可证书已经过期")
// }
// return nil
//}
//
//if len(s.urlWhiteList) == 0 {
// urls, err := s.fetchUrlWhiteList()
// if err == nil {
// s.urlWhiteList = urls
// }
//}
//
//for _, v := range s.urlWhiteList {
// if strings.HasPrefix(uri, v) {
// return nil
// }
//}
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
}

View File

@@ -1,233 +0,0 @@
package mj
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
discordgo "github.com/bg5t/mydiscordgo"
"github.com/gorilla/websocket"
"net/http"
"net/url"
"regexp"
"strings"
)
// MidJourney 机器人
var logger = logger2.GetLogger()
type Bot struct {
config types.MidJourneyConfig
bot *discordgo.Session
name string
service *Service
}
func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) {
bot, err := discordgo.New("Bot " + config.BotToken)
if err != nil {
logger.Error(err)
return nil, err
}
// use CDN reverse proxy
if config.UseCDN {
discordgo.SetEndpointDiscord(config.DiscordAPI)
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
bot.MjGateway = config.DiscordGateway + "/"
} else { // use proxy
discordgo.SetEndpointDiscord("https://discord.com")
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
discordgo.SetEndpointStatus("https://discord.com/api/v2/")
bot.MjGateway = "wss://gateway.discord.gg"
if proxy != "" {
proxy, _ := url.Parse(proxy)
bot.Client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
bot.Dialer = &websocket.Dialer{
Proxy: http.ProxyURL(proxy),
}
}
}
return &Bot{
config: config,
bot: bot,
name: name,
service: service,
}, nil
}
func (b *Bot) Run() error {
b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
b.bot.AddHandler(b.messageCreate)
b.bot.AddHandler(b.messageUpdate)
logger.Infof("Starting MidJourney %s", b.name)
err := b.bot.Open()
if err != nil {
logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
return err
}
logger.Infof("Starting MidJourney %s successfully!", b.name)
return nil
}
type TaskStatus string
const (
Start = TaskStatus("Started")
Running = TaskStatus("Running")
Stopped = TaskStatus("Stopped")
Finished = TaskStatus("Finished")
)
type Image struct {
URL string `json:"url"`
ProxyURL string `json:"proxy_url"`
Filename string `json:"filename"`
Width int `json:"width"`
Height int `json:"height"`
Size int `json:"size"`
Hash string `json:"hash"`
}
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
// ignore messages for other channels
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
return
}
// ignore messages for self
if m.Author == nil || m.Author.ID == s.State.User.ID {
return
}
logger.Debugf("CREATE: %s", utils.JsonEncode(m))
var referenceId = ""
if m.ReferencedMessage != nil {
referenceId = m.ReferencedMessage.ID
}
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
// parse content
req := CBReq{
ChannelId: m.ChannelID,
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
Content: m.Content,
Progress: 0,
Status: Start}
b.service.Notify(req)
return
}
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
// ignore messages for other channels
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
return
}
// ignore messages for self
if m.Author == nil || m.Author.ID == s.State.User.ID {
return
}
logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
var referenceId = ""
if m.ReferencedMessage != nil {
referenceId = m.ReferencedMessage.ID
}
if strings.Contains(m.Content, "(Stopped)") {
req := CBReq{
ChannelId: m.ChannelID,
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
Content: m.Content,
Progress: extractProgress(m.Content),
Status: Stopped}
b.service.Notify(req)
return
}
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
progress := extractProgress(content)
var status TaskStatus
if progress == 100 {
status = Finished
} else {
status = Running
}
for _, attachment := range attachments {
if attachment.Width == 0 || attachment.Height == 0 {
continue
}
image := Image{
URL: attachment.URL,
Height: attachment.Height,
ProxyURL: attachment.ProxyURL,
Width: attachment.Width,
Size: attachment.Size,
Filename: attachment.Filename,
Hash: extractHashFromFilename(attachment.Filename),
}
req := CBReq{
ChannelId: channelId,
MessageId: messageId,
ReferenceId: referenceId,
Image: image,
Prompt: extractPrompt(content),
Content: content,
Progress: progress,
Status: status,
}
b.service.Notify(req)
break // only get one image
}
}
// extract prompt from string
func extractPrompt(input string) string {
pattern := `\*\*(.*?)\*\*`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(input)
if len(matches) > 1 {
return strings.TrimSpace(matches[1])
}
return ""
}
func extractProgress(input string) int {
pattern := `\((\d+)\%\)`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(input)
if len(matches) > 1 {
return utils.IntValue(matches[1], 0)
}
return 100
}
func extractHashFromFilename(filename string) string {
if !strings.HasSuffix(filename, ".png") {
return ""
}
index := strings.LastIndex(filename, "_")
if index != -1 {
return filename[index+1 : len(filename)-4]
}
return ""
}

View File

@@ -1,159 +1,68 @@
package mj
import (
"chatplus/core/types"
"errors"
"fmt"
"time"
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"github.com/imroc/req/v3"
)
import "geekai/core/types"
// MidJourney client
type Client struct {
client *req.Client
Config types.MidJourneyConfig
apiURL string
type Client interface {
Imagine(task types.MjTask) (ImageRes, error)
Blend(task types.MjTask) (ImageRes, error)
SwapFace(task types.MjTask) (ImageRes, error)
Upscale(task types.MjTask) (ImageRes, error)
Variation(task types.MjTask) (ImageRes, error)
QueryTask(taskId string) (QueryRes, error)
}
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second)
var apiURL string
// set proxy URL
if config.UseCDN {
apiURL = config.DiscordAPI + "/api/v9/interactions"
} else {
apiURL = "https://discord.com/api/v9/interactions"
if proxy != "" {
client.SetProxyURL(proxy)
}
}
return &Client{client: client, Config: config, apiURL: apiURL}
type ImageReq struct {
BotType string `json:"botType,omitempty"`
Prompt string `json:"prompt,omitempty"`
Dimensions string `json:"dimensions,omitempty"`
Base64Array []string `json:"base64Array,omitempty"`
AccountFilter interface{} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook,omitempty"`
State string `json:"state,omitempty"`
}
func (c *Client) Imagine(task types.MjTask) error {
interactionsReq := &InteractionsRequest{
Type: 2,
ApplicationID: ApplicationID,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
SessionID: SessionID,
Data: map[string]any{
"version": "1166847114203123795",
"id": "938956540159881230",
"name": "imagine",
"type": "1",
"options": []map[string]any{
{
"type": 3,
"name": "prompt",
"value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt),
},
},
"application_command": map[string]any{
"id": "938956540159881230",
"application_id": ApplicationID,
"version": "1118961510123847772",
"default_permission": true,
"default_member_permissions": nil,
"type": 1,
"nsfw": false,
"name": "imagine",
"description": "Create images with Midjourney",
"dm_permission": true,
"options": []map[string]any{
{
"type": 3,
"name": "prompt",
"description": "The prompt to imagine",
"required": true,
},
},
"attachments": []any{},
},
},
}
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
Post(c.apiURL)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %w%v", err, r.Err)
}
return nil
type ImageRes struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
func (c *Client) Blend(task types.MjTask) error {
return errors.New("function not implemented")
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
func (c *Client) SwapFace(task types.MjTask) error {
return errors.New("function not implemented")
}
// Upscale 放大指定的图片
func (c *Client) Upscale(task types.MjTask) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
MessageFlags: flags,
MessageID: task.MessageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(c.apiURL)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}
return nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(task types.MjTask) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
MessageFlags: flags,
MessageID: task.MessageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(c.apiURL)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}
return nil
type QueryRes struct {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}

View File

@@ -1,198 +0,0 @@
package plus
import (
"chatplus/core/types"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"fmt"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
)
// Service MJ 绘画服务
type Service struct {
Name string // service Name
Client *Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle
HandledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
return &Service{
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: client,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time, 0),
}
}
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for {
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// if it's reference message, check if it's this channel's message
//if task.ChannelId != "" && task.ChannelId != s.Name {
// logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
// s.taskQueue.RPush(task)
// time.Sleep(time.Second)
// continue
//}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
var job model.MidJourneyJob
s.db.Where("id = ?", task.Id).First(&job)
if err != nil || (res.Code != 1 && res.Code != 22) {
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(task.UserId)
continue
}
logger.Infof("任务提交成功:%+v", res)
// lock the task until the execute timeout
s.taskStartTimes[int(task.Id)] = time.Now()
atomic.AddInt32(&s.HandledTaskNum, 1)
// 更新任务 ID/频道
job.TaskId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
}
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.HandledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
}
type CBReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": task.FailReason,
})
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
if s.Client.Config.CdnURL != "" {
job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
} else {
job.OrgURL = task.ImageUrl
}
}
job.MessageId = task.Id
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
if task.Status == "SUCCESS" {
// release lock task
atomic.AddInt32(&s.HandledTaskNum, -1)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
s.notifyQueue.RPush(job.UserId)
}
return nil
}
func GetImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}

View File

@@ -1,74 +1,60 @@
package plus
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
"time"
"github.com/gin-gonic/gin"
)
var logger = logger2.GetLogger()
// Client MidJourney Plus Client
type Client struct {
Config types.MidJourneyPlusConfig
apiURL string
// PlusClient MidJourney Plus ProxyClient
type PlusClient struct {
Config types.MjPlusConfig
apiURL string
client *req.Client
licenseService *service.LicenseService
}
func NewClient(config types.MidJourneyPlusConfig) *Client {
var apiURL string
if config.CdnURL != "" {
apiURL = config.CdnURL
} else {
apiURL = config.ApiURL
func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
return &PlusClient{
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
licenseService: licenseService,
}
if config.Mode == "" {
config.Mode = "fast"
}
func (c *PlusClient) preCheck() error {
return c.licenseService.IsValidApiURL(c.Config.ApiURL)
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
return &Client{Config: config, apiURL: apiURL}
}
type ImageReq struct {
BotType string `json:"botType"`
Prompt string `json:"prompt,omitempty"`
Dimensions string `json:"dimensions,omitempty"`
Base64Array []string `json:"base64Array,omitempty"`
AccountFilter struct {
InstanceId string `json:"instanceId"`
Modes []interface{} `json:"modes"`
Remix bool `json:"remix"`
RemixAutoConsidered bool `json:"remixAutoConsidered"`
} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook"`
State string `json:"state,omitempty"`
}
type ImageRes struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: task.Prompt,
NotifyHook: c.Config.NotifyURL,
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
@@ -81,18 +67,17 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
@@ -104,12 +89,16 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
}
// Blend 融图
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
NotifyHook: c.Config.NotifyURL,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
@@ -125,15 +114,14 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
@@ -144,7 +132,11 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
}
// SwapFace 换脸
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
@@ -171,20 +163,18 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
"accountFilter": gin.H{
"instanceId": "",
},
"notifyHook": c.Config.NotifyURL,
"state": "",
"state": "",
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
r, err := c.client.SetTimeout(time.Minute).R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
@@ -195,16 +185,20 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
}
// Upscale 放大指定的图片
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
"notifyHook": c.Config.NotifyURL,
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
@@ -222,13 +216,17 @@ func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
"notifyHook": c.Config.NotifyURL,
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
@@ -248,35 +246,10 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
return res, nil
}
type QueryRes struct {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}
func (c *Client) QueryTask(taskId string) (QueryRes, error) {
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
@@ -290,3 +263,5 @@ func (c *Client) QueryTask(taskId string) (QueryRes, error) {
return res, nil
}
var _ Client = &PlusClient{}

View File

@@ -1,12 +1,21 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/service/mj/plus"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"github.com/go-redis/redis/v8"
"strings"
"time"
@@ -16,64 +25,21 @@ import (
// ServicePool Mj service pool
type ServicePool struct {
services []interface{}
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
licenseService *service.LicenseService
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]interface{}, 0)
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
for k, config := range appConfig.MjPlusConfigs {
if config.Enabled == false {
continue
}
client := plus.NewClient(config)
name := fmt.Sprintf("mj-service-plus-%d", k)
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
go func() {
servicePlus.Run()
}()
services = append(services, servicePlus)
}
if len(services) == 0 {
// create mj client and service
for k, config := range appConfig.MjConfigs {
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
if err != nil {
continue
}
err = bot.Run()
if err != nil {
continue
}
// run mj service
go func() {
service.Run()
}()
services = append(services, service)
}
}
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
@@ -81,22 +47,59 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
uploaderManager: manager,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
licenseService: licenseService,
}
}
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
// stop old service
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for k, config := range plusConfigs {
if config.Enabled == false {
continue
}
cli := NewPlusClient(config, p.licenseService)
name := fmt.Sprintf("mj-plus-service-%d", k)
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
plusService.Run()
}()
p.services = append(p.services, plusService)
}
// for mid-journey proxy
for k, config := range proxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := fmt.Sprintf("mj-proxy-service-%d", k)
proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
proxyService.Run()
}()
p.services = append(p.services, proxyService)
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
var message sd.NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := p.Clients.Get(userId)
if client == nil {
cli := p.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = client.Send([]byte("Task Updated"))
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
@@ -120,17 +123,23 @@ func (p *ServicePool) DownloadImages() {
}
logger.Infof("try to download image: %s", v.OrgURL)
var imgURL string
var err error
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = plus.GetImageHash(task.Buttons[0].CustomId)
}
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
} else {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
mjService := p.getService(v.ChannelId)
if mjService == nil {
logger.Errorf("Invalid task: %+v", v)
continue
}
task, _ := mjService.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
// 如果是返回的是 discord 图片地址,则使用代理下载
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
@@ -141,11 +150,11 @@ func (p *ServicePool) DownloadImages() {
v.ImgURL = imgURL
p.db.Updates(&v)
client := p.Clients.Get(uint(v.UserId))
if client == nil {
cli := p.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = client.Send([]byte("Task Updated"))
err = cli.Send([]byte(sd.Finished))
if err != nil {
continue
}
@@ -167,25 +176,6 @@ func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}
func (p *ServicePool) Notify(data plus.CBReq) error {
logger.Debugf("收到任务回调:%+v", data)
var job model.MidJourneyJob
res := p.db.Where("task_id = ?", data.Id).First(&job)
if res.Error != nil {
return fmt.Errorf("非法任务:%s", data.Id)
}
// 任务已经拉取完成
if job.Progress == 100 {
return nil
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
return servicePlus.Notify(job)
}
return nil
}
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
@@ -200,10 +190,7 @@ func (p *ServicePool) SyncTaskProgress() {
// 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
p.db.Delete(&job)
// 略过 Upscale 任务
if job.Type != types.TaskUpscale.String() {
continue
}
// 退回算力
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if tx.Error == nil && tx.RowsAffected > 0 {
var user model.User
@@ -220,28 +207,23 @@ func (p *ServicePool) SyncTaskProgress() {
CreatedAt: time.Now(),
})
}
}
if !strings.HasPrefix(job.ChannelId, "mj-service-plus") {
continue
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}
}
time.Sleep(time.Second)
time.Sleep(time.Second * 10)
}
}()
}
func (p *ServicePool) getServicePlus(name string) *plus.Service {
func (p *ServicePool) getService(name string) *Service {
for _, s := range p.services {
if servicePlus, ok := s.(*plus.Service); ok {
if servicePlus.Name == name {
return servicePlus
}
if s.Name == name {
return s
}
}
return nil

View File

@@ -0,0 +1,185 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
)
// ProxyClient MidJourney Proxy Client
type ProxyClient struct {
Config types.MjProxyConfig
apiURL string
}
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
return &ProxyClient{Config: config, apiURL: config.ApiURL}
}
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
body := ImageReq{
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能请使用 MidJourney-Plus")
}
// Upscale 放大指定的图片
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "UPSCALE",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "VARIATION",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &ProxyClient{}

View File

@@ -1,11 +1,21 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/store"
"chatplus/store/model"
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
@@ -13,41 +23,28 @@ import (
// Service MJ 绘画服务
type Service struct {
name string // service name
client *Client // MJ client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
Name string // service Name
Client Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
running bool
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
return &Service{
name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
client: client,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time, 0),
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: cli,
running: true,
}
}
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.name)
for {
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for s.running {
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
@@ -55,124 +52,153 @@ func (s *Service) Run() {
continue
}
// if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
// 如果配置了多个中转平台的 API KEY
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
time.Sleep(time.Second)
continue
}
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type {
case types.TaskImage:
err = s.client.Imagine(task)
break
case types.TaskUpscale:
err = s.client.Upscale(task)
break
case types.TaskVariation:
err = s.client.Variation(task)
break
case types.TaskBlend:
err = s.client.Blend(task)
break
case types.TaskSwapFace:
err = s.client.SwapFace(task)
break
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt))
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(task.UserId)
// restore img_call quota
if task.Type.String() != types.TaskUpscale.String() {
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
}
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil {
logger.Error("任务不存在任务ID", task.TaskId)
continue
}
logger.Infof("Task Executed: %+v", task)
// lock the task until the execute timeout
s.taskStartTimes[int(task.Id)] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
}
}
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum)
return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) {
var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
}
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
continue
}
logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
}
func (s *Service) Notify(data CBReq) {
// extract the task ID
split := strings.Split(data.Prompt, " ")
var job model.MidJourneyJob
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
if res.Error == nil && data.Status == Finished {
logger.Warn("重复消息:", data.MessageId)
return
}
tx := s.db.Session(&gorm.Session{}).Where("progress < ?", 100).Order("id ASC")
if data.ReferenceId != "" {
tx = tx.Where("reference_id = ?", data.ReferenceId)
} else {
tx = tx.Where("task_id = ?", split[0])
}
// fixed: 修复 U/V 操作任务混淆覆盖的 Bug
if strings.Contains(data.Prompt, "** - Image #") { // for upscale
tx = tx.Where("type = ?", types.TaskUpscale.String())
} else if strings.Contains(data.Prompt, "** - Variations (Strong)") { // for Variations
tx = tx.Where("type = ?", types.TaskVariation.String())
}
res = tx.First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
job.ChannelId = data.ChannelId
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Progress = data.Progress
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
if s.client.Config.UseCDN {
job.UseProxy = true
job.OrgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.ImgCdnURL)
} else {
job.OrgURL = data.Image.URL
}
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
if data.Status == Finished {
// release lock task
atomic.AddInt32(&s.handledTaskNum, -1)
}
s.notifyQueue.RPush(job.UserId)
func (s *Service) Stop() {
s.running = false
}
type CBReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := sd.Running
if job.Progress == 100 {
message = sd.Finished
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
}
return nil
}
func GetImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}

View File

@@ -1,35 +0,0 @@
package mj
const (
ApplicationID string = "936929561302675456"
SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe"
)
type InteractionsRequest struct {
Type int `json:"type"`
ApplicationID string `json:"application_id"`
MessageFlags int `json:"message_flags,omitempty"`
MessageID string `json:"message_id,omitempty"`
GuildID string `json:"guild_id"`
ChannelID string `json:"channel_id"`
SessionID string `json:"session_id"`
Data map[string]any `json:"data"`
Nonce string `json:"nonce,omitempty"`
}
type InteractionsResult struct {
Code int `json:"code"`
Message string
Error map[string]any
}
type CBReq struct {
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"`
Image Image `json:"image"`
Content string `json:"content"`
Prompt string `json:"prompt"`
Status TaskStatus `json:"status"`
Progress int `json:"progress"`
}

View File

@@ -1,10 +1,18 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/url"
"path/filepath"
"strings"
@@ -101,6 +109,20 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
}
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
}
func (s AliYunOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {

View File

@@ -1,15 +1,22 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/gin-gonic/gin"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
)
type LocalStorage struct {
@@ -73,6 +80,20 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
err = os.WriteFile(filePath, imageData, 0644)
if err != nil {
return "", fmt.Errorf("error writing to file:%v", err)
}
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) Delete(fileURL string) error {
if _, err := os.Stat(fileURL); err == nil {
return os.Remove(fileURL)

View File

@@ -1,10 +1,18 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/url"
"path/filepath"
"strings"
@@ -96,6 +104,25 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil
}
func (s MiniOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
objectKey,
strings.NewReader(string(imageData)),
int64(len(imageData)),
minio.PutObjectOptions{ContentType: "image/png"})
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
}
func (s MiniOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {

View File

@@ -1,11 +1,19 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core/types"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/url"
"path/filepath"
"strings"
@@ -112,6 +120,22 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {

View File

@@ -1,5 +1,12 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import "github.com/gin-gonic/gin"
const Local = "LOCAL"
@@ -17,5 +24,6 @@ type File struct {
type Uploader interface {
PutFile(ctx *gin.Context, name string) (File, error)
PutImg(imageURL string, useProxy bool) (string, error)
PutBase64(imageData string) (string, error)
Delete(fileURL string) error
}

View File

@@ -1,7 +1,14 @@
package oss
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"geekai/core/types"
"strings"
)

View File

@@ -1,9 +1,16 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"github.com/smartwalle/alipay/v3"
"log"
"net/url"

View File

@@ -1,12 +1,19 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/utils"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
@@ -42,6 +49,8 @@ type HuPiPayReq struct {
CallbackURL string `json:"callback_url"`
Time string `json:"time"`
NonceStr string `json:"nonce_str"`
Type string `json:"type"`
WapUrl string `json:"wap_url"`
}
type HuPiResp struct {

View File

@@ -1,12 +1,19 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/utils"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"

View File

@@ -1,11 +1,18 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"fmt"
"geekai/core/types"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"time"
"github.com/go-redis/redis/v8"
@@ -18,28 +25,14 @@ type ServicePool struct {
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
uploader *oss.UploaderManager
levelDB *store.LevelDB
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.SdConfigs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%d", k)
service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager)
// run sd service
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
taskQueue: taskQueue,
@@ -47,6 +40,32 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
services: services,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
uploader: manager,
levelDB: levelDB,
}
}
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
// stop old service
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for k, config := range configs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf(" sd-service-%d", k)
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
// run sd service
go func() {
service.Run()
}()
p.services = append(p.services, service)
}
}
@@ -58,17 +77,18 @@ func (p *ServicePool) PushTask(task types.SdTask) {
func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
var message NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := p.Clients.Get(userId)
client := p.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
@@ -79,6 +99,7 @@ func (p *ServicePool) CheckTaskNotify() {
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)
@@ -111,7 +132,7 @@ func (p *ServicePool) CheckTaskStatus() {
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}

View File

@@ -1,17 +1,21 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"encoding/json"
"fmt"
"io"
"os"
"strconv"
"sync/atomic"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"strings"
"time"
"github.com/imroc/req/v3"
@@ -21,50 +25,62 @@ import (
// SD 绘画服务
type Service struct {
httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
name string // service name
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
name string // service name
leveldb *store.LevelDB
running bool // 运行状态
}
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
return &Service{
name: name,
config: config,
httpClient: req.C(),
taskQueue: taskQueue,
notifyQueue: notifyQueue,
db: db,
uploadManager: manager,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time),
name: name,
config: config,
httpClient: req.C(),
taskQueue: taskQueue,
notifyQueue: notifyQueue,
db: db,
leveldb: levelDB,
uploadManager: manager,
running: true,
}
}
func (s *Service) Run() {
for {
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
for s.running {
var task types.SdTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// translate prompt
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
if err == nil {
task.Params.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt))
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task)
if err != nil {
@@ -74,239 +90,158 @@ func (s *Service) Run() {
"progress": -1,
"err_msg": err.Error(),
})
// release task num
atomic.AddInt32(&s.handledTaskNum, -1)
// 通知前端,任务失败
s.notifyQueue.RPush(task.UserId)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
continue
}
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
}
}
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum)
return handledNum < s.maxHandleTaskNum
func (s *Service) Stop() {
s.running = false
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
// Txt2ImgReq 文生图请求实体
type Txt2ImgReq struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt"`
Seed int64 `json:"seed,omitempty"`
Steps int `json:"steps"`
CfgScale float32 `json:"cfg_scale"`
Width int `json:"width"`
Height int `json:"height"`
SamplerName string `json:"sampler_name"`
Scheduler string `json:"scheduler"`
EnableHr bool `json:"enable_hr,omitempty"`
HrScale int `json:"hr_scale,omitempty"`
HrUpscaler string `json:"hr_upscaler,omitempty"`
HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"`
DenoisingStrength float32 `json:"denoising_strength,omitempty"`
ForceTaskId string `json:"force_task_id,omitempty"`
}
// Txt2ImgResp 文生图响应实体
type Txt2ImgResp struct {
Images []string `json:"images"`
Parameters struct {
} `json:"parameters"`
Info string `json:"info"`
}
// TaskProgressResp 任务进度响应实体
type TaskProgressResp struct {
Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"`
CurrentImage string `json:"current_image"`
}
// Txt2Img 文生图 API
func (s *Service) Txt2Img(task types.SdTask) error {
var taskInfo TaskInfo
bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath)
if err != nil {
return fmt.Errorf("error with load text2img json template file: %s", err.Error())
body := Txt2ImgReq{
Prompt: task.Params.Prompt,
NegativePrompt: task.Params.NegPrompt,
Steps: task.Params.Steps,
CfgScale: task.Params.CfgScale,
Width: task.Params.Width,
Height: task.Params.Height,
SamplerName: task.Params.Sampler,
Scheduler: task.Params.Scheduler,
ForceTaskId: task.Params.TaskId,
}
err = json.Unmarshal(bytes, &taskInfo)
if err != nil {
return fmt.Errorf("error with decode json params: %s", err.Error())
if task.Params.Seed > 0 {
body.Seed = task.Params.Seed
}
data := taskInfo.Data
params := task.Params
data[ParamKeys["task_id"]] = params.TaskId
data[ParamKeys["prompt"]] = params.Prompt
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
data[ParamKeys["steps"]] = params.Steps
data[ParamKeys["sampler"]] = params.Sampler
// @fix bug: 有些 stable diffusion 没有面部修复功能
//data[ParamKeys["face_fix"]] = params.FaceFix
data[ParamKeys["cfg_scale"]] = params.CfgScale
data[ParamKeys["seed"]] = params.Seed
data[ParamKeys["height"]] = params.Height
data[ParamKeys["width"]] = params.Width
data[ParamKeys["hd_fix"]] = params.HdFix
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
data[ParamKeys["hd_scale"]] = params.HdScale
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
data[ParamKeys["hd_sample_num"]] = params.HdSteps
taskInfo.SessionId = task.SessionId
taskInfo.TaskId = params.TaskId
taskInfo.Data = data
taskInfo.JobId = task.Id
taskInfo.UserId = uint(task.UserId)
if task.Params.HdFix {
body.EnableHr = true
body.HrScale = task.Params.HdScale
body.HrUpscaler = task.Params.HdScaleAlg
body.HrSecondPassSteps = task.Params.HdSteps
body.DenoisingStrength = task.Params.HdRedrawRate
}
var res Txt2ImgResp
var errChan = make(chan error)
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
logger.Debugf("send image request to %s", apiURL)
// send a request to sd api endpoint
go func() {
s.runTask(taskInfo, s.httpClient)
}()
return nil
}
// 执行任务
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
body := map[string]any{
"data": taskInfo.Data,
"event_data": taskInfo.EventData,
"fn_index": taskInfo.FnIndex,
"session_hash": taskInfo.SessionHash,
}
var result = make(chan CBReq)
go func() {
var res struct {
Data []interface{} `json:"data"`
IsGenerating bool `json:"is_generating"`
Duration float64 `json:"duration"`
AverageDuration float64 `json:"average_duration"`
}
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
cbReq.Message = "error with send request: " + err.Error()
cbReq.Success = false
result <- cbReq
errChan <- err
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
cbReq.Message = "error http status code: " + string(bytes)
cbReq.Success = false
result <- cbReq
errChan <- fmt.Errorf("error http code status: %v", response.Status)
return
}
var images []struct {
Name string `json:"name"`
Data interface{} `json:"data"`
IsFile bool `json:"is_file"`
}
err = utils.ForceCovert(res.Data[0], &images)
// 保存 Base64 图片
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
if err != nil {
cbReq.Message = "error with decode image:" + err.Error()
cbReq.Success = false
result <- cbReq
errChan <- fmt.Errorf("error with upload image: %v", err)
return
}
var info map[string]any
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
// 获取绘画真实的 seed
var info map[string]interface{}
err = utils.JsonDecode(res.Info, &info)
if err != nil {
logger.Error(res.Data)
cbReq.Message = "error with decode image url:" + err.Error()
cbReq.Success = false
result <- cbReq
errChan <- fmt.Errorf("error with decode task response: %v", err)
return
}
// 获取真实的 seed 值
cbReq.ImageName = images[0].Name
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
cbReq.Seed = seed
cbReq.Success = true
cbReq.Progress = 100
result <- cbReq
close(result)
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
errChan <- nil
}()
// waiting for task finish
for {
select {
case value := <-result:
s.callback(value)
return
case err := <-errChan:
if err != nil {
return err
}
// task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil
default:
var progressReq = map[string]any{
"id_task": taskInfo.TaskId,
"id_live_preview": 1,
err, resp := s.checkTaskProgress()
// 更新任务进度
if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
}
}
var progressRes struct {
Active bool `json:"active"`
Queued bool `json:"queued"`
Completed bool `json:"completed"`
Progress float64 `json:"progress"`
Eta float64 `json:"eta"`
LivePreview string `json:"live_preview"`
IDLivePreview int `json:"id_live_preview"`
TextInfo interface{} `json:"textinfo"`
}
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
if err != nil { // TODO: 这里可以考虑设置失败重试次数
logger.Error(err)
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
logger.Error(string(bytes))
return
}
cbReq.ImageData = progressRes.LivePreview
cbReq.Progress = int(progressRes.Progress * 100)
s.callback(cbReq)
time.Sleep(time.Second)
}
}
}
func (s *Service) callback(data CBReq) {
// release task num
atomic.AddInt32(&s.handledTaskNum, -1)
if data.Success { // 任务成功
var job model.SdJob
res := s.db.Where("id = ?", data.JobId).First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
// 更新任务进度
job.Progress = data.Progress
// 更新任务 seed
var params types.SdTaskParams
err := utils.JsonDecode(job.Params, &params)
if err != nil {
logger.Error("任务解析失败:", err)
return
}
params.Seed = data.Seed
if data.ImageName != "" { // 下载图片
job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
if data.Progress == 100 {
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imageURL
}
}
job.Params = utils.JsonEncode(params)
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
logger.Debugf("绘图进度:%d", data.Progress)
} else { // 任务失败
logger.Error("任务执行失败:", data.Message)
// update the task progress
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": data.Message,
})
// 执行任务
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
var res TaskProgressResp
response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return err, nil
}
if response.IsErrorState() {
return fmt.Errorf("error http code status: %v", response.Status), nil
}
// 发送更新状态信号
s.notifyQueue.RPush(data.UserId)
return nil, &res
}

View File

@@ -1,47 +1,24 @@
package sd
import logger2 "chatplus/logger"
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import logger2 "geekai/logger"
var logger = logger2.GetLogger()
type TaskInfo struct {
UserId uint `json:"user_id"`
SessionId string `json:"session_id"`
JobId int `json:"job_id"`
TaskId string `json:"task_id"`
Data []interface{} `json:"data"`
EventData interface{} `json:"event_data"`
FnIndex int `json:"fn_index"`
SessionHash string `json:"session_hash"`
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
type CBReq struct {
UserId uint
SessionId string
JobId int
TaskId string
ImageName string
ImageData string
Progress int
Seed int64
Success bool
Message string
}
var ParamKeys = map[string]int{
"task_id": 0,
"prompt": 1,
"negative_prompt": 2,
"steps": 4,
"sampler": 5,
"face_fix": 7, // 面部修复
"cfg_scale": 8,
"seed": 27,
"height": 10,
"width": 9,
"hd_fix": 11,
"hd_redraw_rate": 12, //高清修复重绘幅度
"hd_scale": 13, // 高清修复放大倍数
"hd_scale_alg": 14, // 高清修复放大算法
"hd_sample_num": 15, // 高清修复采样次数
}
const (
Running = "RUNNING"
Finished = "FINISH"
Failed = "FAIL"
)

View File

@@ -1,8 +1,15 @@
package sms
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"fmt"
"geekai/core/types"
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
)

View File

@@ -1,9 +1,16 @@
package sms
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
"chatplus/utils"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"

View File

@@ -1,5 +1,12 @@
package sms
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
const Ali = "ALI"
const Bao = "BAO"

View File

@@ -1,8 +1,15 @@
package sms
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"geekai/core/types"
logger2 "geekai/logger"
"strings"
)

View File

@@ -1,11 +1,20 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"chatplus/core/types"
"crypto/tls"
"fmt"
"geekai/core/types"
"mime"
"net/smtp"
"net/textproto"
)
type SmtpService struct {
@@ -19,12 +28,18 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
}
func (s *SmtpService) SendVerifyCode(to string, code int) error {
subject := "ChatPlus注册验证码"
body := fmt.Sprintf("您正在注册 ChatPlus AI 助手账户,注册验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", code)
subject := "Geek-AI 注册验证码"
body := fmt.Sprintf("您正在注册 Geek-AI 助手账户,注册验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", code)
// 设置SMTP客户端配置
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
if s.config.UseTls {
return s.sendTLS(auth, to, subject, body)
} else {
return s.send(auth, to, subject, body)
}
}
func (s *SmtpService) send(auth smtp.Auth, to string, subject string, body string) error {
// 对主题进行MIME编码
encodedSubject := mime.QEncoding.Encode("UTF-8", subject)
// 组装邮件
@@ -34,11 +49,83 @@ func (s *SmtpService) SendVerifyCode(to string, code int) error {
message.WriteString(fmt.Sprintf("Subject: %s\r\n", encodedSubject))
message.WriteString("\r\n" + body)
// 发送邮件
// 发送邮件
err := smtp.SendMail(s.config.Host+":"+fmt.Sprint(s.config.Port), auth, s.config.From, []string{to}, message.Bytes())
if err != nil {
return fmt.Errorf("error sending email: %v", err)
}
return err
}
func (s *SmtpService) sendTLS(auth smtp.Auth, to string, subject string, body string) error {
// TLS配置
tlsConfig := &tls.Config{
ServerName: s.config.Host,
}
// 建立TLS连接
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port), tlsConfig)
if err != nil {
return fmt.Errorf("error connecting to SMTP server: %v", err)
}
defer conn.Close()
client, err := smtp.NewClient(conn, s.config.Host)
if err != nil {
return fmt.Errorf("error creating SMTP client: %v", err)
}
defer client.Quit()
// 身份验证
if err = client.Auth(auth); err != nil {
return fmt.Errorf("error authenticating: %v", err)
}
// 设置寄件人
if err = client.Mail(s.config.From); err != nil {
return fmt.Errorf("error setting sender: %v", err)
}
// 设置收件人
if err = client.Rcpt(to); err != nil {
return fmt.Errorf("error setting recipient: %v", err)
}
// 发送邮件内容
wc, err := client.Data()
if err != nil {
return fmt.Errorf("error getting data writer: %v", err)
}
defer wc.Close()
header := make(textproto.MIMEHeader)
header.Set("From", s.config.From)
header.Set("To", to)
header.Set("Subject", subject)
// 将邮件头写入
for key, values := range header {
for _, value := range values {
_, err = fmt.Fprintf(wc, "%s: %s\r\n", key, value)
if err != nil {
return fmt.Errorf("error sending email header: %v", err)
}
}
}
_, _ = fmt.Fprintln(wc)
// 将邮件内容写入
_, err = fmt.Fprintf(wc, body)
if err != nil {
return fmt.Errorf("error sending email: %v", err)
}
// 发送完毕
err = wc.Close()
if err != nil {
return fmt.Errorf("error closing data writer: %v", err)
}
return nil
}

Some files were not shown because too many files have changed in this diff Show More