Compare commits

...

162 Commits

Author SHA1 Message Date
RockYang
5213bdf08b fixed bug for calculate chat message tokens 2025-01-19 13:08:38 +08:00
RockYang
cf817fd8ea merge code for v4.1.4 2025-01-18 23:20:41 +08:00
RockYang
a2481ff1cf update readme 2025-01-07 11:58:16 +08:00
RockYang
bc7d06d3e5 修复登录失效的 Bug 2024-12-21 21:44:52 +08:00
RockYang
8e81dfa12a update database name 2024-12-16 11:26:54 +08:00
RockYang
0ff76f0f21 update database file 2024-12-16 11:11:42 +08:00
RockYang
787caa84c8 merge config.toml file 2024-12-16 10:09:23 +08:00
RockYang
c2503e663a merge v4.1.3 2024-12-16 10:07:52 +08:00
RockYang
405a88862b Merge branch 'dev' 2024-11-27 15:04:26 +08:00
RockYang
296eabe09a merge v4.1.2 2024-11-27 15:00:02 +08:00
RockYang
54b45ec2ff update docker-compose.yaml 2024-11-14 16:37:43 +08:00
RockYang
c434f85045 update database 2024-11-14 16:01:27 +08:00
RockYang
4d10279870 merge v4.1.1 and fixed conflicts 2024-11-13 18:40:04 +08:00
RockYang
9de9489673 remove sensitive words 2024-11-04 10:05:57 +08:00
RockYang
9814fec930 update the default api url to https://api.geekai.pro 2024-10-29 14:09:55 +08:00
RockYang
53ba731159 update database sql file 2024-10-29 14:07:49 +08:00
RockYang
b2f57aa483 merge v4.1.0 and fixed conflicts 2024-10-08 17:54:08 +08:00
RockYang
4c2dba1004 merge v4.1.0 and fixed conflicts 2024-10-08 17:51:14 +08:00
RockYang
283a023a06 update database sql file for v4.1.4 2024-09-23 15:54:22 +08:00
RockYang
d315edef5f logout the user when it has been disabled 2024-09-20 16:49:03 +08:00
RockYang
5fa17b300e add release v4.1.4 2024-09-20 15:50:04 +08:00
RockYang
32919de7a7 add email white list check in register handler 2024-09-20 14:10:40 +08:00
RockYang
7d126aab41 add email white list 2024-09-20 10:40:37 +08:00
RockYang
16ac57ced3 payment for mobile page is ready 2024-09-19 19:03:03 +08:00
RockYang
4976b967e7 wechat payment for mobile page is ready 2024-09-19 17:59:27 +08:00
RockYang
e874178782 wechat payment for pc is ready 2024-09-19 14:42:25 +08:00
RockYang
8cb66ad01b fixed bug geek-plus#6, register page first tab not auto active 2024-09-19 09:03:14 +08:00
RockYang
f887a39912 geek payment notify api is ready 2024-09-18 22:24:05 +08:00
RockYang
2beffd3dd3 urgent bug fix: remove suno and luma task will recharge user power 2024-09-18 20:33:29 +08:00
RockYang
d8cb92d8d4 Geek Pay notify is ready 2024-09-18 18:07:49 +08:00
RockYang
158db83965 add geek payment 2024-09-18 07:03:46 +08:00
RockYang
603bfa7def recommend user to use Google Chrome 2024-09-15 10:25:50 +08:00
RockYang
829fb879a6 merge app type function branch 2024-09-14 18:17:55 +08:00
RockYang
0385e60ce1 refactor AI chat message struct, allow users to set whether the AI responds in stream, compatible with the GPT-o1 model 2024-09-14 17:06:13 +08:00
胡双明
aaea23f785 feat: 应用分类功能 2024-09-14 11:05:49 +08:00
RockYang
131efd6ba5 refactor chat message body struct 2024-09-14 07:11:45 +08:00
RockYang
866564370d return at least one chat role for getUserRoles API 2024-09-14 05:55:56 +08:00
RockYang
dcdc0d8918 add tid field for chat app role 2024-09-13 18:32:13 +08:00
RockYang
6c7fa17e50 fixed bug, filelist page support pagination, do not load captcha component for user login first time 2024-09-13 17:03:05 +08:00
胡双明
38a0d00142 Merge branch 'main' into husm_2024-09-02 2024-09-13 10:25:15 +08:00
RockYang
5c77e67b0f fixed bug for reset password 2024-09-12 17:25:19 +08:00
RockYang
961cee5e41 fixed bug for register page code verification 2024-09-12 15:42:09 +08:00
胡双明
c9cc93be8c Merge branch 'main' into husm_2024-09-02 2024-09-11 09:57:01 +08:00
RockYang
49f2e1a71e optimize download function for suno 2024-09-11 08:39:28 +08:00
胡双明
97eff6085a feat: 210 AI对话页面文件列表增加分页功能 2024-09-10 18:14:34 +08:00
RockYang
8b2e2d61af file list api support pagination 2024-09-10 15:24:36 +08:00
胡双明
c096efb416 Merge branch 'main' into husm_2024-09-02 2024-09-10 14:29:08 +08:00
RockYang
cdaf6fb9dc update v4.1.3 database sql file 2024-09-10 11:15:26 +08:00
RockYang
78f443ed6d update v4.1.3 database sql file 2024-09-10 11:11:17 +08:00
RockYang
54e8d72b10 support multiple delete users, update database sql file 2024-09-10 10:56:04 +08:00
RockYang
05161f48fd update config file 2024-09-09 18:58:03 +08:00
RockYang
e971bf6b88 merge luma page code for v4.1.3 2024-09-09 18:07:10 +08:00
RockYang
55b979784c Merge remote-tracking branch 'inet/husm_2024-09-02' into dev-4.1.3 2024-09-09 10:54:43 +08:00
RockYang
79adc871ef export the newest database sql file 2024-09-05 15:29:00 +08:00
RockYang
97aa922b5f 优化聊天页面代码,刷新页面之后自动加载当前对话的 role_id 和 model_id 2024-09-05 14:50:37 +08:00
胡双明
11c760a4e8 feat: 生成视频页面2 2024-09-05 11:56:02 +08:00
RockYang
8144fada25 merged v4.0.9 and fixed conflicts 2024-09-05 11:02:32 +08:00
RockYang
87b03332d9 add auto execute task to downloading video files 2024-09-05 10:54:58 +08:00
胡双明
8b14eeadf4 feat: 生成视频页面 2024-09-05 08:50:49 +08:00
RockYang
e0ead127e0 remove unused css file 2024-09-04 22:42:56 +08:00
RockYang
0887bcdee0 adjust chat page styles 2024-09-04 18:07:39 +08:00
RockYang
67d83041d7 user can select function tools by themself 2024-09-04 14:53:21 +08:00
RockYang
1350f388f0 add page and total field for pagination vo 2024-09-03 18:32:15 +08:00
RockYang
65dde9e69d add sync lock for sub or add user's power 2024-09-03 12:09:36 +08:00
RockYang
2e5bd238b7 luma create and list api is ready 2024-09-02 18:08:50 +08:00
胡双明
8fc8fd6cba feat(Luma): 加上传前后帧调换功能 2024-09-02 17:33:03 +08:00
RockYang
dfc6c87250 optimize ChatPlus page, fixed bug for websocket reconnection 2024-09-02 16:35:15 +08:00
RockYang
b63e01225e add luma api service 2024-08-30 18:12:14 +08:00
RockYang
561b82027a update change log file 2024-08-30 16:47:52 +08:00
RockYang
f6d8fbf570 suno add new function for merging full songs and upload custom music 2024-08-30 16:46:48 +08:00
RockYang
568201ebbb add logo for favirate icon 2024-08-29 13:36:35 +08:00
RockYang
ab421f2185 luma page, upload image and remove image function is ready 2024-08-26 17:59:05 +08:00
RockYang
f71a2f5263 add download for video 2024-08-26 07:24:04 +08:00
RockYang
d000cc5a67 luma page video list component is ready 2024-08-23 18:25:58 +08:00
RockYang
04d6ba0853 luma page is ready 2024-08-20 14:31:40 +08:00
RockYang
8d7c028ca8 refactor reset password functions 2024-08-19 12:05:00 +08:00
RockYang
3ae7ebfeaf fixed styles 2024-08-19 06:42:53 +08:00
RockYang
aa42d38387 add bind mobile, bind email, bind wechat function is ready 2024-08-14 15:56:50 +08:00
RockYang
43843b92f2 add mobile and email filed for user 2024-08-13 18:40:50 +08:00
RockYang
5da879600a add verification code for login and register page 2024-08-13 14:55:47 +08:00
RockYang
87ed2064e3 add Captcha components 2024-08-13 10:01:07 +08:00
RockYang
34e96e91d4 fixed conflicts 2024-08-13 06:48:22 +08:00
RockYang
8c4c2b89ce add wechat login for login dialog 2024-08-12 18:00:34 +08:00
RockYang
373021c191 add drag icon for dragable rows 2024-08-12 14:00:50 +08:00
RockYang
740c3c1b00 release v4.1.2 2024-08-09 18:50:02 +08:00
RockYang
67c7132e6b redeem code function is ready 2024-08-09 18:19:51 +08:00
RockYang
c77843424b add redeem code function 2024-08-08 18:28:50 +08:00
RockYang
2d4959aa7d add cache for getting user info and system configs 2024-08-08 14:37:33 +08:00
RockYang
167c59a159 add clear unpaid order functions 2024-08-07 18:00:28 +08:00
RockYang
1d0006ce59 refactor stable diffusion service, use api key instead of configs 2024-08-07 17:30:59 +08:00
RockYang
6a8b4ee2f1 refactor midjourney service, use api key in database 2024-08-06 18:30:57 +08:00
RockYang
72b1515b68 show sql error message 2024-08-05 16:14:44 +08:00
RockYang
754ba02263 fixed build script 2024-08-01 18:45:53 +08:00
RockYang
7ddf57ae06 merge v4.0.8 2024-08-01 18:09:00 +08:00
RockYang
3f0252b498 update change log 2024-08-01 08:54:04 +08:00
RockYang
cc5180a6f7 update readme 2024-08-01 08:52:46 +08:00
RockYang
1d9d487f0e restore use power when removed not finish jobs 2024-07-31 16:08:46 +08:00
RockYang
96f1126d02 remove platform field for api key and chat model 2024-07-30 17:24:21 +08:00
RockYang
7f9b8d8246 update datebasesl 2024-07-30 14:55:49 +08:00
RockYang
5132d52a44 add back-to-top component for all list page 2024-07-29 11:00:53 +08:00
RockYang
1bcbf74883 remove chat debug log 2024-07-28 18:55:17 +08:00
RockYang
abdf5298fe add function to generate lyrics 2024-07-28 10:04:53 +08:00
RockYang
2129f7a8b7 song detail page is ready 2024-07-26 19:12:44 +08:00
RockYang
f6f8748521 adjust chat records layout styles 2024-07-25 11:01:27 +08:00
RockYang
59301df073 add put url file for oss interface 2024-07-23 18:36:26 +08:00
RockYang
e17dcf4d5f remove other platform supports, ONLY use chatGPT API 2024-07-22 18:36:58 +08:00
RockYang
09f44e6d9b optimize foot copyright snaps 2024-07-22 17:54:09 +08:00
RockYang
59824bffc5 add close button for music player 2024-07-22 07:12:21 +08:00
RockYang
cb0dacd5e0 enable use random pure color background for index page 2024-07-19 18:43:01 +08:00
RockYang
7463cfc66c the music player is ready 2024-07-18 18:34:11 +08:00
RockYang
b248560ba2 add suno page 2024-07-17 18:58:09 +08:00
RockYang
37368fe13f support upload file from clipboard 2024-07-17 10:23:02 +08:00
RockYang
246b023624 allow user to use chat role directly, no need to add to workspace 2024-07-16 18:28:08 +08:00
RockYang
9f44c34d34 update docs url 2024-07-16 18:15:34 +08:00
RockYang
a6b9f57a50 show error message for Midjourney task list page 2024-07-16 17:16:58 +08:00
RockYang
42bc23cacf fixed bug for function call for openai 2024-07-16 06:25:40 +08:00
RockYang
282f55c7a3 优化版权显示逻辑,允许激活用户更改自定义版权 2024-07-15 18:44:14 +08:00
RockYang
44798f89ba update docker-compose file 2024-07-12 18:13:43 +08:00
RockYang
596cb2b206 update database file, add tika host config 2024-07-12 18:10:32 +08:00
RockYang
d1965deff1 tidy apis 2024-07-12 14:39:14 +08:00
RockYang
b793b81768 update geekai image version 2024-07-05 11:09:13 +08:00
RockYang
a5ef4299ec wechat login is ready 2024-07-04 15:34:32 +08:00
RockYang
cdb1a8bde1 feat: support wechat login function 2024-07-02 18:27:06 +08:00
RockYang
64e5fc48ba docs: update change log file 2024-06-28 16:21:43 +08:00
RockYang
a692cf1338 feat: optimize chat page data list style, support list style and chat style 2024-06-28 15:53:49 +08:00
RockYang
6998dd7af4 feat: chat with file function is ready 2024-06-27 18:01:49 +08:00
RockYang
9343c73e0f enable set custom index background image 2024-06-27 10:49:31 +08:00
RockYang
739cd46539 add test code for reading pdf files 2024-06-26 18:50:48 +08:00
RockYang
f8fed83507 feat: new UI for chat file manager is ready 2024-06-25 18:59:27 +08:00
RockYang
d63536d5ef fixed bug: mobile chat list could not update chat title 2024-06-25 09:53:08 +08:00
RockYang
4905fb28d4 update version 2024-06-23 17:54:42 +08:00
RockYang
a3a2a8abcb add wechat payment configs sample 2024-06-22 16:04:04 +08:00
RockYang
839dd8dbf4 update docker-compose file 2024-06-22 12:37:41 +08:00
RockYang
0375164f40 add database files 2024-06-22 12:17:35 +08:00
RockYang
691294b444 finish mobile wechat payment 2024-06-22 12:10:43 +08:00
RockYang
c24b4d7074 update change log files 2024-06-14 18:23:54 +08:00
RockYang
ab24398748 wechat payment is ready for PC 2024-06-12 14:20:37 +08:00
RockYang
6110522b54 change payment component, upgrade golang to 1.22.4 2024-06-11 11:48:41 +08:00
RockYang
bcdf5e3776 fix bug: free model not record the chat history 2024-06-06 15:01:32 +08:00
RockYang
2207830db9 fixe page styles 2024-06-05 18:08:23 +08:00
RockYang
d52dfbfef4 fixed bug markmap generation 2024-06-04 16:21:08 +08:00
RockYang
66ccb387e8 dalle3 and gptt-4o api compatible with azure 2024-06-03 18:34:37 +08:00
RockYang
3cc2263dc7 fixed bug for function call error None is not of type 'array' 2024-05-30 09:59:44 +08:00
RockYang
f0a3c5d8ae fixed bug for mobile chat share 2024-05-30 08:37:14 +08:00
RockYang
2a4ef27774 add v4.0.8 database sql file 2024-05-29 17:41:37 +08:00
RockYang
2b057f32aa feat: add dalle3 page for h5 2024-05-29 17:25:01 +08:00
RockYang
bc6451026f feat: add system config for enable rand background image for index page 2024-05-29 16:24:56 +08:00
RockYang
99fd596862 feat: add system config for enable rand background image for index page 2024-05-29 16:23:42 +08:00
RockYang
f0959b5df6 fix markdown formula parse plugin 2024-05-29 13:49:45 +08:00
RockYang
1b0938b33f micro fixs 2024-05-27 17:39:17 +08:00
RockYang
02faff461a fixed bug for dalle prompt translate 2024-05-27 11:42:14 +08:00
RockYang
e18e5a38c6 put model and app selector on the top of chat page 2024-05-24 12:33:22 +08:00
RockYang
2f9b1b7835 fixed bug for payment api authorization 2024-05-24 11:31:38 +08:00
RockYang
717b137a6d chore: use config value for order pay timeout 2024-05-22 18:15:06 +08:00
RockYang
f755bdccae feat: add sign check for PC QR code payment 2024-05-22 17:47:53 +08:00
RockYang
4bba77ab47 extract code for saving chat history 2024-05-22 15:32:44 +08:00
RockYang
6944a32ff3 check if the api url in whitelist for mj plus client 2024-05-22 11:47:04 +08:00
RockYang
5742b40aee fixed bug for mobile chat page change chat model not work 2024-05-21 17:54:03 +08:00
RockYang
7f1ec90748 auto resize the input element rows, when use inputed more than one line 2024-05-21 17:36:47 +08:00
RockYang
bee19392c1 add logs for updating database failed 2024-05-21 11:55:38 +08:00
RockYang
00d31a2379 update docker image url 2024-05-21 11:03:11 +08:00
RockYang
5d65505ab7 rename project name to geekai 2024-05-20 15:11:14 +08:00
323 changed files with 19753 additions and 23035 deletions

View File

@@ -1,4 +1,86 @@
# 更新日志
## v4.1.4
* 功能优化:用户文件列表组件增加分页功能支持
* Bug修复修复用户注册失败Bug注册操作只弹出一次行为验证码
* 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
* 功能新增:给 AI 应用(角色)增加分类,前端支持分类筛选
* 功能优化:允许用户在聊天页面设置是否使用流式输出或者一次性输出,兼容 GPT-O1 模型。
* 功能优化移除PayJS支付渠道支持PayJs已经关闭注册服务请使用其他支付方式。
* 功能新增新增GeeK易支付支付渠道支持支付宝微信支付QQ钱包京东支付抖音支付Paypal支付等支付方式
* Bug修复修复注册页面 tab 组件没有自动选中问题 [#6](https://github.com/yangjian102621/geekai-plus/issues/6)
* 功能优化Luma生成视频任务增加自动翻译功能
* Bug修复Suno 和 Luma 任务没有判断用户算力
* 功能新增:邮箱注册增加邮箱后缀白名单,防止使用某些垃圾邮箱注册薅羊毛
* 功能优化清空未支付订单时只清空超过15分钟未支付的订单
## v4.1.3
* 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
* 功能优化:重构找回密码模块,支持通过手机或者邮箱找回密码
* 功能优化:管理后台给可以拖动排序的组件添加拖动图标
* 功能优化Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
* Bug修复手机端角色和模型选择不生效
* Bug修复用户登录过期之后聊天页面出现大量报错需要刷新页面才能正常
* 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
* 功能优化:给算力增减服务全部加上数据库事务和同步锁
* 功能优化:支持用户在前端对话界面选择插件
* 功能新增:支持 Luma 文生视频功能
## v4.1.2
* Bug修复修复思维导图页面获取模型失败的问题
* 功能优化优化MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力
* Bug修复修复后台拖动排序组件 Bug
* 功能优化:更新数据库失败时候显示具体的的报错信息
* Bug修复修复管理后台对话详情页内容显示异常问题
* 功能优化:管理后台新增清空所有未支付订单的功能
* 功能优化:给会话信息和系统配置数据加上缓存功能,减少 http 请求
* 功能新增:移除微信机器人收款功能,增加卡密功能,支持用户使用卡密兑换算力
## v4.1.1
* Bug修复修复 GPT 模型 function call 调用后没有输出的问题
* 功能新增:允许获取 License 授权用户可以自定义版权信息
* 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件
* 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求
* 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
* 功能新增MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
* 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式
* 功能新增:允许在管理后台设置首页显示的导航菜单
* Bug修复修复注册页面先显示关闭注册组件然后再显示注册组件
* 功能新增:增加 Suno 文生歌曲功能
* 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
* 功能优化:在所有列表页面增加返回顶部按钮
## v4.1.0
* bug修复修复移动端修改聊天标题不生效的问题
* Bug修复修复用户注册不显示用户名的问题
* Bug修复修复管理后台拖动排序不生效的问题
* 功能优化:允许用户设置自定义首页背景图片
* 功能新增:**支持AI解读 PDF, Word, Excel等文件**
* 功能优化:优化聊天界面的用户上传文件的列表样式
* 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
* 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
## v4.0.9
* 环境升级:升级 Golang 到 go1.22.4
* 功能增加:接入微信商户号支付渠道
* Bug修复修复前端页面菜单把页面撑开底部留白问题
* 功能优化:聊天页面自动根据内容调整输入框的高度
* Bug修复修复Dalle绘图失败退回算力的问题
* 功能优化:邀请码注册时被邀请人也可以获得赠送的算力
* 功能优化:允许设置邮件验证码的抬头
* Bug修复修复免费模型不会记录聊天记录的bug
* Bug修复修复聊天输入公式显示异常的Bug
## v4.0.8
* 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
* 功能优化:当数据库更新失败的时候记录错误日志
* 功能优化:聊天输入框会随着输入内容的增多自动调整高度
* Bug修复修复移动端聊天页面模型切换不生效的Bug
* 功能优化给PC端扫码支付增加签名验证和有效期验证
* Bug修复修复支付码生成API权限控制的问题
* Bug修复模型算力设置为0时不扣减用户算力并且不记录算力消费日志
* 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
* 功能新增H5端支持 Dalle 绘图
## v4.0.7

View File

@@ -1,15 +1,14 @@
# GeekAI
### 本项目已经正式更名为 GeekAI请大家及时更新代码克隆地址
> 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Claude, 通义千问KimiDeepSeekGitee AI 等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
主要特性:
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
- 基于 Websocket 实现,完美的打字机体验。
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
- 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型。
- 支持 OpenAI, Claude, 通义千问KimiDeepSeek等多个大语言模型**支持 Gitee AI Serverless 大模型 API**。
- 支持 Suno 文生音乐
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
@@ -26,63 +25,16 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
- [x] 支持网站 Logo 版权等信息的修改
## 功能截图
### PC 端聊天界面
![ChatGPT Chat Page](/docs/imgs/gpt.gif)
### AI 对话界面
![ChatGPT new Chat Page](/docs/imgs/chat-new.png)
### MidJourney 专业绘画界面
![mid-journey](/docs/imgs/mj_image.jpg)
### Stable-Diffusion 专业绘画页面
![Stable-Diffusion](/docs/imgs/sd_image.jpg)
![Stable-Diffusion](/docs/imgs/sd_image_detail.jpg)
### 绘图作品展
![ChatGPT image_list](/docs/imgs/image-list.png)
### AI应用列表
![ChatGPT-app-list](/docs/imgs/app-list.jpg)
### 会员充值
![会员充值](/docs/imgs/member.png)
### 自动调用函数插件
![ChatGPT function plugin](/docs/imgs/plugin.png)
![ChatGPT function plugin](/docs/imgs/mj.jpg)
### 管理后台
![ChatGPT admin](/docs/imgs/admin_dashboard.png)
![ChatGPT admin](/docs/imgs/admin_config.jpg)
![ChatGPT admin](/docs/imgs/admin_models.jpg)
![ChatGPT admin](/docs/imgs/admin_user.png)
### 移动端 Web 页面
![Mobile chat list](/docs/imgs/mobile_chat_list.png)
![Mobile chat session](/docs/imgs/mobile_chat_session.png)
![Mobile chat setting](/docs/imgs/mobile_user_profile.png)
![Mobile chat setting](/docs/imgs/mobile_pay.png)
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
### 体验地址
> 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/>
> 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/>
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
## 快速部署
请参考文档 [**GeekAI 快速部署**](https://ai.r9it.com/docs/install/)。
请参考文档 [**GeekAI 快速部署**](https://docs.geekai.me/install/)。
## 使用须知
@@ -101,14 +53,14 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
## TODOLIST
* [ ] 支持基于知识库的 AI 问答
* [ ] 会员邀请注册推广功能
* [ ] 文生视频,文生歌曲功能
* [ ] 微信支付功能
## 项目文档
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
详细的部署和开发文档请参考 [**GeekAI 文档**](https://ai.r9it.com/docs/)。
详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
加微信进入微信讨论群可获取 **一键部署脚本添加好友时请注明来自Github!!!)。**

3
api/.gitignore vendored
View File

@@ -17,4 +17,5 @@ bin
data
config.toml
static/upload
storage.json
storage.json
res/certs/wechat/apiclient_key.pem

View File

@@ -3,8 +3,7 @@ ProxyURL = "" # 如 http://127.0.0.1:7777
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
StaticDir = "./static" # 静态资源的目录
StaticUrl = "/static" # 静态资源访问 URL
AesEncryptKey = ""
WeChatBot = false
TikaHost = "http://tika:9998"
[Session]
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
@@ -17,7 +16,7 @@ WeChatBot = false
DB = 0
[ApiConfig] # 微博热搜,今日头条等函数服务 API 配置,此为第三方插件服务,如需使用请联系作者开通
ApiURL = ""
ApiURL = "https://sapi.geekai.me"
AppId = ""
Token = ""
@@ -64,23 +63,6 @@ WeChatBot = false
SubDir = ""
Domain = ""
[[MjProxyConfigs]]
Enabled = true
ApiURL = "http://midjourney-proxy:8082"
ApiKey = "sk-geekmaster"
[[MjPlusConfigs]]
Enabled = false
ApiURL = "https://api.chat-plus.net"
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
ApiKey = "sk-xxx"
[[SdConfigs]]
Enabled = false
ApiURL = ""
ApiKey = ""
Txt2ImgJsonPath = "res/sd/text2img.json"
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动
Enabled = false # 是否启用 XXL JOB 服务
ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
@@ -89,6 +71,15 @@ WeChatBot = false
AccessToken = "xxl-job-api-token" # 执行器 API 通信 token
RegistryKey = "chatgpt-plus" # 任务注册 key
[SmtpConfig] # 注意阿里云服务器禁用了25号端口请使用 465 端口,并开启 TLS 连接
UseTls = false
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"
From = "test@163.com" # 发件邮箱人地址
Password = "" #邮箱 stmp 服务授权码
# 支付宝商户支付
[AlipayConfig]
Enabled = false # 启用支付宝支付通道
SandBox = false # 是否启用沙盒模式
@@ -98,28 +89,27 @@ WeChatBot = false
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
# 虎皮椒支付
[HuPiPayConfig]
Enabled = false
Name = "wechat"
AppId = ""
AppSecret = ""
ApiURL = "https://api.xunhupay.com"
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
[SmtpConfig] # 注意阿里云服务器禁用了25号端口请使用 465 端口,并开启 TLS 连接
UseTls = false
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"
From = "test@163.com" # 发件邮箱人地址
Password = "" #邮箱 stmp 服务授权码
[JPayConfig] # PayJs 支付配置
# 微信商户支付
[WechatPayConfig]
Enabled = false
Name = "wechat" # 请不要改动
AppId = "" # 商户 ID
PrivateKey = "" # 秘钥
ApiURL = "https://payjs.cn"
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己
AppId = "" # 商户应用ID
MchId = "" # 商户
SerialNo = "" # API 证书序列号
PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内
ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置
# 易支付
[GeekPayConfig]
Enabled = true
AppId = "" # 商户ID
PrivateKey = "" # 商户私钥
ApiURL = "https://pay.geekai.cn"
Methods = ["alipay", "wxpay", "qqpay", "jdpay", "douyin", "paypal"] # 支持的支付方式

View File

@@ -9,12 +9,12 @@ package core
import (
"bytes"
"context"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
@@ -32,31 +32,19 @@ import (
)
type AppServer struct {
Debug bool
Config *types.AppConfig
Engine *gin.Engine
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
Debug bool
Config *types.AppConfig
Engine *gin.Engine
SysConfig *types.SystemConfig // system config cache
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
ChatSession *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
}
func NewServer(appConfig *types.AppConfig) *AppServer {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
return &AppServer{
Debug: false,
Config: appConfig,
Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []types.Message](),
ChatSession: types.NewLMap[string, *types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
Debug: false,
Config: appConfig,
Engine: gin.Default(),
}
}
@@ -77,13 +65,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
func (s *AppServer) Run(db *gorm.DB) error {
// load system configs
var sysConfig model.Config
res := db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return res.Error
}
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
err := db.Where("marker", "system").First(&sysConfig).Error
if err != nil {
return err
return fmt.Errorf("failed to load system config: %v", err)
}
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil {
return fmt.Errorf("failed to decode system config: %v", err)
}
logger.Infof("http://%s", s.Config.Listen)
return s.Engine.Run(s.Config.Listen)
@@ -95,7 +83,7 @@ func errorHandler(c *gin.Context) {
if r := recover(); r != nil {
logger.Errorf("Handler Panic: %v", r)
debug.PrintStack()
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
c.JSON(http.StatusBadRequest, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
c.Abort()
}
}()
@@ -151,7 +139,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
if tokenString == "" {
if needLogin(c) {
resp.ERROR(c, "You should put Authorization in request headers")
resp.NotAuth(c, "You should put Authorization in request headers")
c.Abort()
return
} else { // 直接放行
@@ -213,11 +201,12 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/admin/logout" ||
c.Request.URL.Path == "/api/admin/login/captcha" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/user/session" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" ||
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/app/list" ||
c.Request.URL.Path == "/api/app/type/list" ||
c.Request.URL.Path == "/api/app/list/user" ||
c.Request.URL.Path == "/api/model/list" ||
c.Request.URL.Path == "/api/mj/imgWall" ||
c.Request.URL.Path == "/api/mj/client" ||
@@ -227,15 +216,23 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/dall/imgWall" ||
c.Request.URL.Path == "/api/dall/client" ||
c.Request.URL.Path == "/api/config/get" ||
c.Request.URL.Path == "/api/product/list" ||
c.Request.URL.Path == "/api/menu/list" ||
c.Request.URL.Path == "/api/markMap/client" ||
c.Request.URL.Path == "/api/payment/doPay" ||
c.Request.URL.Path == "/api/payment/payWays" ||
c.Request.URL.Path == "/api/suno/client" ||
c.Request.URL.Path == "/api/suno/detail" ||
c.Request.URL.Path == "/api/suno/play" ||
c.Request.URL.Path == "/api/download" ||
c.Request.URL.Path == "/api/video/client" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
return false
}
@@ -370,6 +367,7 @@ func staticResourceMiddleware() gin.HandlerFunc {
// 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
c.Abort() // 中断请求
}
c.Next()
}

View File

@@ -38,7 +38,6 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload",
},
},
WeChatBot: false,
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
}
}

View File

@@ -9,14 +9,14 @@ package types
// ApiRequest API 请求实体
type ApiRequest struct {
Model string `json:"model,omitempty"` // 兼容百度文心一言
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens,omitempty"` // 兼容百度文心一言
Stream bool `json:"stream"`
Messages []interface{} `json:"messages,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
Tools []Tool `json:"tools,omitempty"`
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
Model string `json:"model,omitempty"`
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
Stream bool `json:"stream,omitempty"`
Messages []interface{} `json:"messages,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
ToolChoice string `json:"tool_choice,omitempty"`
@@ -53,23 +53,24 @@ type Delta struct {
// ChatSession 聊天会话对象
type ChatSession struct {
SessionId string `json:"session_id"`
UserId uint `json:"user_id"`
ClientIP string `json:"client_ip"` // 客户端 IP
Username string `json:"username"` // 当前登录的 username
UserId uint `json:"user_id"` // 当前登录的 user ID
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
Model ChatModel `json:"model"` // GPT 模型
Start int64 `json:"start"` // 开始请求时间戳
Tools []int `json:"tools"` // 工具函数列表
Stream bool `json:"stream"` // 是否采用流式输出
}
type ChatModel struct {
Id uint `json:"id"`
Platform Platform `json:"platform"`
Name string `json:"name"`
Value string `json:"value"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长
Temperature float32 `json:"temperature"` // 模型温度
KeyId int `json:"key_id"` // 绑定 API KEY
Id uint `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温
KeyId int `json:"key_id"` // 绑定 API KEY
}
type ApiError struct {
@@ -92,7 +93,7 @@ const (
PowerConsume = PowerType(2) // 消费
PowerRefund = PowerType(3) // 任务SD,MJ执行失败退款
PowerInvite = PowerType(4) // 邀请奖励
PowerReward = PowerType(5) // 众筹
PowerRedeem = PowerType(5) // 众筹
PowerGift = PowerType(6) // 系统赠送
)
@@ -104,8 +105,8 @@ func (t PowerType) String() string {
return "消费"
case PowerRefund:
return "退款"
case PowerReward:
return "众筹"
case PowerRedeem:
return "兑换"
}
return "其他"

View File

@@ -12,28 +12,25 @@ import (
)
type AppConfig struct {
Path string `toml:"-"`
Listen string
Session Session
AdminSession Session
ProxyURL string
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
MjProxyConfigs []MjProxyConfig // MJ proxy config
MjPlusConfigs []MjPlusConfig // MJ plus config
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig
AlipayConfig AlipayConfig
HuPiPayConfig HuPiPayConfig
SmtpConfig SmtpConfig // 邮件发送配置
JPayConfig JPayConfig // payjs 支付配置
Path string `toml:"-"`
Listen string
Session Session
AdminSession Session
ProxyURL string
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
SmtpConfig SmtpConfig // 邮件发送配置
XXLConfig XXLConfig
AlipayConfig AlipayConfig // 支付宝支付渠道配置
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
GeekPayConfig GeekPayConfig // GEEK 支付配置
WechatPayConfig WechatPayConfig // 微信支付渠道配置
TikaHost string // TiKa 服务器地址
}
type SmtpConfig struct {
@@ -51,27 +48,6 @@ type ApiConfig struct {
Token string
}
type MjProxyConfig struct {
Enabled bool
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type StableDiffusionConfig struct {
Enabled bool
Model string // 模型名称
ApiURL string
ApiKey string
}
type MjPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type AlipayConfig struct {
Enabled bool // 是否启用该支付通道
SandBox bool // 是否沙盒环境
@@ -81,29 +57,38 @@ type AlipayConfig struct {
PublicKey string // 用户公钥文件路径
AlipayPublicKey string // 支付宝公钥文件路径
RootCert string // Root 秘钥路径
NotifyURL string // 异步通知回调
ReturnURL string // 支付成功返回地址
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
}
type WechatPayConfig struct {
Enabled bool // 是否启用该支付通道
AppId string // 公众号的APPID,如wxd678efh567hg6787
MchId string // 直连商户的商户号,由微信支付生成并下发
SerialNo string // 商户证书的证书序列号
PrivateKey string // 用户私钥文件路径
ApiV3Key string // API V3 秘钥
NotifyURL string // 异步通知地址
}
type HuPiPayConfig struct { //虎皮椒第四方支付配置
Enabled bool // 是否启用该支付通道
Name string // 支付名称wechat/alipay
AppId string // App ID
AppSecret string // app 密钥
ApiURL string // 支付网关
NotifyURL string // 异步通知回调
ReturnURL string // 支付成功返回地址
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
}
// JPayConfig PayJs 支付配置
type JPayConfig struct {
// GeekPayConfig GEEK支付配置
type GeekPayConfig struct {
Enabled bool
Name string // 支付名称,默认 wechat
AppId string // 商户 ID
PrivateKey string // 私钥
ApiURL string // API 网关
NotifyURL string // 异步回调地址
ReturnURL string // 支付成功返回地址
AppId string // 商户 ID
PrivateKey string // 私钥
ApiURL string // API 网关
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
Methods []string // 支付方式
}
type XXLConfig struct { // XXL 任务调度配置
@@ -126,30 +111,27 @@ type RedisConfig struct {
const LicenseKey = "Geek-AI-License"
type License struct {
Key string `json:"key"` // 许可证书密钥
MachineId string `json:"machine_id"` // 机器码
UserNum int `json:"user_num"` // 用户数量
ExpiredAt int64 `json:"expired_at"` // 过期时间
IsActive bool `json:"is_active"` // 是否激活
Key string `json:"key"` // 许可证书密钥
MachineId string `json:"machine_id"` // 机器码
ExpiredAt int64 `json:"expired_at"` // 过期时间
IsActive bool `json:"is_active"` // 是否激活
Configs LicenseConfig `json:"configs"`
}
type LicenseConfig struct {
UserNum int `json:"user_num"` // 用户数量
DeCopy bool `json:"de_copy"` // 去版权
}
func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port)
}
type Platform string
const OpenAI = Platform("OpenAI")
const Azure = Platform("Azure")
const ChatGLM = Platform("ChatGLM")
const Baidu = Platform("Baidu")
const XunFei = Platform("XunFei")
const QWen = Platform("QWen")
type SystemConfig struct {
Title string `json:"title,omitempty"`
AdminTitle string `json:"admin_title,omitempty"`
Logo string `json:"logo,omitempty"`
Title string `json:"title,omitempty"` // 网站标题
Slogan string `json:"slogan,omitempty"` // 网站 slogan
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
Logo string `json:"logo,omitempty"` // 方形 Logo
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
@@ -158,10 +140,6 @@ type SystemConfig struct {
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式支持手机mobile邮箱注册email账号密码注册
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
@@ -169,7 +147,9 @@ type SystemConfig struct {
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
@@ -177,4 +157,13 @@ type SystemConfig struct {
ContextDeep int `json:"context_deep,omitempty"`
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
MjMode string `json:"mj_mode"` // midjourney 默认的API模式relax, fast, turbo
IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
Copyright string `json:"copyright"` // 版权信息
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
}

View File

@@ -24,5 +24,4 @@ type Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
Required interface{} `json:"required,omitempty"`
}

View File

@@ -22,3 +22,18 @@ type OrderRemark struct {
Price float64 `json:"price"`
Discount float64 `json:"discount"`
}
var PayMethods = map[string]string{
"alipay": "支付宝商号",
"wechat": "微信商号",
"hupi": "虎皮椒",
"geek": "易支付",
}
var PayNames = map[string]string{
"alipay": "支付宝",
"wxpay": "微信支付",
"qqpay": "QQ钱包",
"jdpay": "京东支付",
"douyin": "抖音支付",
"paypal": "PayPal支付",
}

View File

@@ -27,8 +27,6 @@ type MjTask struct {
Id uint `json:"id"`
TaskId string `json:"task_id"`
ImgArr []string `json:"img_arr"`
ChannelId string `json:"channel_id"`
SessionId string `json:"session_id"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
@@ -38,11 +36,12 @@ type MjTask struct {
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`
RetryCount int `json:"retry_count"`
ChannelId string `json:"channel_id"` // 渠道ID用来区分是哪个渠道创建的任务一个任务的 create 和 action 操作必须要再同一个渠道
Mode string `json:"mode"` // 绘画模式relax, fast, turbo
}
type SdTask struct {
Id int `json:"id"` // job 数据库ID
SessionId string `json:"session_id"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Params SdTaskParams `json:"params"`
@@ -55,10 +54,10 @@ type SdTaskParams struct {
NegPrompt string `json:"neg_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
Scheduler string `json:"scheduler"`
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Scheduler string `json:"scheduler"` // 采样调度
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // 启用高清修复
@@ -80,3 +79,47 @@ type DallTask struct {
Power int `json:"power"`
}
type SunoTask struct {
Id uint `json:"id"`
Channel string `json:"channel"`
UserId int `json:"user_id"`
Type int `json:"type"`
Title string `json:"title"`
RefTaskId string `json:"ref_task_id,omitempty"`
RefSongId string `json:"ref_song_id,omitempty"`
Prompt string `json:"prompt"` // 提示词/歌词
Tags string `json:"tags"`
Model string `json:"model"`
Instrumental bool `json:"instrumental"` // 是否纯音乐
ExtendSecs int `json:"extend_secs,omitempty"` // 延长秒杀
SongId string `json:"song_id,omitempty"` // 合并歌曲ID
AudioURL string `json:"audio_url"` // 用户上传音频地址
}
const (
VideoLuma = "luma"
VideoRunway = "runway"
VideoCog = "cog"
)
type VideoTask struct {
Id uint `json:"id"`
Channel string `json:"channel"`
UserId int `json:"user_id"`
Type string `json:"type"`
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
Params VideoParams `json:"params"`
}
type VideoParams struct {
PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词
Loop bool `json:"loop"` // 是否循环参考图
StartImgURL string `json:"start_img_url"` // 第一帧参考图地址
EndImgURL string `json:"end_img_url"` // 最后一帧参考图地址
Model string `json:"model"` // 使用哪个模型生成视频
Radio string `json:"radio"` // 视频尺寸
Style string `json:"style"` // 风格
Duration int `json:"duration"` // 视频时长(秒)
}

View File

@@ -17,30 +17,35 @@ type BizVo struct {
Data interface{} `json:"data,omitempty"`
}
// WsMessage Websocket message
type WsMessage struct {
// ReplyMessage 对话回复消息结构
type ReplyMessage struct {
Type WsMsgType `json:"type"` // 消息类别start, end, img
Content interface{} `json:"content"`
}
type WsMsgType string
const (
WsStart = WsMsgType("start")
WsMiddle = WsMsgType("middle")
WsEnd = WsMsgType("end")
WsErr = WsMsgType("error")
WsContent = WsMsgType("content") // 输出内容
WsEnd = WsMsgType("end")
WsErr = WsMsgType("error")
)
// InputMessage 对话输入消息结构
type InputMessage struct {
Content string `json:"content"`
Tools []int `json:"tools"` // 允许调用工具列表
Stream bool `json:"stream"` // 是否采用流式输出
}
type BizCode int
const (
Success = BizCode(0)
Failed = BizCode(1)
NotAuthorized = BizCode(400) // 未授权
NotPermission = BizCode(403) // 没有权限
NotAuthorized = BizCode(401) // 未授权
OkMsg = "Success"
ErrorMsg = "系统开小差了"
InvalidArgs = "非法参数或参数解析失败"
NoData = "No Data"
)

View File

@@ -8,7 +8,6 @@ require (
github.com/BurntSushi/toml v1.1.0
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
github.com/eatmoreapple/openwechat v1.2.1
github.com/gin-gonic/gin v1.9.1
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt/v5 v5.0.0
@@ -19,7 +18,6 @@ require (
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
github.com/qiniu/go-sdk/v7 v7.17.1
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/smartwalle/alipay/v3 v3.2.15
go.uber.org/zap v1.23.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gorm.io/driver/mysql v1.4.7
@@ -28,15 +26,28 @@ require (
require github.com/xxl-job/xxl-job-executor-go v1.2.0
require (
github.com/mojocn/base64Captcha v1.3.1
github.com/go-pay/gopay v1.5.101
github.com/google/go-tika v0.3.1
github.com/microcosm-cc/bluemonday v1.0.26
github.com/shopspring/decimal v1.3.1
github.com/syndtr/goleveldb v1.0.0
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
golang.org/x/image v0.15.0
)
require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-pay/crypto v0.0.1 // indirect
github.com/go-pay/errgroup v0.0.2 // indirect
github.com/go-pay/util v0.0.2 // indirect
github.com/go-pay/xlog v0.0.2 // indirect
github.com/go-pay/xtime v0.0.2 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/shirou/gopsutil v3.21.11+incompatible // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.uber.org/mock v0.4.0 // indirect
)
@@ -74,9 +85,6 @@ require (
github.com/refraction-networking/utls v1.3.2 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/smartwalle/ncrypto v1.0.2 // indirect
github.com/smartwalle/ngx v1.0.6 // indirect
github.com/smartwalle/nsign v1.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
go.uber.org/dig v1.16.1 // indirect
golang.org/x/arch v0.3.0 // indirect

View File

@@ -6,6 +6,8 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible h1:Sg/2xHwDrioHpxTN6WMiw
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
@@ -26,8 +28,6 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
@@ -43,6 +43,20 @@ github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-pay/crypto v0.0.1 h1:B6InT8CLfSLc6nGRVx9VMJRBBazFMjr293+jl0lLXUY=
github.com/go-pay/crypto v0.0.1/go.mod h1:41oEIvHMKbNcYlWUlRWtsnC6+ASgh7u29z0gJXe5bes=
github.com/go-pay/errgroup v0.0.2 h1:5mZMdm0TDClDm2S3G0/sm0f8AuQRtz0dOrTHDR9R8Cc=
github.com/go-pay/errgroup v0.0.2/go.mod h1:0+4b8mvFMS71MIzsaC+gVvB4x37I93lRb2dqrwuU8x8=
github.com/go-pay/gopay v1.5.101 h1:rVb+sfv6hiQtknAlZnTTLvU27NvFJ4p0yglN/vPpGXI=
github.com/go-pay/gopay v1.5.101/go.mod h1:AW4Yj8jDZX9BM1/GTLTY1Gy5SHjiq8kQvG5sBTN2sxI=
github.com/go-pay/util v0.0.2 h1:goJ4f6kNY5zzdtg1Cj8oWC+Cw7bfg/qq2rJangMAb9U=
github.com/go-pay/util v0.0.2/go.mod h1:qM8VbyF1n7YAPZBSJONSPMPsPedhUTktewUAdf1AjPg=
github.com/go-pay/xlog v0.0.2 h1:kUg5X8/5VZAPDg1J5eGjA3MG0/H5kK6Ew0dW/Bycsws=
github.com/go-pay/xlog v0.0.2/go.mod h1:DbjMADPK4+Sjxj28ekK9goqn4zmyY4hql/zRiab+S9E=
github.com/go-pay/xtime v0.0.2 h1:7YR4/iuELsEHpJ6LUO0SVK80hQxDO9MLCfuVYIiTCRM=
github.com/go-pay/xtime v0.0.2/go.mod h1:W1yRbJaSt4CSBcdAtLBQ8xajiN/Pl5hquGczUcUE9xE=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@@ -68,8 +82,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
@@ -77,11 +89,15 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pO
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
github.com/google/go-tika v0.3.1/go.mod h1:DJh5N8qxXIl85QkqmXknd+PeeRkUOTbvwyYf7ieDz6c=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -123,6 +139,8 @@ github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.62 h1:qNYsFZHEzl+NfH8UxW4jpmlKav1qUAgfY30YNRneVhc=
@@ -135,8 +153,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
@@ -176,20 +192,14 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/smartwalle/alipay/v3 v3.2.15 h1:3fvFJnINKKAOXHR/Iv20k1Z7KJ+nOh3oK214lELPqG8=
github.com/smartwalle/alipay/v3 v3.2.15/go.mod h1:niTNB609KyUYuAx9Bex/MawEjv2yPx4XOjxSAkqmGjE=
github.com/smartwalle/ncrypto v1.0.2 h1:pTAhCqtPCMhpOwFXX+EcMdR6PNzruBNoGQrN2S1GbGI=
github.com/smartwalle/ncrypto v1.0.2/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk=
github.com/smartwalle/ngx v1.0.6 h1:JPNqNOIj+2nxxFtrSkJO+vKJfeNUSEQueck/Wworjps=
github.com/smartwalle/ngx v1.0.6/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0=
github.com/smartwalle/nsign v1.0.8 h1:78KWtwKPrdt4Xsn+tNEBVxaTLIJBX9YRX0ZSrMUeuHo=
github.com/smartwalle/nsign v1.0.8/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -204,6 +214,10 @@ github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gt
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
@@ -215,6 +229,8 @@ github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
@@ -237,14 +253,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -252,15 +270,21 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -270,17 +294,27 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
@@ -288,6 +322,7 @@ golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -8,19 +8,19 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/service"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/mojocn/base64Captcha"
"time"
"github.com/gin-gonic/gin"
@@ -29,37 +29,47 @@ import (
var logger = logger2.GetLogger()
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
Captcha string `json:"captcha"` // 验证码
CaptchaId string `json:"captcha_id"` // 验证码id
}
const SuperManagerID = 1
type ManagerHandler struct {
handler.BaseHandler
redis *redis.Client
redis *redis.Client
captcha *service.CaptchaService
}
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, captcha *service.CaptchaService) *ManagerHandler {
return &ManagerHandler{
BaseHandler: handler.BaseHandler{DB: db, App: app},
redis: client,
captcha: captcha,
}
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data Manager
var data struct {
Username string `json:"username"`
Password string `json:"password"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// add captcha
if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
resp.ERROR(c, "验证码错误!")
return
if h.App.SysConfig.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
var manager model.AdminUser

View File

@@ -8,6 +8,7 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
@@ -31,7 +32,6 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
func (h *ApiKeyHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Platform string `json:"platform"`
Name string `json:"name"`
Type string `json:"type"`
Value string `json:"value"`
@@ -48,23 +48,22 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
if data.Id > 0 {
h.DB.Find(&apiKey, data.Id)
}
apiKey.Platform = data.Platform
apiKey.Value = data.Value
apiKey.Type = data.Type
apiKey.ApiURL = data.ApiURL
apiKey.Enabled = data.Enabled
apiKey.ProxyURL = data.ProxyURL
apiKey.Name = data.Name
res := h.DB.Save(&apiKey)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Save(&apiKey).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var keyVo vo.ApiKey
err := utils.CopyObject(apiKey, &keyVo)
err = utils.CopyObject(apiKey, &keyVo)
if err != nil {
resp.ERROR(c, "数据拷贝失败!")
resp.ERROR(c, fmt.Sprintf("拷贝数据失败:%v", err))
return
}
keyVo.Id = apiKey.Id
@@ -83,7 +82,7 @@ func (h *ApiKeyHandler) List(c *gin.Context) {
if t != "" {
session = session.Where("type", t)
}
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)
res := session.Find(&items)
@@ -116,9 +115,9 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
return
}
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
@@ -131,9 +130,9 @@ func (h *ApiKeyHandler) Remove(c *gin.Context) {
return
}
res := h.DB.Where("id", id).Delete(&model.ApiKey{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Where("id", id).Delete(&model.ApiKey{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)

View File

@@ -1,46 +0,0 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/handler"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/mojocn/base64Captcha"
)
type CaptchaHandler struct {
handler.BaseHandler
}
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
}
type CaptchaVo struct {
CaptchaId string `json:"captcha_id"`
PicPath string `json:"pic_path"`
}
// GetCaptcha 获取验证码
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
var captchaVo CaptchaVo
driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
// b64s是图片的base64编码
id, b64s, err := cp.Generate()
if err != nil {
resp.ERROR(c, "生成验证码错误!")
return
}
captchaVo.CaptchaId = id
captchaVo.PicPath = b64s
resp.SUCCESS(c, captchaVo)
}

View File

@@ -8,6 +8,7 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
@@ -21,16 +22,16 @@ import (
"gorm.io/gorm"
)
type ChatRoleHandler struct {
type ChatAppHandler struct {
handler.BaseHandler
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// Save 创建或者更新某个角色
func (h *ChatRoleHandler) Save(c *gin.Context) {
func (h *ChatAppHandler) Save(c *gin.Context) {
var data vo.ChatRole
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -45,10 +46,16 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
role.Id = data.Id
if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0)
} else {
err = h.DB.Where("marker", data.Key).First(&role).Error
if err == nil {
resp.ERROR(c, fmt.Sprintf("角色 %s 已存在", data.Key))
return
}
}
res := h.DB.Save(&role)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err = h.DB.Save(&role).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 填充 ID 数据
@@ -57,7 +64,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
resp.SUCCESS(c, data)
}
func (h *ChatRoleHandler) List(c *gin.Context) {
func (h *ChatAppHandler) List(c *gin.Context) {
var items []model.ChatRole
var roles = make([]vo.ChatRole, 0)
res := h.DB.Order("sort_num ASC").Find(&items)
@@ -68,13 +75,18 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
// initialize model mane for role
modelIds := make([]int, 0)
typeIds := make([]int, 0)
for _, v := range items {
if v.ModelId > 0 {
modelIds = append(modelIds, v.ModelId)
}
if v.Tid > 0 {
typeIds = append(typeIds, v.Tid)
}
}
modelNameMap := make(map[int]string)
typeNameMap := make(map[int]string)
if len(modelIds) > 0 {
var models []model.ChatModel
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
@@ -84,6 +96,15 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
}
}
}
if len(typeIds) > 0 {
var appTypes []model.AppType
tx := h.DB.Where("id IN ?", typeIds).Find(&appTypes)
if tx.Error == nil {
for _, m := range appTypes {
typeNameMap[int(m.Id)] = m.Name
}
}
}
for _, v := range items {
var role vo.ChatRole
@@ -93,6 +114,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
role.CreatedAt = v.CreatedAt.Unix()
role.UpdatedAt = v.UpdatedAt.Unix()
role.ModelName = modelNameMap[role.ModelId]
role.TypeName = typeNameMap[role.Tid]
roles = append(roles, role)
}
}
@@ -101,7 +123,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
}
// Sort 更新角色排序
func (h *ChatRoleHandler) Sort(c *gin.Context) {
func (h *ChatAppHandler) Sort(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
Sorts []int `json:"sorts"`
@@ -113,9 +135,9 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
@@ -123,7 +145,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
resp.SUCCESS(c)
}
func (h *ChatRoleHandler) Set(c *gin.Context) {
func (h *ChatAppHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
@@ -135,15 +157,15 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
return
}
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *ChatRoleHandler) Remove(c *gin.Context) {
func (h *ChatAppHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
@@ -152,6 +174,7 @@ func (h *ChatRoleHandler) Remove(c *gin.Context) {
}
res := h.DB.Where("id", id).Delete(&model.ChatRole{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "删除失败!")
return
}

View File

@@ -0,0 +1,148 @@
package admin
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatAppTypeHandler struct {
handler.BaseHandler
}
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// Save 创建或更新App类型
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Enabled bool `json:"enabled"`
Icon string `json:"icon"`
SortNum int `json:"sort_num"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id == 0 { // for add
err := h.DB.Where("name", data.Name).First(&model.AppType{}).Error
if err == nil {
resp.ERROR(c, "当前分类已经存在")
return
}
err = h.DB.Create(&model.AppType{
Name: data.Name,
Icon: data.Icon,
Enabled: data.Enabled,
SortNum: data.SortNum,
}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
} else { // for update
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).Updates(map[string]interface{}{
"name": data.Name,
"icon": data.Icon,
"enabled": data.Enabled,
}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c)
}
// List 获取App类型列表
func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType
var appTypes = make([]vo.AppType, 0)
err := h.DB.Order("sort_num ASC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, v := range items {
var appType vo.AppType
err = utils.CopyObject(v, &appType)
if err != nil {
continue
}
appType.Id = v.Id
appType.CreatedAt = v.CreatedAt.Unix()
appTypes = append(appTypes, appType)
}
resp.SUCCESS(c, appTypes)
}
// Remove 删除App类型
func (h *ChatAppTypeHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Where("id", id).Delete(&model.AppType{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Enable 启用|禁用
func (h *ChatAppTypeHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Sort 更新排序
func (h *ChatAppTypeHandler) Sort(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
Sorts []int `json:"sorts"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
for index, id := range data.Ids {
err := h.DB.Model(&model.AppType{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c)
}

View File

@@ -259,9 +259,9 @@ func (h *ChatHandler) RemoveChat(c *gin.Context) {
// RemoveMessage 删除聊天记录
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
if tx.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)

View File

@@ -49,27 +49,32 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
return
}
item := model.ChatModel{
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: data.Temperature,
KeyId: data.KeyId,
Power: data.Power}
item := model.ChatModel{}
// 更新
if data.Id > 0 {
h.DB.Where("id", data.Id).First(&item)
}
item.Name = data.Name
item.Value = data.Value
item.Enabled = data.Enabled
item.SortNum = data.SortNum
item.Open = data.Open
item.Power = data.Power
item.MaxTokens = data.MaxTokens
item.MaxContext = data.MaxContext
item.Temperature = data.Temperature
item.KeyId = data.KeyId
var res *gorm.DB
if data.Id > 0 {
item.Id = data.Id
res = h.DB.Select("*").Omit("created_at").Updates(&item)
res = h.DB.Save(&item)
} else {
res = h.DB.Create(&item)
}
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
logger.Error("error with update database", res.Error)
resp.ERROR(c, res.Error.Error())
return
}
@@ -88,9 +93,13 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
func (h *ChatModelHandler) List(c *gin.Context) {
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
name := h.GetTrim(c, "name")
if enable {
session = session.Where("enabled", enable)
}
if name != "" {
session = session.Where("name LIKE ?", name+"%")
}
var items []model.ChatModel
var cms = make([]vo.ChatModel, 0)
res := session.Order("sort_num ASC").Find(&items)
@@ -138,9 +147,9 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
return
}
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
@@ -158,9 +167,9 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
@@ -175,9 +184,9 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
return
}
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)

View File

@@ -11,21 +11,29 @@ import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/shirou/gopsutil/host"
"gorm.io/gorm"
)
type ConfigHandler struct {
handler.BaseHandler
levelDB *store.LevelDB
levelDB *store.LevelDB
licenseService *service.LicenseService
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB) *ConfigHandler {
return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, levelDB: levelDB}
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
return &ConfigHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
levelDB: levelDB,
licenseService: licenseService,
}
}
func (h *ConfigHandler) Update(c *gin.Context) {
@@ -36,6 +44,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
Content string `json:"content,omitempty"`
Updated bool `json:"updated,omitempty"`
} `json:"config"`
ConfigBak types.SystemConfig `json:"config_bak,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -43,6 +52,12 @@ func (h *ConfigHandler) Update(c *gin.Context) {
return
}
// ONLY authorized user can change the copyright
if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
return
}
value := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: value}
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
@@ -95,3 +110,98 @@ func (h *ConfigHandler) Get(c *gin.Context) {
resp.SUCCESS(c, value)
}
// Active 激活系统
func (h *ConfigHandler) Active(c *gin.Context) {
var data struct {
License string `json:"license"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
info, err := host.Info()
if err != nil {
resp.ERROR(c, err.Error())
return
}
err = h.licenseService.ActiveLicense(data.License, info.HostID)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, info.HostID)
}
// GetLicense 获取 License 信息
func (h *ConfigHandler) GetLicense(c *gin.Context) {
license := h.licenseService.GetLicense()
resp.SUCCESS(c, license)
}
// FixData 修复数据
func (h *ConfigHandler) FixData(c *gin.Context) {
var fixed bool
version := "data_fix_4.1.4"
err := h.levelDB.Get(version, &fixed)
if err == nil || fixed {
resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
return
}
tx := h.DB.Begin()
var users []model.User
err = tx.Find(&users).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, user := range users {
if user.Email != "" || user.Mobile != "" {
continue
}
if utils.IsValidEmail(user.Username) {
user.Email = user.Username
} else if utils.IsValidMobile(user.Username) {
user.Mobile = user.Username
}
err = tx.Save(&user).Error
if err != nil {
resp.ERROR(c, err.Error())
tx.Rollback()
return
}
}
var orders []model.Order
err = h.DB.Find(&orders).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, order := range orders {
if order.PayWay == "支付宝" {
order.PayWay = "alipay"
order.PayType = "alipay"
} else if order.PayWay == "微信支付" {
order.PayWay = "wechat"
order.PayType = "wxpay"
} else if order.PayWay == "hupi" {
order.PayType = "wxpay"
}
err = tx.Save(&order).Error
if err != nil {
resp.ERROR(c, err.Error())
tx.Rollback()
return
}
}
tx.Commit()
err = h.levelDB.Put(version, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -60,13 +60,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
stats.Tokens += item.Tokens
}
// 众筹收入
var rewards []model.Reward
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
for _, item := range rewards {
stats.Income += item.Amount
}
// 订单收入
var orders []model.Order
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
@@ -101,13 +94,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
}
// 浮点数相加?
// 统计最近7天的众筹
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
for _, item := range rewards {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
// 统计最近7天的订单
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
for _, item := range orders {

View File

@@ -69,9 +69,9 @@ func (h *FunctionHandler) Set(c *gin.Context) {
return
}
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
@@ -101,9 +101,9 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.DB.Delete(&model.Function{Id: uint(id)})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Delete(&model.Function{Id: uint(id)}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}

View File

@@ -41,16 +41,16 @@ func (h *MenuHandler) Save(c *gin.Context) {
return
}
res := h.DB.Save(&model.Menu{
err := h.DB.Save(&model.Menu{
Id: data.Id,
Name: data.Name,
Icon: data.Icon,
URL: data.URL,
SortNum: data.SortNum,
Enabled: data.Enabled,
})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
@@ -84,9 +84,9 @@ func (h *MenuHandler) Enable(c *gin.Context) {
return
}
res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
@@ -104,9 +104,9 @@ func (h *MenuHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
@@ -118,9 +118,9 @@ func (h *MenuHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.DB.Where("id", id).Delete(&model.Menu{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Where("id", id).Delete(&model.Menu{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}

View File

@@ -15,6 +15,7 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -67,6 +68,16 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
if !ok {
payMethod = item.PayWay
}
payName, ok := types.PayNames[item.PayType]
if !ok {
payName = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
list = append(list, order)
} else {
logger.Error(err)
@@ -92,11 +103,33 @@ func (h *OrderHandler) Remove(c *gin.Context) {
return
}
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Where("id = ?", id).Delete(&model.Order{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c)
}
func (h *OrderHandler) Clear(c *gin.Context) {
var orders []model.Order
err := h.DB.Where("status <> ?", 2).Where("pay_time", 0).Find(&orders).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
deleteIds := make([]uint, 0)
for _, order := range orders {
// 只删除 15 分钟内的未支付订单
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
deleteIds = append(deleteIds, order.Id)
}
}
err = h.DB.Where("id IN ?", deleteIds).Delete(&model.Order{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -55,16 +55,16 @@ func (h *ProductHandler) Save(c *gin.Context) {
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Save(&item).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var itemVo vo.Product
err := utils.CopyObject(item, &itemVo)
err = utils.CopyObject(item, &itemVo)
if err != nil {
resp.ERROR(c, "数据拷贝失败")
resp.ERROR(c, "数据拷贝失败: "+err.Error())
return
}
itemVo.Id = item.Id
@@ -105,9 +105,9 @@ func (h *ProductHandler) Enable(c *gin.Context) {
return
}
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
@@ -125,9 +125,9 @@ func (h *ProductHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
@@ -139,9 +139,9 @@ func (h *ProductHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.DB.Where("id", id).Delete(&model.Product{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
err := h.DB.Where("id", id).Delete(&model.Product{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}

View File

@@ -0,0 +1,164 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type RedeemHandler struct {
handler.BaseHandler
}
func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RedeemHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
code := c.Query("code")
status := h.GetInt(c, "status", -1)
session := h.DB.Session(&gorm.Session{})
if code != "" {
session.Where("code LIKE ?", "%"+code+"%")
}
if status == 0 {
session.Where("redeem_at = ?", 0)
} else if status == 1 {
session.Where("redeem_at > ?", 0)
}
var total int64
session.Model(&model.Redeem{}).Count(&total)
var redeems []model.Redeem
offset := (page - 1) * pageSize
err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&redeems).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var items = make([]vo.Redeem, 0)
userIds := make([]uint, 0)
for _, v := range redeems {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
}
for _, v := range redeems {
var r vo.Redeem
err = utils.CopyObject(v, &r)
if err != nil {
continue
}
r.Id = v.Id
r.Username = userMap[v.UserId].Username
r.CreatedAt = v.CreatedAt.Unix()
items = append(items, r)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *RedeemHandler) Create(c *gin.Context) {
var data struct {
Name string `json:"name"`
Power int `json:"power"`
Num int `json:"num"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
counter := 0
codes := make([]string, 0)
var errMsg = ""
if data.Num > 0 {
for i := 0; i < data.Num; i++ {
code, err := utils.GenRedeemCode(32)
if err != nil {
errMsg = err.Error()
continue
}
err = h.DB.Create(&model.Redeem{
Code: code,
Name: data.Name,
Power: data.Power,
Enabled: true,
}).Error
if err != nil {
errMsg = err.Error()
continue
}
codes = append(codes, code)
counter++
}
}
if counter == 0 {
resp.ERROR(c, errMsg)
return
}
resp.SUCCESS(c, gin.H{
"counter": counter,
})
}
func (h *RedeemHandler) Set(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Filed string `json:"filed"`
Value interface{} `json:"value"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Model(&model.Redeem{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *RedeemHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
err := h.DB.Where("id", data.Id).Delete(&model.Redeem{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c)
}

View File

@@ -1,80 +0,0 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type RewardHandler struct {
handler.BaseHandler
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RewardHandler) List(c *gin.Context) {
var items []model.Reward
res := h.DB.Order("id DESC").Find(&items)
var rewards = make([]vo.Reward, 0)
if res.Error == nil {
userIds := make([]uint, 0)
for _, v := range items {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
}
for _, v := range items {
var r vo.Reward
err := utils.CopyObject(v, &r)
if err != nil {
continue
}
r.Id = v.Id
r.Username = userMap[v.UserId].Username
r.CreatedAt = v.CreatedAt.Unix()
r.UpdatedAt = v.UpdatedAt.Unix()
rewards = append(rewards, r)
}
}
resp.SUCCESS(c, rewards)
}
func (h *RewardHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -12,10 +12,12 @@ import (
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/go-redis/redis/v8"
"time"
"github.com/gin-gonic/gin"
@@ -24,10 +26,12 @@ import (
type UserHandler struct {
handler.BaseHandler
licenseService *service.LicenseService
redis *redis.Client
}
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService, redisCli *redis.Client) *UserHandler {
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
}
// List 用户列表
@@ -47,7 +51,7 @@ func (h *UserHandler) List(c *gin.Context) {
}
session.Model(&model.User{}).Count(&total)
res := session.Offset(offset).Limit(pageSize).Find(&items)
res := session.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
if res.Error == nil {
for _, item := range items {
var user vo.User
@@ -71,6 +75,8 @@ func (h *UserHandler) Save(c *gin.Context) {
Id uint `json:"id"`
Password string `json:"password"`
Username string `json:"username"`
Mobile string `json:"mobile"`
Email string `json:"email"`
ChatRoles []string `json:"chat_roles"`
ChatModels []int `json:"chat_models"`
ExpiredTime string `json:"expired_time"`
@@ -82,7 +88,13 @@ func (h *UserHandler) Save(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
// 检测最大注册人数
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
var user = model.User{}
var res *gorm.DB
var userVo vo.User
@@ -94,6 +106,8 @@ func (h *UserHandler) Save(c *gin.Context) {
}
var oldPower = user.Power
user.Username = data.Username
user.Email = data.Email
user.Mobile = data.Mobile
user.Status = data.Status
user.Vip = data.Vip
user.Power = data.Power
@@ -101,9 +115,11 @@ func (h *UserHandler) Save(c *gin.Context) {
user.ChatModels = utils.JsonEncode(data.ChatModels)
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
res = h.DB.Select("username", "mobile", "email", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
logger.Error("error with update database", res.Error)
resp.ERROR(c, res.Error.Error())
return
}
// 记录算力日志
@@ -126,12 +142,27 @@ func (h *UserHandler) Save(c *gin.Context) {
CreatedAt: time.Now(),
})
}
// 如果禁用了用户,则将用户踢下线
if user.Status == false {
key := fmt.Sprintf("users/%v", user.Id)
if _, err := h.redis.Del(c, key).Result(); err != nil {
logger.Error("error with delete session: ", err)
}
}
} else {
// 检查用户是否已经存在
h.DB.Where("username", data.Username).First(&user)
if user.Id > 0 {
resp.ERROR(c, "用户名已存在")
return
}
salt := utils.RandString(8)
u := model.User{
Username: data.Username,
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
Password: utils.GenPassword(data.Password, salt),
Mobile: data.Mobile,
Email: data.Email,
Avatar: "/images/avatar/user.png",
Salt: salt,
Power: data.Power,
@@ -140,6 +171,11 @@ func (h *UserHandler) Save(c *gin.Context) {
ChatModels: utils.JsonEncode(data.ChatModels),
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
}
if h.licenseService.GetLicense().Configs.DeCopy {
u.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
} else {
u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
}
res = h.DB.Create(&u)
_ = utils.CopyObject(u, &userVo)
userVo.Id = u.Id
@@ -148,7 +184,7 @@ func (h *UserHandler) Save(c *gin.Context) {
}
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
resp.ERROR(c, res.Error.Error())
return
}
@@ -184,33 +220,69 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
func (h *UserHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
id := c.Query("id")
ids := c.QueryArray("ids[]")
if id != "" {
ids = append(ids, id)
}
if len(ids) == 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
// 删除用户
res := h.DB.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
tx := h.DB.Begin()
var err error
for _, id = range ids {
// 删除用户
if err = tx.Where("id", id).Delete(&model.User{}).Error; err != nil {
break
}
// 删除聊天记录
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatItem{}).Error; err != nil {
break
}
// 删除聊天历史记录
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatMessage{}).Error; err != nil {
break
}
// 删除登录日志
if err = tx.Where("user_id = ?", id).Delete(&model.UserLoginLog{}).Error; err != nil {
break
}
// 删除算力日志
if err = tx.Where("user_id = ?", id).Delete(&model.PowerLog{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.InviteLog{}).Error; err != nil {
break
}
// 删除众筹日志
if err = tx.Where("user_id = ?", id).Delete(&model.Redeem{}).Error; err != nil {
break
}
// 删除绘图任务
if err = tx.Where("user_id = ?", id).Delete(&model.MidJourneyJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.SdJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.DallJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.SunoJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.VideoJob{}).Error; err != nil {
break
}
}
if err != nil {
resp.ERROR(c, err.Error())
tx.Rollback()
return
}
// 删除聊天记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
// 删除聊天历史记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
// 删除登录日志
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
// 删除算力日志
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
// 删除众筹日志
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
// 删除绘图任务
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
// 删除订单
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
tx.Commit()
resp.SUCCESS(c)
}

View File

@@ -8,13 +8,13 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/store/model"
"geekai/utils"
"errors"
"fmt"
"gorm.io/gorm"
"strings"
@@ -85,7 +85,7 @@ func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
}
var user model.User
res := h.DB.First(&user, userId)
res := h.DB.Where("id", userId).First(&user)
// 更新缓存
if res.Error == nil {
c.Set(types.LoginUserCache, user)

View File

@@ -0,0 +1,44 @@
package handler
import (
"geekai/core"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatAppTypeHandler struct {
BaseHandler
}
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 获取App类型列表
func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType
var appTypes = make([]vo.AppType, 0)
err := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, v := range items {
var appType vo.AppType
err = utils.CopyObject(v, &appType)
if err != nil {
continue
}
appType.Id = v.Id
appType.CreatedAt = v.CreatedAt.Unix()
appTypes = append(appTypes, appType)
}
resp.SUCCESS(c, appTypes)
}

View File

@@ -31,9 +31,14 @@ func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0)
var res *gorm.DB
session := h.DB.Session(&gorm.Session{}).Where("enabled", true)
t := c.Query("type")
if t != "" {
session = session.Where("type", t)
}
// 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) {
res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items)
res = session.Where("open", true).Order("sort_num ASC").Find(&items)
} else {
user, _ := h.GetLoginUser(c)
var models []int

View File

@@ -29,45 +29,63 @@ func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
// List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) {
all := h.GetBool(c, "all")
tid := h.GetInt(c, "tid", 0)
var roles []model.ChatRole
session := h.DB.Where("enable", true)
if tid > 0 {
session = session.Where("tid", tid)
}
err := session.Order("sort_num ASC").Find(&roles).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
var v vo.ChatRole
err := utils.CopyObject(r, &v)
if err == nil {
v.Id = r.Id
roleVos = append(roleVos, v)
}
}
resp.SUCCESS(c, roleVos)
}
// ListByUser 获取用户添加的角色列表
func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var roles []model.ChatRole
var roleVos = make([]vo.ChatRole, 0)
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
session := h.DB.Where("enable", true)
// 如果用户没登录,则获取所有角色
if userId > 0 {
var user model.User
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
resp.ERROR(c, "角色解析失败!")
return
}
// 保证用户至少有一个角色可用
if len(roleKeys) > 0 {
session = session.Where("marker IN ?", roleKeys)
}
}
if id > 0 {
session = session.Or("id", id)
}
res := session.Order("sort_num ASC").Find(&roles)
if res.Error != nil {
resp.SUCCESS(c, roleVos)
return
}
// 获取所有角色
if userId == 0 || all {
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
var v vo.ChatRole
err := utils.CopyObject(r, &v)
if err == nil {
v.Id = r.Id
roleVos = append(roleVos, v)
}
}
resp.SUCCESS(c, roleVos)
return
}
var user model.User
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
resp.ERROR(c, "角色解析失败!")
resp.ERROR(c, res.Error.Error())
return
}
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
if !utils.ContainsStr(roleKeys, r.Key) {
continue
}
var v vo.ChatRole
err := utils.CopyObject(r, &v)
if err == nil {
@@ -94,10 +112,9 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
return
}
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
if res.Error != nil {
logger.Error("添加应用失败:", err)
resp.ERROR(c, "更新数据库失败!")
err = h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}

View File

@@ -1,205 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
// 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return errors.New(line)
}
if len(responseBody.Choices) == 0 {
continue
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if strings.Contains(res.Error.Message, "maximum context length") {
logger.Error(res.Error.Message)
h.App.ChatContexts.Delete(session.ChatId)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
return fmt.Errorf("请求 Azure API 失败:%v", res.Error)
}
}
return nil
}

View File

@@ -1,275 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"html/template"
"io"
"net/http"
"strings"
"time"
"unicode/utf8"
)
type baiduResp struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
IsTruncated bool `json:"is_truncated"`
Result string `json:"result"`
NeedClearHistory bool `json:"need_clear_history"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
logger.Error(err)
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var content string
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
var resp baiduResp
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(resp.Result),
})
contents = append(contents, resp.Result)
if resp.IsTruncated {
utils.ReplyMessage(ws, "AI 输出异常中断")
break
}
if resp.IsEnd {
break
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
}
return nil
}
func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
ctx := context.Background()
tokenString, err := h.redis.Get(ctx, apiKey).Result()
if err == nil {
return tokenString, nil
}
expr := time.Hour * 24 * 20 // access_token 有效期
key := strings.Split(apiKey, "|")
if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
client := &http.Client{}
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return "", err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error with send request: %w", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return "", fmt.Errorf("error with read response: %w", err)
}
var r map[string]interface{}
err = json.Unmarshal(body, &r)
if err != nil {
return "", fmt.Errorf("error with parse response: %w", err)
}
if r["error"] != nil {
return "", fmt.Errorf("error with api response: %s", r["error_description"])
}
tokenString = fmt.Sprintf("%s", r["access_token"])
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, nil
}

View File

@@ -17,16 +17,19 @@ import (
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"html/template"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
@@ -38,15 +41,23 @@ var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
redis *redis.Client
uploadManager *oss.UploaderManager
redis *redis.Client
uploadManager *oss.UploaderManager
licenseService *service.LicenseService
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
userService *service.UserService
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
licenseService: licenseService,
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
ChatContexts: types.NewLMap[string, []types.Message](),
userService: userService,
}
}
@@ -67,7 +78,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
var chatRole model.ChatRole
res := h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
utils.ReplyErrorMessage(client, "当前聊天角色不存在或者未启用,对话已关闭!!!")
c.Abort()
return
}
@@ -79,26 +90,15 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
var chatModel model.ChatModel
res = h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭!!!")
utils.ReplyErrorMessage(client, "当前AI模型暂未启用对话已关闭!!!")
c.Abort()
return
}
session := h.App.ChatSession.Get(sessionId)
if session == nil {
user, err := h.GetLoginUser(c)
if err != nil {
logger.Info("用户未登录")
c.Abort()
return
}
session = &types.ChatSession{
SessionId: sessionId,
ClientIP: c.ClientIP(),
Username: user.Username,
UserId: user.Id,
}
h.App.ChatSession.Put(sessionId, session)
session := &types.ChatSession{
SessionId: sessionId,
ClientIP: c.ClientIP(),
UserId: h.GetLoginUserId(c),
}
// use old chat data override the chat model and role ID
@@ -118,51 +118,42 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
KeyId: chatModel.KeyId,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
KeyId: chatModel.KeyId}
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
// 保存会话连接
h.App.ChatClients.Put(sessionId, client)
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
h.App.ChatClients.Delete(sessionId)
h.App.ChatSession.Delete(sessionId)
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
cancelFunc := h.ReqCancelFunc.Get(sessionId)
if cancelFunc != nil {
cancelFunc()
h.App.ReqCancelFunc.Delete(sessionId)
h.ReqCancelFunc.Delete(sessionId)
}
return
}
var message types.WsMessage
var message types.InputMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 Chat 心跳消息:", message.Content)
continue
}
logger.Info("Receive a message: ", message.Content)
logger.Infof("Receive a message:%+v", message)
session.Tools = message.Tools
session.Stream = message.Stream
ctx, cancel := context.WithCancel(context.Background())
h.App.ReqCancelFunc.Put(sessionId, cancel)
h.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil {
logger.Error(err)
utils.ReplyMessage(client, err.Error())
} else {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
logger.Infof("回答完毕: %v", message.Content)
}
@@ -211,70 +202,57 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
var req = types.ApiRequest{
Model: session.Model.Value,
Stream: true,
Model: session.Model.Value,
}
switch session.Model.Platform {
case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
req.Temperature = session.Model.Temperature
// 兼容 GPT-O1 模型
if strings.HasPrefix(session.Model.Value, "o1-") {
utils.ReplyContent(ws, "AI 正在思考...\n")
req.Stream = false
session.Start = time.Now().Unix()
} else {
req.MaxTokens = session.Model.MaxTokens
break
case types.OpenAI:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
req.Stream = session.Stream
}
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
var items []model.Function
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
required := parameters["required"]
delete(parameters, "required")
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
}
if v, ok := parameters["required"]; v == nil || !ok {
tool.Function.Parameters["required"] = []string{}
}
tools = append(tools, tool)
}
// Fixed: compatible for gpt4-turbo-xxx model
if !strings.HasPrefix(req.Model, "gpt-4-turbo-") {
tool.Function.Required = required
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
tools = append(tools, tool)
}
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
case types.QWen:
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
break
default:
return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
}
// 加载聊天上下文
chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.SysConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
messages = h.App.ChatContexts.Get(session.ChatId)
if h.ChatContexts.Has(session.ChatId) {
messages = h.ChatContexts.Get(session.ChatId)
} else {
_ = utils.JsonDecode(role.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
@@ -299,8 +277,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
for i := len(messages) - 1; i >= 0; i-- {
v := messages[i]
tks, _ = utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
break
@@ -322,66 +301,69 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
reqMgs = append(reqMgs, m)
}
if session.Model.Platform == types.QWen {
req.Input = make(map[string]interface{})
reqMgs = append(reqMgs, types.Message{
Role: "user",
Content: prompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURL(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
text := prompt
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
fullPrompt := prompt
text := prompt
// extract files in prompt
files := utils.ExtractFileURLs(prompt)
logger.Debugf("detected FILES: %+v", files)
// 如果不是逆向模型,则提取文件内容
if len(files) > 0 && !(session.Model.Value == "gpt-4-all" ||
strings.HasPrefix(session.Model.Value, "gpt-4-gizmo") ||
strings.HasSuffix(session.Model.Value, "claude-3")) {
contents := make([]string, 0)
var file model.File
for _, v := range files {
h.DB.Where("url = ?", v).First(&file)
content, err := utils.ReadFileContent(v, h.App.Config.TikaHost)
if err != nil {
logger.Error("error with read file: ", err)
} else {
contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
}
data = append(data, gin.H{
"type": "text",
"text": text,
})
content = data
} else {
content = prompt
text = strings.Replace(text, v, "", 1)
}
if len(contents) > 0 {
fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML)\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
}
tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
if tokens > session.Model.MaxContext {
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
}
logger.Debug("最终Prompt", fullPrompt)
// extract images from prompt
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
}
data = append(data, gin.H{
"type": "text",
"text": strings.TrimSpace(text),
})
content = data
} else {
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
logger.Debugf("%+v", req.Messages)
switch session.Model.Platform {
case types.Azure:
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.OpenAI:
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGLM:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.Baidu:
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.QWen:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
return nil
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
// Tokens 统计 token 数量
@@ -397,17 +379,17 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
}
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
if data.Text == "" && data.ChatId != "" {
var item model.ChatMessage
userId, _ := c.Get(types.LoginUserID)
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c, item.Tokens)
return
}
//if data.Text == "" && data.ChatId != "" {
// var item model.ChatMessage
// userId, _ := c.Get(types.LoginUserID)
// res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
// if res.Error != nil {
// resp.ERROR(c, res.Error.Error())
// return
// }
// resp.SUCCESS(c, item.Tokens)
// return
//}
tokens, err := utils.CalcTokens(data.Text, data.Model)
if err != nil {
@@ -441,9 +423,9 @@ func getTotalTokens(req types.ApiRequest) int {
// StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id")
if h.App.ReqCancelFunc.Has(sessionId) {
h.App.ReqCancelFunc.Get(sessionId)()
h.App.ReqCancelFunc.Delete(sessionId)
if h.ReqCancelFunc.Has(sessionId) {
h.ReqCancelFunc.Get(sessionId)()
h.ReqCancelFunc.Delete(sessionId)
}
resp.SUCCESS(c, types.OkMsg)
}
@@ -453,51 +435,24 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
if session.Model.KeyId > 0 {
h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
h.DB.Where("id", session.Model.KeyId).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
h.DB.Debug().Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
if apiKey.Id == 0 {
return nil, errors.New("no available key, please import key")
}
var apiURL string
switch session.Model.Platform {
case types.Azure:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break
case types.ChatGLM:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
case types.QWen:
apiURL = apiKey.ApiURL
req.Messages = nil
break
default:
apiURL = apiKey.ApiURL
}
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if session.Model.Platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
}
logger.Info("百度文心 Access_Token", token)
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
// ONLY allow apiURL in blank list
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
logger.Debugf("对话请求消息体:%+v", req)
logger.Debugf(utils.JsonEncode(req))
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
@@ -521,28 +476,10 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
switch session.Model.Platform {
case types.Azure:
request.Header.Set("api-key", apiKey.Value)
break
case types.ChatGLM:
token, err := h.getChatGLMToken(apiKey.Value)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
case types.Baidu:
request.RequestURI = ""
case types.OpenAI:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
break
case types.QWen:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
request.Header.Set("X-DashScope-SSE", "enable")
break
}
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
// 更新API KEY 最后使用时间
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
return client.Do(request)
}
@@ -552,24 +489,115 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
if session.Model.Power > 0 {
power = session.Model.Power
}
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userVo.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,
Amount: power,
Mark: types.PowerSub,
Balance: u.Power,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
CreatedAt: time.Now(),
})
err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{
Type: types.PowerConsume,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
})
if err != nil {
logger.Error(err)
}
}
func (h *ChatHandler) saveChatHistory(
req types.ApiRequest,
usage Usage,
message types.Message,
chatCtx []types.Message,
session *types.ChatSession,
role model.ChatRole,
userVo vo.User,
promptCreatedAt time.Time,
replyCreatedAt time.Time) {
useMsg := types.Message{Role: "user", Content: usage.Prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
var promptTokens, replyTokens, totalTokens int
if usage.PromptTokens > 0 {
promptTokens = usage.PromptTokens
} else {
promptTokens, _ = utils.CalcTokens(usage.Content, req.Model)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(usage.Prompt),
Tokens: promptTokens,
TotalTokens: promptTokens,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
err := h.DB.Save(&historyUserMsg).Error
if err != nil {
logger.Error("failed to save prompt history message: ", err)
}
// for reply
// 计算本次对话消耗的总 token 数量
if usage.CompletionTokens > 0 {
replyTokens = usage.CompletionTokens
totalTokens = usage.TotalTokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
totalTokens = replyTokens + getTotalTokens(req)
}
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: usage.Content,
Tokens: replyTokens,
TotalTokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
err = h.DB.Create(&historyReplyMsg).Error
if err != nil {
logger.Error("failed to save reply history message: ", err)
}
// 更新用户算力
if session.Model.Power > 0 {
h.subUserPower(userVo, session, promptTokens, replyTokens)
}
// 保存当前会话
var chatItem model.ChatItem
err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
if err != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = userVo.Id
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(usage.Prompt) > 30 {
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
} else {
chatItem.Title = usage.Prompt
}
chatItem.Model = req.Model
err = h.DB.Create(&chatItem).Error
if err != nil {
logger.Error("failed to save chat item: ", err)
}
}
}
// 将AI回复消息中生成的图片链接下载到本地
@@ -587,7 +615,7 @@ func (h *ChatHandler) extractImgUrl(text string) string {
continue
}
newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
newImgURL, err := h.uploadManager.GetUploadHandler().PutUrlFile(imageURL, false)
if err != nil {
logger.Error("error with download image: ", err)
continue

View File

@@ -96,7 +96,7 @@ func (h *ChatHandler) Clear(c *gin.Context) {
for _, chat := range chats {
chatIds = append(chatIds, chat.ChatId)
// 清空会话上下文
h.App.ChatContexts.Delete(chat.ChatId)
h.ChatContexts.Delete(chat.ChatId)
}
err = h.DB.Transaction(func(tx *gorm.DB) error {
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
@@ -108,8 +108,6 @@ func (h *ChatHandler) Clear(c *gin.Context) {
if res.Error != nil {
return res.Error
}
// TODO: 是否要删除 MidJourney 绘画记录和图片文件?
return nil
})
@@ -175,7 +173,7 @@ func (h *ChatHandler) Remove(c *gin.Context) {
// TODO: 是否要删除 MidJourney 绘画记录和图片文件?
// 清空会话上下文
h.App.ChatContexts.Delete(chatId)
h.ChatContexts.Delete(chatId)
resp.SUCCESS(c, types.OkMsg)
}

View File

@@ -1,237 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/golang-jwt/jwt/v5"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
// 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var event, content string
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "event:") {
event = line[6:]
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
switch event {
case "add":
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
contents = append(contents, content)
case "finish":
break
case "error":
utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
break
case "interrupted":
utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"code"`
Success bool `json:"success"`
Msg string `json:"msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if !res.Success {
utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
}
}
return nil
}
func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) {
ctx := context.Background()
tokenString, err := h.redis.Get(ctx, apiKey).Result()
if err == nil {
return tokenString, nil
}
expr := time.Hour * 2
key := strings.Split(apiKey, ".")
if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"api_key": key[0],
"timestamp": time.Now().Unix(),
"exp": time.Now().Add(expr).Add(time.Second * 10).Unix(),
})
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
delete(token.Header, "typ")
// Sign and get the complete encoded token as a string using the secret
tokenString, err = token.SignedString([]byte(key[1]))
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, err
}

View File

@@ -17,15 +17,38 @@ import (
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"html/template"
req2 "github.com/imroc/req/v3"
"io"
"strings"
"time"
"unicode/utf8"
req2 "github.com/imroc/req/v3"
)
type Usage struct {
Prompt string `json:"prompt,omitempty"`
Content string `json:"content,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIResVo struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage Usage `json:"usage"`
}
// OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage(
chatCtx []types.Message,
@@ -52,23 +75,26 @@ func (h *ChatHandler) sendOpenAiMessage(
defer response.Body.Close()
}
if response.StatusCode != 200 {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var message = types.Message{Role: "assistant"}
var contents = make([]string, 0)
var function model.Function
var toolCall = false
var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
@@ -77,6 +103,9 @@ func (h *ChatHandler) sendOpenAiMessage(
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].Delta.Content == nil && responseBody.Choices[0].Delta.ToolCalls == nil {
continue
}
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.ReplyMessage(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
@@ -107,8 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: callMsg})
contents = append(contents, callMsg)
}
continue
@@ -125,12 +153,8 @@ func (h *ChatHandler) sendOpenAiMessage(
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
if isNew {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
utils.ReplyChunkMessage(ws, types.ReplyMessage{
Type: types.WsContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
@@ -145,7 +169,7 @@ func (h *ChatHandler) sendOpenAiMessage(
}
if toolCall { // 调用函数完成任务
var params map[string]interface{}
params := make(map[string]interface{})
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
params["user_id"] = userVo.Id
@@ -162,14 +186,14 @@ func (h *ChatHandler) sendOpenAiMessage(
}
if errMsg != "" || apiRes.Code != types.Success {
msg := "调用函数工具出错:" + apiRes.Message + errMsg
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
utils.ReplyChunkMessage(ws, types.ReplyMessage{
Type: types.WsContent,
Content: msg,
})
contents = append(contents, msg)
} else {
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
utils.ReplyChunkMessage(ws, types.ReplyMessage{
Type: types.WsContent,
Content: apiRes.Data,
})
contents = append(contents, utils.InterfaceToString(apiRes.Data))
@@ -178,126 +202,34 @@ func (h *ChatHandler) sendOpenAiMessage(
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext && toolCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
useContext := true
if toolCall {
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyTokens += tokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
usage := Usage{
Prompt: prompt,
Content: strings.Join(contents, ""),
PromptTokens: 0,
CompletionTokens: 0,
TotalTokens: 0,
}
message.Content = usage.Content
h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
} else { // 非流式输出
var respVo OpenAIResVo
body, err := io.ReadAll(response.Body)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
return fmt.Errorf("error with reading response: %v", err)
return fmt.Errorf("读取响应失败:%v", body)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
err = json.Unmarshal(body, &respVo)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
return fmt.Errorf("error with decode response: %v", err)
return fmt.Errorf("解析响应失败:%v", body)
}
// OpenAI API 调用异常处理
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
logger.Error(res.Error.Message)
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
h.App.ChatContexts.Delete(session.ChatId)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
content := respVo.Choices[0].Message.Content
if strings.HasPrefix(req.Model, "o1-") {
content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
}
utils.ReplyMessage(ws, content)
respVo.Usage.Prompt = prompt
respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
}
return nil

View File

@@ -1,242 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/syndtr/goleveldb/leveldb/errors"
"html/template"
"io"
"strings"
"time"
"unicode/utf8"
)
type qWenResp struct {
Output struct {
FinishReason string `json:"finish_reason"`
Text string `json:"text"`
} `json:"output,omitempty"`
Usage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage,omitempty"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var content, lastText, newText string
var outPutStart = false
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") ||
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
continue
}
if !strings.HasPrefix(line, "data:") {
continue
}
content = line[5:]
var resp qWenResp
if len(contents) == 0 { // 发送消息头
if !outPutStart {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
outPutStart = true
continue
} else {
// 处理代码换行
content = "\n"
}
} else {
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", content)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if resp.Message != "" {
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
break
}
}
//通过比较 lastText上一次的文本和 currentText当前的文本
//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
//每次循环结束后lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
currentText := resp.Output.Text
if currentText != lastText {
// 提取新增文本
newText = strings.Replace(currentText, lastText, "", 1)
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(newText),
})
lastText = currentText // 更新 lastText
}
contents = append(contents, newText)
if resp.Output.FinishReason == "stop" {
break
}
} //end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res struct {
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg)
}
return nil
}

View File

@@ -1,336 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"html/template"
"io"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
)
type xunFeiResp struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
} `json:"text"`
} `json:"choices"`
Usage struct {
Text struct {
QuestionTokens int `json:"question_tokens"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
var Model2URL = map[string]string{
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"generalv3.5": "v3.5",
}
// 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
var res *gorm.DB
// use the bind key
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
}
if res.Error != nil {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
key := strings.Split(apiKey.Value, "|")
if len(key) != 3 {
utils.ReplyMessage(ws, "非法的 API KEY")
return nil
}
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)
if err != nil {
logger.Error(readResp(resp) + err.Error())
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
} else if resp.StatusCode != 101 {
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
}
data := buildRequest(key[0], req)
fmt.Printf("%+v", data)
fmt.Println(apiURL)
err = conn.WriteJSON(data)
if err != nil {
utils.ReplyMessage(ws, "发送消息失败:"+err.Error())
return nil
}
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var content string
for {
_, msg, err := conn.ReadMessage()
if err != nil {
logger.Error("error with read message:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err))
break
}
// 解析数据
var result xunFeiResp
err = json.Unmarshal(msg, &result)
if err != nil {
logger.Error("error with parsing JSON:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
return nil
}
if result.Header.Code != 0 {
utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message))
return nil
}
content = result.Payload.Choices.Text[0].Content
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
contents = append(contents, content)
// 第一个结果
if result.Payload.Choices.Status == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
if result.Payload.Choices.Status == 2 { // 最终结果
_ = conn.Close() // 关闭连接
break
}
select {
case <-ctx.Done():
utils.ReplyMessage(ws, "**用户取消了生成指令!**")
return nil
default:
continue
}
}
// 消息发送成功
if len(contents) > 0 {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
}
}
return nil
}
// 构建 websocket 请求实体
func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
return map[string]interface{}{
"header": map[string]interface{}{
"app_id": appid,
},
"parameter": map[string]interface{}{
"chat": map[string]interface{}{
"domain": req.Model,
"temperature": req.Temperature,
"top_k": int64(6),
"max_tokens": int64(req.MaxTokens),
"auditing": "default",
},
},
"payload": map[string]interface{}{
"message": map[string]interface{}{
"text": req.Messages,
},
},
}
}
// 创建鉴权 URL
func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) {
ul, err := url.Parse(hostURL)
if err != nil {
return "", err
}
date := time.Now().UTC().Format(time.RFC1123)
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
//拼接签名字符串
signStr := strings.Join(signString, "\n")
sha := hmacWithSha256(signStr, apiSecret)
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
"hmac-sha256", "host date request-line", sha)
//将请求参数使用base64编码
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
v := url.Values{}
v.Add("host", ul.Host)
v.Add("date", date)
v.Add("authorization", authorization)
//将编码后的字符串url encode后添加到url后面
return hostURL + "?" + v.Encode(), nil
}
// 使用 sha256 签名
func hmacWithSha256(data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(data))
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
}
// 读取响应
func readResp(resp *http.Response) string {
if resp == nil {
return ""
}
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
}

View File

@@ -9,6 +9,7 @@ package handler
import (
"geekai/core"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
@@ -19,10 +20,11 @@ import (
type ConfigHandler struct {
BaseHandler
licenseService *service.LicenseService
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *ConfigHandler {
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
}
// Get 获取指定的系统配置
@@ -44,3 +46,9 @@ func (h *ConfigHandler) Get(c *gin.Context) {
resp.SUCCESS(c, value)
}
// License 获取 License 配置
func (h *ConfigHandler) License(c *gin.Context) {
license := h.licenseService.GetLicense()
resp.SUCCESS(c, license.Configs)
}

View File

@@ -8,34 +8,36 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"net/http"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
)
type DallJobHandler struct {
BaseHandler
redis *redis.Client
service *dalle.Service
uploader *oss.UploaderManager
redis *redis.Client
dallService *dalle.Service
uploader *oss.UploaderManager
userService *service.UserService
}
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
return &DallJobHandler{
service: service,
uploader: manager,
dallService: service,
uploader: manager,
userService: userService,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -60,18 +62,18 @@ func (h *DallJobHandler) Client(c *gin.Context) {
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
h.dallService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.service.Clients.Delete(uint(userId))
h.dallService.Clients.Delete(uint(userId))
return
}
var message types.WsMessage
var message types.ReplyMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
@@ -126,7 +128,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
return
}
h.service.PushTask(types.DallTask{
h.dallService.PushTask(types.DallTask{
JobId: job.Id,
UserId: uint(userId),
Prompt: data.Prompt,
@@ -136,7 +138,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
Power: job.Power,
})
client := h.service.Clients.Get(job.UserId)
client := h.dallService.Clients.Get(job.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
@@ -158,13 +160,13 @@ func (h *DallJobHandler) ImgWall(c *gin.Context) {
// JobList 获取 SD 任务列表
func (h *DallJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
finish := h.GetBool(c, "finish")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
err, jobs := h.getData(finish, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -174,11 +176,11 @@ func (h *DallJobHandler) JobList(c *gin.Context) {
}
// JobList 获取任务列表
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
session = session.Where("progress >= ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
}
@@ -192,11 +194,14 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
// 统计总数
var total int64
session.Model(&model.DallJob{}).Count(&total)
var items []model.DallJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
return res.Error, vo.Page{}
}
var jobs = make([]vo.DallJob, 0)
@@ -209,30 +214,44 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
jobs = append(jobs, job)
}
return nil, jobs
return nil, vo.NewPage(total, page, pageSize, jobs)
}
// Remove remove task image
func (h *DallJobHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.DallJob
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// remove job recode
res := h.DB.Delete(&model.DallJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
@@ -242,18 +261,13 @@ func (h *DallJobHandler) Remove(c *gin.Context) {
// Publish 发布/取消发布图片到画廊显示
func (h *DallJobHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享
res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
err := h.DB.Model(&model.DallJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}

View File

@@ -8,15 +8,16 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"errors"
"fmt"
"strings"
"time"
@@ -224,3 +225,27 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
resp.SUCCESS(c, content)
}
// List 获取所有的工具函数列表
func (h *FunctionHandler) List(c *gin.Context) {
var items []model.Function
err := h.DB.Where("enabled", true).Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
tools := make([]vo.Function, 0)
for _, v := range items {
var f vo.Function
err = utils.CopyObject(v, &f)
if err != nil {
continue
}
f.Action = ""
f.Token = ""
tools = append(tools, f)
}
resp.SUCCESS(c, tools)
}

View File

@@ -9,7 +9,6 @@ package handler
import (
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
@@ -59,23 +58,16 @@ func (h *InviteHandler) Code(c *gin.Context) {
// List Log 用户邀请记录
func (h *InviteHandler) List(c *gin.Context) {
var data struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
var total int64
session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog
var list = make([]vo.InviteLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
offset := (page - 1) * pageSize
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var v vo.InviteLog
@@ -89,7 +81,7 @@ func (h *InviteHandler) List(c *gin.Context) {
}
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
}
// Hits 访问邀请码

View File

@@ -15,6 +15,7 @@ import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"github.com/gin-gonic/gin"
@@ -30,13 +31,15 @@ import (
// MarkMapHandler 生成思维导图
type MarkMapHandler struct {
BaseHandler
clients *types.LMap[int, *types.WsClient]
clients *types.LMap[int, *types.WsClient]
userService *service.UserService
}
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler {
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *MarkMapHandler {
return &MarkMapHandler{
BaseHandler: BaseHandler{App: app, DB: db},
clients: types.NewLMap[int, *types.WsClient](),
userService: userService,
}
}
@@ -61,7 +64,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
return
}
var message types.WsMessage
var message types.ReplyMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
@@ -82,7 +85,9 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
utils.ReplyErrorMessage(client, err.Error())
} else {
utils.ReplyMessage(client, types.ReplyMessage{Type: types.WsEnd})
}
}
@@ -101,17 +106,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
return fmt.Errorf("error with query chat model: %v", res.Error)
}
if user.Status == false {
return errors.New("当前用户被禁用")
}
if user.Power < chatModel.Power {
return fmt.Errorf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power)
}
messages := make([]interface{}, 0)
messages = append(messages, types.Message{Role: "system", Content: `
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
你是一位非常优秀的思维导图助手, 你能帮助用户整理思路,根据用户提供的主题或内容,快速生成结构清晰,有条理的思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
# Geek-AI 助手
## 完整的开源系统
@@ -130,7 +131,7 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
另外,除此之外不要任何解释性语句。
`})
messages = append(messages, types.Message{Role: "user", Content: prompt})
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图要求结构清晰有条理", prompt)})
var req = types.ApiRequest{
Model: chatModel.Value,
Stream: true,
@@ -149,7 +150,6 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
if strings.Contains(contentType, "text/event-stream") {
// 循环读取 Chunk 消息
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -170,79 +170,50 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
break
}
if isNew {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(client, types.WsMessage{
Type: types.WsMiddle,
utils.ReplyChunkMessage(client, types.ReplyMessage{
Type: types.WsContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} // end for
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("读取响应失败: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("解析响应失败: %v", err)
}
// OpenAI API 调用异常处理
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
// remove key
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
return errors.New("请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
return errors.New("请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else {
return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message)
}
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
}
// 扣减算力
res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userId).First(&u)
h.DB.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerConsume,
Amount: chatModel.Power,
Mark: types.PowerSub,
Balance: u.Power,
Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
CreatedAt: time.Now(),
if chatModel.Power > 0 {
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
Type: types.PowerConsume,
Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
})
if err != nil {
return err
}
}
return nil
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
session := h.DB.Session(&gorm.Session{})
// if the chat model bind a KEY, use it directly
var res *gorm.DB
if chatModel.KeyId > 0 {
res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", types.OpenAI).
Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC").First(apiKey)
session = session.Where("id", chatModel.KeyId)
} else { // use the last unused key
session = session.Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC")
}
res := session.First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := apiKey.ApiURL
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
@@ -269,5 +240,6 @@ func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatMod
client = http.DefaultClient
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
return client.Do(request)
}

View File

@@ -27,9 +27,15 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
// List 数据列表
func (h *MenuHandler) List(c *gin.Context) {
index := h.GetBool(c, "index")
var items []model.Menu
var list = make([]vo.Menu, 0)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
session := h.DB.Session(&gorm.Session{})
session = session.Where("enabled", true)
if index {
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
}
res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Menu

View File

@@ -8,6 +8,8 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
@@ -17,8 +19,6 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"encoding/base64"
"fmt"
"net/http"
"strings"
"time"
@@ -30,16 +30,18 @@ import (
type MidJourneyHandler struct {
BaseHandler
pool *mj.ServicePool
snowflake *service.Snowflake
uploader *oss.UploaderManager
mjService *mj.Service
snowflake *service.Snowflake
uploader *oss.UploaderManager
userService *service.UserService
}
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
return &MidJourneyHandler{
snowflake: snowflake,
pool: pool,
uploader: manager,
snowflake: snowflake,
mjService: service,
uploader: manager,
userService: userService,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -59,11 +61,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
return false
}
if !h.pool.HasAvailableService() {
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
return false
}
return true
}
@@ -85,26 +82,25 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
h.mjService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
// Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct {
SessionId string `json:"session_id"`
TaskType string `json:"task_type"`
Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
Model string `json:"model"` // 模型
Chaos int `json:"chaos"` // 创意度取值范围: 0-100
Raw bool `json:"raw"` // 是否开启原始模型
Seed int64 `json:"seed"` // 随机数
Stylize int `json:"stylize"` // 风格化
ImgArr []string `json:"img_arr"`
Tile bool `json:"tile"`
Quality float32 `json:"quality"`
Tile bool `json:"tile"` // 重复平铺
Quality float32 `json:"quality"` // 画质
Iw float32 `json:"iw"`
CRef string `json:"cref"` //生成角色一致的图像
SRef string `json:"sref"` //生成风格一致的图像
@@ -202,40 +198,34 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
h.pool.PushTask(types.MjTask{
h.mjService.PushTask(types.MjTask{
Id: job.Id,
TaskId: taskId,
SessionId: data.SessionId,
Type: types.TaskType(data.TaskType),
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
Mode: h.App.SysConfig.MjMode,
})
client := h.pool.Clients.Get(uint(job.UserId))
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
CreatedAt: time.Now(),
})
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
@@ -244,17 +234,12 @@ type reqVo struct {
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"`
SessionId string `json:"session_id"`
Prompt string `json:"prompt"`
ChatId string `json:"chat_id"`
RoleId int `json:"role_id"`
Icon string `json:"icon"`
}
// Upscale send upscale command to MidJourney Bot
func (h *MidJourneyHandler) Upscale(c *gin.Context) {
var data reqVo
if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" {
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
@@ -272,7 +257,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
@@ -281,46 +265,40 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
return
}
h.pool.PushTask(types.MjTask{
h.mjService.PushTask(types.MjTask{
Id: job.Id,
SessionId: data.SessionId,
Type: types.TaskUpscale,
Prompt: data.Prompt,
UserId: userId,
ChannelId: data.ChannelId,
Index: data.Index,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
})
client := h.pool.Clients.Get(uint(job.UserId))
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Variation send variation command to MidJourney Bot
func (h *MidJourneyHandler) Variation(c *gin.Context) {
var data reqVo
if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" {
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
@@ -339,7 +317,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(),
}
@@ -348,40 +325,32 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
return
}
h.pool.PushTask(types.MjTask{
h.mjService.PushTask(types.MjTask{
Id: job.Id,
SessionId: data.SessionId,
Type: types.TaskVariation,
Prompt: data.Prompt,
UserId: userId,
Index: data.Index,
ChannelId: data.ChannelId,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
})
client := h.pool.Clients.Get(uint(job.UserId))
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
@@ -400,13 +369,13 @@ func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
finish := h.GetBool(c, "finish")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
err, jobs := h.getData(finish, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -416,10 +385,10 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
session = session.Where("progress >= ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
}
@@ -434,10 +403,14 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
session = session.Offset(offset).Limit(pageSize)
}
// 统计总数
var total int64
session.Model(&model.MidJourneyJob{}).Count(&total)
var items []model.MidJourneyJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
return res.Error, vo.Page{}
}
var jobs = make([]vo.MidJourneyJob, 0)
@@ -449,48 +422,57 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
}
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
// discord 服务器图片需要使用代理转发图片数据流
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
} else {
job.ImgURL = job.OrgURL
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
jobs = append(jobs, job)
}
return nil, jobs
return nil, vo.NewPage(total, page, pageSize, jobs)
}
// Remove remove task image
func (h *MidJourneyHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
id := h.GetInt(c, "id", 0)
userId := h.GetInt(c, "user_id", 0)
var job model.MidJourneyJob
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// remove job recode
res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "mid-journey",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
client := h.pool.Clients.Get(data.UserId)
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
@@ -500,18 +482,12 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
// Publish 发布图片到画廊显示
func (h *MidJourneyHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
id := h.GetInt(c, "id", 0)
userId := h.GetInt(c, "user_id", 0)
action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享
err := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}

161
api/handler/net_handler.go Normal file
View File

@@ -0,0 +1,161 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"io"
"net/http"
"time"
)
type NetHandler struct {
BaseHandler
uploaderManager *oss.UploaderManager
}
func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *NetHandler {
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
func (h *NetHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
logger.Info("upload file: ", file.Name)
// cut the file name if it's too long
if len(file.Name) > 100 {
file.Name = file.Name[:90] + file.Ext
}
userId := h.GetLoginUserId(c)
res := h.DB.Create(&model.File{
UserId: int(userId),
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}
func (h *NetHandler) List(c *gin.Context) {
var data struct {
Urls []string `json:"urls,omitempty"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c)
var items []model.File
var files = make([]vo.File, 0)
session := h.DB.Session(&gorm.Session{})
session = session.Where("user_id = ?", userId)
if len(data.Urls) > 0 {
session = session.Where("url IN ?", data.Urls)
}
// 统计总数
var total int64
session.Model(&model.File{}).Count(&total)
if data.Page > 0 && data.PageSize > 0 {
offset := (data.Page - 1) * data.PageSize
session = session.Offset(offset).Limit(data.PageSize)
}
err := session.Order("id desc").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, v := range items {
var file vo.File
err := utils.CopyObject(v, &file)
if err != nil {
logger.Error(err)
continue
}
file.CreatedAt = v.CreatedAt.Unix()
files = append(files, file)
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, files))
}
// Remove remove files
func (h *NetHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0)
var file model.File
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
if tx.Error != nil || file.Id == 0 {
resp.ERROR(c, "file not existed")
return
}
// remove database
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
if tx.Error != nil || tx.RowsAffected == 0 {
resp.ERROR(c, "failed to update database")
return
}
// remove files
objectKey := file.ObjKey
if objectKey == "" {
objectKey = file.URL
}
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
resp.SUCCESS(c)
}
func (h *NetHandler) Download(c *gin.Context) {
fileUrl := c.Query("url")
// 使用http工具下载文件
if fileUrl == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
// 使用http.Get下载文件
r, err := http.Get(fileUrl)
if err != nil {
resp.ERROR(c, err.Error())
return
}
defer r.Body.Close()
if r.StatusCode != http.StatusOK {
resp.ERROR(c, "error status"+r.Status)
return
}
c.Status(http.StatusOK)
// 将下载的文件内容写入响应
_, _ = io.Copy(c.Writer, r.Body)
}

View File

@@ -14,6 +14,7 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -27,23 +28,18 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 订单列表
func (h *OrderHandler) List(c *gin.Context) {
var data struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
var total int64
session.Model(&model.Order{}).Count(&total)
var items []model.Order
var list = make([]vo.Order, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
offset := (page - 1) * pageSize
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var order vo.Order
@@ -52,11 +48,51 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
if !ok {
payMethod = item.PayWay
}
payName, ok := types.PayNames[item.PayType]
if !ok {
payName = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
list = append(list, order)
} else {
logger.Error(err)
}
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
}
// Query 查询订单状态
func (h *OrderHandler) Query(c *gin.Context) {
orderNo := h.GetTrim(c, "order_no")
var order model.Order
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
}
if order.Status == types.OrderPaidSuccess {
resp.SUCCESS(c, gin.H{"status": order.Status})
return
}
counter := 0
for {
time.Sleep(time.Second)
var item model.Order
h.DB.Where("order_no = ?", orderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
}
counter++
}
resp.SUCCESS(c, gin.H{"status": order.Status})
}

View File

@@ -8,6 +8,8 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"embed"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
@@ -15,13 +17,8 @@ import (
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"embed"
"encoding/base64"
"fmt"
"github.com/shopspring/decimal"
"math"
"net/http"
"net/url"
"sync"
"time"
@@ -29,343 +26,192 @@ import (
"gorm.io/gorm"
)
const (
PayWayAlipay = "支付宝"
PayWayXunHu = "虎皮椒"
PayWayJs = "PayJS"
)
type PayWay struct {
Name string `json:"name"`
Value string `json:"value"`
}
// PaymentHandler 支付服务回调 handler
type PaymentHandler struct {
BaseHandler
alipayService *payment.AlipayService
huPiPayService *payment.HuPiPayService
js *payment.PayJS
snowflake *service.Snowflake
fs embed.FS
lock sync.Mutex
alipayService *payment.AlipayService
huPiPayService *payment.HuPiPayService
geekPayService *payment.GeekPayService
wechatPayService *payment.WechatPayService
snowflake *service.Snowflake
userService *service.UserService
fs embed.FS
lock sync.Mutex
signKey string // 用来签名的随机秘钥
}
func NewPaymentHandler(
server *core.AppServer,
alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService,
js *payment.PayJS,
geekPayService *payment.GeekPayService,
wechatPayService *payment.WechatPayService,
db *gorm.DB,
userService *service.UserService,
snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler {
return &PaymentHandler{
alipayService: alipayService,
huPiPayService: huPiPayService,
js: js,
snowflake: snowflake,
fs: fs,
lock: sync.Mutex{},
alipayService: alipayService,
huPiPayService: huPiPayService,
geekPayService: geekPayService,
wechatPayService: wechatPayService,
snowflake: snowflake,
userService: userService,
fs: fs,
lock: sync.Mutex{},
BaseHandler: BaseHandler{
App: server,
DB: db,
},
signKey: utils.RandString(32),
}
}
func (h *PaymentHandler) DoPay(c *gin.Context) {
orderNo := h.GetTrim(c, "order_no")
payWay := h.GetTrim(c, "pay_way")
if orderNo == "" {
func (h *PaymentHandler) Pay(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"`
PayType string `json:"pay_type"`
ProductId int `json:"product_id"`
UserId int `json:"user_id"`
Device string `json:"device"`
Host string `json:"host"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var order model.Order
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
var product model.Product
err := h.DB.Where("id", data.ProductId).First(&product).Error
if err != nil {
resp.ERROR(c, "Product not found")
return
}
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
if order.Status == types.OrderPaidSuccess {
resp.ERROR(c, "This order had been paid, please do not pay twice")
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
err = h.DB.Where("id", data.UserId).First(&user).Error
if err != nil {
resp.NotAuth(c)
return
}
// 更新扫码状态
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
if payWay == "alipay" { // 支付宝
// 生成支付链接
notifyURL := h.App.Config.AlipayConfig.NotifyURL
returnURL := "" // 关闭同步回跳
amount := fmt.Sprintf("%.2f", order.Amount)
uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject)
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
var payURL, returnURL, notifyURL string
switch data.PayWay {
case "alipay":
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
notifyURL = h.App.Config.AlipayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
}
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
returnURL = h.App.Config.AlipayConfig.ReturnURL
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
money := fmt.Sprintf("%.2f", amount)
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
return
}
c.Redirect(302, uri)
return
} else if payWay == "hupi" { // 虎皮椒支付
params := payment.HuPiPayReq{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", order.Amount),
Title: order.Subject,
NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL,
WapName: "极客学长",
break
case "wechat":
if h.App.Config.WechatPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
}
r, err := h.huPiPayService.Pay(params)
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
OutTradeNo: orderNo,
TotalFee: int(amount * 100),
Subject: product.Name,
NotifyURL: notifyURL,
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
c.Redirect(302, r.URL)
}
resp.ERROR(c, "Invalid operations")
}
// OrderQuery 查询订单状态
func (h *PaymentHandler) OrderQuery(c *gin.Context) {
var data struct {
OrderNo string `json:"order_no"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var order model.Order
res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
}
if order.Status == types.OrderPaidSuccess {
resp.SUCCESS(c, gin.H{"status": order.Status})
return
}
counter := 0
for {
time.Sleep(time.Second)
var item model.Order
h.DB.Where("order_no = ?", data.OrderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
}
counter++
}
resp.SUCCESS(c, gin.H{"status": order.Status})
}
// PayQrcode 生成支付 URL 二维码
func (h *PaymentHandler) PayQrcode(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
var payWay string
var notifyURL string
switch data.PayWay {
break
case "hupi":
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
case "payjs":
payWay = PayWayJs
notifyURL = h.App.Config.JPayConfig.NotifyURL
default:
payWay = PayWayAlipay
notifyURL = h.App.Config.AlipayConfig.NotifyURL
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
}
// PayJs 单独处理,只能用官方生成的二维码
if data.PayWay == "payjs" {
params := payment.JPayReq{
TotalFee: int(math.Ceil(order.Amount * 100)),
OutTradeNo: order.OrderNo,
Subject: product.Name,
}
r := h.js.Pay(params)
if r.IsOK() {
resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
return
if h.App.Config.HuPiPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
} else {
resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg)
return
notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
}
}
var logo string
if data.PayWay == "alipay" {
logo = "res/img/alipay.jpg"
} else if data.PayWay == "hupi" {
if h.App.Config.HuPiPayConfig.Name == "wechat" {
logo = "res/img/wechat-pay.jpg"
if h.App.Config.HuPiPayConfig.ReturnURL != "" {
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
} else {
logo = "res/img/alipay.jpg"
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
}
file, err := h.fs.Open(logo)
if err != nil {
resp.ERROR(c, "error with open qrcode log file: "+err.Error())
return
}
parse, err := url.Parse(notifyURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s", parse.Scheme, parse.Host, orderNo, data.PayWay)
imgData, err := utils.GenQrcode(imageURL, 400, file)
if err != nil {
resp.ERROR(c, err.Error())
return
}
imgDataBase64 := base64.StdEncoding.EncodeToString(imgData)
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
}
// Mobile 移动端支付
func (h *PaymentHandler) Mobile(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
var payWay string
var notifyURL, returnURL string
var payURL string
switch data.PayWay {
case "hupi":
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
params := payment.HuPiPayReq{
r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name,
NotifyURL: notifyURL,
ReturnURL: returnURL,
CallbackURL: returnURL,
WapName: "极客学长",
}
r, err := h.huPiPayService.Pay(params)
WapName: "GeekAI助手",
})
if err != nil {
logger.Error("error with generating Pay URL: ", err.Error())
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
resp.ERROR(c, err.Error())
return
}
payURL = r.URL
case "payjs":
payWay = PayWayJs
notifyURL = h.App.Config.JPayConfig.NotifyURL
returnURL = h.App.Config.JPayConfig.ReturnURL
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
params := url.Values{}
params.Add("total_fee", fmt.Sprintf("%d", totalFee))
params.Add("out_trade_no", orderNo)
params.Add("body", product.Name)
params.Add("notify_url", notifyURL)
params.Add("auto", "0")
payURL = h.js.PayH5(params)
case "alipay":
payWay = PayWayAlipay
notifyURL = h.App.Config.AlipayConfig.NotifyURL
returnURL = h.App.Config.AlipayConfig.ReturnURL
payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
break
case "geek":
if h.App.Config.GeekPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.GeekPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
}
if h.App.Config.GeekPayConfig.ReturnURL != "" {
data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
}
if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
params := payment.GeekPayParams{
OutTradeNo: orderNo,
Method: "web",
Name: product.Name,
Money: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
Type: data.PayType,
ReturnURL: returnURL,
NotifyURL: notifyURL,
}
res, err := h.geekPayService.Pay(params)
if err != nil {
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
resp.ERROR(c, err.Error())
return
}
payURL = res.PayURL
default:
resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
resp.ERROR(c, "不支持的支付渠道")
return
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
@@ -374,7 +220,6 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
Price: product.Price,
Discount: product.Discount,
}
order := model.Order{
UserId: user.Id,
Username: user.Username,
@@ -383,26 +228,24 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
PayWay: data.PayWay,
PayType: data.PayType,
Remark: utils.JsonEncode(remark),
}
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
err = h.DB.Create(&order).Error
if err != nil {
resp.ERROR(c, "error with create order: "+err.Error())
return
}
resp.SUCCESS(c, payURL)
}
// 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var order model.Order
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
err := fmt.Errorf("error with fetch order: %v", res.Error)
logger.Error(err)
return err
err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
if err != nil {
return fmt.Errorf("error with fetch order: %v", err)
}
h.lock.Lock()
@@ -414,45 +257,24 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
}
var user model.User
res = h.DB.First(&user, order.UserId)
if res.Error != nil {
err := fmt.Errorf("error with fetch user info: %v", res.Error)
logger.Error(err)
return err
err = h.DB.First(&user, order.UserId).Error
if err != nil {
return fmt.Errorf("error with fetch user info: %v", err)
}
var remark types.OrderRemark
err := utils.JsonDecode(order.Remark, &remark)
err = utils.JsonDecode(order.Remark, &remark)
if err != nil {
err := fmt.Errorf("error with decode order remark: %v", err)
logger.Error(err)
return err
return fmt.Errorf("error with decode order remark: %v", err)
}
var opt string
var power int
if remark.Days > 0 { // VIP 充值
if user.ExpiredTime >= time.Now().Unix() {
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
opt = "VIP充值VIP 没到期,只延期不增加算力"
} else {
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Power += h.App.SysConfig.VipMonthPower
power = h.App.SysConfig.VipMonthPower
opt = "VIP充值"
}
user.Vip = true
} else { // 充值点卡,直接增加次数即可
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
// 更新用户信息
res = h.DB.Updates(&user)
if res.Error != nil {
err := fmt.Errorf("error with update user info: %v", res.Error)
logger.Error(err)
// 增加用户算力
err = h.userService.IncreasePower(int(order.UserId), remark.Power, model.PowerLog{
Type: types.PowerRecharge,
Model: order.PayWay,
Remark: fmt.Sprintf("充值算力,金额:%f订单号%s", order.Amount, order.OrderNo),
})
if err != nil {
return err
}
@@ -460,29 +282,16 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
order.PayTime = time.Now().Unix()
order.Status = types.OrderPaidSuccess
order.TradeNo = tradeNo
res = h.DB.Updates(&order)
if res.Error != nil {
err := fmt.Errorf("error with update order info: %v", res.Error)
logger.Error(err)
return err
err = h.DB.Updates(&order).Error
if err != nil {
return fmt.Errorf("error with update order info: %v", err)
}
// 更新产品销量
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
// 记录算力充值日志
if opt != "" {
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRecharge,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: order.PayWay,
Remark: fmt.Sprintf("%s金额%f订单号%s", opt, order.Amount, order.OrderNo),
CreatedAt: time.Now(),
})
err = h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).
UpdateColumn("sales", gorm.Expr("sales + ?", 1)).Error
if err != nil {
return fmt.Errorf("error with update product sales: %v", err)
}
return nil
@@ -490,17 +299,22 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
// GetPayWays 获取支付方式
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
data := gin.H{}
payWays := make([]gin.H, 0)
if h.App.Config.AlipayConfig.Enabled {
data["alipay"] = gin.H{"name": "alipay"}
payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
}
if h.App.Config.HuPiPayConfig.Enabled {
data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name}
payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
}
if h.App.Config.JPayConfig.Enabled {
data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name}
if h.App.Config.GeekPayConfig.Enabled {
for _, v := range h.App.Config.GeekPayConfig.Methods {
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
}
}
resp.SUCCESS(c, data)
if h.App.Config.WechatPayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
}
resp.SUCCESS(c, payWays)
}
// HuPiPayNotify 虎皮椒支付异步回调
@@ -513,15 +327,17 @@ func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
orderNo := c.Request.Form.Get("trade_order_id")
tradeNo := c.Request.Form.Get("open_order_id")
logger.Infof("收到虎皮椒订单支付回调,订单 NO%s交易流水号%s", orderNo, tradeNo)
logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
if err = h.huPiPayService.Check(tradeNo); err != nil {
if err = h.huPiPayService.Check(orderNo); err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
@@ -537,18 +353,18 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
return
}
// TODO验证交易签名
res := h.alipayService.TradeVerify(c.Request.Form)
logger.Infof("验证支付结果:%+v", res)
if !res.Success() {
logger.Error("订单校验失败:", res.Message)
result := h.alipayService.TradeVerify(c.Request)
logger.Infof("收到支付宝商号订单支付回调:%+v", result)
if !result.Success() {
logger.Error("订单校验失败:", result.Message)
c.String(http.StatusOK, "fail")
return
}
tradeNo := c.Request.Form.Get("trade_no")
err = h.notify(res.OutTradeNo, tradeNo)
err = h.notify(result.OutTradeNo, tradeNo)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
@@ -556,33 +372,59 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
c.String(http.StatusOK, "success")
}
// PayJsNotify PayJs 支付异步回调
func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
// GeekPayNotify 支付异步回调
func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
var params = make(map[string]string)
for k := range c.Request.URL.Query() {
params[k] = c.Query(k)
}
logger.Infof("收到GeekPay订单支付回调%+v", params)
// 检查支付状态
if params["trade_status"] != "TRADE_SUCCESS" {
c.String(http.StatusOK, "success")
return
}
sign := h.geekPayService.Sign(params)
if sign != c.Query("sign") {
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
c.String(http.StatusOK, "fail")
return
}
err := h.notify(params["out_trade_no"], params["trade_no"])
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
}
// WechatPayNotify 微信商户支付异步回调
func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
c.String(http.StatusOK, "fail")
return
}
orderNo := c.Request.Form.Get("out_trade_no")
returnCode := c.Request.Form.Get("return_code")
logger.Infof("收到订单支付回调,订单 NO%s支付结果代码%v", orderNo, returnCode)
// 支付失败
if returnCode != "1" {
return
}
// 校验订单支付状态
tradeNo := c.Request.Form.Get("payjs_order_id")
err = h.js.Check(tradeNo)
if err != nil {
result := h.wechatPayService.TradeVerify(c.Request)
logger.Infof("收到微信商号订单支付回调:%+v", result)
if !result.Success() {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
c.JSON(http.StatusBadRequest, gin.H{
"code": "FAIL",
"message": err.Error(),
})
return
}
err = h.notify(orderNo, tradeNo)
err = h.notify(result.OutTradeNo, result.TradeId)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}

View File

@@ -0,0 +1,88 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"sync"
"time"
)
type RedeemHandler struct {
BaseHandler
lock sync.Mutex
userService *service.UserService
}
func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *RedeemHandler {
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
}
func (h *RedeemHandler) Verify(c *gin.Context) {
var data struct {
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c)
h.lock.Lock()
defer h.lock.Unlock()
var item model.Redeem
res := h.DB.Where("code", data.Code).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的兑换码!")
return
}
if !item.Enabled {
resp.ERROR(c, "当前兑换码已被禁用!")
return
}
if item.RedeemedAt > 0 {
resp.ERROR(c, "当前兑换码已使用,请勿重复使用!")
return
}
tx := h.DB.Begin()
err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{
Type: types.PowerRedeem,
Model: "兑换码",
Remark: fmt.Sprintf("兑换码核销,算力:%d兑换码%s...", item.Power, item.Code[:10]),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 更新核销状态
item.RedeemedAt = time.Now().Unix()
item.UserId = userId
err = tx.Updates(&item).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Commit()
resp.SUCCESS(c)
}

View File

@@ -1,106 +0,0 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"math"
"strings"
"sync"
"time"
)
type RewardHandler struct {
BaseHandler
lock sync.Mutex
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Verify 打赏码核销
func (h *RewardHandler) Verify(c *gin.Context) {
var data struct {
TxId string `json:"tx_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.HACKER(c)
return
}
// 移除转账单号中间的空格,防止有人复制的时候多复制了空格
data.TxId = strings.ReplaceAll(data.TxId, " ", "")
h.lock.Lock()
defer h.lock.Unlock()
var item model.Reward
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的众筹交易流水号!")
return
}
if item.Status {
resp.ERROR(c, "当前众筹交易流水号已经被核销,请不要重复核销!")
return
}
tx := h.DB.Begin()
exchange := vo.RewardExchange{}
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
exchange.Power = int(power)
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
// 更新核销状态
item.Status = true
item.UserId = user.Id
item.Exchange = utils.JsonEncode(exchange)
res = tx.Updates(&item)
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力充值日志
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerReward,
Amount: exchange.Power,
Balance: user.Power + exchange.Power,
Mark: types.PowerAdd,
Model: "众筹支付",
Remark: fmt.Sprintf("众筹充值算力,金额:%f价格%f", item.Amount, h.App.SysConfig.PowerPrice),
CreatedAt: time.Now(),
})
tx.Commit()
resp.SUCCESS(c)
}

View File

@@ -8,6 +8,7 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
@@ -18,7 +19,6 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"fmt"
"net/http"
"time"
@@ -31,19 +31,27 @@ import (
type SdJobHandler struct {
BaseHandler
redis *redis.Client
pool *sd.ServicePool
uploader *oss.UploaderManager
snowflake *service.Snowflake
leveldb *store.LevelDB
redis *redis.Client
sdService *sd.Service
uploader *oss.UploaderManager
snowflake *service.Snowflake
leveldb *store.LevelDB
userService *service.UserService
}
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
func NewSdJobHandler(app *core.AppServer,
db *gorm.DB,
service *sd.Service,
manager *oss.UploaderManager,
snowflake *service.Snowflake,
userService *service.UserService,
levelDB *store.LevelDB) *SdJobHandler {
return &SdJobHandler{
pool: pool,
uploader: manager,
snowflake: snowflake,
leveldb: levelDB,
sdService: service,
uploader: manager,
snowflake: snowflake,
leveldb: levelDB,
userService: userService,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -68,7 +76,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
h.sdService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
@@ -79,11 +87,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
return false
}
if !h.pool.HasAvailableService() {
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
return false
}
if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
@@ -99,10 +102,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return
}
var data struct {
SessionId string `json:"session_id"`
types.SdTaskParams
}
var data types.SdTaskParams
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
@@ -167,35 +167,27 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return
}
h.pool.PushTask(types.SdTask{
Id: int(job.Id),
SessionId: data.SessionId,
Type: types.TaskImage,
Params: params,
UserId: userId,
h.sdService.PushTask(types.SdTask{
Id: int(job.Id),
Type: types.TaskImage,
Params: params,
UserId: userId,
})
client := h.pool.Clients.Get(uint(job.UserId))
client := h.sdService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
@@ -216,13 +208,13 @@ func (h *SdJobHandler) ImgWall(c *gin.Context) {
// JobList 获取 SD 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
finish := h.GetBool(c, "finish")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
err, jobs := h.getData(finish, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -232,11 +224,11 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
}
// JobList 获取 MJ 任务列表
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
session = session.Where("progress >= ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
}
@@ -251,10 +243,14 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
session = session.Offset(offset).Limit(pageSize)
}
// 统计总数
var total int64
session.Model(&model.SdJob{}).Count(&total)
var items []model.SdJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
return res.Error, vo.Page{}
}
var jobs = make([]vo.SdJob, 0)
@@ -276,56 +272,60 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
jobs = append(jobs, job)
}
return nil, jobs
return nil, vo.NewPage(total, page, pageSize, jobs)
}
// Remove remove task image
func (h *SdJobHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.SdJob
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// remove job recode
res := h.DB.Delete(&model.SdJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s Err: %s", job.TaskId, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte(sd.Finished))
}
resp.SUCCESS(c)
}
// Publish 发布/取消发布图片到画廊显示
func (h *SdJobHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
action := h.GetBool(c, "action") // 发布动作true => 发布false => 取消分享
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
err := h.DB.Model(&model.SdJob{Id: uint(id), UserId: int(userId)}).UpdateColumn("publish", action).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}

View File

@@ -49,28 +49,50 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
var data struct {
Receiver string `json:"receiver"` // 接收者
Key string `json:"key"`
Dots string `json:"dots"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if !h.captcha.Check(data) {
resp.ERROR(c, "验证码错误,请先完人机验证")
return
if h.App.SysConfig.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
code := utils.RandomNumber(6)
var err error
if strings.Contains(data.Receiver, "@") { // email
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") {
if !utils.Contains(h.App.SysConfig.RegisterWays, "email") {
resp.ERROR(c, "系统已禁用邮箱注册!")
return
}
// 检查邮箱后缀是否在白名单
if len(h.App.SysConfig.EmailWhiteList) > 0 {
inWhiteList := false
for _, suffix := range h.App.SysConfig.EmailWhiteList {
if strings.HasSuffix(data.Receiver, suffix) {
inWhiteList = true
break
}
}
if !inWhiteList {
resp.ERROR(c, "邮箱后缀不在白名单中")
return
}
}
err = h.smtp.SendVerifyCode(data.Receiver, code)
} else {
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
resp.ERROR(c, "系统已禁用手机号注册!")
return
}
@@ -89,5 +111,9 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
return
}
resp.SUCCESS(c)
if h.App.Debug {
resp.SUCCESS(c, code)
} else {
resp.SUCCESS(c)
}
}

398
api/handler/suno_handler.go Normal file
View File

@@ -0,0 +1,398 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/service/suno"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"time"
)
type SunoHandler struct {
BaseHandler
sunoService *suno.Service
uploader *oss.UploaderManager
userService *service.UserService
}
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
return &SunoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
sunoService: service,
uploader: uploader,
userService: userService,
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SunoHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.sunoService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SunoHandler) Create(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
Instrumental bool `json:"instrumental"`
Lyrics string `json:"lyrics"`
Model string `json:"model"`
Tags string `json:"tags"`
Title string `json:"title"`
Type int `json:"type"`
RefTaskId string `json:"ref_task_id"` // 续写的任务id
ExtendSecs int `json:"extend_secs"` // 续写秒数
RefSongId string `json:"ref_song_id"` // 续写的歌曲id
SongId string `json:"song_id,omitempty"` // 要拼接的歌曲id
AudioURL string `json:"audio_url,omitempty"` // 上传自己创作的歌曲
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.SunoPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
// 歌曲拼接
if data.SongId != "" && data.Type == 3 {
var song model.SunoJob
if err := h.DB.Where("song_id = ?", data.SongId).First(&song).Error; err == nil {
data.Instrumental = song.Instrumental
data.Model = song.ModelName
data.Tags = song.Tags
}
// 拼接歌词
var refSong model.SunoJob
if err := h.DB.Where("song_id = ?", data.RefSongId).First(&refSong).Error; err == nil {
data.Prompt = fmt.Sprintf("%s\n%s", song.Prompt, refSong.Prompt)
}
}
// 插入数据库
job := model.SunoJob{
UserId: int(h.GetLoginUserId(c)),
Prompt: data.Prompt,
Instrumental: data.Instrumental,
ModelName: data.Model,
Tags: data.Tags,
Title: data.Title,
Type: data.Type,
RefSongId: data.RefSongId,
RefTaskId: data.RefTaskId,
ExtendSecs: data.ExtendSecs,
Power: h.App.SysConfig.SunoPower,
SongId: utils.RandString(32),
}
if data.Lyrics != "" {
job.Prompt = data.Lyrics
}
tx := h.DB.Create(&job)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
// 创建任务
h.sunoService.PushTask(types.SunoTask{
Id: job.Id,
UserId: job.UserId,
Type: job.Type,
Title: job.Title,
RefTaskId: data.RefTaskId,
RefSongId: data.RefSongId,
ExtendSecs: data.ExtendSecs,
Prompt: job.Prompt,
Tags: data.Tags,
Model: data.Model,
Instrumental: data.Instrumental,
SongId: data.SongId,
AudioURL: data.AudioURL,
})
// update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
CreatedAt: time.Now(),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
client := h.sunoService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
func (h *SunoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
// 统计总数
var total int64
session.Model(&model.SunoJob{}).Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var list []model.SunoJob
err := session.Order("id desc").Find(&list).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 初始化续写关系
songIds := make([]string, 0)
for _, v := range list {
if v.RefTaskId != "" {
songIds = append(songIds, v.RefSongId)
}
}
var tasks []model.SunoJob
h.DB.Where("song_id IN ?", songIds).Find(&tasks)
songMap := make(map[string]model.SunoJob)
for _, t := range tasks {
songMap[t.SongId] = t
}
// 转换为 VO
items := make([]vo.SunoJob, 0)
for _, v := range list {
var item vo.SunoJob
err = utils.CopyObject(v, &item)
if err != nil {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
if s, ok := songMap[v.RefSongId]; ok {
item.RefSong = map[string]interface{}{
"id": s.Id,
"title": s.Title,
"cover": s.CoverURL,
"audio": s.AudioURL,
}
}
items = append(items, item)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *SunoHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.SunoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 只有失败,或者超时的任务才能删除
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*10)) {
resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
return
}
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 恢复用户算力
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: job.ModelName,
Remark: fmt.Sprintf("Suno 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Commit()
// 删除文件
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
_ = h.uploader.GetUploadHandler().Delete(job.AudioURL)
}
func (h *SunoHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
publish := h.GetBool(c, "publish")
err := h.DB.Model(&model.SunoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *SunoHandler) Update(c *gin.Context) {
var data struct {
Id int `json:"id"`
Title string `json:"title"`
Cover string `json:"cover"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id == 0 || data.Title == "" || data.Cover == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c)
var item model.SunoJob
if err := h.DB.Where("id", data.Id).Where("user_id", userId).First(&item).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
item.Title = data.Title
item.CoverURL = data.Cover
if err := h.DB.Updates(&item).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Detail 歌曲详情
func (h *SunoHandler) Detail(c *gin.Context) {
songId := c.Query("song_id")
if songId == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
var item model.SunoJob
if err := h.DB.Where("song_id", songId).First(&item).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
// 读取用户信息
var user model.User
if err := h.DB.Where("id", item.UserId).First(&user).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
var itemVo vo.SunoJob
if err := utils.CopyObject(item, &itemVo); err != nil {
resp.ERROR(c, err.Error())
return
}
itemVo.CreatedAt = item.CreatedAt.Unix()
itemVo.User = map[string]interface{}{
"nickname": user.Nickname,
"avatar": user.Avatar,
}
resp.SUCCESS(c, itemVo)
}
// Play 增加歌曲播放次数
func (h *SunoHandler) Play(c *gin.Context) {
songId := c.Query("song_id")
if songId == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1))
}
const genLyricTemplate = `
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
请以【%s】为主题创作一首歌曲歌曲时间不要太短3分钟左右不要输出任何解释性的内容。
输出格式如下:
歌曲名称
第一节:
{{歌词内容}}
副歌:
{{歌词内容}}
第二节:
{{歌词内容}}
副歌:
{{歌词内容}}
尾声:
{{歌词内容}}
`
// Lyric 生成歌词
func (h *SunoHandler) Lyric(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}

View File

@@ -3,15 +3,52 @@ package handler
import (
"geekai/service"
"geekai/service/payment"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"net/http"
)
type TestHandler struct {
db *gorm.DB
snowflake *service.Snowflake
js *payment.PayJS
js *payment.GeekPayService
}
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js}
}
func (h *TestHandler) SseTest(c *gin.Context) {
//c.Header("Content-Type", "text/event-stream")
//c.Header("Cache-Control", "no-cache")
//c.Header("Connection", "keep-alive")
//
//
//// 模拟实时数据更新
//for i := 0; i < 10; i++ {
// // 发送 SSE 数据
// _, err := fmt.Fprintf(c.Writer, "data: %v\n\n", data)
// if err != nil {
// return
// }
// c.Writer.Flush() // 确保立即发送数据
// time.Sleep(1 * time.Second) // 每秒发送一次数据
//}
//c.Abort()
}
func (h *TestHandler) PostTest(c *gin.Context) {
var data struct {
Message string `json:"message"`
UserId uint `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 将参数存储在上下文中
c.Set("data", data)
c.Next()
}

View File

@@ -1,101 +0,0 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type UploadHandler struct {
BaseHandler
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
userId := h.GetLoginUserId(c)
res := h.DB.Create(&model.File{
UserId: int(userId),
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}
func (h *UploadHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
var items []model.File
var files = make([]vo.File, 0)
h.DB.Where("user_id = ?", userId).Find(&items)
if len(items) > 0 {
for _, v := range items {
var file vo.File
err := utils.CopyObject(v, &file)
if err != nil {
logger.Error(err)
continue
}
file.CreatedAt = v.CreatedAt.Unix()
files = append(files, file)
}
}
resp.SUCCESS(c, files)
}
// Remove remove files
func (h *UploadHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0)
var file model.File
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
if tx.Error != nil || file.Id == 0 {
resp.ERROR(c, "file not existed")
return
}
// remove database
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
if tx.Error != nil || tx.RowsAffected == 0 {
resp.ERROR(c, "failed to update database")
return
}
// remove files
objectKey := file.ObjKey
if objectKey == "" {
objectKey = file.URL
}
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
resp.SUCCESS(c)
}

View File

@@ -11,10 +11,12 @@ import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/imroc/req/v3"
"strings"
"time"
@@ -28,19 +30,28 @@ import (
type UserHandler struct {
BaseHandler
searcher *xdb.Searcher
redis *redis.Client
searcher *xdb.Searcher
redis *redis.Client
licenseService *service.LicenseService
captcha *service.CaptchaService
userService *service.UserService
}
func NewUserHandler(
app *core.AppServer,
db *gorm.DB,
searcher *xdb.Searcher,
client *redis.Client) *UserHandler {
client *redis.Client,
captcha *service.CaptchaService,
userService *service.UserService,
licenseService *service.LicenseService) *UserHandler {
return &UserHandler{
BaseHandler: BaseHandler{DB: db, App: app},
searcher: searcher,
redis: client,
BaseHandler: BaseHandler{DB: db, App: app},
searcher: searcher,
redis: client,
captcha: captcha,
licenseService: licenseService,
userService: userService,
}
}
@@ -50,24 +61,58 @@ func (h *UserHandler) Register(c *gin.Context) {
var data struct {
RegWay string `json:"reg_way"`
Username string `json:"username"`
Mobile string `json:"mobile"`
Email string `json:"email"`
Password string `json:"password"`
Code string `json:"code"`
InviteCode string `json:"invite_code"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
data.Password = strings.TrimSpace(data.Password)
if len(data.Password) < 8 {
resp.ERROR(c, "密码长度不能少于8个字符")
return
}
// 检测最大注册人数
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
// 检查验证码
var key string
if data.RegWay == "email" || data.RegWay == "mobile" {
key = CodeStorePrefix + data.Username
if data.RegWay == "email" {
key = CodeStorePrefix + data.Email
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
return
}
} else if data.RegWay == "mobile" {
key = CodeStorePrefix + data.Mobile
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
@@ -85,9 +130,19 @@ func (h *UserHandler) Register(c *gin.Context) {
}
}
// check if the username is exists
// check if the username is existing
var item model.User
res := h.DB.Where("username = ?", data.Username).First(&item)
session := h.DB.Session(&gorm.Session{})
if data.Mobile != "" {
session = session.Where("mobile = ?", data.Mobile)
data.Username = data.Mobile
} else if data.Email != "" {
session = session.Where("email = ?", data.Email)
data.Username = data.Email
} else if data.Username != "" {
session = session.Where("username = ?", data.Username)
}
session.First(&item)
if item.Id > 0 {
resp.ERROR(c, "该用户名已经被注册")
return
@@ -96,8 +151,9 @@ func (h *UserHandler) Register(c *gin.Context) {
salt := utils.RandString(8)
user := model.User{
Username: data.Username,
Mobile: data.Mobile,
Email: data.Email,
Password: utils.GenPassword(data.Password, salt),
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
@@ -106,10 +162,19 @@ func (h *UserHandler) Register(c *gin.Context) {
Power: h.App.SysConfig.InitPower,
}
res = h.DB.Create(&user)
if res.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(res.Error)
// 被邀请人也获得赠送算力
if data.InviteCode != "" {
user.Power += h.App.SysConfig.InvitePower
}
if h.licenseService.GetLicense().Configs.DeCopy {
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
} else {
user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
}
tx := h.DB.Begin()
if err := tx.Create(&user).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
@@ -118,35 +183,35 @@ func (h *UserHandler) Register(c *gin.Context) {
// 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 {
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
// 记录邀请算力充值日志
var inviter model.User
h.DB.Where("id", inviteCode.UserId).First(&inviter)
h.DB.Create(&model.PowerLog{
UserId: inviter.Id,
Username: inviter.Username,
Type: types.PowerInvite,
Amount: h.App.SysConfig.InvitePower,
Balance: inviter.Power,
Mark: types.PowerAdd,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
CreatedAt: time.Now(),
err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{
Type: types.PowerInvite,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
// 添加邀请记录
h.DB.Create(&model.InviteLog{
err := tx.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
})
}).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
// 自动登录创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
@@ -163,7 +228,7 @@ func (h *UserHandler) Register(c *gin.Context) {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
resp.SUCCESS(c, tokenString)
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
}
// Login 用户登录
@@ -171,20 +236,41 @@ func (h *UserHandler) Login(c *gin.Context) {
var data struct {
Username string `json:"username"`
Password string `json:"password"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
needVerify, err := h.redis.Get(c, verifyKey).Bool()
if h.App.SysConfig.EnabledVerify && needVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
var user model.User
res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名不存在")
return
}
password := utils.GenPassword(data.Password, user.Salt)
if password != user.Password {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名或密码错误")
return
}
@@ -217,12 +303,14 @@ func (h *UserHandler) Login(c *gin.Context) {
return
}
// 保存到 redis
key := fmt.Sprintf("users/%d", user.Id)
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
sessionKey := fmt.Sprintf("users/%d", user.Id)
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
resp.SUCCESS(c, tokenString)
// 移除登录行为验证码
h.redis.Del(c, verifyKey)
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
}
// Logout 注 销
@@ -234,21 +322,176 @@ func (h *UserHandler) Logout(c *gin.Context) {
resp.SUCCESS(c)
}
// CLogin 第三方登录请求二维码
func (h *UserHandler) CLogin(c *gin.Context) {
returnURL := h.GetTrim(c, "return_url")
var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL)
r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}).
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status)
return
}
if res.Code != types.Success {
resp.ERROR(c, "error with http response: "+res.Message)
return
}
resp.SUCCESS(c, res.Data)
}
// CLoginCallback 第三方登录回调
func (h *UserHandler) CLoginCallback(c *gin.Context) {
loginType := c.Query("login_type")
code := c.Query("code")
userId := h.GetInt(c, "user_id", 0)
action := c.Query("action")
var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status)
return
}
if res.Code != types.Success {
resp.ERROR(c, "error with http response: "+res.Message)
return
}
// login successfully
data := res.Data.(map[string]interface{})
var user model.User
if action == "bind" && userId > 0 {
err = h.DB.Where("openid", data["openid"]).First(&user).Error
if err == nil {
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
return
}
err = h.DB.Where("id", userId).First(&user).Error
if err != nil {
resp.ERROR(c, "绑定用户不存在")
return
}
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
if err != nil {
resp.ERROR(c, "更新用户信息失败,"+err.Error())
return
}
resp.SUCCESS(c, gin.H{"token": ""})
return
}
session := gin.H{}
tx := h.DB.Where("openid", data["openid"]).First(&user)
if tx.Error != nil {
// create new user
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
salt := utils.RandString(8)
password := fmt.Sprintf("%d", utils.RandomNumber(8))
user = model.User{
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
Password: utils.GenPassword(password, salt),
Avatar: fmt.Sprintf("%s", data["avatar"]),
Salt: salt,
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
Power: h.App.SysConfig.InitPower,
OpenId: fmt.Sprintf("%s", data["openid"]),
Nickname: fmt.Sprintf("%s", data["nickname"]),
}
tx = h.DB.Create(&user)
if tx.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(tx.Error)
return
}
session["username"] = user.Username
session["password"] = password
} else { // login directly
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.DB.Model(&user).Updates(user)
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
}
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := fmt.Sprintf("users/%d", user.Id)
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
session["token"] = tokenString
resp.SUCCESS(c, session)
}
// Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) {
user, err := h.GetLoginUser(c)
if err == nil {
var userVo vo.User
err := utils.CopyObject(user, &userVo)
if err != nil {
resp.ERROR(c)
}
userVo.Id = user.Id
resp.SUCCESS(c, userVo)
} else {
resp.NotAuth(c)
if err != nil {
resp.NotAuth(c, err.Error())
return
}
var userVo vo.User
err = utils.CopyObject(user, &userVo)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 用户 VIP 到期
if user.ExpiredTime > 0 && user.ExpiredTime < time.Now().Unix() {
h.DB.Model(&user).UpdateColumn("vip", false)
}
userVo.Id = user.Id
resp.SUCCESS(c, userVo)
}
type userProfile struct {
@@ -335,20 +578,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
}
newPass := utils.GenPassword(data.Password, user.Salt)
res := h.DB.Model(&user).UpdateColumn("password", newPass)
if res.Error != nil {
logger.Error("更新数据库失败: ", res.Error)
resp.ERROR(c, "更新数据库失败")
err = h.DB.Model(&user).UpdateColumn("password", newPass).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// ResetPass 重置密码
// ResetPass 找回密码
func (h *UserHandler) ResetPass(c *gin.Context) {
var data struct {
Username string `json:"username"`
Type string `json:"type"` // 验证类别mobile, email
Mobile string `json:"mobile"` // 手机号
Email string `json:"email"` // 邮箱地址
Code string `json:"code"` // 验证码
Password string `json:"password"` // 新密码
}
@@ -357,37 +601,47 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
return
}
session := h.DB.Session(&gorm.Session{})
var key string
if data.Type == "email" {
session = session.Where("email", data.Email)
key = CodeStorePrefix + data.Email
} else if data.Type == "mobile" {
session = session.Where("mobile", data.Mobile)
key = CodeStorePrefix + data.Mobile
} else {
resp.ERROR(c, "验证类别错误")
return
}
var user model.User
res := h.DB.Where("username", data.Username).First(&user)
if res.Error != nil {
err := session.First(&user).Error
if err != nil {
resp.ERROR(c, "用户不存在!")
return
}
// 检查验证码
key := CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
resp.ERROR(c, "验证码错误")
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
err = h.DB.Model(&user).UpdateColumn("password", password).Error
if err != nil {
resp.ERROR(c, err.Error())
} else {
h.redis.Del(c, key)
resp.SUCCESS(c)
}
}
// BindUsername 重置账
func (h *UserHandler) BindUsername(c *gin.Context) {
// BindMobile 绑定手机
func (h *UserHandler) BindMobile(c *gin.Context) {
var data struct {
Username string `json:"username"`
Code string `json:"code"`
Mobile string `json:"mobile"`
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -395,7 +649,7 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
}
// 检查验证码
key := CodeStorePrefix + data.Username
key := CodeStorePrefix + data.Mobile
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
@@ -404,21 +658,56 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
// 检查手机号是否被其他账号绑定
var item model.User
res := h.DB.Where("username = ?", data.Username).First(&item)
res := h.DB.Where("mobile", data.Mobile).First(&item)
if res.Error == nil {
resp.ERROR(c, "该号已经其他账号绑定")
resp.ERROR(c, "该手机号已经绑定了其他账号,请更换手机号")
return
}
user, err := h.GetLoginUser(c)
userId := h.GetLoginUserId(c)
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("mobile", data.Mobile).Error
if err != nil {
resp.NotAuth(c)
return
}
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
resp.ERROR(c, err.Error())
return
}
_ = h.redis.Del(c, key) // 删除短信验证码
resp.SUCCESS(c)
}
// BindEmail 绑定邮箱
func (h *UserHandler) BindEmail(c *gin.Context) {
var data struct {
Email string `json:"email"`
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 检查验证码
key := CodeStorePrefix + data.Email
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
return
}
// 检查手机号是否被其他账号绑定
var item model.User
res := h.DB.Where("email", data.Email).First(&item)
if res.Error == nil {
resp.ERROR(c, "该邮箱地址已经绑定了其他账号,请更邮箱地址")
return
}
userId := h.GetLoginUserId(c)
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("email", data.Email).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}

View File

@@ -0,0 +1,250 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/service/video"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"time"
)
type VideoHandler struct {
BaseHandler
videoService *video.Service
uploader *oss.UploaderManager
userService *service.UserService
}
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
return &VideoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
videoService: service,
uploader: uploader,
userService: userService,
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *VideoHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.videoService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *VideoHandler) LumaCreate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
FirstFrameImg string `json:"first_frame_img,omitempty"`
EndFrameImg string `json:"end_frame_img,omitempty"`
ExpandPrompt bool `json:"expand_prompt,omitempty"`
Loop bool `json:"loop,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.LumaPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
if data.Prompt == "" {
resp.ERROR(c, "prompt is needed")
return
}
userId := int(h.GetLoginUserId(c))
params := types.VideoParams{
PromptOptimize: data.ExpandPrompt,
Loop: data.Loop,
StartImgURL: data.FirstFrameImg,
EndImgURL: data.EndFrameImg,
}
// 插入数据库
job := model.VideoJob{
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Power: h.App.SysConfig.LumaPower,
Params: utils.JsonEncode(params),
}
tx := h.DB.Create(&job)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
// 创建任务
h.videoService.PushTask(types.VideoTask{
Id: job.Id,
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
})
// update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "luma",
Remark: fmt.Sprintf("Luma 文生视频任务ID%d", job.Id),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
client := h.videoService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
func (h *VideoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
t := c.Query("type")
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
all := h.GetBool(c, "all")
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
if t != "" {
session = session.Where("type", t)
}
if all {
session = session.Where("publish", 0).Where("progress", 100)
} else {
session = session.Where("user_id", h.GetLoginUserId(c))
}
// 统计总数
var total int64
session.Model(&model.VideoJob{}).Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var list []model.VideoJob
err := session.Order("id desc").Find(&list).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 转换为 VO
items := make([]vo.VideoJob, 0)
for _, v := range list {
var item vo.VideoJob
err = utils.CopyObject(v, &item)
if err != nil {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
items = append(items, item)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *VideoHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.VideoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 只有失败或者超时的任务才能删除
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*30)) {
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
return
}
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 恢复算力
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "luma",
Remark: fmt.Sprintf("Luma 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Commit()
// 删除文件
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
}
func (h *VideoHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
publish := h.GetBool(c, "publish")
var job model.VideoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
err = h.DB.Model(&job).UpdateColumn("publish", publish).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -23,7 +23,8 @@ import (
"geekai/service/payment"
"geekai/service/sd"
"geekai/service/sms"
"geekai/service/wx"
"geekai/service/suno"
"geekai/service/video"
"geekai/store"
"io"
"log"
@@ -128,9 +129,9 @@ func main() {
fx.Provide(handler.NewChatRoleHandler),
fx.Provide(handler.NewUserHandler),
fx.Provide(chatimpl.NewChatHandler),
fx.Provide(handler.NewUploadHandler),
fx.Provide(handler.NewNetHandler),
fx.Provide(handler.NewSmsHandler),
fx.Provide(handler.NewRewardHandler),
fx.Provide(handler.NewRedeemHandler),
fx.Provide(handler.NewCaptchaHandler),
fx.Provide(handler.NewMidJourneyHandler),
fx.Provide(handler.NewChatModelHandler),
@@ -145,8 +146,8 @@ func main() {
fx.Provide(admin.NewAdminHandler),
fx.Provide(admin.NewApiKeyHandler),
fx.Provide(admin.NewUserHandler),
fx.Provide(admin.NewChatRoleHandler),
fx.Provide(admin.NewRewardHandler),
fx.Provide(admin.NewChatAppHandler),
fx.Provide(admin.NewRedeemHandler),
fx.Provide(admin.NewDashboardHandler),
fx.Provide(admin.NewChatModelHandler),
fx.Provide(admin.NewProductHandler),
@@ -160,53 +161,59 @@ func main() {
return service.NewCaptchaService(config.ApiConfig)
}),
fx.Provide(oss.NewUploaderManager),
fx.Provide(mj.NewService),
fx.Provide(dalle.NewService),
fx.Invoke(func(service *dalle.Service) {
service.Run()
service.CheckTaskNotify()
service.DownloadImages()
service.CheckTaskStatus()
fx.Invoke(func(s *dalle.Service) {
s.Run()
s.CheckTaskNotify()
s.DownloadImages()
s.CheckTaskStatus()
}),
// 邮件服务
fx.Provide(service.NewSmtpService),
// 微信机器人服务
fx.Provide(wx.NewWeChatBot),
fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) {
if config.WeChatBot {
err := bot.Run()
if err != nil {
logger.Error("微信登录失败:", err)
}
}
// License 服务
fx.Provide(service.NewLicenseService),
fx.Invoke(func(licenseService *service.LicenseService) {
licenseService.SyncLicense()
}),
// MidJourney service pool
fx.Provide(mj.NewServicePool),
fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) {
pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs)
if pool.HasAvailableService() {
pool.DownloadImages()
pool.CheckTaskNotify()
pool.SyncTaskProgress()
}
fx.Provide(mj.NewService),
fx.Provide(mj.NewClient),
fx.Invoke(func(s *mj.Service) {
s.Run()
s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadImages()
}),
// Stable Diffusion 机器人
fx.Provide(sd.NewServicePool),
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
pool.InitServices(config.SdConfigs)
if pool.HasAvailableService() {
pool.CheckTaskNotify()
pool.CheckTaskStatus()
}
fx.Provide(sd.NewService),
fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
s.Run()
s.CheckTaskStatus()
s.CheckTaskNotify()
}),
fx.Provide(suno.NewService),
fx.Invoke(func(s *suno.Service) {
s.Run()
s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadFiles()
}),
fx.Provide(video.NewService),
fx.Invoke(func(s *video.Service) {
s.Run()
s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadFiles()
}),
fx.Provide(service.NewUserService),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
fx.Provide(payment.NewPayJS),
fx.Provide(payment.NewJPayService),
fx.Provide(payment.NewWechatService),
fx.Provide(service.NewSnowflake),
fx.Provide(service.NewXXLJobExecutor),
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
@@ -219,8 +226,9 @@ func main() {
// 注册路由
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
group := s.Engine.Group("/api/role/")
group := s.Engine.Group("/api/app/")
group.GET("list", h.List)
group.GET("list/user", h.ListByUser)
group.POST("update", h.UpdateRole)
}),
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
@@ -232,8 +240,11 @@ func main() {
group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass)
group.POST("bind/username", h.BindUsername)
group.POST("bind/mobile", h.BindMobile)
group.POST("bind/email", h.BindEmail)
group.POST("resetPass", h.ResetPass)
group.GET("clogin", h.CLogin)
group.GET("clogin/callback", h.CLoginCallback)
}),
fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) {
group := s.Engine.Group("/api/chat/")
@@ -247,10 +258,11 @@ func main() {
group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate)
}),
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
s.Engine.POST("/api/upload", h.Upload)
s.Engine.GET("/api/upload/list", h.List)
s.Engine.POST("/api/upload/list", h.List)
s.Engine.GET("/api/upload/remove", h.Remove)
s.Engine.GET("/api/download", h.Download)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
group := s.Engine.Group("/api/sms/")
@@ -263,8 +275,8 @@ func main() {
group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck)
}),
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
group := s.Engine.Group("/api/reward/")
fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
group := s.Engine.Group("/api/redeem/")
group.POST("verify", h.Verify)
}),
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
@@ -275,8 +287,8 @@ func main() {
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd")
@@ -284,19 +296,23 @@ func main() {
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/")
group.GET("get", h.Get)
group.GET("license", h.License)
}),
// 管理后台控制器
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
group := s.Engine.Group("/api/admin/")
group.POST("config/update", h.Update)
group.GET("config/get", h.Get)
group := s.Engine.Group("/api/admin/config")
group.POST("update", h.Update)
group.GET("get", h.Get)
group.POST("active", h.Active)
group.GET("fixData", h.FixData)
group.GET("license", h.GetLicense)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
group := s.Engine.Group("/api/admin/")
@@ -324,7 +340,7 @@ func main() {
group.GET("loginLog", h.LoginLog)
group.POST("resetPass", h.ResetPass)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ChatRoleHandler) {
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
group := s.Engine.Group("/api/admin/role/")
group.GET("list", h.List)
group.POST("save", h.Save)
@@ -332,9 +348,11 @@ func main() {
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
group := s.Engine.Group("/api/admin/reward/")
fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
group := s.Engine.Group("/api/admin/redeem/")
group.GET("list", h.List)
group.POST("create", h.Create)
group.POST("set", h.Set)
group.POST("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
@@ -355,14 +373,12 @@ func main() {
}),
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
group := s.Engine.Group("/api/payment/")
group.GET("doPay", h.DoPay)
group.POST("doPay", h.Pay)
group.GET("payWays", h.GetPayWays)
group.POST("query", h.OrderQuery)
group.POST("qrcode", h.PayQrcode)
group.POST("mobile", h.Mobile)
group.POST("alipay/notify", h.AlipayNotify)
group.POST("hupipay/notify", h.HuPiPayNotify)
group.POST("payjs/notify", h.PayJsNotify)
group.POST("notify/alipay", h.AlipayNotify)
group.GET("notify/geek", h.GeekPayNotify)
group.POST("notify/wechat", h.WechatPayNotify)
group.POST("notify/hupi", h.HuPiPayNotify)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
group := s.Engine.Group("/api/admin/product/")
@@ -376,10 +392,12 @@ func main() {
group := s.Engine.Group("/api/admin/order/")
group.POST("list", h.List)
group.GET("remove", h.Remove)
group.GET("clear", h.Clear)
}),
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
group := s.Engine.Group("/api/order/")
group.POST("list", h.List)
group.GET("list", h.List)
group.GET("query", h.Query)
}),
fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
group := s.Engine.Group("/api/product/")
@@ -390,7 +408,7 @@ func main() {
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
group := s.Engine.Group("/api/invite/")
group.GET("code", h.Code)
group.POST("list", h.List)
group.GET("list", h.List)
group.GET("hits", h.Hits)
}),
@@ -404,13 +422,6 @@ func main() {
group.GET("token", h.GenToken)
}),
// 验证码
fx.Provide(admin.NewCaptchaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
group := s.Engine.Group("/api/admin/login/")
group.GET("captcha", h.GetCaptcha)
}),
fx.Provide(admin.NewUploadHandler),
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
s.Engine.POST("/api/admin/upload", h.Upload)
@@ -422,6 +433,7 @@ func main() {
group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
group.GET("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/")
@@ -465,14 +477,56 @@ func main() {
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}),
fx.Provide(handler.NewSunoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
group := s.Engine.Group("/api/suno")
group.Any("client", h.Client)
group.POST("create", h.Create)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
group.POST("update", h.Update)
group.GET("detail", h.Detail)
group.GET("play", h.Play)
group.POST("lyric", h.Lyric)
}),
fx.Provide(handler.NewVideoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
group := s.Engine.Group("/api/video")
group.Any("client", h.Client)
group.POST("luma/create", h.LumaCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}),
fx.Provide(admin.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
group := s.Engine.Group("/api/admin/app/type")
group.POST("save", h.Save)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
}),
fx.Provide(handler.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
group := s.Engine.Group("/api/app/type")
group.GET("list", h.List)
}),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test")
group.Any("sse", h.PostTest, h.SseTest)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
go func() {
err := s.Run(db)
if err != nil {
log.Fatal(err)
logger.Error(err)
os.Exit(0)
}
}()
}),

BIN
api/res/img/geek-pay.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

BIN
api/res/img/qq-pay.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

View File

@@ -14,7 +14,6 @@ import (
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/utils"
@@ -36,9 +35,10 @@ type Service struct {
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
userService *service.UserService
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
@@ -46,6 +46,7 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
userService: userService,
}
}
@@ -70,10 +71,10 @@ func (s *Service) Run() {
if err != nil {
logger.Errorf("error with image task: %v", err)
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": -1,
"progress": service.FailTaskProgress,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
}
}
}()
@@ -109,13 +110,12 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err != nil {
return "", fmt.Errorf("error with translate prompt: %v", err)
if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
if err == nil {
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
var user model.User
@@ -124,14 +124,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return "", errors.New("insufficient of power")
}
// 扣减算力
err := s.userService.DecreasePower(int(user.Id), task.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
})
if err != nil {
return "", fmt.Errorf("error with decrease power: %v", err)
}
// get image generation API KEY
var apiKey model.ApiKey
tx := s.db.Where("platform", types.OpenAI).
Where("type", "img").
err = s.db.Where("type", "dalle").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return "", fmt.Errorf("no available DALL-E api key: %v", err)
}
var res imgRes
@@ -139,36 +148,42 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
}
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
reqBody := imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
Style: task.Style,
Quality: task.Quality,
}).
SetBody(reqBody).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
return "", fmt.Errorf("error with send request: %v", err)
}
if r.IsErrorState() {
return "", fmt.Errorf("error with send request: %v", errRes.Error)
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
}
// update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// update task progress
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
err = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": 100,
"org_url": res.Data[0].Url,
"prompt": prompt,
})
}).Error
if err != nil {
return "", fmt.Errorf("err with update database: %v", err)
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
var content string
if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
@@ -178,25 +193,6 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
}
// 更新用户算力
tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
}
return content, nil
}
@@ -204,7 +200,7 @@ func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running DALL-E task notify checking ...")
for {
var message sd.NotifyMessage
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
@@ -221,6 +217,30 @@ func (s *Service) CheckTaskNotify() {
}()
}
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running DALL-E task status checking ...")
for {
var jobs []model.DallJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 超时的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
}
}
time.Sleep(time.Second * 10)
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.DallJob
@@ -254,7 +274,7 @@ func (s *Service) DownloadImages() {
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
// sava image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
if err != nil {
return "", err
}
@@ -264,47 +284,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
if res.Error != nil {
return "", err
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished})
s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
return imgURL, nil
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.DallJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
s.db.Delete(&job)
var user model.User
s.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d", job.Id),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}

View File

@@ -0,0 +1,197 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/store"
"time"
"github.com/imroc/req/v3"
)
type LicenseService struct {
config types.ApiConfig
levelDB *store.LevelDB
license *types.License
urlWhiteList []string
machineId string
}
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
var license types.License
return &LicenseService{
config: server.Config.ApiConfig,
levelDB: levelDB,
license: &license,
machineId: "",
}
}
type License struct {
Name string `json:"name"`
License string `json:"license"`
MachineId string `json:"mid"`
ActiveAt int64 `json:"active_at"`
ExpiredAt int64 `json:"expired_at"`
UserNum int `json:"user_num"`
Configs types.LicenseConfig `json:"configs"`
}
// ActiveLicense 激活 License
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
response, err := req.C().R().
SetBody(map[string]string{"license": license, "machine_id": machineId}).
SetSuccessResult(&res).Post(apiURL)
if err != nil {
return fmt.Errorf("发送激活请求失败: %v", err)
}
if response.IsErrorState() {
return fmt.Errorf("发送激活请求失败:%v", response.Status)
}
if res.Code != types.Success {
return fmt.Errorf("激活失败:%v", res.Message)
}
s.license = &types.License{
Key: license,
MachineId: machineId,
Configs: res.Data.Configs,
ExpiredAt: res.Data.ExpiredAt,
IsActive: true,
}
err = s.levelDB.Put(types.LicenseKey, s.license)
if err != nil {
return fmt.Errorf("保存许可证书失败:%v", err)
}
return nil
}
// SyncLicense 定期同步 License
func (s *LicenseService) SyncLicense() {
go func() {
retryCounter := 0
for {
license, err := s.fetchLicense()
if err != nil {
retryCounter++
if retryCounter < 5 {
logger.Warn(err)
}
s.license.IsActive = false
} else {
s.license = license
}
urls, err := s.fetchUrlWhiteList()
if err == nil {
s.urlWhiteList = urls
}
time.Sleep(time.Second * 10)
}
}()
}
func (s *LicenseService) fetchLicense() (*types.License, error) {
//var res struct {
// Code types.BizCode `json:"code"`
// Message string `json:"message"`
// Data License `json:"data"`
//}
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
//response, err := req.C().R().
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
// SetSuccessResult(&res).Post(apiURL)
//if err != nil {
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
//}
//if response.IsErrorState() {
// return nil, fmt.Errorf("激活失败:%v", response.Status)
//}
//if res.Code != types.Success {
// return nil, fmt.Errorf("激活失败:%v", res.Message)
//}
return &types.License{
Key: "abc",
MachineId: "abc",
Configs: types.LicenseConfig{
UserNum: 10000,
DeCopy: false,
},
ExpiredAt: 0,
IsActive: true,
}, nil
}
func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data []string `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %v", err)
}
if response.IsErrorState() {
return nil, fmt.Errorf("发送请求失败:%v", response.Status)
}
if res.Code != types.Success {
return nil, fmt.Errorf("获取白名单失败:%v", res.Message)
}
return res.Data, nil
}
// GetLicense 获取许可信息
func (s *LicenseService) GetLicense() *types.License {
return s.license
}
// IsValidApiURL 判断是否合法的中转 URL
func (s *LicenseService) IsValidApiURL(uri string) error {
// 获得许可授权的直接放行
return nil
//if s.license.IsActive {
// if s.license.MachineId != s.machineId {
// return errors.New("系统使用了盗版的许可证书")
// }
//
// if time.Now().Unix() > s.license.ExpiredAt {
// return errors.New("系统许可证书已经过期")
// }
// return nil
//}
//
//if len(s.urlWhiteList) == 0 {
// urls, err := s.fetchUrlWhiteList()
// if err == nil {
// s.urlWhiteList = urls
// }
//}
//
//for _, v := range s.urlWhiteList {
// if strings.HasPrefix(uri, v) {
// return nil
// }
//}
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
}

View File

@@ -7,15 +7,28 @@ package mj
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import "geekai/core/types"
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"io"
"time"
type Client interface {
Imagine(task types.MjTask) (ImageRes, error)
Blend(task types.MjTask) (ImageRes, error)
SwapFace(task types.MjTask) (ImageRes, error)
Upscale(task types.MjTask) (ImageRes, error)
Variation(task types.MjTask) (ImageRes, error)
QueryTask(taskId string) (QueryRes, error)
"github.com/gin-gonic/gin"
)
// Client MidJourney client
type Client struct {
client *req.Client
licenseService *service.LicenseService
db *gorm.DB
}
type ImageReq struct {
@@ -33,13 +46,8 @@ type ImageRes struct {
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
Result string `json:"result"`
Channel string `json:"channel,omitempty"`
}
type QueryRes struct {
@@ -66,3 +74,177 @@ type QueryRes struct {
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}
var logger = logger2.GetLogger()
func NewClient(licenseService *service.LicenseService, db *gorm.DB) *Client {
return &Client{
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
licenseService: licenseService,
db: db,
}
}
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/submit/imagine", task.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// Blend 融图
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/submit/blend", task.Mode)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// SwapFace 换脸
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/insight-face/swap", task.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// Upscale 放大指定的图片
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
return c.doRequest(body, apiPath, task.ChannelId)
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
return c.doRequest(body, apiPath, task.ChannelId)
}
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
var res ImageRes
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
if channel != "" {
session = session.Where("api_url", channel)
}
var apiKey model.ApiKey
err := session.Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return ImageRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
}
if err = c.licenseService.IsValidApiURL(apiKey.ApiURL); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/%s", apiKey.ApiURL, apiPath)
logger.Info("API URL: ", apiURL)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(body).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
errMsg, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
}
// update the api key last used time
if err = c.db.Model(&apiKey).Update("last_used_at", time.Now().Unix()).Error; err != nil {
logger.Error("update api key last used time error: ", err)
}
res.Channel = apiKey.ApiURL
return res, nil
}
func (c *Client) QueryTask(taskId string, channel string) (QueryRes, error) {
var apiKey model.ApiKey
err := c.db.Where("type", "mj").Where("enabled", true).Where("api_url", channel).First(&apiKey).Error
if err != nil {
return QueryRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
}
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", apiKey.ApiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+apiKey.Value).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}

View File

@@ -1,240 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
"time"
"github.com/gin-gonic/gin"
)
// PlusClient MidJourney Plus ProxyClient
type PlusClient struct {
Config types.MjPlusConfig
apiURL string
client *req.Client
}
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
return &PlusClient{
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
}
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
var res ImageRes
var errRes ErrRes
r, err := c.client.SetTimeout(time.Minute).R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Upscale 放大指定的图片
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &PlusClient{}

View File

@@ -1,227 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"github.com/go-redis/redis/v8"
"strings"
"time"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
uploaderManager: manager,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
// stop old service
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for k, config := range plusConfigs {
if config.Enabled == false {
continue
}
cli := NewPlusClient(config)
name := fmt.Sprintf("mj-plus-service-%d", k)
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
plusService.Run()
}()
p.services = append(p.services, plusService)
}
// for mid-journey proxy
for k, config := range proxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := fmt.Sprintf("mj-proxy-service-%d", k)
proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
proxyService.Run()
}()
p.services = append(p.services, proxyService)
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var message sd.NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
cli := p.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (p *ServicePool) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
mjService := p.getService(v.ChannelId)
if mjService == nil {
logger.Errorf("Invalid task: %+v", v)
continue
}
task, _ := mjService.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
// 如果是返回的是 discord 图片地址,则使用代理下载
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
p.db.Updates(&v)
cli := p.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(sd.Finished))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("progress < ?", 100).Find(&items)
if res.Error != nil {
continue
}
for _, job := range items {
// 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
p.db.Delete(&job)
// 退回算力
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if tx.Error == nil && tx.RowsAffected > 0 {
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}
}
time.Sleep(time.Second * 10)
}
}()
}
func (p *ServicePool) getService(name string) *Service {
for _, s := range p.services {
if s.Name == name {
return s
}
}
return nil
}

View File

@@ -1,185 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
)
// ProxyClient MidJourney Proxy Client
type ProxyClient struct {
Config types.MjProxyConfig
apiURL string
}
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
return &ProxyClient{Config: config, apiURL: config.ApiURL}
}
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
body := ImageReq{
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能请使用 MidJourney-Plus")
}
// Upscale 放大指定的图片
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "UPSCALE",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "VARIATION",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &ProxyClient{}

View File

@@ -11,10 +11,11 @@ import (
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/service/sd"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"strings"
"time"
@@ -23,112 +24,112 @@ import (
// Service MJ 绘画服务
type Service struct {
Name string // service Name
Client Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
running bool
client *Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
uploaderManager *oss.UploaderManager
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service {
return &Service{
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: cli,
running: true,
db: db,
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
client: client,
Clients: types.NewLMap[uint, *types.WsClient](),
uploaderManager: manager,
}
}
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for s.running {
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// 如果配置了多个中转平台的 API KEY
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
time.Sleep(time.Second)
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
logger.Info("Starting MidJourney job consumer for service")
go func() {
for {
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt))
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini")
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
}
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil {
logger.Error("任务不存在任务ID", task.TaskId)
continue
}
// use fast mode as default
if task.Mode == "" {
task.Mode = "fast"
}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil {
logger.Error("任务不存在任务ID", task.TaskId)
continue
}
if err != nil || (res.Code != 1 && res.Code != 22) {
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.ErrMsg = errMsg
// update the task progress
logger.Infof("handle a new MidJourney task: %+v", task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.client.Upscale(task)
break
case types.TaskVariation:
res, err = s.client.Variation(task)
break
case types.TaskBlend:
res, err = s.client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) {
var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
}
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = service.FailTaskProgress
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue
}
logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = res.Channel
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
continue
}
logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
}
func (s *Service) Stop() {
s.running = false
}()
}
type CBReq struct {
@@ -149,46 +150,6 @@ type CBReq struct {
} `json:"properties"`
}
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := sd.Running
if job.Progress == 100 {
message = sd.Finished
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
}
return nil
}
func GetImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
@@ -196,3 +157,143 @@ func GetImageHash(action string) string {
}
return split[len(split)-1]
}
func (s *Service) CheckTaskNotify() {
go func() {
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
cli := s.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
// 如果是返回的是 discord 图片地址,则使用代理下载
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
s.db.Updates(&v)
cli := s.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(service.TaskStatusFinished))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.MidJourneyJob
for {
res := s.db.Where("progress < ?", 100).Where("channel_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
// 10 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
continue
}
task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
if err != nil {
logger.Errorf("error with query task: %v", err)
continue
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": task.FailReason,
})
logger.Errorf("task failed: %v", task.FailReason)
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
err = s.db.Updates(&job).Error
if err != nil {
logger.Errorf("error with update database: %v", err)
continue
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := service.TaskStatusRunning
if job.Progress == 100 {
message = service.TaskStatusFinished
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
}
}
time.Sleep(time.Second * 5)
}
}()
}

View File

@@ -84,25 +84,25 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil
}
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
var imageData []byte
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
} else {
imageData, err = utils.DownloadImage(imageURL, "")
fileData, err = utils.DownloadImage(fileURL, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
parse, err := url.Parse(imageURL)
parse, err := url.Parse(fileURL)
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := utils.GetImgExt(parse.Path)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
if err != nil {
return "", err
}

View File

@@ -57,8 +57,8 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil
}
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
parse, err := url.Parse(imageURL)
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
parse, err := url.Parse(fileURL)
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
@@ -69,9 +69,9 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
}
if useProxy {
err = utils.DownloadFile(imageURL, filePath, s.proxyURL)
err = utils.DownloadFile(fileURL, filePath, s.proxyURL)
} else {
err = utils.DownloadFile(imageURL, filePath, "")
err = utils.DownloadFile(fileURL, filePath, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)

View File

@@ -44,18 +44,18 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
}
func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
var imageData []byte
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
} else {
imageData, err = utils.DownloadImage(imageURL, "")
fileData, err = utils.DownloadImage(fileURL, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
parse, err := url.Parse(imageURL)
parse, err := url.Parse(fileURL)
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
@@ -65,8 +65,8 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
context.Background(),
s.config.Bucket,
filename,
strings.NewReader(string(imageData)),
int64(len(imageData)),
strings.NewReader(string(fileData)),
int64(len(fileData)),
minio.PutObjectOptions{ContentType: "image/png"})
if err != nil {
return "", err

View File

@@ -93,18 +93,18 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
var imageData []byte
func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
} else {
imageData, err = utils.DownloadImage(imageURL, "")
fileData, err = utils.DownloadImage(fileURL, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
parse, err := url.Parse(imageURL)
parse, err := url.Parse(fileURL)
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
@@ -113,7 +113,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(imageData), int64(len(imageData)), &extra)
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(fileData), int64(len(fileData)), &extra)
if err != nil {
return "", err
}

View File

@@ -23,7 +23,7 @@ type File struct {
}
type Uploader interface {
PutFile(ctx *gin.Context, name string) (File, error)
PutImg(imageURL string, useProxy bool) (string, error)
PutUrlFile(url string, useProxy bool) (string, error)
PutBase64(imageData string) (string, error)
Delete(fileURL string) error
}

View File

@@ -8,12 +8,13 @@ package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"github.com/smartwalle/alipay/v3"
"log"
"net/url"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/alipay"
"net/http"
"os"
)
@@ -35,93 +36,98 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
return nil, fmt.Errorf("error with read App Private key: %v", err)
}
xClient, err := alipay.New(config.AppId, priKey, !config.SandBox)
client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
if err != nil {
return nil, fmt.Errorf("error with initialize alipay service: %v", err)
}
if err = xClient.LoadAppCertPublicKeyFromFile(config.PublicKey); err != nil {
return nil, fmt.Errorf("error with loading App PublicKey: %v", err)
}
if err = xClient.LoadAliPayRootCertFromFile(config.RootCert); err != nil {
return nil, fmt.Errorf("error with loading alipay RootCert: %v", err)
}
if err = xClient.LoadAlipayCertPublicKeyFromFile(config.AlipayPublicKey); err != nil {
return nil, fmt.Errorf("error with loading Alipay PublicKey: %v", err)
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
return nil, fmt.Errorf("error with load payment public key: %v", err)
}
return &AlipayService{config: &config, client: xClient}, nil
return &AlipayService{config: &config, client: client}, nil
}
func (s *AlipayService) PayUrlMobile(outTradeNo string, notifyURL string, returnURL string, Amount string, subject string) (string, error) {
var p = alipay.TradeWapPay{}
p.NotifyURL = notifyURL
p.ReturnURL = returnURL
p.Subject = subject
p.OutTradeNo = outTradeNo
p.TotalAmount = Amount
p.ProductCode = "QUICK_WAP_WAY"
res, err := s.client.TradeWapPay(p)
if err != nil {
return "", err
}
return res.String(), err
type AlipayParams struct {
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"subject"`
TotalFee string `json:"total_fee"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *AlipayService) PayUrlPc(outTradeNo string, notifyURL string, returnURL string, amount string, subject string) (string, error) {
var p = alipay.TradePagePay{}
p.NotifyURL = notifyURL
p.ReturnURL = returnURL
p.Subject = subject
p.OutTradeNo = outTradeNo
p.TotalAmount = amount
p.ProductCode = "FAST_INSTANT_TRADE_PAY"
res, err := s.client.TradePagePay(p)
if err != nil {
return "", nil
}
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("quit_url", params.ReturnURL)
bm.Set("total_amount", params.TotalFee)
bm.Set("return_url", params.ReturnURL)
bm.Set("notify_url", params.NotifyURL)
bm.Set("product_code", "QUICK_WAP_WAY")
return s.client.TradeWapPay(context.Background(), bm)
}
return res.String(), err
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
}
// TradeVerify 交易验证
func (s *AlipayService) TradeVerify(reqForm url.Values) NotifyVo {
err := s.client.VerifySign(reqForm)
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
if err != nil {
log.Println("异步通知验证签名发生错误", err)
return NotifyVo{
Status: 0,
Message: "异步通知验证签名发生错误",
Status: Failure,
Message: "error with parse notify request: " + err.Error(),
}
}
return s.TradeQuery(reqForm.Get("out_trade_no"))
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "error with verify sign: " + err.Error(),
}
}
return s.TradeQuery(request.Form.Get("out_trade_no"))
}
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
var p = alipay.TradeQuery{}
p.OutTradeNo = outTradeNo
rsp, err := s.client.TradeQuery(p)
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
//查询订单
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return NotifyVo{
Status: 0,
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
}
}
if rsp.IsSuccess() == true && rsp.TradeStatus == "TRADE_SUCCESS" {
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
return NotifyVo{
Status: 1,
OutTradeNo: rsp.OutTradeNo,
TradeNo: rsp.TradeNo,
Amount: rsp.TotalAmount,
Subject: rsp.Subject,
Status: Success,
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Subject: rsp.Response.Subject,
Message: "OK",
}
} else {
return NotifyVo{
Status: 0,
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo,
}
}
@@ -134,16 +140,3 @@ func readKey(filename string) (string, error) {
}
return string(data), nil
}
type NotifyVo struct {
Status int
OutTradeNo string
TradeNo string
Amount string
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == 1
}

View File

@@ -0,0 +1,139 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strings"
"time"
)
// GeekPayService Geek 支付服务
type GeekPayService struct {
config *types.GeekPayConfig
}
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
return &GeekPayService{
config: &appConfig.GeekPayConfig,
}
}
type GeekPayParams struct {
Method string `json:"method"` // 接口类型
Device string `json:"device"` // 设备类型
Type string `json:"type"` // 支付方式
OutTradeNo string `json:"out_trade_no"` // 商户订单号
Name string `json:"name"` // 商品名称
Money string `json:"money"` // 商品金额
ClientIP string `json:"clientip"` //用户IP地址
SubOpenId string `json:"sub_openid"` // 微信用户 openid仅小程序支付需要
SubAppId string `json:"sub_appid"` // 小程序 AppId仅小程序支付需要
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
}
// Pay 支付订单
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
p := map[string]string{
"pid": s.config.AppId,
//"method": params.Method,
"device": params.Device,
"type": params.Type,
"out_trade_no": params.OutTradeNo,
"name": params.Name,
"money": params.Money,
"clientip": params.ClientIP,
"notify_url": params.NotifyURL,
"return_url": params.ReturnURL,
"timestamp": fmt.Sprintf("%d", time.Now().Unix()),
}
p["sign"] = s.Sign(p)
p["sign_type"] = "MD5"
return s.sendRequest(s.config.ApiURL, p)
}
func (s *GeekPayService) Sign(params map[string]string) string {
// 按字母顺序排序参数
var keys []string
for k := range params {
if params[k] == "" || k == "sign" || k == "sign_type" {
continue
}
keys = append(keys, k)
}
sort.Strings(keys)
// 构建待签名字符串
var signStr strings.Builder
for _, k := range keys {
signStr.WriteString(k)
signStr.WriteString("=")
signStr.WriteString(params[k])
signStr.WriteString("&")
}
signString := strings.TrimSuffix(signStr.String(), "&") + s.config.PrivateKey
return utils.Md5(signString)
}
type GeekPayResp struct {
Code int `json:"code"`
Msg string `json:"msg"`
TradeNo string `json:"trade_no"`
PayURL string `json:"payurl"`
QrCode string `json:"qrcode"`
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
}
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
form := url.Values{}
for k, v := range params {
form.Add(k, v)
}
apiURL := fmt.Sprintf("%s/mapi.php", endpoint)
logger.Infof(apiURL)
tr := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true, // 取消 SSL 证书验证
},
}
client := &http.Client{Transport: tr}
resp, err := client.PostForm(apiURL, form)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
logger.Debugf(string(body))
if err != nil {
return nil, err
}
var r GeekPayResp
err = json.Unmarshal(body, &r)
if err != nil {
return nil, errors.New("当前支付渠道暂不支持")
}
if r.Code != 1 {
return nil, errors.New(r.Msg)
}
return &r, nil
}

View File

@@ -37,7 +37,7 @@ func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
}
}
type HuPiPayReq struct {
type HuPiPayParams struct {
AppId string `json:"appid"`
Version string `json:"version"`
TradeOrderId string `json:"trade_order_id"`
@@ -49,9 +49,11 @@ type HuPiPayReq struct {
CallbackURL string `json:"callback_url"`
Time string `json:"time"`
NonceStr string `json:"nonce_str"`
Type string `json:"type"`
WapUrl string `json:"wap_url"`
}
type HuPiResp struct {
type HuPiPayResp struct {
Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
@@ -60,7 +62,7 @@ type HuPiResp struct {
}
// Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10)
params.AppId = s.appId
@@ -78,22 +80,22 @@ func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return HuPiResp{}, fmt.Errorf("error with requst api: %v", err)
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
}
defer resp.Body.Close()
all, err := io.ReadAll(resp.Body)
if err != nil {
return HuPiResp{}, fmt.Errorf("error with reading response: %v", err)
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
}
var res HuPiResp
var res HuPiPayResp
err = utils.JsonDecode(string(all), &res)
if err != nil {
return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err)
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
}
if res.ErrCode != 0 {
return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
}
return res, nil
@@ -125,10 +127,10 @@ func (s *HuPiPayService) Sign(params url.Values) string {
}
// Check 校验订单状态
func (s *HuPiPayService) Check(tradeNo string) error {
func (s *HuPiPayService) Check(outTradeNo string) error {
data := url.Values{}
data.Add("appid", s.appId)
data.Add("open_order_id", tradeNo)
data.Add("out_trade_order", outTradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp)
data.Add("nonce_str", stamp)

View File

@@ -1,155 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strings"
)
type PayJS struct {
config *types.JPayConfig
}
func NewPayJS(appConfig *types.AppConfig) *PayJS {
return &PayJS{
config: &appConfig.JPayConfig,
}
}
type JPayReq struct {
TotalFee int `json:"total_fee"`
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"body"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"callback_url"`
}
type JPayReps struct {
OutTradeNo string `json:"out_trade_no"`
OrderId string `json:"payjs_order_id"`
ReturnCode int `json:"return_code"`
ReturnMsg string `json:"return_msg"`
Sign string `json:"Sign"`
TotalFee string `json:"total_fee"`
CodeUrl string `json:"code_url,omitempty"`
Qrcode string `json:"qrcode,omitempty"`
}
func (r JPayReps) IsOK() bool {
return r.ReturnMsg == "SUCCESS"
}
func (js *PayJS) Pay(param JPayReq) JPayReps {
param.NotifyURL = js.config.NotifyURL
var p = url.Values{}
encode := utils.JsonEncode(param)
m := make(map[string]interface{})
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
p.Add(k, fmt.Sprintf("%v", v))
}
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
cli := http.Client{}
apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
r, err := cli.PostForm(apiURL, p)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
defer r.Body.Close()
bs, err := io.ReadAll(r.Body)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
var data JPayReps
err = utils.JsonDecode(string(bs), &data)
if err != nil {
return JPayReps{ReturnMsg: err.Error()}
}
return data
}
func (js *PayJS) PayH5(p url.Values) string {
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
}
func (js *PayJS) sign(params url.Values) string {
params.Del(`sign`)
var keys = make([]string, 0, 0)
for key := range params {
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += "&key=" + js.config.PrivateKey
md5bs := md5.Sum([]byte(src))
md5res := hex.EncodeToString(md5bs[:])
return strings.ToUpper(md5res)
}
// Check 查询订单支付状态
// @param tradeNo 支付平台交易 ID
func (js *PayJS) Check(tradeNo string) error {
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
params := url.Values{}
params.Add("payjs_order_id", tradeNo)
params.Add("sign", js.sign(params))
data := strings.NewReader(params.Encode())
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
defer resp.Body.Close()
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ReturnCode int `json:"return_code"`
Status int `json:"status"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ReturnCode == 1 && r.Status == 1 {
return nil
} else {
logger.Errorf("PayJs 支付验证响应:%s", string(body))
return errors.New("order not paid")
}
}

View File

@@ -0,0 +1,19 @@
package payment
type NotifyVo struct {
Status int
OutTradeNo string // 商户订单号
TradeId string // 交易ID
Amount string // 交易金额
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == Success
}
const (
Success = 0
Failure = 1
)

View File

@@ -0,0 +1,144 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
"net/http"
"time"
)
type WechatPayService struct {
config *types.WechatPayConfig
client *wechat.ClientV3
}
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
config := appConfig.WechatPayConfig
if !config.Enabled {
logger.Info("Disabled WechatPay service")
return nil, nil
}
priKey, err := readKey(config.PrivateKey)
if err != nil {
return nil, fmt.Errorf("error with read App Private key: %v", err)
}
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
if err != nil {
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
}
//client.DebugSwitch = gopay.DebugOn
return &WechatPayService{config: &config, client: client}, nil
}
type WechatPayParams struct {
OutTradeNo string `json:"out_trade_no"`
TotalFee int `json:"total_fee"`
Subject string `json:"subject"`
ClientIP string `json:"client_ip"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
})
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
}).
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP).
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
bm.Set("type", "Wap")
})
})
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.H5Url, nil
}
type NotifyResponse struct {
Code string `json:"code"`
Message string `xml:"message"`
}
// TradeVerify 交易验证
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
}
// TODO: 这里验签程序有 Bug一直报错crypto/rsa: verification error先暂时取消验签
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
//if err != nil {
// return fmt.Errorf("error with client v3 verify sign: %v", err)
//}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
}
return NotifyVo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
}
}

View File

@@ -1,143 +0,0 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core/types"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
uploader *oss.UploaderManager
levelDB *store.LevelDB
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
uploader: manager,
levelDB: levelDB,
}
}
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
// stop old service
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for k, config := range configs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf(" sd-service-%d", k)
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
// run sd service
go func() {
service.Run()
}()
p.services = append(p.services, service)
}
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := p.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
p.db.Delete(&job)
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@@ -10,95 +10,91 @@ package sd
import (
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"strings"
"github.com/go-redis/redis/v8"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
// SD 绘画服务
type Service struct {
httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
name string // service name
leveldb *store.LevelDB
running bool // 运行状态
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
return &Service{
name: name,
config: config,
httpClient: req.C(),
taskQueue: taskQueue,
notifyQueue: notifyQueue,
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db,
leveldb: levelDB,
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
running: true,
}
}
func (s *Service) Run() {
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
for s.running {
var task types.SdTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("Starting Stable-Diffusion job consumer")
go func() {
for {
var task types.SdTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// translate prompt
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
if err == nil {
task.Params.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
// translate prompt
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
if err == nil {
task.Params.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("handle a new Stable-Diffusion task: %+v", task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue
}
}
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt))
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
continue
}
}
}
func (s *Service) Stop() {
s.running = false
}()
}
// Txt2ImgReq 文生图请求实体
@@ -160,12 +156,19 @@ func (s *Service) Txt2Img(task types.SdTask) error {
}
var res Txt2ImgResp
var errChan = make(chan error)
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
var apiKey model.ApiKey
err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
}
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
logger.Debugf("send image request to %s", apiURL)
// send a request to sd api endpoint
go func() {
response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey).
SetHeader("Authorization", apiKey.Value).
SetBody(body).
SetSuccessResult(&res).
Post(apiURL)
@@ -178,6 +181,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
return
}
// update the last used time
apiKey.LastUsedAt = time.Now().Unix()
s.db.Updates(&apiKey)
// 保存 Base64 图片
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
if err != nil {
@@ -192,7 +199,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
return
}
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params), Prompt: task.Params.Prompt})
errChan <- nil
}()
@@ -206,17 +213,17 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil
default:
err, resp := s.checkTaskProgress()
err, resp := s.checkTaskProgress(apiKey)
// 更新任务进度
if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
@@ -229,11 +236,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
}
// 执行任务
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
var res TaskProgressResp
response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey).
SetHeader("Authorization", apiKey.Value).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
@@ -245,3 +252,54 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
return nil, &res
}
func (s *Service) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
}
}
time.Sleep(time.Second * 5)
}
}()
}

View File

@@ -1,24 +0,0 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import logger2 "geekai/logger"
var logger = logger2.GetLogger()
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const (
Running = "RUNNING"
Finished = "FINISH"
Failed = "FAIL"
)

View File

@@ -28,8 +28,8 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
}
func (s *SmtpService) SendVerifyCode(to string, code int) error {
subject := "Geek-AI 注册验证码"
body := fmt.Sprintf("您正在注册 Geek-AI 助手账户,注册验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", code)
subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
body := fmt.Sprintf("【%s】您的验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", s.config.AppName, code)
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
if s.config.UseTls {

456
api/service/suno/service.go Normal file
View File

@@ -0,0 +1,456 @@
package suno
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"io"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
}
}
func (s *Service) PushTask(task types.SunoTask) {
logger.Infof("add a new Suno task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.SunoJob
s.db.Where("task_id", "").Find(&jobs)
for _, v := range jobs {
s.PushTask(types.SunoTask{
Id: v.Id,
Channel: v.Channel,
UserId: v.UserId,
Type: v.Type,
Title: v.Title,
RefTaskId: v.RefTaskId,
RefSongId: v.RefSongId,
Prompt: v.Prompt,
Tags: v.Tags,
Model: v.ModelName,
Instrumental: v.Instrumental,
ExtendSecs: v.ExtendSecs,
})
}
logger.Info("Starting Suno job consumer...")
go func() {
for {
var task types.SunoTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
var r RespVo
if task.Type == 3 && task.SongId != "" { // 歌曲拼接
r, err = s.Merge(task)
} else if task.Type == 4 && task.AudioURL != "" { // 上传歌曲
r, err = s.Upload(task)
} else { // 歌曲创作
r, err = s.Create(task)
}
if err != nil {
logger.Errorf("create task with error: %v", err)
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
})
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
// 更新任务信息
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Data,
"channel": r.Channel,
})
}
}()
}
type RespVo struct {
Code string `json:"code"`
Message string `json:"message"`
Data string `json:"data"`
Channel string `json:"channel,omitempty"`
}
func (s *Service) Create(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"task_id": task.RefTaskId,
"continue_clip_id": task.RefSongId,
"continue_at": task.ExtendSecs,
"make_instrumental": task.Instrumental,
}
// 灵感模式
if task.Type == 1 {
reqBody["gpt_description_prompt"] = task.Prompt
} else { // 自定义模式
reqBody["prompt"] = task.Prompt
reqBody["tags"] = task.Tags
reqBody["mv"] = task.Model
reqBody["title"] = task.Title
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/submit/music", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) Merge(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"clip_id": task.SongId,
"is_infill": false,
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/submit/concat", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"url": task.AudioURL,
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/uploads/audio-url", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 {
return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadFiles() {
go func() {
var items []model.SunoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
// 下载图片和音频
logger.Infof("try download cover image: %s", v.CoverURL)
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
if err != nil {
logger.Errorf("download image with error: %v", err)
continue
}
logger.Infof("try download audio: %s", v.AudioURL)
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
if err != nil {
logger.Errorf("download audio with error: %v", err)
continue
}
v.CoverURL = coverURL
v.AudioURL = audioURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 10)
}
}()
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.SunoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
task, err := s.QueryTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
continue
}
if task.Code != "success" {
logger.Errorf("query task with error: %v", task.Message)
continue
}
logger.Debugf("task: %+v", task.Data.Status)
// 任务完成,删除旧任务插入两条新任务
if task.Data.Status == "SUCCESS" {
var jobId = job.Id
var flag = false
tx := s.db.Begin()
for _, v := range task.Data.Data {
job.Id = 0
job.Progress = 102 // 102 表示资源未下载完成
job.Title = v.Title
job.SongId = v.Id
job.Duration = int(v.Metadata.Duration)
job.Prompt = v.Metadata.Prompt
job.Tags = v.Metadata.Tags
job.ModelName = v.ModelName
job.RawData = utils.JsonEncode(v)
job.CoverURL = v.ImageLargeUrl
job.AudioURL = v.AudioUrl
if err = tx.Create(&job).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
break
}
flag = true
}
// 删除旧任务
if flag {
if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
continue
}
}
tx.Commit()
} else if task.Data.FailReason != "" {
job.Progress = service.FailTaskProgress
job.ErrMsg = task.Data.FailReason
s.db.Updates(&job)
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
}
}
time.Sleep(time.Second * 10)
}
}()
}
type QueryRespVo struct {
Code string `json:"code"`
Message string `json:"message"`
Data struct {
TaskId string `json:"task_id"`
Action string `json:"action"`
Status string `json:"status"`
FailReason string `json:"fail_reason"`
SubmitTime int `json:"submit_time"`
StartTime int `json:"start_time"`
FinishTime int `json:"finish_time"`
Progress string `json:"progress"`
Data []struct {
Id string `json:"id"`
Title string `json:"title"`
Status string `json:"status"`
Metadata struct {
Tags string `json:"tags"`
Type string `json:"type"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
Duration float64 `json:"duration"`
ErrorMessage interface{} `json:"error_message"`
} `json:"metadata"`
AudioUrl string `json:"audio_url"`
ImageUrl string `json:"image_url"`
VideoUrl string `json:"video_url"`
ModelName string `json:"model_name"`
DisplayName string `json:"display_name"`
ImageLargeUrl string `json:"image_large_url"`
MajorModelVersion string `json:"major_model_version"`
} `json:"data"`
} `json:"data"`
}
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
err := s.db.Session(&gorm.Session{}).Where("type", "suno").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return QueryRespVo{}, errors.New("no available API KEY for Suno")
}
apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
var res QueryRespVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return QueryRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil
}

View File

@@ -1,4 +1,17 @@
package service
const FailTaskProgress = 101
const (
TaskStatusRunning = "RUNNING"
TaskStatusFinished = "FINISH"
TaskStatusFailed = "FAIL"
)
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"

View File

@@ -0,0 +1,83 @@
package service
import (
"fmt"
"geekai/core/types"
"geekai/store/model"
"gorm.io/gorm"
"sync"
"time"
)
type UserService struct {
db *gorm.DB
lock sync.Mutex
}
func NewUserService(db *gorm.DB) *UserService {
return &UserService{db: db, lock: sync.Mutex{}}
}
// IncreasePower 增加用户算力
func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error {
s.lock.Lock()
defer s.lock.Unlock()
tx := s.db.Begin()
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", power)).Error
if err != nil {
tx.Rollback()
return err
}
var user model.User
tx.Where("id", userId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: log.Type,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: log.Model,
Remark: log.Remark,
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return nil
}
// DecreasePower 减少用户算力
func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error {
s.lock.Lock()
defer s.lock.Unlock()
tx := s.db.Begin()
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error
if err != nil {
tx.Rollback()
return fmt.Errorf("扣减算力失败:%v", err)
}
var user model.User
tx.Where("id", userId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: log.Type,
Amount: power,
Balance: user.Power,
Mark: types.PowerSub,
Model: log.Model,
Remark: log.Remark,
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
return fmt.Errorf("记录算力日志失败:%v", err)
}
tx.Commit()
return nil
}

341
api/service/video/luma.go Normal file
View File

@@ -0,0 +1,341 @@
package video
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"io"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
}
}
func (s *Service) PushTask(task types.VideoTask) {
logger.Infof("add a new Video task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.VideoJob
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var params types.VideoParams
if err := utils.JsonDecode(v.Params, &params); err != nil {
logger.Errorf("unmarshal params failed: %v", err)
continue
}
s.PushTask(types.VideoTask{
Id: v.Id,
Channel: v.Channel,
UserId: v.UserId,
Type: v.Type,
TaskId: v.TaskId,
Prompt: v.Prompt,
Params: params,
})
}
logger.Info("Starting Video job consumer...")
go func() {
for {
var task types.VideoTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
var r LumaRespVo
r, err = s.LumaCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
"cover_url": "/images/failed.jpg",
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
// 更新任务信息
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Id,
"channel": r.Channel,
"prompt_ext": r.Prompt,
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
s.PushTask(task)
}
}
}()
}
type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
CreatedAt time.Time `json:"created_at"`
Video interface{} `json:"video"`
Liked interface{} `json:"liked"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
Channel string `json:"channel,omitempty"`
}
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return LumaRespVo{}, errors.New("no available API KEY for Luma")
}
reqBody := map[string]interface{}{
"user_prompt": task.Prompt,
"expand_prompt": task.Params.PromptOptimize,
"loop": task.Params.Loop,
"image_url": task.Params.StartImgURL,
"image_end_url": task.Params.EndImgURL,
}
var res LumaRespVo
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 && r.StatusCode != 201 {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadFiles() {
go func() {
var items []model.VideoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
if v.WaterURL == "" {
continue
}
logger.Infof("try download video: %s", v.WaterURL)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
logger.Infof("download video success: %s", videoURL)
v.WaterURL = videoURL
if v.VideoURL != "" {
logger.Infof("try download no water video: %s", v.VideoURL)
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
}
logger.Info("download no water video success: %s", videoURL)
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 10)
}
}()
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.VideoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
// 更新任务信息
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
"err_msg": err.Error(),
})
continue
}
logger.Debugf("task: %+v", task)
if task.State == "completed" { // 更新任务信息
data := map[string]interface{}{
"progress": 102, // 102 表示资源未下载完成,
"water_url": task.Video.Url,
"raw_data": utils.JsonEncode(task),
"prompt_ext": task.Prompt,
}
if task.Video.DownloadUrl != "" {
data["video_url"] = task.Video.DownloadUrl
}
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
if err != nil {
logger.Errorf("更新数据库失败:%v", err)
continue
}
}
}
time.Sleep(time.Second * 10)
}
}()
}
type LumaTaskVo struct {
Id string `json:"id"`
Liked interface{} `json:"liked"`
State string `json:"state"`
Video struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
DownloadUrl string `json:"download_url"`
} `json:"video"`
Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
}
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
}
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
var res LumaTaskVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
if r.StatusCode != 200 {
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil
}

View File

@@ -1,101 +0,0 @@
package wx
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
logger2 "geekai/logger"
"geekai/store/model"
"github.com/eatmoreapple/openwechat"
"github.com/skip2/go-qrcode"
"gorm.io/gorm"
"os"
"strconv"
)
// 微信收款机器人
var logger = logger2.GetLogger()
type Bot struct {
bot *openwechat.Bot
token string
db *gorm.DB
}
func NewWeChatBot(db *gorm.DB) *Bot {
bot := openwechat.DefaultBot(openwechat.Desktop)
return &Bot{
bot: bot,
db: db,
}
}
func (b *Bot) Run() error {
logger.Info("Starting WeChat Bot...")
// set message handler
b.bot.MessageHandler = func(msg *openwechat.Message) {
b.messageHandler(msg)
}
// scan code login callback
b.bot.UUIDCallback = b.qrCodeCallBack
debug, err := strconv.ParseBool(os.Getenv("APP_DEBUG"))
if debug {
reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json")
err = b.bot.HotLogin(reloadStorage, true)
} else {
err = b.bot.Login()
}
if err != nil {
return err
}
logger.Info("微信登录成功!")
return nil
}
// message handler
func (b *Bot) messageHandler(msg *openwechat.Message) {
sender, err := msg.Sender()
if err != nil {
return
}
// 只处理微信支付的推送消息
if sender.NickName == "微信支付" ||
msg.MsgType == openwechat.MsgTypeApp ||
msg.AppMsgType == openwechat.AppMsgTypeUrl {
// 解析支付金额
message := parseTransactionMessage(msg.Content)
transaction := extractTransaction(message)
logger.Infof("解析到收款信息:%+v", transaction)
if transaction.TransId != "" {
var item model.Reward
res := b.db.Where("tx_id = ?", transaction.TransId).First(&item)
if item.Id > 0 {
logger.Error("当前交易 ID 己经存在!")
return
}
res = b.db.Create(&model.Reward{
TxId: transaction.TransId,
Amount: transaction.Amount,
Remark: transaction.Remark,
Status: false,
})
if res.Error != nil {
logger.Errorf("交易保存失败: %v", res.Error)
}
}
}
}
func (b *Bot) qrCodeCallBack(uuid string) {
logger.Info("请使用微信扫描下面二维码登录")
q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium)
logger.Info(q.ToString(true))
}

View File

@@ -1,112 +0,0 @@
package wx
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/xml"
"net/url"
"strconv"
"strings"
)
// Message 转账消息
type Message struct {
Des string
Url string
}
// Transaction 解析后的交易信息
type Transaction struct {
TransId string `json:"trans_id"` // 微信转账交易 ID
Amount float64 `json:"amount"` // 微信转账交易金额
Remark string `json:"remark"` // 转账备注
}
// 解析微信转账消息
func parseTransactionMessage(xmlData string) *Message {
decoder := xml.NewDecoder(strings.NewReader(xmlData))
message := Message{}
for {
token, err := decoder.Token()
if err != nil {
break
}
switch se := token.(type) {
case xml.StartElement:
var value string
if se.Name.Local == "des" && message.Des == "" {
if err := decoder.DecodeElement(&value, &se); err == nil {
message.Des = strings.TrimSpace(value)
}
break
}
if se.Name.Local == "weapp_path" || se.Name.Local == "url" {
if err := decoder.DecodeElement(&value, &se); err == nil {
if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") {
message.Url = value
}
}
break
}
}
}
// 兼容旧版消息记录
if message.Url == "" {
var msg struct {
XMLName xml.Name `xml:"msg"`
AppMsg struct {
Des string `xml:"des"`
Url string `xml:"url"`
} `xml:"appmsg"`
}
if err := xml.Unmarshal([]byte(xmlData), &msg); err == nil {
message.Url = msg.AppMsg.Url
}
}
return &message
}
// 导出交易信息
func extractTransaction(message *Message) Transaction {
var tx = Transaction{}
// 导出交易金额和备注
lines := strings.Split(message.Des, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
// 解析收款金额
prefix := "收款金额¥"
if strings.HasPrefix(line, prefix) {
if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil {
tx.Amount = value
continue
}
}
// 解析收款备注
prefix = "付款方备注"
if strings.HasPrefix(line, prefix) {
tx.Remark = line[len(prefix):]
break
}
}
// 解析交易 ID
parse, err := url.Parse(message.Url)
if err == nil {
tx.TransId = parse.Query().Get("id")
if tx.TransId == "" {
tx.TransId = parse.Query().Get("trans_id")
}
}
return tx
}

View File

@@ -81,51 +81,6 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms
// 自动将 VIP 会员的算力补充到每月赠送的最大值
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Info("开始进行月底账号盘点...")
var users []model.User
res := e.db.Where("vip", 1).Where("status", 1).Find(&users)
if res.Error != nil {
return "No vip users found"
}
var sysConfig model.Config
res = e.db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return "error with get system config: " + res.Error.Error()
}
var config types.SystemConfig
err := utils.JsonDecode(sysConfig.Config, &config)
if err != nil {
return "error with decode system config: " + err.Error()
}
for _, u := range users {
// 处理过期的 VIP
if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() {
u.Vip = false
e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
continue
}
// update user
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", config.VipMonthPower))
// 记录算力变动日志
if tx.Error == nil {
var user model.User
e.db.Where("id", u.Id).First(&user)
e.db.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerRecharge,
Amount: config.VipMonthPower,
Mark: types.PowerAdd,
Balance: user.Power,
Model: "系统盘点",
Remark: fmt.Sprintf("VIP会员每月算力派发%d", config.VipMonthPower),
CreatedAt: time.Now(),
})
}
}
logger.Info("月底盘点完成!")
return "success"
}

View File

@@ -29,15 +29,9 @@ func NewLevelDB() (*LevelDB, error) {
}
func (db *LevelDB) Put(key string, value interface{}) error {
var byteData []byte
if v, ok := value.(string); ok {
byteData = []byte(v)
} else {
b, err := json.Marshal(value)
if err != nil {
return err
}
byteData = b
byteData, err := json.Marshal(value)
if err != nil {
return err
}
return db.driver.Put([]byte(key), byteData, nil)
}

View File

@@ -3,7 +3,6 @@ package model
// ApiKey OpenAI API 模型
type ApiKey struct {
BaseModel
Platform string
Name string
Type string // 用途 chat => 聊天img => 绘图
Value string // API Key 的值

View File

@@ -0,0 +1,12 @@
package model
import "time"
type AppType struct {
Id uint `gorm:"primarykey"`
Name string
Icon string
Enabled bool
SortNum int
CreatedAt time.Time
}

View File

@@ -4,16 +4,17 @@ import "gorm.io/gorm"
type ChatMessage struct {
BaseModel
ChatId string // 会话 ID
UserId uint // 用户 ID
RoleId uint // 角色 ID
Model string // AI模型
Type string
Icon string
Tokens int
Content string
UseContext bool // 是否可以作为聊天上下文
DeletedAt gorm.DeletedAt
ChatId string // 会话 ID
UserId uint // 用户 ID
RoleId uint // 角色 ID
Model string // AI模型
Type string
Icon string
Tokens int
TotalTokens int // 总 token 消耗
Content string
UseContext bool // 是否可以作为聊天上下文
DeletedAt gorm.DeletedAt
}
func (ChatMessage) TableName() string {

View File

@@ -2,7 +2,6 @@ package model
type ChatModel struct {
BaseModel
Platform string
Name string
Value string // API Key 的值
SortNum int

View File

@@ -2,6 +2,7 @@ package model
type ChatRole struct {
BaseModel
Tid int
Key string `gorm:"column:marker;unique"` // 角色唯一标识
Name string // 角色名称
Context string `gorm:"column:context_json"` // 角色语料信息 json

View File

@@ -2,7 +2,6 @@ package model
import (
"geekai/core/types"
"gorm.io/gorm"
)
// Order 充值订单
@@ -18,6 +17,6 @@ type Order struct {
Status types.OrderStatus
Remark string
PayTime int64
PayWay string // 支付方式
DeletedAt gorm.DeletedAt
PayWay string // 支付渠道
PayType string // 支付类型
}

Some files were not shown because too many files have changed in this diff Show More