Compare commits

...

140 Commits

Author SHA1 Message Date
RockYang
8eed7ff534 fix: fix overflow hidden for admin page 2024-03-28 18:51:09 +08:00
RockYang
c79c4e74d0 fix: fix overflow hidden for mobile page 2024-03-28 18:13:33 +08:00
RockYang
f1855fd0a1 fix: use slide captcha for iphone 2024-03-28 15:00:53 +08:00
RockYang
1f964c74e9 feat: add mj_action_power system config item 2024-03-28 09:53:41 +08:00
RockYang
4fb2c5803c feat: change midjourney origin implements, replace midjourney bot with midjourney-proxy 2024-03-27 18:57:15 +08:00
RockYang
b5947545cb feat: auto translate and rewrite prompt for midjourney and stable-diffusion 2024-03-27 13:45:52 +08:00
RockYang
342b76f666 feat: stable-diffusion refactored, replace websocket api with sdapi 2024-03-26 18:23:08 +08:00
RockYang
49b5906bc7 fix: can not change user's power in admin console 2024-03-25 11:40:03 +08:00
RockYang
3075bfb7fc fix: fix bug for update user's power in admin page did not work 2024-03-23 16:21:37 +08:00
RockYang
82e06fad33 No need to login with Stable-Diffusion page and Invite page 2024-03-23 15:45:37 +08:00
RockYang
4a9028747b fix: fix code highlight error when add formule detecting 2024-03-22 19:09:04 +08:00
RockYang
4a8ff0ccf0 chore: correct prompt messages 2024-03-22 18:27:57 +08:00
RockYang
99341f0484 feat: add chart for admin dashbord 2024-03-22 16:57:30 +08:00
RockYang
f58ac29ad0 feat: integrate xxl-job-admin to implements automatic task scheduling 2024-03-22 13:47:16 +08:00
RockYang
7060edb3e5 feat: save prompt in power log for dalle-3 2024-03-21 15:55:39 +08:00
RockYang
41ae411f9b feat: add manager list page in console page 2024-03-21 15:24:28 +08:00
RockYang
79b7fee47c feat: add powerlog page for admin console 2024-03-21 13:46:39 +08:00
RockYang
0044bf10af opt: optimize the formula show styles 2024-03-21 11:04:12 +08:00
RockYang
e9348d3611 always parse authorization token for all request 2024-03-20 21:11:52 +08:00
RockYang
b9236e09a7 fixed conflicts 2024-03-20 20:40:22 +08:00
RockYang
09b38d5f42 update README.md.
Signed-off-by: RockYang <yangjian102621@gmail.com>
2024-03-20 20:39:44 +08:00
RockYang
7bb539a06e feat: Async loading midjourney job for mobile MidJourney page 2024-03-20 18:39:14 +08:00
RockYang
5cdada8265 feat: h5 payment for payjs is ready 2024-03-20 17:46:39 +08:00
RockYang
4147c217b1 feat: payment for mobile pages is ready 2024-03-20 16:14:02 +08:00
RockYang
8dda639b23 feat: the power log page is ready 2024-03-20 14:14:30 +08:00
RockYang
8487d2c9eb feat: no need login refactor with member and chatApps page 2024-03-19 18:59:02 +08:00
RockYang
c5e583b215 feat: load preview page do not require user to login 2024-03-19 18:25:01 +08:00
RockYang
549f618cff feat: optimize login dialog 2024-03-19 10:47:13 +08:00
RockYang
e9a3510346 feat: refactoring adjustments for reward page is ready 2024-03-18 18:28:34 +08:00
RockYang
30e6e963b3 feat: refactoring adjustments for member pages 2024-03-18 16:59:07 +08:00
RockYang
c72d963f45 feat: The 'chat_models' field of user table, holds the model IDS in place of the model values 2024-03-18 15:37:46 +08:00
RockYang
172d498618 remove new-ui files 2024-03-18 12:01:34 +08:00
RockYang
313993532e chore: adjust page styles 2024-03-18 06:46:08 +08:00
RockYang
e53db3582c 重构主体工作完成 2024-03-15 18:35:10 +08:00
RockYang
72c6bd3f77 restore new ui files 2024-03-15 11:13:02 +08:00
廖彦棋
ca8b349df3 fix(ui): 交互修复调整 2024-03-15 11:07:41 +08:00
RockYang
1b206c3640 feat: refactor user list page for new UI 2024-03-15 09:29:19 +08:00
廖彦棋
c60276fc9f fix(ui): 用户管理有效期传参调整,权限标识补充 2024-03-15 09:25:06 +08:00
廖彦棋
d00a3167c0 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:49:18 +08:00
廖彦棋
6b1cd8c30c refactor(ui): 无权限页面调整 2024-03-14 17:49:15 +08:00
吴汉强
46f12dc9ad feat(ui): 后台首页去掉权限判断 2024-03-14 17:24:40 +08:00
廖彦棋
a3e1d8ae21 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:19:41 +08:00
廖彦棋
72a066b93e feat(ui): 无权限判断 2024-03-14 17:19:39 +08:00
吴汉强
0327a829ac Merge remote-tracking branch 'origin/ui' into ui 2024-03-14 17:12:53 +08:00
吴汉强
882e9b8819 feat(ui): 新增 sql 2024-03-14 17:12:48 +08:00
廖彦棋
ef58cfadaa Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 17:06:15 +08:00
吴汉强
bf958d6113 feat(ui): 403,没权限 2024-03-14 16:41:38 +08:00
吴汉强
71611273d7 feat(ui): 网站配置需授权,去掉 2024-03-14 16:28:49 +08:00
廖彦棋
b27c654311 fix(ui): 调整 2024-03-14 16:14:49 +08:00
吴汉强
90930ea9f9 feat(ui): 后端加权限验证 2024-03-14 15:39:12 +08:00
廖彦棋
1ab2185ff1 feat(ui): 角色管理 2024-03-14 15:25:17 +08:00
廖彦棋
0f2f978d4c feat(ui): 新增系统分类菜单 2024-03-14 15:17:53 +08:00
廖彦棋
f61963b0b0 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-14 10:56:56 +08:00
廖彦棋
2aa413960d fix(ui): 修复 2024-03-14 10:56:54 +08:00
吴汉强
aa4bbba5ec Merge remote-tracking branch 'origin/ui' into ui 2024-03-14 10:28:39 +08:00
吴汉强
eba61fea2d feat(ui): 登录接口返回权限 2024-03-14 10:28:32 +08:00
廖彦棋
34e3455128 feat(ui): 管理后台新增权限及部分组合式函数优化 2024-03-14 10:27:09 +08:00
廖彦棋
07dca3e739 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 17:30:26 +08:00
廖彦棋
4cb4b145f9 feat(ui): web移动端初始化 2024-03-13 17:30:24 +08:00
吴汉强
1ed417cb69 Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 17:24:38 +08:00
吴汉强
6cf91a84ca feat(ui): 增加角色管理,管理员方法新增角色关联 2024-03-13 17:24:30 +08:00
chenzifan
0b566980fc Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 14:40:45 +08:00
chenzifan
f86176b342 refactor: remove api change to post request 2024-03-13 14:40:38 +08:00
吴汉强
c700b32670 Merge remote-tracking branch 'origin/ui' into ui 2024-03-13 11:41:07 +08:00
吴汉强
22641b452a feat(ui): 增加系统权限管理 2024-03-13 11:41:01 +08:00
廖彦棋
d3fbb8c19e fix(ui): ui调整 2024-03-13 09:54:20 +08:00
RockYang
e3bb69ff10 docs: update sql file 2024-03-13 08:49:40 +08:00
RockYang
770360c614 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 08:48:10 +08:00
chenzifan
f302a0478f fix: 删除系统管理员失效的问题 2024-03-13 08:47:17 +08:00
chenzifan
a88697b43a Merge remote-tracking branch 'origin/ui' into ui
# Conflicts:
#	api/handler/admin/admin_user_handler.go
2024-03-13 08:46:16 +08:00
chenzifan
cc6f140812 fix: 删除系统管理员失效的问题 2024-03-13 08:45:09 +08:00
廖彦棋
424f2b3bdc Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-13 08:44:50 +08:00
廖彦棋
ec0c13a600 feat(ui): 调整 2024-03-13 08:44:48 +08:00
chenzf@pvc123.com
a1f03bec4c feat: 超级管理员不支持修改和删除 2024-03-12 21:16:05 +08:00
RockYang
b5bd4a5e0e Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-12 18:07:24 +08:00
RockYang
7c2e49bfdb fix conflicts 2024-03-12 18:07:19 +08:00
chenzifan
f80fe6d041 feat: 增加系统管理员 2024-03-12 18:06:49 +08:00
RockYang
72f80a96bc fix conflicts 2024-03-12 18:03:24 +08:00
RockYang
2de655a1cf refactor: use power replace calls for front pages 2024-03-12 17:47:06 +08:00
RockYang
da2bd4a501 refactor: 重构项目,为所有的 AI 工具都引入算力,采用算力统一结算各个工具的调用次数和权限 2024-03-12 15:40:44 +08:00
廖彦棋
e0aa62c40d Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-12 08:37:29 +08:00
廖彦棋
9d26a892d1 refactor(ui): 调整 2024-03-12 08:37:27 +08:00
huangqj
4ece7f2847 fix(ui):环境变量 2024-03-11 16:10:49 +08:00
廖彦棋
32368caf1b feat(ui): 新增系统管理员 2024-03-11 15:59:15 +08:00
廖彦棋
e91f54e79e fix(ui): 删除冗余 2024-03-11 14:23:11 +08:00
廖彦棋
bb8f4c57c4 fix(ui): 删除冗余 2024-03-11 14:22:39 +08:00
RockYang
43bfac99b6 feat: replace Tools param with Function param for OpenAI chat API 2024-03-11 14:09:19 +08:00
廖彦棋
be379b6d63 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-11 13:52:24 +08:00
廖彦棋
17f3c9b840 fix(ui): type 2024-03-11 13:52:22 +08:00
chenzifan
24de97fac2 feat: 优化后台UI 2024-03-11 13:51:26 +08:00
chenzifan
bf27b44fee Merge remote-tracking branch 'origin/ui' into ui 2024-03-11 13:46:46 +08:00
廖彦棋
1802b4fe4d refactor(ui): 调整优化 2024-03-11 13:46:08 +08:00
廖彦棋
241a5c7bc9 feat(ui): 细节优化 2024-03-11 12:02:20 +08:00
廖彦棋
557d547bf1 feat(ui): 上传功能补充 2024-03-11 11:41:50 +08:00
廖彦棋
2e7b75affb feat(ui): 登录新增验证码及记住密码功能 2024-03-11 10:49:13 +08:00
廖彦棋
bc21a1d443 feat(ul): 顶部信息 2024-03-11 09:00:00 +08:00
huangqj
3fc9e10a24 feat(ui):看板图表 样式调整 2024-03-11 08:35:21 +08:00
chenzifan
5fa1aa2060 Merge remote-tracking branch 'origin/ui' into ui 2024-03-11 08:01:54 +08:00
廖彦棋
ff4b267858 fix(ui): 细节调整 2024-03-08 17:46:48 +08:00
huangqj
a590d0497f feat(ui):细节调整 2024-03-08 10:24:38 +08:00
廖彦棋
ac30d906f0 feat(ui): prettier 2024-03-08 09:59:40 +08:00
廖彦棋
5bc071e038 refactor(ui): 登录页重构 2024-03-08 09:45:09 +08:00
廖彦棋
88b956cf98 refactor(ui): 优化 2024-03-08 09:12:39 +08:00
huangqj
f725cf4661 feat(ui):路由 2024-03-08 08:35:41 +08:00
廖彦棋
057cc1e8a6 Merge branch 'ui' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-07 18:03:08 +08:00
廖彦棋
de122735b8 feat(ui): 过期跳转登录 2024-03-07 18:03:06 +08:00
huangqj
e87ede981c feat(ui):apiKey 语言模型 角色管理 产品 2024-03-07 17:58:25 +08:00
廖彦棋
606fb498e1 feat(ui): 新增系统设置 2024-03-07 17:24:50 +08:00
廖彦棋
a0c06e40a4 feat(ui): 对话管理 2024-03-07 15:32:32 +08:00
huangqj
aba8f57279 feat(ui):用户 2024-03-07 15:05:01 +08:00
huangqj
960286a350 feat(ui):simpleTable 2024-03-07 14:59:45 +08:00
huangqj
8c93fa51f6 feat(ui):searchtable 2024-03-07 14:59:16 +08:00
huangqj
cb0e7d64ff feat(ui):用户 2024-03-07 14:03:55 +08:00
廖彦棋
8e7413da97 feat(ui): 函数管理 2024-03-07 11:40:57 +08:00
廖彦棋
a36f14eb94 feat(ui): 新增弹窗及时间格式化 2024-03-07 09:23:45 +08:00
chenzifan
f2f9f6e488 Merge branch 'main' of 172.28.1.6:yangjian/chatgpt-plus into ui 2024-03-07 08:37:54 +08:00
chenzifan
85068b8ca2 feat: 增加系统用户管理 2024-03-07 08:37:48 +08:00
廖彦棋
f2cfcfeefc fix(ui): ts类型 2024-03-06 18:20:07 +08:00
RockYang
755273a898 feat: update changelogs 2024-03-06 17:58:17 +08:00
廖彦棋
d4a24a0f1d feat(ui): 新增登录 2024-03-06 17:54:38 +08:00
RockYang
92281fcbb7 feat: Mj and sd jobs data loading in pages 2024-03-06 17:31:54 +08:00
RockYang
636db4afcc add prompt translating function for mobile midjourney page 2024-03-06 16:22:03 +08:00
huangqj
ba25b8755e feat(ui):simpleTable 2024-03-06 15:33:37 +08:00
廖彦棋
6399d13a49 feat(ui): 新增请求方法及表格 2024-03-06 13:55:38 +08:00
廖彦棋
06fa54fd25 feat(ui): 管理后台基础配置 2024-03-06 10:23:55 +08:00
廖彦棋
a335b965d0 chore(ui): 目录结构调整 2024-03-06 09:32:47 +08:00
廖彦棋
725adaa7d0 feat(ui): 初始化 2024-03-06 09:27:11 +08:00
chenzifan
7e7e81e974 refactor: 初始化UI重构 2024-03-06 08:57:46 +08:00
RockYang
8cfe6bfc17 Merge branch 'dev' of gitee.com:blackfox/chatgpt-plus-pro into dev 2024-03-04 08:34:12 +08:00
RockYang
33de83f2ac feat: add removing order button in admin order list page 2024-03-03 19:27:22 +08:00
RockYang
3f856afec8 fix: fix major bugs for unauthorized access to data 2024-03-03 10:40:32 +08:00
RockYang
02a9c422fe fix: fixed bug image preview im mobile chat session page 2024-02-29 15:41:45 +08:00
RockYang
ca69341024 feat: add draw same image for midjourney page 2024-02-29 11:44:09 +08:00
RockYang
169bf069ce opt: add logs for mj-plus api error 2024-02-28 15:50:42 +08:00
RockYang
1bee0ab04d opt: replace proxy url for discord image url 2024-02-27 17:45:57 +08:00
RockYang
440d91dd0e feat: add change password in with mobile page 2024-02-27 15:36:20 +08:00
RockYang
8168e246a8 feat: add image preview for mobile chat page 2024-02-26 18:11:37 +08:00
RockYang
2ef07574ae feat: replace http polling with webscoket notify in sd image page 2024-02-26 15:45:54 +08:00
RockYang
37392f2bb2 chore: replace 'token' with power 2024-02-23 18:11:57 +08:00
RockYang
a80cd3848e docs: update mj-plus api domain 2024-02-23 15:41:02 +08:00
193 changed files with 9690 additions and 5996 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
deploy
docs
api/static
web/node_modules
desktop

View File

@@ -1,4 +1,33 @@
# 更新日志 # 更新日志
## 4.0.2
* 功能新增:支持前端菜单可以配置
* 功能优化:手机端支持免登录预览功能
* 功能新增:手机端支持 Stable-Diffusion 绘画
* 功能新增:管理后台登录页面增加行为验证码,防止爆破
## v4.0.1
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口SDAPI 兼容各种 stable-diffusion 发行版,稳定性更强一些
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API兼容 MJ-Plus 中转
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
* Bug修复修复 iphone 手机无法通过图形验证码的Bug使用滑动验证码替换
* Bug修复修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
## v4.0.0
非兼容版本重大重构引入算力概念将系统中所有的能力AI对话MJ绘画SD绘画DALL绘画全部使用算力来兑换。
只要你的算力值余额不为0你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力...
* 功能重构:重构整体系统,全部采用算力来进行结算
* 功能优化SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
* 功能优化:移动端聊天页面图片支持预览和放大功能
* 功能优化MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
* 功能优化:**PC端不登录也可以预览功能只有在发起操作的时候才需要登录**
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
* 功能新增支持H5支付
* 功能优化:支持数学公式的识别和美化输出
* 功能新增:新增算力消费日志功能
* 功能优化:整合 XXL-JOB 实现订单清理每日算力派发VIP 算力重置等任务
* 功能新增管理后台新增7日内新增用户和新增订单统计
## v3.2.7 ## v3.2.7
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能 * 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能 * 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能

View File

@@ -73,9 +73,11 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
**演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。 **演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。
```shell ```shell
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.5-400fea2598.sh)" bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.7-6c232bdaf8.sh)"
``` ```
最新版本的一键部署脚本请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/install/)。
目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址 目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址
* 前端访问地址http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。 * 前端访问地址http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。
@@ -145,6 +147,3 @@ KEY。
![打赏](docs/imgs/donate.png) ![打赏](docs/imgs/donate.png)
![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/chatgpt-plus&type=Date) ![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/chatgpt-plus&type=Date)

View File

@@ -1,6 +1,6 @@
Listen = "0.0.0.0:5678" Listen = "0.0.0.0:5678"
ProxyURL = "" # 如 http://127.0.0.1:7777 ProxyURL = "" # 如 http://127.0.0.1:7777
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8&parseTime=True&loc=Local" MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
StaticDir = "./static" # 静态资源的目录 StaticDir = "./static" # 静态资源的目录
StaticUrl = "/static" # 静态资源访问 URL StaticUrl = "/static" # 静态资源访问 URL
AesEncryptKey = "" AesEncryptKey = ""
@@ -10,10 +10,6 @@ WeChatBot = false
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换 SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
MaxAge = 86400 MaxAge = 86400
[Manager]
Username = "admin"
Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改
[Redis] # redis 配置信息 [Redis] # redis 配置信息
Host = "localhost" Host = "localhost"
Port = 6379 Port = 6379
@@ -46,7 +42,7 @@ WeChatBot = false
Active = "local" # 默认使用本地文件存储引擎 Active = "local" # 默认使用本地文件存储引擎
[OSS.Local] [OSS.Local]
BasePath = "./static/upload" # 本地文件上传根路径 BasePath = "./static/upload" # 本地文件上传根路径
BaseURL = "/static/upload" # 本地上传文件 URL 如果是线上,则直接设置为 /static/upload 即可 BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP
[OSS.Minio] [OSS.Minio]
Endpoint = "" # 如 172.22.11.200:9000 Endpoint = "" # 如 172.22.11.200:9000
AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
@@ -60,25 +56,24 @@ WeChatBot = false
AccessSecret = "" AccessSecret = ""
Bucket = "" Bucket = ""
Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
[OSS.AliYun]
Endpoint = "oss-cn-hangzhou.aliyuncs.com"
AccessKey = ""
AccessSecret = ""
Bucket = "chatgpt-plus"
SubDir = ""
Domain = ""
[[MjConfigs]] [[MjProxyConfigs]]
Enabled = false Enabled = true
UserToken = "" ApiURL = "http://midjourney-proxy:8082"
BotToken = "" ApiKey = "sk-geekmaster"
GuildId = ""
ChanelId = ""
UseCDN = false #是否使用反向代理访问设置为true下面的设置才会生效
DiscordAPI = "" # discord API 反代地址
DiscordCDN = "" # mj 图片反代地址
DiscordGateway = "" # discord 机器人反代地址
[[MjPlusConfigs]] [[MjPlusConfigs]]
Enabled = false Enabled = false
ApiURL = "https://api.chatgpt-plus.net" # 目前暂时不支持更改 ApiURL = "https://api.chat-plus.net"
CdnURL = "" # CND 加速的 URL如果有的话就设置
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
ApiKey = "sk-xxx" ApiKey = "sk-xxx"
NotifyURL = "https://ai.r9it.com/api/mj/notify" # 这里需要改成你的域名
[[SdConfigs]] [[SdConfigs]]
Enabled = false Enabled = false

View File

@@ -28,10 +28,9 @@ type AppServer struct {
Debug bool Debug bool
Config *types.AppConfig Config *types.AppConfig
Engine *gin.Engine Engine *gin.Engine
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
ChatConfig *types.ChatConfig // chat config cache SysConfig *types.SystemConfig // system config cache
SysConfig *types.SystemConfig // system config cache
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次 // 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API // 防止第三方直接连接 socket 调用 OpenAI API
@@ -47,7 +46,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
Debug: false, Debug: false,
Config: appConfig, Config: appConfig,
Engine: gin.Default(), Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []interface{}](), ChatContexts: types.NewLMap[string, []types.Message](),
ChatSession: types.NewLMap[string, *types.ChatSession](), ChatSession: types.NewLMap[string, *types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](), ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
@@ -69,23 +68,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
} }
func (s *AppServer) Run(db *gorm.DB) error { func (s *AppServer) Run(db *gorm.DB) error {
// load chat config from database
var chatConfig model.Config
res := db.Where("marker", "chat").First(&chatConfig)
if res.Error != nil {
return res.Error
}
err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig)
if err != nil {
return err
}
// load system configs // load system configs
var sysConfig model.Config var sysConfig model.Config
res = db.Where("marker", "system").First(&sysConfig) res := db.Where("marker", "system").First(&sysConfig)
if res.Error != nil { if res.Error != nil {
return res.Error return res.Error
} }
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig) err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
if err != nil { if err != nil {
return err return err
} }
@@ -143,73 +132,64 @@ func corsMiddleware() gin.HandlerFunc {
// 用户授权验证 // 用户授权验证
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/user/resetPass" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/mj/jobs" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/jobs" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
c.Request.URL.Path == "/api/admin/config/get" {
c.Next()
return
}
var tokenString string var tokenString string
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
if isAdminApi { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader) tokenString = c.GetHeader(types.AdminAuthHeader)
} else if c.Request.URL.Path == "/api/chat/new" { } else if c.Request.URL.Path == "/api/chat/new" {
tokenString = c.Query("token") tokenString = c.Query("token")
} else { } else {
tokenString = c.GetHeader(types.UserAuthHeader) tokenString = c.GetHeader(types.UserAuthHeader)
} }
if tokenString == "" { if tokenString == "" {
resp.ERROR(c, "You should put Authorization in request headers") if needLogin(c) {
c.Abort() resp.ERROR(c, "You should put Authorization in request headers")
return c.Abort()
return
} else { // 直接放行
c.Next()
return
}
} }
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
} }
if isAdminApi {
return []byte(s.Config.AdminSession.SecretKey), nil
} else {
return []byte(s.Config.Session.SecretKey), nil
}
return []byte(s.Config.Session.SecretKey), nil
}) })
if err != nil { if err != nil && needLogin(c) {
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err)) resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
c.Abort() c.Abort()
return return
} }
claims, ok := token.Claims.(jwt.MapClaims) claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid { if !ok || !token.Valid && needLogin(c) {
resp.NotAuth(c, "Token is invalid") resp.NotAuth(c, "Token is invalid")
c.Abort() c.Abort()
return return
} }
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0) expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() { if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
resp.NotAuth(c, "Token is expired") resp.NotAuth(c, "Token is expired")
c.Abort() c.Abort()
return return
} }
key := fmt.Sprintf("users/%v", claims["user_id"]) key := fmt.Sprintf("users/%v", claims["user_id"])
if _, err := client.Get(context.Background(), key).Result(); err != nil { if isAdminApi {
key = fmt.Sprintf("admin/%v", claims["user_id"])
}
if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
resp.NotAuth(c, "Token is not found in redis") resp.NotAuth(c, "Token is not found in redis")
c.Abort() c.Abort()
return return
@@ -218,6 +198,36 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
} }
} }
func needLogin(c *gin.Context) bool {
if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/user/resetPass" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/admin/login/captcha" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" ||
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/model/list" ||
c.Request.URL.Path == "/api/mj/imgWall" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/imgWall" ||
c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/config/get" ||
c.Request.URL.Path == "/api/product/list" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
return false
}
return true
}
// 统一参数处理 // 统一参数处理
func parameterHandlerMiddleware() gin.HandlerFunc { func parameterHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {

View File

@@ -16,7 +16,6 @@ func NewDefaultConfig() *types.AppConfig {
return &types.AppConfig{ return &types.AppConfig{
Listen: "0.0.0.0:5678", Listen: "0.0.0.0:5678",
ProxyURL: "", ProxyURL: "",
Manager: types.Manager{Username: "admin", Password: "admin123"},
StaticDir: "./static", StaticDir: "./static",
StaticUrl: "http://localhost/5678/static", StaticUrl: "http://localhost/5678/static",
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""}, Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},

View File

@@ -54,10 +54,14 @@ type ChatSession struct {
} }
type ChatModel struct { type ChatModel struct {
Id uint `json:"id"` Id uint `json:"id"`
Platform Platform `json:"platform"` Platform Platform `json:"platform"`
Value string `json:"value"` Name string `json:"name"`
Weight int `json:"weight"` Value string `json:"value"`
Power int `json:"power"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
} }
type ApiError struct { type ApiError struct {
@@ -72,23 +76,36 @@ type ApiError struct {
const PromptMsg = "prompt" // prompt message const PromptMsg = "prompt" // prompt message
const ReplyMsg = "reply" // reply message const ReplyMsg = "reply" // reply message
var ModelToTokens = map[string]int{ // PowerType 算力日志类型
"gpt-3.5-turbo": 4096, type PowerType int
"gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192, const (
"gpt-4-32k": 32768, PowerRecharge = PowerType(1) // 充值
"chatglm_pro": 32768, // 清华智普 PowerConsume = PowerType(2) // 消费
"chatglm_std": 16384, PowerRefund = PowerType(3) // 任务SD,MJ执行失败退款
"chatglm_lite": 4096, PowerInvite = PowerType(4) // 邀请奖励
"ernie_bot_turbo": 8192, // 文心一言 PowerReward = PowerType(5) // 众筹
"general": 8192, // 科大讯飞 PowerGift = PowerType(6) // 系统赠送
"general2": 8192, )
"general3": 8192,
func (t PowerType) String() string {
switch t {
case PowerRecharge:
return "充值"
case PowerConsume:
return "消费"
case PowerRefund:
return "退款"
case PowerReward:
return "众筹"
}
return "其他"
} }
func GetModelMaxToken(model string) int { type PowerMark int
if token, ok := ModelToTokens[model]; ok {
return token const (
} PowerSub = PowerMark(0)
return 4096 PowerAdd = PowerMark(1)
} )

View File

@@ -5,22 +5,22 @@ import (
) )
type AppConfig struct { type AppConfig struct {
Path string `toml:"-"` Path string `toml:"-"`
Listen string Listen string
Session Session Session Session
ProxyURL string AdminSession Session
MysqlDns string // mysql 连接地址 ProxyURL string
Manager Manager // 后台管理员账户信息 MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录 StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息 Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config OSS OSSConfig // OSS config
MjConfigs []MidJourneyConfig // mj AI draw service pool MjProxyConfigs []MjProxyConfig // MJ proxy config
MjPlusConfigs []MidJourneyPlusConfig // MJ plus config MjPlusConfigs []MjPlusConfig // MJ plus config
WeChatBot bool // 是否启用微信机器人 WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig XXLConfig XXLConfig
AlipayConfig AlipayConfig AlipayConfig AlipayConfig
@@ -43,32 +43,25 @@ type ChatPlusApiConfig struct {
Token string Token string
} }
type MidJourneyConfig struct { type MjProxyConfig struct {
Enabled bool Enabled bool
UserToken string ApiURL string // api 地址
BotToken string Mode string // 绘画模式可选值fast/turbo/relax
GuildId string // Server ID ApiKey string
ChanelId string // Chanel ID
UseCDN bool
ImgCdnURL string // 图片反代加速地址
DiscordAPI string
DiscordGateway string
} }
type StableDiffusionConfig struct { type StableDiffusionConfig struct {
Enabled bool Enabled bool
ApiURL string Model string // 模型名称
ApiKey string ApiURL string
Txt2ImgJsonPath string ApiKey string
} }
type MidJourneyPlusConfig struct { type MjPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务 Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址 ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax Mode string // 绘画模式可选值fast/turbo/relax
CdnURL string // CDN 加速地址 ApiKey string
ApiKey string
NotifyURL string // 任务进度更新回调地址
} }
type AlipayConfig struct { type AlipayConfig struct {
@@ -81,6 +74,7 @@ type AlipayConfig struct {
AlipayPublicKey string // 支付宝公钥文件路径 AlipayPublicKey string // 支付宝公钥文件路径
RootCert string // Root 秘钥路径 RootCert string // Root 秘钥路径
NotifyURL string // 异步通知回调 NotifyURL string // 异步通知回调
ReturnURL string // 支付成功返回地址
} }
type HuPiPayConfig struct { //虎皮椒第四方支付配置 type HuPiPayConfig struct { //虎皮椒第四方支付配置
@@ -90,6 +84,7 @@ type HuPiPayConfig struct { //虎皮椒第四方支付配置
AppSecret string // app 密钥 AppSecret string // app 密钥
ApiURL string // 支付网关 ApiURL string // 支付网关
NotifyURL string // 异步通知回调 NotifyURL string // 异步通知回调
ReturnURL string // 支付成功返回地址
} }
// JPayConfig PayJs 支付配置 // JPayConfig PayJs 支付配置
@@ -100,6 +95,7 @@ type JPayConfig struct {
PrivateKey string // 私钥 PrivateKey string // 私钥
ApiURL string // API 网关 ApiURL string // API 网关
NotifyURL string // 异步回调地址 NotifyURL string // 异步回调地址
ReturnURL string // 支付成功返回地址
} }
type XXLConfig struct { // XXL 任务调度配置 type XXLConfig struct { // XXL 任务调度配置
@@ -122,26 +118,6 @@ func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port) return fmt.Sprintf("%s:%d", c.Host, c.Port)
} }
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
}
// ChatConfig 系统默认的聊天配置
type ChatConfig struct {
OpenAI ModelAPIConfig `json:"open_ai"`
Azure ModelAPIConfig `json:"azure"`
ChatGML ModelAPIConfig `json:"chat_gml"`
Baidu ModelAPIConfig `json:"baidu"`
XunFei ModelAPIConfig `json:"xun_fei"`
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
ContextDeep int `json:"context_deep"` // 上下文深度
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
}
type Platform string type Platform string
const OpenAI = Platform("OpenAI") const OpenAI = Platform("OpenAI")
@@ -151,42 +127,33 @@ const Baidu = Platform("Baidu")
const XunFei = Platform("XunFei") const XunFei = Platform("XunFei")
const QWen = Platform("QWen") const QWen = Platform("QWen")
// UserChatConfig 用户的聊天配置
type UserChatConfig struct {
ApiKeys map[Platform]string `json:"api_keys"`
}
type InviteReward struct {
ChatCalls int `json:"chat_calls"`
ImgCalls int `json:"img_calls"`
}
type ModelAPIConfig struct {
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
}
type SystemConfig struct { type SystemConfig struct {
Title string `json:"title"` Title string `json:"title,omitempty"`
AdminTitle string `json:"admin_title"` AdminTitle string `json:"admin_title,omitempty"`
InitChatCalls int `json:"init_chat_calls"` // 新用户注册赠送对话次数 Logo string `json:"logo,omitempty"`
InitImgCalls int `json:"init_img_calls"` // 新用户注册赠送绘图次数 InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
VipMonthCalls int `json:"vip_month_calls"` // VIP 会员每月赠送的对话次数 DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
VipMonthImgCalls int `json:"vip_month_img_calls"` // VIP 会员每月赠送绘图次数 InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
RegisterWays []string `json:"register_ways"` // 注册方式:支持手机,邮箱注册 RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机,邮箱注册,账号密码注册
EnabledRegister bool `json:"enabled_register"` // 是否开放注册 EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
RewardImg string `json:"reward_img"` // 众筹收款二维码地址 RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
EnabledReward bool `json:"enabled_reward"` // 启用众筹功能 EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
ChatCallPrice float64 `json:"chat_call_price"` // 对话单次调用费用 PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
ImgCallPrice float64 `json:"img_call_price"` // 绘图单次调用费用
OrderPayTimeout int `json:"order_pay_timeout"` //订单支付超时时间 OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
DefaultModels []string `json:"default_models"` // 默认开通的 AI 模型 VipInfoText string `json:"vip_info_text"` // 会员页面充值说明
OrderPayInfoText string `json:"order_pay_info_text"` // 订单支付页面说明文字 DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
InviteChatCalls int `json:"invite_chat_calls"` // 邀请用户注册奖励对话次数
InviteImgCalls int `json:"invite_img_calls"` // 邀请用户注册奖励绘图次数
WechatCardURL string `json:"wechat_card_url"` // 微信客服地址 MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
EnableContext bool `json:"enable_context,omitempty"`
ContextDeep int `json:"context_deep,omitempty"`
} }

View File

@@ -9,7 +9,7 @@ type MKey interface {
string | int | uint string | int | uint
} }
type MValue interface { type MValue interface {
*WsClient | *ChatSession | context.CancelFunc | []interface{} *WsClient | *ChatSession | context.CancelFunc | []Message
} }
type LMap[K MKey, T MValue] struct { type LMap[K MKey, T MValue] struct {
lock sync.RWMutex lock sync.RWMutex

View File

@@ -9,10 +9,9 @@ const (
) )
type OrderRemark struct { type OrderRemark struct {
Days int `json:"days"` // 有效期 Days int `json:"days"` // 有效期
Calls int `json:"calls"` // 增加对话次 Power int `json:"power"` // 增加算力点
ImgCalls int `json:"img_calls"` // 增加绘图次数 Name string `json:"name"` // 产品名称
Name string `json:"name"` // 产品名称
Price float64 `json:"price"` Price float64 `json:"price"`
Discount float64 `json:"discount"` Discount float64 `json:"discount"`
} }

View File

@@ -36,7 +36,6 @@ type SdTask struct {
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
Params SdTaskParams `json:"params"` Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"` RetryCount int `json:"retry_count"`
} }

View File

@@ -30,6 +30,7 @@ const (
Success = BizCode(0) Success = BizCode(0)
Failed = BizCode(1) Failed = BizCode(1)
NotAuthorized = BizCode(400) // 未授权 NotAuthorized = BizCode(400) // 未授权
NotPermission = BizCode(403) // 没有权限
OkMsg = "Success" OkMsg = "Success"
ErrorMsg = "系统开小差了" ErrorMsg = "系统开小差了"

View File

@@ -25,7 +25,15 @@ require (
require github.com/xxl-job/xxl-job-executor-go v1.2.0 require github.com/xxl-job/xxl-job-executor-go v1.2.0
require github.com/bg5t/mydiscordgo v0.28.1 require (
github.com/mojocn/base64Captcha v1.3.1
github.com/shopspring/decimal v1.3.1
)
require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect
)
require ( require (
github.com/andybalholm/brotli v1.0.4 // indirect github.com/andybalholm/brotli v1.0.4 // indirect

View File

@@ -7,8 +7,6 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/bg5t/mydiscordgo v0.28.1 h1:mVH0ZWstVdJffCi/EXJAYQDtXwIKAJYVXLmECu1hEK8=
github.com/bg5t/mydiscordgo v0.28.1/go.mod h1:n3aba73N18k1DzM0t0mGE8rwW3Z+vwTvI8pcsBgxN/8=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -65,6 +63,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
@@ -76,7 +76,6 @@ github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9S
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -129,6 +128,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
@@ -166,6 +167,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
@@ -219,7 +222,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
@@ -227,6 +229,8 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=

View File

@@ -5,10 +5,15 @@ import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/handler" "chatplus/handler"
logger2 "chatplus/logger" logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"context" "context"
"fmt"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/mojocn/base64Captcha"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -17,47 +22,88 @@ import (
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
Captcha string `json:"captcha"` // 验证码
CaptchaId string `json:"captcha_id"` // 验证码id
}
const SuperManagerID = 1
type ManagerHandler struct { type ManagerHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
redis *redis.Client redis *redis.Client
} }
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler { func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
h := ManagerHandler{db: db, redis: client} return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
h.App = app
return &h
} }
// Login 登录 // Login 登录
func (h *ManagerHandler) Login(c *gin.Context) { func (h *ManagerHandler) Login(c *gin.Context) {
var data types.Manager var data Manager
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
manager := h.App.Config.Manager
if data.Username == manager.Username && data.Password == manager.Password { // add captcha
// 创建 token if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ resp.ERROR(c, "验证码错误!")
"user_id": manager.Username, return
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := "users/" + manager.Username
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
resp.SUCCESS(c, tokenString)
} else {
resp.ERROR(c, "用户名或者密码错误")
} }
var manager model.AdminUser
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil {
resp.ERROR(c, "请检查用户名或者密码是否填写正确")
return
}
password := utils.GenPassword(data.Password, manager.Salt)
if password != manager.Password {
resp.ERROR(c, "用户名或密码错误")
return
}
// 超级管理员默认是ID:1
if manager.Id != SuperManagerID && manager.Status == false {
resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员")
return
}
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": manager.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := fmt.Sprintf("admin/%d", manager.Id)
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
// 更新最后登录时间和IP
manager.LastLoginIp = c.ClientIP()
manager.LastLoginAt = time.Now().Unix()
h.DB.Updates(&manager)
var result = struct {
IsSuperAdmin bool `json:"is_super_admin"`
Token string `json:"token"`
}{
IsSuperAdmin: manager.Id == 1,
Token: tokenString,
}
resp.SUCCESS(c, result)
} }
// Logout 注销 // Logout 注销
@@ -72,10 +118,155 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
// Session 会话检测 // Session 会话检测
func (h *ManagerHandler) Session(c *gin.Context) { func (h *ManagerHandler) Session(c *gin.Context) {
token := c.GetHeader(types.AdminAuthHeader) id := h.GetLoginUserId(c)
if token == "" { key := fmt.Sprintf("admin/%d", id)
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c) resp.NotAuth(c)
} else { return
resp.SUCCESS(c)
} }
var manager model.AdminUser
res := h.DB.Where("id", id).First(&manager)
if res.Error != nil {
resp.NotAuth(c)
return
}
resp.SUCCESS(c, manager)
}
// List 数据列表
func (h *ManagerHandler) List(c *gin.Context) {
var items []model.AdminUser
res := h.DB.Find(&items)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
users := make([]vo.AdminUser, 0)
for _, item := range items {
var u vo.AdminUser
err := utils.CopyObject(item, &u)
if err != nil {
continue
}
u.Id = item.Id
u.CreatedAt = item.CreatedAt.Unix()
users = append(users, u)
}
resp.SUCCESS(c, users)
}
func (h *ManagerHandler) Save(c *gin.Context) {
var data struct {
Username string `json:"username"`
Password string `json:"password"`
Status bool `json:"status"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.AdminUser
res := h.DB.Where("username", data.Username).First(&user)
if res.Error == nil {
resp.ERROR(c, "用户名已存在")
return
}
// 生成密码
salt := utils.RandString(8)
password := utils.GenPassword(data.Password, salt)
res = h.DB.Save(&model.AdminUser{
Username: data.Username,
Password: password,
Salt: salt,
Status: data.Status,
})
if res.Error != nil {
resp.ERROR(c, "failed with update database")
return
}
resp.SUCCESS(c)
}
// Remove 删除管理员
func (h *ManagerHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if id == SuperManagerID {
resp.ERROR(c, "超级管理员不能删除")
return
}
res := h.DB.Where("id", id).Delete(&model.AdminUser{})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}
// Enable 启用/禁用
func (h *ManagerHandler) Enable(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Enabled bool `json:"enabled"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.AdminUser{}).Where("id", data.Id).UpdateColumn("status", data.Enabled)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
}
// ResetPass 重置密码
func (h *ManagerHandler) ResetPass(c *gin.Context) {
id := h.GetLoginUserId(c)
if id != SuperManagerID {
resp.ERROR(c, "只有超级管理员能够进行该操作")
return
}
var data struct {
Id int `json:"id"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.AdminUser
res := h.DB.Where("id", data.Id).First(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c)
} }

View File

@@ -14,13 +14,10 @@ import (
type ApiKeyHandler struct { type ApiKeyHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler { func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
h := ApiKeyHandler{db: db} return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
h.App = app
return &h
} }
func (h *ApiKeyHandler) Save(c *gin.Context) { func (h *ApiKeyHandler) Save(c *gin.Context) {
@@ -32,7 +29,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
Value string `json:"value"` Value string `json:"value"`
ApiURL string `json:"api_url"` ApiURL string `json:"api_url"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
UseProxy bool `json:"use_proxy"` ProxyURL string `json:"proxy_url"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -41,16 +38,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey := model.ApiKey{} apiKey := model.ApiKey{}
if data.Id > 0 { if data.Id > 0 {
h.db.Find(&apiKey, data.Id) h.DB.Find(&apiKey, data.Id)
} }
apiKey.Platform = data.Platform apiKey.Platform = data.Platform
apiKey.Value = data.Value apiKey.Value = data.Value
apiKey.Type = data.Type apiKey.Type = data.Type
apiKey.ApiURL = data.ApiURL apiKey.ApiURL = data.ApiURL
apiKey.Enabled = data.Enabled apiKey.Enabled = data.Enabled
apiKey.UseProxy = data.UseProxy apiKey.ProxyURL = data.ProxyURL
apiKey.Name = data.Name apiKey.Name = data.Name
res := h.db.Save(&apiKey) res := h.DB.Save(&apiKey)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -68,9 +65,14 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
} }
func (h *ApiKeyHandler) List(c *gin.Context) { func (h *ApiKeyHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.ApiKey var items []model.ApiKey
var keys = make([]vo.ApiKey, 0) var keys = make([]vo.ApiKey, 0)
res := h.db.Find(&items) res := h.DB.Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var key vo.ApiKey var key vo.ApiKey
@@ -100,7 +102,7 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
return return
} }
res := h.db.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -109,10 +111,15 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
} }
func (h *ApiKeyHandler) Remove(c *gin.Context) { func (h *ApiKeyHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) var data struct {
Id uint
if id > 0 { }
res := h.db.Where("id = ?", id).Delete(&model.ApiKey{}) if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.ApiKey{})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return

View File

@@ -0,0 +1,39 @@
package admin
import (
"chatplus/core"
"chatplus/handler"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"github.com/mojocn/base64Captcha"
)
type CaptchaHandler struct {
handler.BaseHandler
}
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
}
type CaptchaVo struct {
CaptchaId string `json:"captcha_id"`
PicPath string `json:"pic_path"`
}
// GetCaptcha 获取验证码
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
var captchaVo CaptchaVo
driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
// b64s是图片的base64编码
id, b64s, err := cp.Generate()
if err != nil {
resp.ERROR(c, "生成验证码错误!")
return
}
captchaVo.CaptchaId = id
captchaVo.PicPath = b64s
resp.SUCCESS(c, captchaVo)
}

View File

@@ -14,27 +14,30 @@ import (
type ChatHandler struct { type ChatHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler { func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
h := ChatHandler{db: db} return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
type chatItemVo struct { type chatItemVo struct {
Username string `json:"username"` Username string `json:"username"`
UserId uint `json:"user_id"` UserId uint `json:"user_id"`
ChatId string `json:"chat_id"` ChatId string `json:"chat_id"`
Title string `json:"title"` Title string `json:"title"`
Model string `json:"model"` Role vo.ChatRole `json:"role"`
Token int `json:"token"` Model string `json:"model"`
CreatedAt int64 `json:"created_at"` Token int `json:"token"`
MsgNum int `json:"msg_num"` // 消息数量 CreatedAt int64 `json:"created_at"`
MsgNum int `json:"msg_num"` // 消息数量
} }
func (h *ChatHandler) List(c *gin.Context) { func (h *ChatHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var data struct { var data struct {
Title string `json:"title"` Title string `json:"title"`
UserId uint `json:"user_id"` UserId uint `json:"user_id"`
@@ -48,7 +51,7 @@ func (h *ChatHandler) List(c *gin.Context) {
return return
} }
session := h.db.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if data.Title != "" { if data.Title != "" {
session = session.Where("title LIKE ?", "%"+data.Title+"%") session = session.Where("title LIKE ?", "%"+data.Title+"%")
} }
@@ -73,18 +76,23 @@ func (h *ChatHandler) List(c *gin.Context) {
if res.Error == nil { if res.Error == nil {
userIds := make([]uint, 0) userIds := make([]uint, 0)
chatIds := make([]string, 0) chatIds := make([]string, 0)
roleIds := make([]uint, 0)
for _, item := range items { for _, item := range items {
userIds = append(userIds, item.UserId) userIds = append(userIds, item.UserId)
chatIds = append(chatIds, item.ChatId) chatIds = append(chatIds, item.ChatId)
roleIds = append(roleIds, item.RoleId)
} }
var messages []model.ChatMessage var messages []model.ChatMessage
var users []model.User var users []model.User
h.db.Where("chat_id IN ?", chatIds).Find(&messages) var roles []model.ChatRole
h.db.Where("id IN ?", userIds).Find(&users) h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
h.DB.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", roleIds).Find(&roles)
tokenMap := make(map[string]int) tokenMap := make(map[string]int)
userMap := make(map[uint]string) userMap := make(map[uint]string)
msgMap := make(map[string]int) msgMap := make(map[string]int)
roleMap := make(map[uint]vo.ChatRole)
for _, msg := range messages { for _, msg := range messages {
tokenMap[msg.ChatId] += msg.Tokens tokenMap[msg.ChatId] += msg.Tokens
msgMap[msg.ChatId] += 1 msgMap[msg.ChatId] += 1
@@ -92,6 +100,14 @@ func (h *ChatHandler) List(c *gin.Context) {
for _, user := range users { for _, user := range users {
userMap[user.Id] = user.Username userMap[user.Id] = user.Username
} }
for _, r := range roles {
var roleVo vo.ChatRole
err := utils.CopyObject(r, &roleVo)
if err != nil {
continue
}
roleMap[r.Id] = roleVo
}
for _, item := range items { for _, item := range items {
list = append(list, chatItemVo{ list = append(list, chatItemVo{
UserId: item.UserId, UserId: item.UserId,
@@ -101,6 +117,7 @@ func (h *ChatHandler) List(c *gin.Context) {
Model: item.Model, Model: item.Model,
Token: tokenMap[item.ChatId], Token: tokenMap[item.ChatId],
MsgNum: msgMap[item.ChatId], MsgNum: msgMap[item.ChatId],
Role: roleMap[item.RoleId],
CreatedAt: item.CreatedAt.Unix(), CreatedAt: item.CreatedAt.Unix(),
}) })
} }
@@ -135,7 +152,7 @@ func (h *ChatHandler) Messages(c *gin.Context) {
return return
} }
session := h.db.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if data.Content != "" { if data.Content != "" {
session = session.Where("content LIKE ?", "%"+data.Content+"%") session = session.Where("content LIKE ?", "%"+data.Content+"%")
} }
@@ -163,7 +180,7 @@ func (h *ChatHandler) Messages(c *gin.Context) {
userIds = append(userIds, item.UserId) userIds = append(userIds, item.UserId)
} }
var users []model.User var users []model.User
h.db.Where("id IN ?", userIds).Find(&users) h.DB.Where("id IN ?", userIds).Find(&users)
userMap := make(map[uint]string) userMap := make(map[uint]string)
for _, user := range users { for _, user := range users {
userMap[user.Id] = user.Username userMap[user.Id] = user.Username
@@ -190,7 +207,7 @@ func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID chatId := c.Query("chat_id") // 会话 ID
var items []model.ChatMessage var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0) var messages = make([]vo.HistoryMessage, 0)
res := h.db.Where("chat_id = ?", chatId).Find(&items) res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "No history message") resp.ERROR(c, "No history message")
return return
@@ -212,9 +229,14 @@ func (h *ChatHandler) History(c *gin.Context) {
// RemoveChat 删除对话 // RemoveChat 删除对话
func (h *ChatHandler) RemoveChat(c *gin.Context) { func (h *ChatHandler) RemoveChat(c *gin.Context) {
chatId := h.GetTrim(c, "chat_id") chatId := h.GetTrim(c, "chat_id")
tx := h.db.Begin() if chatId == "" {
resp.ERROR(c, "请传入 ChatId")
return
}
tx := h.DB.Begin()
// 删除聊天记录 // 删除聊天记录
res := tx.Unscoped().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{}) res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "failed to remove chat message") resp.ERROR(c, "failed to remove chat message")
return return
@@ -235,7 +257,7 @@ func (h *ChatHandler) RemoveChat(c *gin.Context) {
// RemoveMessage 删除聊天记录 // RemoveMessage 删除聊天记录
func (h *ChatHandler) RemoveMessage(c *gin.Context) { func (h *ChatHandler) RemoveMessage(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
tx := h.db.Unscoped().Delete(&model.ChatMessage{}, id) tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
if tx.Error != nil { if tx.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return

View File

@@ -15,26 +15,26 @@ import (
type ChatModelHandler struct { type ChatModelHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler { func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db} return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
func (h *ChatModelHandler) Save(c *gin.Context) { func (h *ChatModelHandler) Save(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Value string `json:"value"` Value string `json:"value"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"` SortNum int `json:"sort_num"`
Open bool `json:"open"` Open bool `json:"open"`
Platform string `json:"platform"` Platform string `json:"platform"`
Weight int `json:"weight"` Power int `json:"power"`
CreatedAt int64 `json:"created_at"` MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
CreatedAt int64 `json:"created_at"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -42,18 +42,21 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
} }
item := model.ChatModel{ item := model.ChatModel{
Platform: data.Platform, Platform: data.Platform,
Name: data.Name, Name: data.Name,
Value: data.Value, Value: data.Value,
Enabled: data.Enabled, Enabled: data.Enabled,
SortNum: data.SortNum, SortNum: data.SortNum,
Open: data.Open, Open: data.Open,
Weight: data.Weight} MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: data.Temperature,
Power: data.Power}
item.Id = data.Id item.Id = data.Id
if item.Id > 0 { if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0) item.CreatedAt = time.Unix(data.CreatedAt, 0)
} }
res := h.db.Save(&item) res := h.DB.Save(&item)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -72,7 +75,12 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
// List 模型列表 // List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) { func (h *ChatModelHandler) List(c *gin.Context) {
session := h.db.Session(&gorm.Session{}) if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable") enable := h.GetBool(c, "enable")
if enable { if enable {
session = session.Where("enabled", enable) session = session.Where("enabled", enable)
@@ -109,7 +117,7 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
return return
} }
res := h.db.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -129,7 +137,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.db.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -141,13 +149,15 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
func (h *ChatModelHandler) Remove(c *gin.Context) { func (h *ChatModelHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if id > 0 { res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
res := h.db.Where("id = ?", id).Delete(&model.ChatModel{}) if res.Error != nil {
if res.Error != nil { resp.ERROR(c, "更新数据库失败!")
resp.ERROR(c, "更新数据库失败!") return
return
}
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }

View File

@@ -15,13 +15,10 @@ import (
type ChatRoleHandler struct { type ChatRoleHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
h := ChatRoleHandler{db: db} return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
// Save 创建或者更新某个角色 // Save 创建或者更新某个角色
@@ -41,7 +38,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
if data.CreatedAt > 0 { if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0) role.CreatedAt = time.Unix(data.CreatedAt, 0)
} }
res := h.db.Save(&role) res := h.DB.Save(&role)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -53,9 +50,14 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
} }
func (h *ChatRoleHandler) List(c *gin.Context) { func (h *ChatRoleHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.ChatRole var items []model.ChatRole
var roles = make([]vo.ChatRole, 0) var roles = make([]vo.ChatRole, 0)
res := h.db.Order("sort_num ASC").Find(&items) res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "No data found") resp.ERROR(c, "No data found")
return return
@@ -88,7 +90,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.db.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -110,7 +112,7 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
return return
} }
res := h.db.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -119,13 +121,18 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
} }
func (h *ChatRoleHandler) Remove(c *gin.Context) { func (h *ChatRoleHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) var data struct {
if id <= 0 { Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
if data.Id <= 0 {
res := h.db.Where("id = ?", id).Delete(&model.ChatRole{}) resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "删除失败!") resp.ERROR(c, "删除失败!")
return return

View File

@@ -14,19 +14,20 @@ import (
type ConfigHandler struct { type ConfigHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler { func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
h := ConfigHandler{db: db} return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
func (h *ConfigHandler) Update(c *gin.Context) { func (h *ConfigHandler) Update(c *gin.Context) {
var data struct { var data struct {
Key string `json:"key"` Key string `json:"key"`
Config map[string]interface{} `json:"config"` Config struct {
types.SystemConfig
Content string `json:"content,omitempty"`
Updated bool `json:"updated,omitempty"`
} `json:"config"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
@@ -36,7 +37,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
value := utils.JsonEncode(&data.Config) value := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: value} config := model.Config{Key: data.Key, Config: value}
res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key}) res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, res.Error.Error())
return return
@@ -44,7 +45,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
if config.Id > 0 { if config.Id > 0 {
config.Config = value config.Config = value
res := h.db.Updates(&config) res := h.DB.Updates(&config)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, res.Error.Error())
return return
@@ -52,12 +53,10 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// update config cache for AppServer // update config cache for AppServer
var cfg model.Config var cfg model.Config
h.db.Where("marker", data.Key).First(&cfg) h.DB.Where("marker", data.Key).First(&cfg)
var err error var err error
if data.Key == "system" { if data.Key == "system" {
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig) err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
} else if data.Key == "chat" {
err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig)
} }
if err != nil { if err != nil {
resp.ERROR(c, "Failed to update config cache: "+err.Error()) resp.ERROR(c, "Failed to update config cache: "+err.Error())
@@ -71,20 +70,25 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// Get 获取指定的系统配置 // Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) { func (h *ConfigHandler) Get(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
key := c.Query("key") key := c.Query("key")
var config model.Config var config model.Config
res := h.db.Where("marker", key).First(&config) res := h.DB.Where("marker", key).First(&config)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, res.Error.Error())
return return
} }
var m map[string]interface{} var value map[string]interface{}
err := utils.JsonDecode(config.Config, &m) err := utils.JsonDecode(config.Config, &value)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
resp.SUCCESS(c, m) resp.SUCCESS(c, value)
} }

View File

@@ -7,26 +7,25 @@ import (
"chatplus/store/model" "chatplus/store/model"
"chatplus/utils/resp" "chatplus/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"gorm.io/gorm" "gorm.io/gorm"
"time" "time"
) )
type DashboardHandler struct { type DashboardHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler { func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
h := DashboardHandler{db: db} return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
type statsVo struct { type statsVo struct {
Users int64 `json:"users"` Users int64 `json:"users"`
Chats int64 `json:"chats"` Chats int64 `json:"chats"`
Tokens int `json:"tokens"` Tokens int `json:"tokens"`
Income float64 `json:"income"` Income float64 `json:"income"`
Chart map[string]map[string]float64 `json:"chart"`
} }
func (h *DashboardHandler) Stats(c *gin.Context) { func (h *DashboardHandler) Stats(c *gin.Context) {
@@ -35,37 +34,84 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
var userCount int64 var userCount int64
now := time.Now() now := time.Now()
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
res := h.db.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount) res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
if res.Error == nil { if res.Error == nil {
stats.Users = userCount stats.Users = userCount
} }
// new chats statistic // new chats statistic
var chatCount int64 var chatCount int64
res = h.db.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount) res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
if res.Error == nil { if res.Error == nil {
stats.Chats = chatCount stats.Chats = chatCount
} }
// tokens took stats // tokens took stats
var historyMessages []model.ChatMessage var historyMessages []model.ChatMessage
res = h.db.Where("created_at > ?", zeroTime).Find(&historyMessages) res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
for _, item := range historyMessages { for _, item := range historyMessages {
stats.Tokens += item.Tokens stats.Tokens += item.Tokens
} }
// 众筹收入 // 众筹收入
var rewards []model.Reward var rewards []model.Reward
res = h.db.Where("created_at > ?", zeroTime).Find(&rewards) res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
for _, item := range rewards { for _, item := range rewards {
stats.Income += item.Amount stats.Income += item.Amount
} }
// 订单收入 // 订单收入
var orders []model.Order var orders []model.Order
res = h.db.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders) res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
for _, item := range orders { for _, item := range orders {
stats.Income += item.Amount stats.Income += item.Amount
} }
// 统计7天的订单的图表
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
var statsChart = make(map[string]map[string]float64)
//// 初始化
var userStatistic, historyMessagesStatistic, incomeStatistic = make(map[string]float64), make(map[string]float64), make(map[string]float64)
for i := 0; i < 7; i++ {
var initTime = time.Date(now.Year(), now.Month(), now.Day()-i, 0, 0, 0, 0, now.Location()).Format("2006-01-02")
userStatistic[initTime] = float64(0)
historyMessagesStatistic[initTime] = float64(0)
incomeStatistic[initTime] = float64(0)
}
// 统计用户7天增加的曲线
var users []model.User
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
if res.Error == nil {
for _, item := range users {
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
}
}
// 统计7天Token 消耗
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
for _, item := range historyMessages {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
}
// 浮点数相加?
// 统计最近7天的众筹
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
for _, item := range rewards {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
// 统计最近7天的订单
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
for _, item := range orders {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
statsChart["users"] = userStatistic
statsChart["historyMessage"] = historyMessagesStatistic
statsChart["orders"] = incomeStatistic
stats.Chart = statsChart
resp.SUCCESS(c, stats) resp.SUCCESS(c, stats)
} }

View File

@@ -17,13 +17,10 @@ import (
type FunctionHandler struct { type FunctionHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler { func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
h := FunctionHandler{db: db} return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
func (h *FunctionHandler) Save(c *gin.Context) { func (h *FunctionHandler) Save(c *gin.Context) {
@@ -44,7 +41,7 @@ func (h *FunctionHandler) Save(c *gin.Context) {
Enabled: data.Enabled, Enabled: data.Enabled,
} }
res := h.db.Save(&f) res := h.DB.Save(&f)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "error with save data:"+res.Error.Error()) resp.ERROR(c, "error with save data:"+res.Error.Error())
return return
@@ -65,7 +62,7 @@ func (h *FunctionHandler) Set(c *gin.Context) {
return return
} }
res := h.db.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -74,8 +71,13 @@ func (h *FunctionHandler) Set(c *gin.Context) {
} }
func (h *FunctionHandler) List(c *gin.Context) { func (h *FunctionHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Function var items []model.Function
res := h.db.Find(&items) res := h.DB.Find(&items)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "No data found") resp.ERROR(c, "No data found")
return return
@@ -97,7 +99,7 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id > 0 { if id > 0 {
res := h.db.Delete(&model.Function{Id: uint(id)}) res := h.DB.Delete(&model.Function{Id: uint(id)})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return

View File

@@ -8,24 +8,28 @@ import (
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
type OrderHandler struct { type OrderHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db} return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
func (h *OrderHandler) List(c *gin.Context) { func (h *OrderHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var data struct { var data struct {
OrderNo string `json:"order_no"` OrderNo string `json:"order_no"`
Status int `json:"status"`
PayTime []string `json:"pay_time"` PayTime []string `json:"pay_time"`
Page int `json:"page"` Page int `json:"page"`
PageSize int `json:"page_size"` PageSize int `json:"page_size"`
@@ -35,7 +39,7 @@ func (h *OrderHandler) List(c *gin.Context) {
return return
} }
session := h.db.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if data.OrderNo != "" { if data.OrderNo != "" {
session = session.Where("order_no", data.OrderNo) session = session.Where("order_no", data.OrderNo)
} }
@@ -44,8 +48,9 @@ func (h *OrderHandler) List(c *gin.Context) {
end := utils.Str2stamp(data.PayTime[1] + " 00:00:00") end := utils.Str2stamp(data.PayTime[1] + " 00:00:00")
session = session.Where("pay_time >= ? AND pay_time <= ?", start, end) session = session.Where("pay_time >= ? AND pay_time <= ?", start, end)
} }
session = session.Where("status = ?", types.OrderPaidSuccess) if data.Status >= 0 {
session = session.Where("status", data.Status)
}
var total int64 var total int64
session.Model(&model.Order{}).Count(&total) session.Model(&model.Order{}).Count(&total)
var items []model.Order var items []model.Order
@@ -74,7 +79,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
if id > 0 { if id > 0 {
var item model.Order var item model.Order
res := h.db.First(&item, id) res := h.DB.First(&item, id)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "记录不存在!") resp.ERROR(c, "记录不存在!")
return return
@@ -85,7 +90,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
return return
} }
res = h.db.Where("id = ?", id).Delete(&model.Order{}) res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return

View File

@@ -0,0 +1,77 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type PowerLogHandler struct {
handler.BaseHandler
}
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Username string `json:"username"`
Type int `json:"type"`
Model string `json:"model"`
Date []string `json:"date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Model != "" {
session = session.Where("model", data.Model)
}
if data.Type > 0 {
session = session.Where("type", data.Type)
}
if len(data.Date) == 2 {
start := data.Date[0] + " 00:00:00"
end := data.Date[1] + " 00:00:00"
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.PowerLog{}).Count(&total)
var items []model.PowerLog
var list = make([]vo.PowerLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var log vo.PowerLog
err := utils.CopyObject(item, &log)
if err != nil {
continue
}
log.Id = item.Id
log.CreatedAt = item.CreatedAt.Unix()
log.TypeStr = item.Type.String()
list = append(list, log)
}
}
// 统计消费算力总和
var totalPower float64
if len(data.Date) == 2 {
session.Where("mark", 0).Select("SUM(amount) as total_sum").Scan(&totalPower)
}
resp.SUCCESS(c, gin.H{"data": vo.NewPage(total, data.Page, data.PageSize, list), "stat": totalPower})
}

View File

@@ -15,13 +15,10 @@ import (
type ProductHandler struct { type ProductHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler { func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db} return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
func (h *ProductHandler) Save(c *gin.Context) { func (h *ProductHandler) Save(c *gin.Context) {
@@ -32,8 +29,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
Discount float64 `json:"discount"` Discount float64 `json:"discount"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Days int `json:"days"` Days int `json:"days"`
Calls int `json:"calls"` Power int `json:"power"`
ImgCalls int `json:"img_calls"`
CreatedAt int64 `json:"created_at"` CreatedAt int64 `json:"created_at"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
@@ -46,14 +42,13 @@ func (h *ProductHandler) Save(c *gin.Context) {
Price: data.Price, Price: data.Price,
Discount: data.Discount, Discount: data.Discount,
Days: data.Days, Days: data.Days,
Calls: data.Calls, Power: data.Power,
ImgCalls: data.ImgCalls,
Enabled: data.Enabled} Enabled: data.Enabled}
item.Id = data.Id item.Id = data.Id
if item.Id > 0 { if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0) item.CreatedAt = time.Unix(data.CreatedAt, 0)
} }
res := h.db.Save(&item) res := h.DB.Save(&item)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -72,7 +67,12 @@ func (h *ProductHandler) Save(c *gin.Context) {
// List 模型列表 // List 模型列表
func (h *ProductHandler) List(c *gin.Context) { func (h *ProductHandler) List(c *gin.Context) {
session := h.db.Session(&gorm.Session{}) if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable") enable := h.GetBool(c, "enable")
if enable { if enable {
session = session.Where("enabled", enable) session = session.Where("enabled", enable)
@@ -108,7 +108,7 @@ func (h *ProductHandler) Enable(c *gin.Context) {
return return
} }
res := h.db.Model(&model.Product{}).Where("id = ?", data.Id).Update("enabled", data.Enabled) res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -128,7 +128,7 @@ func (h *ProductHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
res := h.db.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) res := h.DB.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
@@ -142,7 +142,7 @@ func (h *ProductHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id > 0 { if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.Product{}) res := h.DB.Where("id = ?", id).Delete(&model.Product{})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return

View File

@@ -2,6 +2,7 @@ package admin
import ( import (
"chatplus/core" "chatplus/core"
"chatplus/core/types"
"chatplus/handler" "chatplus/handler"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
@@ -13,18 +14,20 @@ import (
type RewardHandler struct { type RewardHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler { func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db} return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
func (h *RewardHandler) List(c *gin.Context) { func (h *RewardHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Reward var items []model.Reward
res := h.db.Order("id DESC").Find(&items) res := h.DB.Order("id DESC").Find(&items)
var rewards = make([]vo.Reward, 0) var rewards = make([]vo.Reward, 0)
if res.Error == nil { if res.Error == nil {
userIds := make([]uint, 0) userIds := make([]uint, 0)
@@ -32,7 +35,7 @@ func (h *RewardHandler) List(c *gin.Context) {
userIds = append(userIds, v.UserId) userIds = append(userIds, v.UserId)
} }
var users []model.User var users []model.User
h.db.Where("id IN ?", userIds).Find(&users) h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User) var userMap = make(map[uint]model.User)
for _, u := range users { for _, u := range users {
userMap[u.Id] = u userMap[u.Id] = u
@@ -57,10 +60,15 @@ func (h *RewardHandler) List(c *gin.Context) {
} }
func (h *RewardHandler) Remove(c *gin.Context) { func (h *RewardHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) var data struct {
Id uint
if id > 0 { }
res := h.db.Where("id = ?", id).Delete(&model.Reward{}) if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return

View File

@@ -0,0 +1,45 @@
package admin
import (
"chatplus/core"
"chatplus/handler"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type UploadHandler struct {
handler.BaseHandler
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
return
}
userId := 0
res := h.DB.Create(&model.File{
UserId: userId,
Name: file.Name,
ObjKey: file.ObjKey,
URL: file.URL,
Ext: file.Ext,
Size: file.Size,
CreatedAt: time.Time{},
})
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with update database: "+res.Error.Error())
return
}
resp.SUCCESS(c, file)
}

View File

@@ -9,6 +9,7 @@ import (
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"fmt" "fmt"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -16,17 +17,19 @@ import (
type UserHandler struct { type UserHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
} }
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler { func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
h := UserHandler{db: db} return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
// List 用户列表 // List 用户列表
func (h *UserHandler) List(c *gin.Context) { func (h *UserHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
page := h.GetInt(c, "page", 1) page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20) pageSize := h.GetInt(c, "page_size", 20)
username := h.GetTrim(c, "username") username := h.GetTrim(c, "username")
@@ -36,7 +39,7 @@ func (h *UserHandler) List(c *gin.Context) {
var users = make([]vo.User, 0) var users = make([]vo.User, 0)
var total int64 var total int64
session := h.db.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if username != "" { if username != "" {
session = session.Where("username LIKE ?", "%"+username+"%") session = session.Where("username LIKE ?", "%"+username+"%")
} }
@@ -66,13 +69,12 @@ func (h *UserHandler) Save(c *gin.Context) {
Id uint `json:"id"` Id uint `json:"id"`
Password string `json:"password"` Password string `json:"password"`
Username string `json:"username"` Username string `json:"username"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
ChatRoles []string `json:"chat_roles"` ChatRoles []string `json:"chat_roles"`
ChatModels []string `json:"chat_models"` ChatModels []int `json:"chat_models"`
ExpiredTime string `json:"expired_time"` ExpiredTime string `json:"expired_time"`
Status bool `json:"status"` Status bool `json:"status"`
Vip bool `json:"vip"` Vip bool `json:"vip"`
Power int `json:"power"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -82,18 +84,45 @@ func (h *UserHandler) Save(c *gin.Context) {
var res *gorm.DB var res *gorm.DB
var userVo vo.User var userVo vo.User
if data.Id > 0 { // 更新 if data.Id > 0 { // 更新
user.Id = data.Id res = h.DB.Where("id", data.Id).First(&user)
// 此处需要用 map 更新,用结构体无法更新 0 值 if res.Error != nil {
res = h.db.Model(&user).Updates(map[string]interface{}{ resp.ERROR(c, "user not found")
"username": data.Username, return
"calls": data.Calls, }
"img_calls": data.ImgCalls, var oldPower = user.Power
"status": data.Status, user.Username = data.Username
"vip": data.Vip, user.Status = data.Status
"chat_roles_json": utils.JsonEncode(data.ChatRoles), user.Vip = data.Vip
"chat_models_json": utils.JsonEncode(data.ChatModels), user.Power = data.Power
"expired_time": utils.Str2stamp(data.ExpiredTime), user.ChatRoles = utils.JsonEncode(data.ChatRoles)
}) user.ChatModels = utils.JsonEncode(data.ChatModels)
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力日志
if oldPower != user.Power {
mark := types.PowerAdd
amount := user.Power - oldPower
if oldPower > user.Power {
mark = types.PowerSub
amount = oldPower - user.Power
}
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerGift,
Amount: amount,
Balance: user.Power,
Mark: mark,
Model: "管理员",
Remark: fmt.Sprintf("后台管理员强制修改用户算力,修改前:%d,修改后:%d, 管理员ID%d", oldPower, user.Power, h.GetLoginUserId(c)),
CreatedAt: time.Now(),
})
}
} else { } else {
salt := utils.RandString(8) salt := utils.RandString(8)
u := model.User{ u := model.User{
@@ -102,21 +131,13 @@ func (h *UserHandler) Save(c *gin.Context) {
Password: utils.GenPassword(data.Password, salt), Password: utils.GenPassword(data.Password, salt),
Avatar: "/images/avatar/user.png", Avatar: "/images/avatar/user.png",
Salt: salt, Salt: salt,
Power: data.Power,
Status: true, Status: true,
ChatRoles: utils.JsonEncode(data.ChatRoles), ChatRoles: utils.JsonEncode(data.ChatRoles),
ChatModels: utils.JsonEncode(data.ChatModels), ChatModels: utils.JsonEncode(data.ChatModels),
ExpiredTime: utils.Str2stamp(data.ExpiredTime), ExpiredTime: utils.Str2stamp(data.ExpiredTime),
ChatConfig: utils.JsonEncode(types.UserChatConfig{
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
Calls: data.Calls,
ImgCalls: data.ImgCalls,
} }
res = h.db.Create(&u) res = h.DB.Create(&u)
_ = utils.CopyObject(u, &userVo) _ = utils.CopyObject(u, &userVo)
userVo.Id = u.Id userVo.Id = u.Id
userVo.CreatedAt = u.CreatedAt.Unix() userVo.CreatedAt = u.CreatedAt.Unix()
@@ -143,7 +164,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
} }
var user model.User var user model.User
res := h.db.First(&user, data.Id) res := h.DB.First(&user, data.Id)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "No user found") resp.ERROR(c, "No user found")
return return
@@ -151,7 +172,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
password := utils.GenPassword(data.Password, user.Salt) password := utils.GenPassword(data.Password, user.Salt)
user.Password = password user.Password = password
res = h.db.Updates(&user) res = h.DB.Updates(&user)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c) resp.ERROR(c)
} else { } else {
@@ -161,36 +182,32 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
func (h *UserHandler) Remove(c *gin.Context) { func (h *UserHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
if id > 0 { if id <= 0 {
tx := h.db.Begin() resp.ERROR(c, types.InvalidArgs)
res := h.db.Where("id = ?", id).Delete(&model.User{}) return
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
res = h.db.Where("user_id = ?", id).Delete(&model.ChatItem{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除聊天历史记录
res = h.db.Where("user_id = ?", id).Delete(&model.ChatMessage{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除登录日志
res = h.db.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
tx.Commit()
} }
// 删除用户
res := h.DB.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
// 删除聊天历史记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
// 删除登录日志
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
// 删除算力日志
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
// 删除众筹日志
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
// 删除绘图任务
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
// 删除订单
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -198,10 +215,10 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
page := h.GetInt(c, "page", 1) page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20) pageSize := h.GetInt(c, "page_size", 20)
var total int64 var total int64
h.db.Model(&model.UserLoginLog{}).Count(&total) h.DB.Model(&model.UserLoginLog{}).Count(&total)
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
var items []model.UserLoginLog var items []model.UserLoginLog
res := h.db.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items) res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "获取数据失败") resp.ERROR(c, "获取数据失败")
return return

View File

@@ -4,8 +4,11 @@ import (
"chatplus/core" "chatplus/core"
"chatplus/core/types" "chatplus/core/types"
logger2 "chatplus/logger" logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/utils" "chatplus/utils"
"errors"
"fmt" "fmt"
"gorm.io/gorm"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -15,6 +18,7 @@ var logger = logger2.GetLogger()
type BaseHandler struct { type BaseHandler struct {
App *core.AppServer App *core.AppServer
DB *gorm.DB
} }
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string { func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
@@ -57,3 +61,27 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
} }
return uint(utils.IntValue(utils.InterfaceToString(userId), 0)) return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
} }
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
return h.GetLoginUserId(c) > 0
}
func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
value, exists := c.Get(types.LoginUserCache)
if exists {
return value.(model.User), nil
}
userId, ok := c.Get(types.LoginUserID)
if !ok {
return model.User{}, errors.New("user not login")
}
var user model.User
res := h.DB.First(&user, userId)
// 更新缓存
if res.Error == nil {
c.Set(types.LoginUserCache, user)
}
return user, res.Error
}

View File

@@ -45,3 +45,33 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
} }
} }
// SlideGet 获取滑动验证图片
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
data, err := h.service.SlideGet()
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// SlideCheck 滑动验证结果校验
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
var data struct {
Key string `json:"key"`
X int `json:"x"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if h.service.SlideCheck(data) {
resp.SUCCESS(c)
} else {
resp.ERROR(c)
}
}

View File

@@ -12,37 +12,34 @@ import (
type ChatModelHandler struct { type ChatModelHandler struct {
BaseHandler BaseHandler
db *gorm.DB
} }
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler { func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db} return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
// List 模型列表 // List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) { func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0) var chatModels = make([]vo.ChatModel, 0)
// 只加载用户订阅的 AI 模型 var res *gorm.DB
user, err := utils.GetLoginUser(c, h.db) // 如果用户没有登录,则加载所有开放模型
if err != nil { if !h.IsLogin(c) {
resp.NotAuth(c) res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items)
return } else {
user, _ := h.GetLoginUser(c)
var models []int
err := utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res = h.DB.Where("enabled = ?", true).Where(
h.DB.Where("id IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
} }
var models []string
err = utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res := h.db.Where("enabled = ?", true).Where(
h.db.Where("value IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var cm vo.ChatModel var cm vo.ChatModel

View File

@@ -14,27 +14,25 @@ import (
type ChatRoleHandler struct { type ChatRoleHandler struct {
BaseHandler BaseHandler
db *gorm.DB
} }
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
handler := &ChatRoleHandler{db: db} return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
handler.App = app
return handler
} }
// List get user list // List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) { func (h *ChatRoleHandler) List(c *gin.Context) {
all := h.GetBool(c, "all") all := h.GetBool(c, "all")
userId := h.GetLoginUserId(c)
var roles []model.ChatRole var roles []model.ChatRole
res := h.db.Where("enable", true).Order("sort_num ASC").Find(&roles) res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "No roles found,"+res.Error.Error()) resp.ERROR(c, "No roles found,"+res.Error.Error())
return return
} }
// 获取所有角色 // 获取所有角色
if all { if userId == 0 || all {
// 转成 vo // 转成 vo
var roleVos = make([]vo.ChatRole, 0) var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles { for _, r := range roles {
@@ -49,13 +47,8 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
return return
} }
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.NotAuth(c)
return
}
var user model.User var user model.User
h.db.First(&user, userId) h.DB.First(&user, userId)
var roleKeys []string var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys) err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil { if err != nil {
@@ -80,7 +73,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
// UpdateRole 更新用户聊天角色 // UpdateRole 更新用户聊天角色
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) { func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
@@ -94,7 +87,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
return return
} }
res := h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)) res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
if res.Error != nil { if res.Error != nil {
logger.Error("添加应用失败:", err) logger.Error("添加应用失败:", err)
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")

View File

@@ -19,7 +19,7 @@ import (
// 微软 Azure 模型消息发送实现 // 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage( func (h *ChatHandler) sendAzureMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,
@@ -103,8 +103,6 @@ func (h *ChatHandler) sendAzureMessage(
// 消息发送成功 // 消息发送成功
if len(contents) > 0 { if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" { if message.Role == "" {
message.Role = "assistant" message.Role = "assistant"
@@ -113,66 +111,64 @@ func (h *ChatHandler) sendAzureMessage(
useMsg := types.Message{Role: "user", Content: prompt} useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文 // 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext { if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息 chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx) h.App.ChatContexts.Put(session.ChatId, chatCtx)
} }
// 追加聊天记录 // 追加聊天记录
if h.App.ChatConfig.EnableHistory { // for prompt
// for prompt promptToken, err := utils.CalcTokens(prompt, req.Model)
promptToken, err := utils.CalcTokens(prompt, req.Model) if err != nil {
if err != nil { logger.Error(err)
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
} }
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话 // 保存当前会话
var chatItem model.ChatItem var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil { if res.Error != nil {
chatItem.ChatId = session.ChatId chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId chatItem.UserId = session.UserId
@@ -184,7 +180,7 @@ func (h *ChatHandler) sendAzureMessage(
chatItem.Title = prompt chatItem.Title = prompt
} }
chatItem.Model = req.Model chatItem.Model = req.Model
h.db.Create(&chatItem) h.DB.Create(&chatItem)
} }
} }
} else { } else {

View File

@@ -36,7 +36,7 @@ type baiduResp struct {
// 百度文心一言消息发送实现 // 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage( func (h *ChatHandler) sendBaiduMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,
@@ -128,9 +128,6 @@ func (h *ChatHandler) sendBaiduMessage(
// 消息发送成功 // 消息发送成功
if len(contents) > 0 { if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" { if message.Role == "" {
message.Role = "assistant" message.Role = "assistant"
} }
@@ -138,65 +135,63 @@ func (h *ChatHandler) sendBaiduMessage(
useMsg := types.Message{Role: "user", Content: prompt} useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文 // 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext { if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息 chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx) h.App.ChatContexts.Put(session.ChatId, chatCtx)
} }
// 追加聊天记录 // 追加聊天记录
if h.App.ChatConfig.EnableHistory { // for prompt
// for prompt promptToken, err := utils.CalcTokens(prompt, req.Model)
promptToken, err := utils.CalcTokens(prompt, req.Model) if err != nil {
if err != nil { logger.Error(err)
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
} }
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话 // 保存当前会话
var chatItem model.ChatItem var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil { if res.Error != nil {
chatItem.ChatId = session.ChatId chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId chatItem.UserId = session.UserId
@@ -208,7 +203,7 @@ func (h *ChatHandler) sendBaiduMessage(
chatItem.Title = prompt chatItem.Title = prompt
} }
chatItem.Model = req.Model chatItem.Model = req.Model
h.db.Create(&chatItem) h.DB.Create(&chatItem)
} }
} }
} else { } else {

View File

@@ -35,19 +35,16 @@ var logger = logger2.GetLogger()
type ChatHandler struct { type ChatHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB
redis *redis.Client redis *redis.Client
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
} }
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler { func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
h := ChatHandler{ return &ChatHandler{
db: db, BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis, redis: redis,
uploadManager: manager, uploadManager: manager,
} }
h.App = app
return &h
} }
func (h *ChatHandler) Init() { func (h *ChatHandler) Init() {
@@ -57,8 +54,6 @@ func (h *ChatHandler) Init() {
} }
} }
var chatConfig types.ChatConfig
// ChatHandle 处理聊天 WebSocket 请求 // ChatHandle 处理聊天 WebSocket 请求
func (h *ChatHandler) ChatHandle(c *gin.Context) { func (h *ChatHandler) ChatHandle(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
@@ -75,7 +70,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
client := types.NewWsClient(ws) client := types.NewWsClient(ws)
// get model info // get model info
var chatModel model.ChatModel var chatModel model.ChatModel
res := h.db.First(&chatModel, modelId) res := h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false { if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭") utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort() c.Abort()
@@ -84,7 +79,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session := h.App.ChatSession.Get(sessionId) session := h.App.ChatSession.Get(sessionId)
if session == nil { if session == nil {
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
logger.Info("用户未登录") logger.Info("用户未登录")
c.Abort() c.Abort()
@@ -101,7 +96,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
// use old chat data override the chat model and role ID // use old chat data override the chat model and role ID
var chat model.ChatItem var chat model.ChatItem
res = h.db.Where("chat_id = ?", chatId).First(&chat) res = h.DB.Where("chat_id = ?", chatId).First(&chat)
if res.Error == nil { if res.Error == nil {
chatModel.Id = chat.ModelId chatModel.Id = chat.ModelId
roleId = int(chat.RoleId) roleId = int(chat.RoleId)
@@ -109,28 +104,24 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session.ChatId = chatId session.ChatId = chatId
session.Model = types.ChatModel{ session.Model = types.ChatModel{
Id: chatModel.Id, Id: chatModel.Id,
Value: chatModel.Value, Name: chatModel.Name,
Weight: chatModel.Weight, Value: chatModel.Value,
Platform: types.Platform(chatModel.Platform)} Power: chatModel.Power,
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username) logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId) res = h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable { if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort() c.Abort()
return return
} }
// 初始化聊天配置 h.Init()
var config model.Config
h.db.Where("marker", "chat").First(&config)
err = utils.JsonDecode(config.Config, &chatConfig)
if err != nil {
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
c.Abort()
return
}
// 保存会话连接 // 保存会话连接
h.App.ChatClients.Put(sessionId, client) h.App.ChatClients.Put(sessionId, client)
@@ -188,9 +179,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
var user model.User var user model.User
res := h.db.Model(&model.User{}).First(&user, session.UserId) res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil { if res.Error != nil {
utils.ReplyMessage(ws, "非法用户,请联系管理员") utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作")
return res.Error return res.Error
} }
var userVo vo.User var userVo vo.User
@@ -206,14 +197,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil return nil
} }
if userVo.Calls < session.Model.Weight { if userVo.Power < session.Model.Power {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数%d已不足以支付当前模型的单次对话需要消耗的对话额度%d", userVo.Calls, session.Model.Weight)) utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力%d已不足以支付当前模型的单次对话需要消耗的算力%d", userVo.Power, session.Model.Power))
utils.ReplyMessage(ws, ErrImg)
return nil
}
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
utils.ReplyMessage(ws, ErrImg) utils.ReplyMessage(ws, ErrImg)
return nil return nil
} }
@@ -223,35 +208,34 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
utils.ReplyMessage(ws, ErrImg) utils.ReplyMessage(ws, ErrImg)
return nil return nil
} }
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > session.Model.MaxContext {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
}
var req = types.ApiRequest{ var req = types.ApiRequest{
Model: session.Model.Value, Model: session.Model.Value,
Stream: true, Stream: true,
} }
switch session.Model.Platform { switch session.Model.Platform {
case types.Azure: case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
req.Temperature = h.App.ChatConfig.Azure.Temperature req.Temperature = session.Model.Temperature
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens req.MaxTokens = session.Model.MaxTokens
break
case types.ChatGLM:
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
break
case types.Baidu:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
// TODO 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
break break
case types.OpenAI: case types.OpenAI:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature req.Temperature = session.Model.Temperature
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能 // OpenAI 支持函数功能
var items []model.Function var items []model.Function
res := h.db.Where("enabled", true).Find(&items) res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil { if res.Error != nil {
break break
} }
var tools = make([]interface{}, 0) var tools = make([]interface{}, 0)
var functions = make([]interface{}, 0)
for _, v := range items { for _, v := range items {
var parameters map[string]interface{} var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters) err = utils.JsonDecode(v.Parameters, &parameters)
@@ -269,30 +253,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
"required": required, "required": required,
}, },
}) })
functions = append(functions, gin.H{
"name": v.Name,
"description": v.Description,
"parameters": parameters,
"required": required,
})
} }
//if len(tools) > 0 { if len(tools) > 0 {
// req.Tools = tools req.Tools = tools
// req.ToolChoice = "auto" req.ToolChoice = "auto"
//}
if len(functions) > 0 {
req.Functions = functions
} }
case types.XunFei:
req.Temperature = h.App.ChatConfig.XunFei.Temperature
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
break
case types.QWen: case types.QWen:
req.Input = map[string]interface{}{"messages": []map[string]string{{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}}} req.Parameters = map[string]interface{}{
req.Parameters = map[string]interface{}{} "max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
break break
default: default:
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!") utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
utils.ReplyMessage(ws, ErrImg) utils.ReplyMessage(ws, ErrImg)
@@ -300,40 +273,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
// 加载聊天上下文 // 加载聊天上下文
var chatCtx []interface{} chatCtx := make([]types.Message, 0)
if h.App.ChatConfig.EnableContext { messages := make([]types.Message, 0)
if h.App.SysConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) { if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId) messages = h.App.ChatContexts.Get(session.ChatId)
} else { } else {
// calculate the tokens of current request, to prevent to exceeding the max tokens num _ = utils.JsonDecode(role.Context, &messages)
tokens := req.MaxTokens if h.App.SysConfig.ContextDeep > 0 {
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks
// loading the role context
var messages []types.Message
err := utils.JsonDecode(role.Context, &messages)
if err == nil {
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
if tokens+tks >= types.GetModelMaxToken(req.Model) {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
}
// loading recent chat history as chat context
if chatConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage var historyMessages []model.ChatMessage
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages) res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil { if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- { for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i] msg := historyMessages[i]
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
break
}
tokens += msg.Tokens
ms := types.Message{Role: "user", Content: msg.Content} ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg { if msg.Type == types.ReplyMsg {
ms.Role = "assistant" ms.Role = "assistant"
@@ -343,6 +295,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
} }
} }
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
// MaxContextLength = Response + Tool + Prompt + Context
tokens := req.MaxTokens // 最大响应长度
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
break
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
logger.Debugf("聊天上下文:%+v", chatCtx) logger.Debugf("聊天上下文:%+v", chatCtx)
} }
reqMgs := make([]interface{}, 0) reqMgs := make([]interface{}, 0)
@@ -350,10 +325,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
reqMgs = append(reqMgs, m) reqMgs = append(reqMgs, m)
} }
req.Messages = append(reqMgs, map[string]interface{}{ if session.Model.Platform == types.QWen {
"role": "user", req.Input = map[string]interface{}{"prompt": prompt}
"content": prompt, if len(reqMgs) > 0 {
}) req.Input["messages"] = reqMgs
}
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
}
switch session.Model.Platform { switch session.Model.Platform {
case types.Azure: case types.Azure:
@@ -392,7 +374,7 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
if data.Text == "" && data.ChatId != "" { if data.Text == "" && data.ChatId != "" {
var item model.ChatMessage var item model.ChatMessage
userId, _ := c.Get(types.LoginUserID) userId, _ := c.Get(types.LoginUserID)
res := h.db.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item) res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, res.Error.Error())
return return
@@ -443,7 +425,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器 // 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY // useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) { func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
if res.Error != nil { if res.Error != nil {
return nil, errors.New("no available key, please import key") return nil, errors.New("no available key, please import key")
} }
@@ -469,7 +451,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
apiURL = apiKey.ApiURL apiURL = apiKey.ApiURL
} }
// 更新 API KEY 的最后使用时间 // 更新 API KEY 的最后使用时间
h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token // 百度文心,需要串接 access_token
if platform == types.Baidu { if platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value) token, err := h.getBaiduToken(apiKey.Value)
@@ -496,9 +478,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
request = request.WithContext(ctx) request = request.WithContext(ctx)
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
var proxyURL string var proxyURL string
if h.App.Config.ProxyURL != "" && apiKey.UseProxy { // 使用代理 if apiKey.ProxyURL != "" { // 使用代理
proxyURL = h.App.Config.ProxyURL proxy, _ := url.Parse(apiKey.ProxyURL)
proxy, _ := url.Parse(proxyURL)
client = &http.Client{ client = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyURL(proxy), Proxy: http.ProxyURL(proxy),
@@ -532,23 +513,30 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
return client.Do(request) return client.Do(request)
} }
// 扣减用户的对话次数 // 扣减用户算力
func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) { func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
// 仅当用户没有导入自己的 API KEY 时才进行扣减 power := 1
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { if session.Model.Power > 0 {
num := 1 power = session.Model.Power
if session.Model.Weight > 0 { }
num = session.Model.Weight res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
} if res.Error == nil {
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num)) // 记录算力消费日志
var u model.User
h.DB.Where("id", userVo.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,
Amount: power,
Mark: types.PowerSub,
Balance: u.Power,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
CreatedAt: time.Now(),
})
} }
}
func (h *ChatHandler) incUserTokenFee(userId uint, tokens int) {
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", tokens))
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("tokens", gorm.Expr("tokens + ?", tokens))
} }
// 将AI回复消息中生成的图片链接下载到本地 // 将AI回复消息中生成的图片链接下载到本地

View File

@@ -6,27 +6,29 @@ import (
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
// List 获取会话列表 // List 获取会话列表
func (h *ChatHandler) List(c *gin.Context) { func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetInt(c, "user_id", 0) if !h.IsLogin(c) {
if userId == 0 { resp.SUCCESS(c)
resp.ERROR(c, "The parameter 'user_id' is needed.")
return return
} }
userId := h.GetLoginUserId(c)
var items = make([]vo.ChatItem, 0) var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem var chats []model.ChatItem
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats) res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
if res.Error == nil { if res.Error == nil {
var roleIds = make([]uint, 0) var roleIds = make([]uint, 0)
for _, chat := range chats { for _, chat := range chats {
roleIds = append(roleIds, chat.RoleId) roleIds = append(roleIds, chat.RoleId)
} }
var roles []model.ChatRole var roles []model.ChatRole
res = h.db.Find(&roles, roleIds) res = h.DB.Find(&roles, roleIds)
if res.Error == nil { if res.Error == nil {
roleMap := make(map[uint]model.ChatRole) roleMap := make(map[uint]model.ChatRole)
for _, role := range roles { for _, role := range roles {
@@ -58,7 +60,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title) res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "Failed to update database") resp.ERROR(c, "Failed to update database")
return return
@@ -70,14 +72,14 @@ func (h *ChatHandler) Update(c *gin.Context) {
// Clear 清空所有聊天记录 // Clear 清空所有聊天记录
func (h *ChatHandler) Clear(c *gin.Context) { func (h *ChatHandler) Clear(c *gin.Context) {
// 获取当前登录用户所有的聊天会话 // 获取当前登录用户所有的聊天会话
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }
var chats []model.ChatItem var chats []model.ChatItem
res := h.db.Where("user_id = ?", user.Id).Find(&chats) res := h.DB.Where("user_id = ?", user.Id).Find(&chats)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "No chats found") resp.ERROR(c, "No chats found")
return return
@@ -89,13 +91,13 @@ func (h *ChatHandler) Clear(c *gin.Context) {
// 清空会话上下文 // 清空会话上下文
h.App.ChatContexts.Delete(chat.ChatId) h.App.ChatContexts.Delete(chat.ChatId)
} }
err = h.db.Transaction(func(tx *gorm.DB) error { err = h.DB.Transaction(func(tx *gorm.DB) error {
res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{}) res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
if res.Error != nil { if res.Error != nil {
return res.Error return res.Error
} }
res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{}) res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
if res.Error != nil { if res.Error != nil {
return res.Error return res.Error
} }
@@ -118,7 +120,7 @@ func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID chatId := c.Query("chat_id") // 会话 ID
var items []model.ChatMessage var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0) var messages = make([]vo.HistoryMessage, 0)
res := h.db.Where("chat_id = ?", chatId).Find(&items) res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "No history message") resp.ERROR(c, "No history message")
return return
@@ -144,20 +146,20 @@ func (h *ChatHandler) Remove(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }
res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{}) res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "Failed to update database") resp.ERROR(c, "Failed to update database")
return return
} }
// 删除当前会话的聊天记录 // 删除当前会话的聊天记录
res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{}) res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "Failed to remove chat from database.") resp.ERROR(c, "Failed to remove chat from database.")
return return
@@ -179,7 +181,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
} }
var chatItem model.ChatItem var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", chatId).First(&chatItem) res := h.DB.Where("chat_id = ?", chatId).First(&chatItem)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "No chat found") resp.ERROR(c, "No chat found")
return return

View File

@@ -20,7 +20,7 @@ import (
// 清华大学 ChatGML 消息发送实现 // 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage( func (h *ChatHandler) sendChatGLMMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,
@@ -107,9 +107,6 @@ func (h *ChatHandler) sendChatGLMMessage(
// 消息发送成功 // 消息发送成功
if len(contents) > 0 { if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" { if message.Role == "" {
message.Role = "assistant" message.Role = "assistant"
} }
@@ -117,65 +114,64 @@ func (h *ChatHandler) sendChatGLMMessage(
useMsg := types.Message{Role: "user", Content: prompt} useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文 // 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext { if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息 chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx) h.App.ChatContexts.Put(session.ChatId, chatCtx)
} }
// 追加聊天记录 // 追加聊天记录
if h.App.ChatConfig.EnableHistory { // for prompt
// for prompt promptToken, err := utils.CalcTokens(prompt, req.Model)
promptToken, err := utils.CalcTokens(prompt, req.Model) if err != nil {
if err != nil { logger.Error(err)
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
} }
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话 // 保存当前会话
var chatItem model.ChatItem var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil { if res.Error != nil {
chatItem.ChatId = session.ChatId chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId chatItem.UserId = session.UserId
@@ -187,7 +183,7 @@ func (h *ChatHandler) sendChatGLMMessage(
chatItem.Title = prompt chatItem.Title = prompt
} }
chatItem.Model = req.Model chatItem.Model = req.Model
h.db.Create(&chatItem) h.DB.Create(&chatItem)
} }
} }
} else { } else {

View File

@@ -20,7 +20,7 @@ import (
// OPenAI 消息发送实现 // OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage( func (h *ChatHandler) sendOpenAiMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,
@@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
utils.ReplyMessage(ws, ErrorMsg) utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg) utils.ReplyMessage(ws, ErrImg)
all, _ := io.ReadAll(response.Body) if response.Body != nil {
logger.Error(string(all)) all, _ := io.ReadAll(response.Body)
logger.Error(string(all))
}
return err return err
} else { } else {
defer response.Body.Close() defer response.Body.Close()
@@ -98,7 +100,7 @@ func (h *ChatHandler) sendOpenAiMessage(
} }
if !utils.IsEmptyValue(tool) { if !utils.IsEmptyValue(tool) {
res := h.db.Where("name = ?", tool.Function.Name).First(&function) res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil { if res.Error == nil {
toolCall = true toolCall = true
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
@@ -171,9 +173,6 @@ func (h *ChatHandler) sendOpenAiMessage(
// 消息发送成功 // 消息发送成功
if len(contents) > 0 { if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" { if message.Role == "" {
message.Role = "assistant" message.Role = "assistant"
} }
@@ -181,79 +180,77 @@ func (h *ChatHandler) sendOpenAiMessage(
useMsg := types.Message{Role: "user", Content: prompt} useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文 // 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext && toolCall == false { if h.App.SysConfig.EnableContext && toolCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息 chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx) h.App.ChatContexts.Put(session.ChatId, chatCtx)
} }
// 追加聊天记录 // 追加聊天记录
if h.App.ChatConfig.EnableHistory { useContext := true
useContext := true if toolCall {
if toolCall { useContext = false
useContext = false
}
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
totalTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: totalTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
} }
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: useContext,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// 计算本次对话消耗的总 token 数量
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyTokens += tokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话 // 保存当前会话
var chatItem model.ChatItem var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil { if res.Error != nil {
chatItem.ChatId = session.ChatId chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId chatItem.UserId = session.UserId
@@ -265,17 +262,19 @@ func (h *ChatHandler) sendOpenAiMessage(
chatItem.Title = prompt chatItem.Title = prompt
} }
chatItem.Model = req.Model chatItem.Model = req.Model
h.db.Create(&chatItem) h.DB.Create(&chatItem)
} }
} }
} else { } else {
body, err := io.ReadAll(response.Body) body, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
return fmt.Errorf("error with reading response: %v", err) return fmt.Errorf("error with reading response: %v", err)
} }
var res types.ApiError var res types.ApiError
err = json.Unmarshal(body, &res) err = json.Unmarshal(body, &res)
if err != nil { if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
return fmt.Errorf("error with decode response: %v", err) return fmt.Errorf("error with decode response: %v", err)
} }
@@ -283,7 +282,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。") utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key // 移除当前 API key
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{}) h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") { } else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。") utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") { } else if strings.Contains(res.Error.Message, "This model's maximum context length") {

View File

@@ -20,18 +20,21 @@ type qWenResp struct {
Output struct { Output struct {
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
Text string `json:"text"` Text string `json:"text"`
} `json:"output"` } `json:"output,omitempty"`
Usage struct { Usage struct {
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
} `json:"usage"` } `json:"usage,omitempty"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
} }
// 通义千问消息发送实现 // 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage( func (h *ChatHandler) sendQWenMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,
@@ -70,6 +73,7 @@ func (h *ChatHandler) sendQWenMessage(
scanner := bufio.NewScanner(response.Body) scanner := bufio.NewScanner(response.Body)
var content, lastText, newText string var content, lastText, newText string
var outPutStart = false
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@@ -77,24 +81,32 @@ func (h *ChatHandler) sendQWenMessage(
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") { strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
continue continue
} }
if strings.HasPrefix(line, "data:") { if strings.HasPrefix(line, "data:") {
content = line[5:] content = line[5:]
} }
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
var resp qWenResp var resp qWenResp
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if len(contents) == 0 { // 发送消息头 if len(contents) == 0 { // 发送消息头
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) if !outPutStart {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
outPutStart = true
continue
} else {
// 处理代码换行
content = "\n"
}
} else {
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", content)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if resp.Message != "" {
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
break
}
} }
//通过比较 lastText上一次的文本和 currentText当前的文本 //通过比较 lastText上一次的文本和 currentText当前的文本
@@ -128,9 +140,6 @@ func (h *ChatHandler) sendQWenMessage(
// 消息发送成功 // 消息发送成功
if len(contents) > 0 { if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" { if message.Role == "" {
message.Role = "assistant" message.Role = "assistant"
} }
@@ -138,65 +147,64 @@ func (h *ChatHandler) sendQWenMessage(
useMsg := types.Message{Role: "user", Content: prompt} useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文 // 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext { if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息 chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx) h.App.ChatContexts.Put(session.ChatId, chatCtx)
} }
// 追加聊天记录 // 追加聊天记录
if h.App.ChatConfig.EnableHistory { // for prompt
// for prompt promptToken, err := utils.CalcTokens(prompt, req.Model)
promptToken, err := utils.CalcTokens(prompt, req.Model) if err != nil {
if err != nil { logger.Error(err)
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
} }
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话 // 保存当前会话
var chatItem model.ChatItem var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil { if res.Error != nil {
chatItem.ChatId = session.ChatId chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId chatItem.UserId = session.UserId
@@ -208,7 +216,7 @@ func (h *ChatHandler) sendQWenMessage(
chatItem.Title = prompt chatItem.Title = prompt
} }
chatItem.Model = req.Model chatItem.Model = req.Model
h.db.Create(&chatItem) h.DB.Create(&chatItem)
} }
} }
} else { } else {

View File

@@ -50,15 +50,16 @@ type xunFeiResp struct {
} }
var Model2URL = map[string]string{ var Model2URL = map[string]string{
"general": "v1.1", "general": "v1.1",
"generalv2": "v2.1", "generalv2": "v2.1",
"generalv3": "v3.1", "generalv3": "v3.1",
"generalv3.5": "v3.5",
} }
// 科大讯飞消息发送实现 // 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage( func (h *ChatHandler) sendXunFeiMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,
@@ -68,13 +69,13 @@ func (h *ChatHandler) sendXunFeiMessage(
ws *types.WsClient) error { ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间 promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey var apiKey model.ApiKey
res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if res.Error != nil { if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员") utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil return nil
} }
// 更新 API KEY 的最后使用时间 // 更新 API KEY 的最后使用时间
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{ d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
@@ -86,6 +87,7 @@ func (h *ChatHandler) sendXunFeiMessage(
} }
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1) apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2]) wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接 //握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil) conn, resp, err := d.Dial(wsURL, nil)
@@ -166,9 +168,6 @@ func (h *ChatHandler) sendXunFeiMessage(
// 消息发送成功 // 消息发送成功
if len(contents) > 0 { if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" { if message.Role == "" {
message.Role = "assistant" message.Role = "assistant"
} }
@@ -176,65 +175,64 @@ func (h *ChatHandler) sendXunFeiMessage(
useMsg := types.Message{Role: "user", Content: prompt} useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文 // 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext { if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息 chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx) h.App.ChatContexts.Put(session.ChatId, chatCtx)
} }
// 追加聊天记录 // 追加聊天记录
if h.App.ChatConfig.EnableHistory { // for prompt
// for prompt promptToken, err := utils.CalcTokens(prompt, req.Model)
promptToken, err := utils.CalcTokens(prompt, req.Model) if err != nil {
if err != nil { logger.Error(err)
logger.Error(err)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
} }
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
// 保存当前会话 // 保存当前会话
var chatItem model.ChatItem var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil { if res.Error != nil {
chatItem.ChatId = session.ChatId chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId chatItem.UserId = session.UserId
@@ -246,7 +244,7 @@ func (h *ChatHandler) sendXunFeiMessage(
chatItem.Title = prompt chatItem.Title = prompt
} }
chatItem.Model = req.Model chatItem.Model = req.Model
h.db.Create(&chatItem) h.DB.Create(&chatItem)
} }
} }
@@ -262,7 +260,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
"parameter": map[string]interface{}{ "parameter": map[string]interface{}{
"chat": map[string]interface{}{ "chat": map[string]interface{}{
"domain": req.Model, "domain": req.Model,
"temperature": float64(req.Temperature), "temperature": req.Temperature,
"top_k": int64(6), "top_k": int64(6),
"max_tokens": int64(req.MaxTokens), "max_tokens": int64(req.MaxTokens),
"auditing": "default", "auditing": "default",

View File

@@ -0,0 +1,39 @@
package handler
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ConfigHandler struct {
BaseHandler
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key")
var config model.Config
res := h.DB.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
var value map[string]interface{}
err := utils.JsonDecode(config.Config, &value)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, value)
}

View File

@@ -19,21 +19,18 @@ import (
type FunctionHandler struct { type FunctionHandler struct {
BaseHandler BaseHandler
db *gorm.DB
config types.ChatPlusApiConfig config types.ChatPlusApiConfig
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
proxyURL string
} }
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler { func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
return &FunctionHandler{ return &FunctionHandler{
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: server, App: server,
DB: db,
}, },
db: db,
config: config.ApiConfig, config: config.ApiConfig,
uploadManager: manager, uploadManager: manager,
proxyURL: config.ProxyURL,
} }
} }
@@ -192,68 +189,49 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
} }
logger.Debugf("绘画参数:%+v", params) logger.Debugf("绘画参数:%+v", params)
// check img calls
var user model.User var user model.User
tx := h.db.Where("id = ?", params["user_id"]).First(&user) tx := h.DB.Where("id = ?", params["user_id"]).First(&user)
if tx.Error != nil { if tx.Error != nil {
resp.ERROR(c, "当前用户不存在!") resp.ERROR(c, "当前用户不存在!")
return return
} }
if user.ImgCalls <= 0 { if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户的绘图次数额度不足") resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return return
} }
prompt := utils.InterfaceToString(params["prompt"]) prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY // get image generation API KEY
var apiKey model.ApiKey var apiKey model.ApiKey
tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil { if tx.Error != nil {
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error()) resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
return return
} }
// get image generation api URL
var conf model.Config
var chatConfig types.ChatConfig
tx = h.db.Where("marker", "chat").First(&conf)
if tx.Error != nil {
resp.ERROR(c, "error with get chat configs:"+tx.Error.Error())
return
}
err := utils.JsonDecode(conf.Config, &chatConfig)
if err != nil {
resp.ERROR(c, "error with decode chat config: "+err.Error())
return
}
// translate prompt // translate prompt
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
pt, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, params["prompt"]), h.App.Config.ProxyURL) pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
if err == nil { if err == nil {
logger.Debugf("翻译绘画提示词,原文:%s译文%s", prompt, pt)
prompt = pt prompt = pt
} }
imgNum := chatConfig.DallImgNum
if imgNum <= 0 {
imgNum = 1
}
var res imgRes var res imgRes
var errRes ErrRes var errRes ErrRes
var request *req.Request var request *req.Request
if apiKey.UseProxy && h.proxyURL != "" { if apiKey.ProxyURL != "" {
request = req.C().SetProxyURL(h.proxyURL).R() request = req.C().SetProxyURL(apiKey.ProxyURL).R()
} else { } else {
request = req.C().R() request = req.C().R()
} }
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, h.proxyURL) logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := request.SetHeader("Content-Type", "application/json"). r, err := request.SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value). SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{ SetBody(imgReq{
Model: "dall-e-3", Model: "dall-e-3",
Prompt: prompt, Prompt: prompt,
N: imgNum, N: 1,
Size: "1024x1024", Size: "1024x1024",
}). }).
SetErrorResult(&errRes). SetErrorResult(&errRes).
@@ -263,7 +241,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
return return
} }
// 更新 API KEY 的最后使用时间 // 更新 API KEY 的最后使用时间
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
logger.Debugf("%+v", res) logger.Debugf("%+v", res)
// 存储图片 // 存储图片
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false) imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
@@ -273,8 +251,24 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
} }
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL) content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL)
// update user's img_calls // 更新用户算力
h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
h.DB.Where("id", user.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: h.App.SysConfig.DallPower,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c, content) resp.SUCCESS(c, content)
} }

View File

@@ -15,32 +15,29 @@ import (
// InviteHandler 用户邀请 // InviteHandler 用户邀请
type InviteHandler struct { type InviteHandler struct {
BaseHandler BaseHandler
db *gorm.DB
} }
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler { func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
h := InviteHandler{db: db} return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
// Code 获取当前用户邀请码 // Code 获取当前用户邀请码
func (h *InviteHandler) Code(c *gin.Context) { func (h *InviteHandler) Code(c *gin.Context) {
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
var inviteCode model.InviteCode var inviteCode model.InviteCode
res := h.db.Where("user_id = ?", userId).First(&inviteCode) res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
// 如果邀请码不存在,则创建一个 // 如果邀请码不存在,则创建一个
if res.Error != nil { if res.Error != nil {
code := strings.ToUpper(utils.RandString(8)) code := strings.ToUpper(utils.RandString(8))
for { for {
res = h.db.Where("code = ?", code).First(&inviteCode) res = h.DB.Where("code = ?", code).First(&inviteCode)
if res.Error != nil { // 不存在相同的邀请码则退出 if res.Error != nil { // 不存在相同的邀请码则退出
break break
} }
} }
inviteCode.UserId = userId inviteCode.UserId = userId
inviteCode.Code = code inviteCode.Code = code
h.db.Create(&inviteCode) h.DB.Create(&inviteCode)
} }
var codeVo vo.InviteCode var codeVo vo.InviteCode
@@ -65,7 +62,7 @@ func (h *InviteHandler) List(c *gin.Context) {
return return
} }
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
session := h.db.Session(&gorm.Session{}).Where("inviter_id = ?", userId) session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
var total int64 var total int64
session.Model(&model.InviteLog{}).Count(&total) session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog var items []model.InviteLog
@@ -91,6 +88,6 @@ func (h *InviteHandler) List(c *gin.Context) {
// Hits 访问邀请码 // Hits 访问邀请码
func (h *InviteHandler) Hits(c *gin.Context) { func (h *InviteHandler) Hits(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
h.db.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1)) h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
resp.SUCCESS(c) resp.SUCCESS(c)
} }

View File

@@ -5,7 +5,6 @@ import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/service" "chatplus/service"
"chatplus/service/mj" "chatplus/service/mj"
"chatplus/service/mj/plus"
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
@@ -24,32 +23,32 @@ import (
type MidJourneyHandler struct { type MidJourneyHandler struct {
BaseHandler BaseHandler
db *gorm.DB
pool *mj.ServicePool pool *mj.ServicePool
snowflake *service.Snowflake snowflake *service.Snowflake
uploader *oss.UploaderManager uploader *oss.UploaderManager
} }
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler { func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
h := MidJourneyHandler{ return &MidJourneyHandler{
db: db,
snowflake: snowflake, snowflake: snowflake,
pool: pool, pool: pool,
uploader: manager, uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
} }
h.App = app
return &h
} }
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return false return false
} }
if user.ImgCalls <= 0 { if user.Power < h.App.SysConfig.MjPower {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值") resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return false return false
} }
@@ -160,14 +159,19 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
TaskId: taskId, TaskId: taskId,
Progress: 0, Progress: 0,
Prompt: prompt, Prompt: prompt,
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
opt := "绘图"
if data.TaskType == types.TaskBlend.String() { if data.TaskType == types.TaskBlend.String() {
data.Prompt = "融图:" + strings.Join(data.ImgArr, ",") job.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
opt = "融图"
} else if data.TaskType == types.TaskSwapFace.String() { } else if data.TaskType == types.TaskSwapFace.String() {
data.Prompt = "换脸:" + strings.Join(data.ImgArr, ",") job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
opt = "换脸"
} }
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error()) resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return return
} }
@@ -187,8 +191,23 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
_ = client.Send([]byte("Task Updated")) _ = client.Send([]byte("Task Updated"))
} }
// update user's img calls // update user's power
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -226,9 +245,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
TaskId: taskId, TaskId: taskId,
Progress: 0, Progress: 0,
Prompt: data.Prompt, Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 { if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error()) resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return return
} }
@@ -249,7 +269,23 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
if client != nil { if client != nil {
_ = client.Send([]byte("Task Updated")) _ = client.Send([]byte("Task Updated"))
} }
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -276,9 +312,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
TaskId: taskId, TaskId: taskId,
Progress: 0, Progress: 0,
Prompt: data.Prompt, Prompt: data.Prompt,
Power: h.App.SysConfig.MjActionPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 { if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error()) resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return return
} }
@@ -300,21 +337,60 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
_ = client.Send([]byte("Task Updated")) _ = client.Send([]byte("Task Updated"))
} }
// update user's img calls // update user's power
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// ImgWall 照片墙
func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表 // JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) JobList(c *gin.Context) { func (h *MidJourneyHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0) status := h.GetBool(c, "status")
userId := h.GetInt(c, "user_id", 0) userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0) page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0) pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish") publish := h.GetBool(c, "publish")
session := h.db.Session(&gorm.Session{}) err, jobs := h.getData(status, userId, page, pageSize, publish)
if status == 1 { if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC") session = session.Where("progress = ?", 100).Order("id DESC")
} else { } else {
session = session.Where("progress < ?", 100).Order("id ASC") session = session.Where("progress < ?", 100).Order("id ASC")
@@ -333,8 +409,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
var items []model.MidJourneyJob var items []model.MidJourneyJob
res := session.Find(&items) res := session.Find(&items)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, types.NoData) return res.Error, nil
return
} }
var jobs = make([]vo.MidJourneyJob, 0) var jobs = make([]vo.MidJourneyJob, 0)
@@ -345,13 +420,6 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
continue continue
} }
// 失败的任务直接删除
if job.Progress == -1 {
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
jobs = append(jobs, job)
continue
}
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" { if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
// discord 服务器图片需要使用代理转发图片数据流 // discord 服务器图片需要使用代理转发图片数据流
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") { if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
@@ -366,7 +434,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
jobs = append(jobs, job) jobs = append(jobs, job)
} }
resp.SUCCESS(c, jobs) return nil, jobs
} }
// Remove remove task image // Remove remove task image
@@ -382,7 +450,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
} }
// remove job recode // remove job recode
res := h.db.Delete(&model.MidJourneyJob{Id: data.Id}) res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, res.Error.Error())
return return
@@ -402,27 +470,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// Notify MidJourney Plus 服务任务回调处理
func (h *MidJourneyHandler) Notify(c *gin.Context) {
var data plus.CBReq
if err := c.ShouldBindJSON(&data); err != nil {
logger.Error("非法任务回调:%+v", err)
return
}
err := h.pool.Notify(data)
if err != nil {
logger.Error(err)
} else {
userId := h.GetLoginUserId(c)
client := h.pool.Clients.Get(userId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
}
resp.SUCCESS(c)
}
// Publish 发布图片到画廊显示 // Publish 发布图片到画廊显示
func (h *MidJourneyHandler) Publish(c *gin.Context) { func (h *MidJourneyHandler) Publish(c *gin.Context) {
var data struct { var data struct {
@@ -434,7 +481,7 @@ func (h *MidJourneyHandler) Publish(c *gin.Context) {
return return
} }
res := h.db.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action) res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败") resp.ERROR(c, "更新数据库失败")
return return

View File

@@ -7,19 +7,17 @@ import (
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
type OrderHandler struct { type OrderHandler struct {
BaseHandler BaseHandler
db *gorm.DB
} }
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db} return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
func (h *OrderHandler) List(c *gin.Context) { func (h *OrderHandler) List(c *gin.Context) {
@@ -31,8 +29,8 @@ func (h *OrderHandler) List(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
user, _ := utils.GetLoginUser(c, h.db) userId := h.GetLoginUserId(c)
session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess) session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
var total int64 var total int64
session.Model(&model.Order{}).Count(&total) session.Model(&model.Order{}).Count(&total)
var items []model.Order var items []model.Order

View File

@@ -11,6 +11,7 @@ import (
"embed" "embed"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/shopspring/decimal"
"math" "math"
"net/http" "net/http"
"net/url" "net/url"
@@ -34,7 +35,6 @@ type PaymentHandler struct {
huPiPayService *payment.HuPiPayService huPiPayService *payment.HuPiPayService
js *payment.PayJS js *payment.PayJS
snowflake *service.Snowflake snowflake *service.Snowflake
db *gorm.DB
fs embed.FS fs embed.FS
lock sync.Mutex lock sync.Mutex
} }
@@ -44,20 +44,21 @@ func NewPaymentHandler(
alipayService *payment.AlipayService, alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService, huPiPayService *payment.HuPiPayService,
js *payment.PayJS, js *payment.PayJS,
snowflake *service.Snowflake,
db *gorm.DB, db *gorm.DB,
snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler { fs embed.FS) *PaymentHandler {
h := PaymentHandler{ return &PaymentHandler{
alipayService: alipayService, alipayService: alipayService,
huPiPayService: huPiPayService, huPiPayService: huPiPayService,
js: js, js: js,
snowflake: snowflake, snowflake: snowflake,
fs: fs, fs: fs,
db: db,
lock: sync.Mutex{}, lock: sync.Mutex{},
BaseHandler: BaseHandler{
App: server,
DB: db,
},
} }
h.App = server
return &h
} }
func (h *PaymentHandler) DoPay(c *gin.Context) { func (h *PaymentHandler) DoPay(c *gin.Context) {
@@ -70,7 +71,7 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
} }
var order model.Order var order model.Order
res := h.db.Where("order_no = ?", orderNo).First(&order) res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "Order not found") resp.ERROR(c, "Order not found")
return return
@@ -83,7 +84,7 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
} }
// 更新扫码状态 // 更新扫码状态
h.db.Model(&order).UpdateColumn("status", types.OrderScanned) h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
if payWay == "alipay" { // 支付宝 if payWay == "alipay" { // 支付宝
// 生成支付链接 // 生成支付链接
notifyURL := h.App.Config.AlipayConfig.NotifyURL notifyURL := h.App.Config.AlipayConfig.NotifyURL
@@ -129,7 +130,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
} }
var order model.Order var order model.Order
res := h.db.Where("order_no = ?", data.OrderNo).First(&order) res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "Order not found") resp.ERROR(c, "Order not found")
return return
@@ -144,7 +145,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
for { for {
time.Sleep(time.Second) time.Sleep(time.Second)
var item model.Order var item model.Order
h.db.Where("order_no = ?", data.OrderNo).First(&item) h.DB.Where("order_no = ?", data.OrderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status { if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status order.Status = item.Status
break break
@@ -168,7 +169,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
} }
var product model.Product var product model.Product
res := h.db.First(&product, data.ProductId) res := h.DB.First(&product, data.ProductId)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "Product not found") resp.ERROR(c, "Product not found")
return return
@@ -180,7 +181,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
return return
} }
var user model.User var user model.User
res = h.db.First(&user, data.UserId) res = h.DB.First(&user, data.UserId)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "Invalid user ID") resp.ERROR(c, "Invalid user ID")
return return
@@ -202,24 +203,25 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
// 创建订单 // 创建订单
remark := types.OrderRemark{ remark := types.OrderRemark{
Days: product.Days, Days: product.Days,
Calls: product.Calls, Power: product.Power,
ImgCalls: product.ImgCalls,
Name: product.Name, Name: product.Name,
Price: product.Price, Price: product.Price,
Discount: product.Discount, Discount: product.Discount,
} }
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
order := model.Order{ order := model.Order{
UserId: user.Id, UserId: user.Id,
Username: user.Username, Username: user.Username,
ProductId: product.Id, ProductId: product.Id,
OrderNo: orderNo, OrderNo: orderNo,
Subject: product.Name, Subject: product.Name,
Amount: product.Price - product.Discount, Amount: amount,
Status: types.OrderNotPaid, Status: types.OrderNotPaid,
PayWay: payWay, PayWay: payWay,
Remark: utils.JsonEncode(remark), Remark: utils.JsonEncode(remark),
} }
res = h.db.Create(&order) res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 { if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error()) resp.ERROR(c, "error with create order: "+res.Error.Error())
return return
@@ -275,10 +277,121 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL}) resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
} }
// Mobile 移动端支付
func (h *PaymentHandler) Mobile(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"` // 支付方式
ProductId uint `json:"product_id"`
UserId int `json:"user_id"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var product model.Product
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
}
orderNo, err := h.snowflake.Next(false)
if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
var user model.User
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
var payWay string
var notifyURL, returnURL string
var payURL string
switch data.PayWay {
case "hupi":
payWay = PayWayXunHu
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
params := payment.HuPiPayReq{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name,
NotifyURL: notifyURL,
ReturnURL: returnURL,
CallbackURL: returnURL,
WapName: "极客学长",
}
r, err := h.huPiPayService.Pay(params)
if err != nil {
logger.Error("error with generating Pay URL: ", err.Error())
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
return
}
payURL = r.URL
case "payjs":
payWay = PayWayJs
notifyURL = h.App.Config.JPayConfig.NotifyURL
returnURL = h.App.Config.JPayConfig.ReturnURL
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
params := url.Values{}
params.Add("total_fee", fmt.Sprintf("%d", totalFee))
params.Add("out_trade_no", orderNo)
params.Add("body", product.Name)
params.Add("notify_url", notifyURL)
params.Add("auto", "0")
payURL = h.js.PayH5(params)
case "alipay":
payWay = PayWayAlipay
notifyURL = h.App.Config.AlipayConfig.NotifyURL
returnURL = h.App.Config.AlipayConfig.ReturnURL
payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
if err != nil {
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
return
}
default:
resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
return
}
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
}
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
}
resp.SUCCESS(c, payURL)
}
// 异步通知回调公共逻辑 // 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var order model.Order var order model.Order
res := h.db.Where("order_no = ?", orderNo).First(&order) res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil { if res.Error != nil {
err := fmt.Errorf("error with fetch order: %v", res.Error) err := fmt.Errorf("error with fetch order: %v", res.Error)
logger.Error(err) logger.Error(err)
@@ -294,7 +407,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
} }
var user model.User var user model.User
res = h.db.First(&user, order.UserId) res = h.DB.First(&user, order.UserId)
if res.Error != nil { if res.Error != nil {
err := fmt.Errorf("error with fetch user info: %v", res.Error) err := fmt.Errorf("error with fetch user info: %v", res.Error)
logger.Error(err) logger.Error(err)
@@ -309,29 +422,33 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
return err return err
} }
var opt string
var power int
if user.Vip { // 已经是 VIP 用户 if user.Vip { // 已经是 VIP 用户
if remark.Days > 0 { // 只延期 VIP不增加调用次数 if remark.Days > 0 { // 只延期 VIP不增加调用次数
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix() user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
} else { // 充值点卡,直接增加次数即可 } else { // 充值点卡,直接增加次数即可
user.Calls += remark.Calls user.Power += remark.Power
user.ImgCalls += remark.ImgCalls opt = "点卡充值"
power = remark.Power
} }
} else { // 非 VIP 用户 } else { // 非 VIP 用户
if remark.Days > 0 { // vip 套餐days > 0, calls == 0 if remark.Days > 0 { // vip 套餐days > 0, power == 0
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix() user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Calls += h.App.SysConfig.VipMonthCalls user.Power += h.App.SysConfig.VipMonthPower
user.ImgCalls += h.App.SysConfig.VipMonthImgCalls
user.Vip = true user.Vip = true
opt = "VIP充值"
power = h.App.SysConfig.VipMonthPower
} else { //点卡days == 0, calls > 0 } else { //点卡days == 0, calls > 0
user.Calls += remark.Calls user.Power += remark.Power
user.ImgCalls += remark.ImgCalls opt = "点卡充值"
power = remark.Power
} }
} }
// 更新用户信息 // 更新用户信息
res = h.db.Updates(&user) res = h.DB.Updates(&user)
if res.Error != nil { if res.Error != nil {
err := fmt.Errorf("error with update user info: %v", res.Error) err := fmt.Errorf("error with update user info: %v", res.Error)
logger.Error(err) logger.Error(err)
@@ -342,7 +459,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
order.PayTime = time.Now().Unix() order.PayTime = time.Now().Unix()
order.Status = types.OrderPaidSuccess order.Status = types.OrderPaidSuccess
order.TradeNo = tradeNo order.TradeNo = tradeNo
res = h.db.Updates(&order) res = h.DB.Updates(&order)
if res.Error != nil { if res.Error != nil {
err := fmt.Errorf("error with update order info: %v", res.Error) err := fmt.Errorf("error with update order info: %v", res.Error)
logger.Error(err) logger.Error(err)
@@ -350,7 +467,23 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
} }
// 更新产品销量 // 更新产品销量
h.db.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1)) h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
// 记录算力充值日志
if opt != "" {
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRecharge,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: order.PayWay,
Remark: fmt.Sprintf("%s金额%f订单号%s", opt, order.Amount, order.OrderNo),
CreatedAt: time.Now(),
})
}
return nil return nil
} }

View File

@@ -0,0 +1,67 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type PowerLogHandler struct {
BaseHandler
}
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Model string `json:"model"`
Date []string `json:"date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
userId := h.GetLoginUserId(c)
session = session.Where("user_id", userId)
if data.Model != "" {
session = session.Where("model", data.Model)
}
if len(data.Date) == 2 {
start := data.Date[0] + " 00:00:00"
end := data.Date[1] + " 00:00:00"
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
}
var total int64
session.Model(&model.PowerLog{}).Count(&total)
var items []model.PowerLog
var list = make([]vo.PowerLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var log vo.PowerLog
err := utils.CopyObject(item, &log)
if err != nil {
continue
}
log.Id = item.Id
log.CreatedAt = item.CreatedAt.Unix()
log.TypeStr = item.Type.String()
list = append(list, log)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}

View File

@@ -12,20 +12,17 @@ import (
type ProductHandler struct { type ProductHandler struct {
BaseHandler BaseHandler
db *gorm.DB
} }
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler { func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db} return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
h.App = app
return &h
} }
// List 模型列表 // List 模型列表
func (h *ProductHandler) List(c *gin.Context) { func (h *ProductHandler) List(c *gin.Context) {
var items []model.Product var items []model.Product
var list = make([]vo.Product, 0) var list = make([]vo.Product, 0)
res := h.db.Where("enabled", true).Order("sort_num ASC").Find(&items) res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
if res.Error == nil { if res.Error == nil {
for _, item := range items { for _, item := range items {
var product vo.Product var product vo.Product

View File

@@ -1,63 +0,0 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
type PromptHandler struct {
BaseHandler
db *gorm.DB
}
func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler {
h := &PromptHandler{db: db}
h.App = app
return h
}
// Rewrite translate and rewrite prompt with ChatGPT
func (h *PromptHandler) Rewrite(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(rewritePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) Translate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}

View File

@@ -7,37 +7,35 @@ import (
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"math" "math"
"strings" "strings"
"sync" "sync"
"time"
) )
type RewardHandler struct { type RewardHandler struct {
BaseHandler BaseHandler
db *gorm.DB
lock sync.Mutex lock sync.Mutex
} }
func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler { func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db, lock: sync.Mutex{}} return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
h.App = server
return &h
} }
// Verify 打赏码核销 // Verify 打赏码核销
func (h *RewardHandler) Verify(c *gin.Context) { func (h *RewardHandler) Verify(c *gin.Context) {
var data struct { var data struct {
TxId string `json:"tx_id"` TxId string `json:"tx_id"`
Type string `json:"type"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.HACKER(c) resp.HACKER(c)
return return
@@ -50,7 +48,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
defer h.lock.Unlock() defer h.lock.Unlock()
var item model.Reward var item model.Reward
res := h.db.Where("tx_id = ?", data.TxId).First(&item) res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "无效的众筹交易流水号!") resp.ERROR(c, "无效的众筹交易流水号!")
return return
@@ -61,18 +59,13 @@ func (h *RewardHandler) Verify(c *gin.Context) {
return return
} }
tx := h.db.Begin() tx := h.DB.Begin()
exchange := vo.RewardExchange{} exchange := vo.RewardExchange{}
if data.Type == "chat" { power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
calls := math.Ceil(item.Amount / h.App.SysConfig.ChatCallPrice) exchange.Power = int(power)
exchange.Calls = int(calls) res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
res = h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls + ?", calls))
} else if data.Type == "img" {
calls := math.Ceil(item.Amount / h.App.SysConfig.ImgCallPrice)
exchange.ImgCalls = int(calls)
res = h.db.Model(&user).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", calls))
}
if res.Error != nil { if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
} }
@@ -81,13 +74,25 @@ func (h *RewardHandler) Verify(c *gin.Context) {
item.Status = true item.Status = true
item.UserId = user.Id item.UserId = user.Id
item.Exchange = utils.JsonEncode(exchange) item.Exchange = utils.JsonEncode(exchange)
res = h.db.Updates(&item) res = tx.Updates(&item)
if res.Error != nil { if res.Error != nil {
tx.Rollback() tx.Rollback()
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return
} }
// 记录算力充值日志
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerReward,
Amount: exchange.Power,
Balance: user.Power + exchange.Power,
Mark: types.PowerAdd,
Model: "众筹支付",
Remark: fmt.Sprintf("众筹充值算力,金额:%f价格%f", item.Amount, h.App.SysConfig.PowerPrice),
CreatedAt: time.Now(),
})
tx.Commit() tx.Commit()
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@@ -11,10 +11,11 @@ import (
"chatplus/utils/resp" "chatplus/utils/resp"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/gorilla/websocket"
"net/http" "net/http"
"time" "time"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
@@ -23,19 +24,19 @@ import (
type SdJobHandler struct { type SdJobHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
db *gorm.DB
pool *sd.ServicePool pool *sd.ServicePool
uploader *oss.UploaderManager uploader *oss.UploaderManager
} }
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler { func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler {
h := SdJobHandler{ return &SdJobHandler{
db: db,
pool: pool, pool: pool,
uploader: manager, uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
} }
h.App = app
return &h
} }
// Client WebSocket 客户端,用于通知任务状态变更 // Client WebSocket 客户端,用于通知任务状态变更
@@ -60,7 +61,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
} }
func (h *SdJobHandler) checkLimits(c *gin.Context) bool { func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return false return false
@@ -71,8 +72,8 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
return false return false
} }
if user.ImgCalls <= 0 { if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值") resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return false return false
} }
@@ -132,6 +133,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
HdScaleAlg: data.HdScaleAlg, HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps, HdSteps: data.HdSteps,
} }
job := model.SdJob{ job := model.SdJob{
UserId: userId, UserId: userId,
Type: types.TaskImage.String(), Type: types.TaskImage.String(),
@@ -139,9 +141,10 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Params: utils.JsonEncode(params), Params: utils.JsonEncode(params),
Prompt: data.Prompt, Prompt: data.Prompt,
Progress: 0, Progress: 0,
Power: h.App.SysConfig.SdPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
res := h.db.Create(&job) res := h.DB.Create(&job)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error()) resp.ERROR(c, "error with save job: "+res.Error.Error())
return return
@@ -151,27 +154,71 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Id: int(job.Id), Id: int(job.Id),
SessionId: data.SessionId, SessionId: data.SessionId,
Type: types.TaskImage, Type: types.TaskImage,
Prompt: data.Prompt,
Params: params, Params: params,
UserId: userId, UserId: userId,
}) })
// update user's img calls client := h.pool.Clients.Get(uint(job.UserId))
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// JobList 获取 stable diffusion 任务列表 // ImgWall 照片墙
func (h *SdJobHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 SD 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) { func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0) status := h.GetBool(c, "status")
userId := h.GetInt(c, "user_id", 0) userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0) page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0) pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish") publish := h.GetBool(c, "publish")
session := h.db.Session(&gorm.Session{}) err, jobs := h.getData(status, userId, page, pageSize, publish)
if status == 1 { if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 MJ 任务列表
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC") session = session.Where("progress = ?", 100).Order("id DESC")
} else { } else {
session = session.Where("progress < ?", 100).Order("id ASC") session = session.Where("progress < ?", 100).Order("id ASC")
@@ -190,8 +237,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
var items []model.SdJob var items []model.SdJob
res := session.Find(&items) res := session.Find(&items)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, types.NoData) return res.Error, nil
return
} }
var jobs = make([]vo.SdJob, 0) var jobs = make([]vo.SdJob, 0)
@@ -202,18 +248,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
continue continue
} }
if job.Progress == -1 {
h.db.Delete(&model.SdJob{Id: job.Id})
}
if item.Progress < 100 { if item.Progress < 100 {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*5 {
h.db.Delete(&item)
// 退回绘图次数
h.db.Model(&model.User{}).Where("id = ?", item.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
continue
}
// 正在运行中任务使用代理访问图片 // 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, "") image, err := utils.DownloadImage(item.ImgURL, "")
if err == nil { if err == nil {
@@ -222,13 +257,15 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
} }
jobs = append(jobs, job) jobs = append(jobs, job)
} }
resp.SUCCESS(c, jobs)
return nil, jobs
} }
// Remove remove task image // Remove remove task image
func (h *SdJobHandler) Remove(c *gin.Context) { func (h *SdJobHandler) Remove(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"` ImgURL string `json:"img_url"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
@@ -237,7 +274,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
} }
// remove job recode // remove job recode
res := h.db.Delete(&model.SdJob{Id: data.Id}) res := h.DB.Delete(&model.SdJob{Id: data.Id})
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, res.Error.Error())
return return
@@ -249,6 +286,11 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -263,7 +305,7 @@ func (h *SdJobHandler) Publish(c *gin.Context) {
return return
} }
res := h.db.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true) res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败") resp.ERROR(c, "更新数据库失败")
return return

View File

@@ -29,9 +29,12 @@ func NewSmsHandler(
sms *sms.ServiceManager, sms *sms.ServiceManager,
smtp *service.SmtpService, smtp *service.SmtpService,
captcha *service.CaptchaService) *SmsHandler { captcha *service.CaptchaService) *SmsHandler {
handler := &SmsHandler{redis: client, sms: sms, captcha: captcha, smtp: smtp} return &SmsHandler{
handler.App = app redis: client,
return handler sms: sms,
captcha: captcha,
smtp: smtp,
BaseHandler: BaseHandler{App: app}}
} }
// SendCode 发送验证码 // SendCode 发送验证码

View File

@@ -3,12 +3,6 @@ package handler
import ( import (
"chatplus/service" "chatplus/service"
"chatplus/service/payment" "chatplus/service/payment"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -21,208 +15,3 @@ type TestHandler struct {
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler { func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js} return &TestHandler{db: db, snowflake: snowflake, js: js}
} }
type reqBody struct {
BotType string `json:"botType"`
Prompt string `json:"prompt"`
Base64Array []interface{} `json:"base64Array,omitempty"`
AccountFilter struct {
InstanceId string `json:"instanceId"`
Modes []interface{} `json:"modes"`
Remix bool `json:"remix"`
RemixAutoConsidered bool `json:"remixAutoConsidered"`
} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook"`
State string `json:"state,omitempty"`
}
type resBody struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
func (h *TestHandler) Test(c *gin.Context) {
image(c)
}
func upscale(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj/submit/action"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
body := map[string]string{
"customId": "MJ::JOB::upsample::1::c80a8eb1-f2d1-4f40-8785-97eb99b7ba0a",
"taskId": "1704880156226095",
"notifyHook": "http://r9it.com:6004/api/test/mj",
}
var res resBody
var resErr errRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+token).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&resErr).
Post(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
return
}
resp.SUCCESS(c, res)
}
type queryRes struct {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}
func query(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj/task/1704960661008372/fetch"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
var res queryRes
r, err := req.C().R().SetHeader("Authorization", "Bearer "+token).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+r.Status)
return
}
resp.SUCCESS(c, res)
}
type errRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
func image(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj-fast/mj/submit/imagine"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
body := reqBody{
BotType: "MID_JOURNEY",
Prompt: "一个中国美女,手上拿着一桶爆米花,脸上带着迷人的微笑,白色衣服 --s 750 --v 6",
NotifyHook: "http://r9it.com:6004/api/test/mj",
}
var res resBody
var resErr errRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+token).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&resErr).
Post(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
return
}
resp.SUCCESS(c, res)
}
type cbReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (h *TestHandler) Mj(c *gin.Context) {
var data cbReq
if err := c.ShouldBindJSON(&data); err != nil {
logger.Error(err)
}
logger.Debugf("任务ID%s,任务进度:%s,图片地址:%s, 最终提示词:%s", data.Id, data.Progress, data.ImageUrl, data.Properties.FinalPrompt)
apiURL := "https://api.openai1s.cn/mj/task/" + data.Id + "/fetch"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
var res queryRes
_, _ = req.C().R().SetHeader("Authorization", "Bearer "+token).
SetSuccessResult(&res).
Get(apiURL)
fmt.Println(res.State, ",", res.ImageUrl, ",", res.Progress)
}
func (h *TestHandler) initUserNickname(c *gin.Context) {
var users []model.User
tx := h.db.Find(&users)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
for _, u := range users {
u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
h.db.Updates(&u)
}
resp.SUCCESS(c)
}
func (h *TestHandler) initMjTaskId(c *gin.Context) {
var jobs []model.MidJourneyJob
tx := h.db.Find(&jobs)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
for _, job := range jobs {
id, _ := h.snowflake.Next(true)
job.TaskId = id
h.db.Updates(&job)
}
resp.SUCCESS(c)
}

View File

@@ -14,14 +14,11 @@ import (
type UploadHandler struct { type UploadHandler struct {
BaseHandler BaseHandler
db *gorm.DB
uploaderManager *oss.UploaderManager uploaderManager *oss.UploaderManager
} }
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler { func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
handler := &UploadHandler{db: db, uploaderManager: manager} return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
handler.App = app
return handler
} }
func (h *UploadHandler) Upload(c *gin.Context) { func (h *UploadHandler) Upload(c *gin.Context) {
@@ -32,8 +29,8 @@ func (h *UploadHandler) Upload(c *gin.Context) {
} }
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
res := h.db.Create(&model.File{ res := h.DB.Create(&model.File{
UserId: userId, UserId: int(userId),
Name: file.Name, Name: file.Name,
ObjKey: file.ObjKey, ObjKey: file.ObjKey,
URL: file.URL, URL: file.URL,
@@ -53,7 +50,7 @@ func (h *UploadHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
var items []model.File var items []model.File
var files = make([]vo.File, 0) var files = make([]vo.File, 0)
h.db.Where("user_id = ?", userId).Find(&items) h.DB.Where("user_id = ?", userId).Find(&items)
if len(items) > 0 { if len(items) > 0 {
for _, v := range items { for _, v := range items {
var file vo.File var file vo.File
@@ -75,14 +72,14 @@ func (h *UploadHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
var file model.File var file model.File
tx := h.db.Where("user_id = ? AND id = ?", userId, id).First(&file) tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
if tx.Error != nil || file.Id == 0 { if tx.Error != nil || file.Id == 0 {
resp.ERROR(c, "file not existed") resp.ERROR(c, "file not existed")
return return
} }
// remove database // remove database
tx = h.db.Model(&model.File{}).Delete("id = ?", id) tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
if tx.Error != nil || tx.RowsAffected == 0 { if tx.Error != nil || tx.RowsAffected == 0 {
resp.ERROR(c, "failed to update database") resp.ERROR(c, "failed to update database")
return return

View File

@@ -21,7 +21,6 @@ import (
type UserHandler struct { type UserHandler struct {
BaseHandler BaseHandler
db *gorm.DB
searcher *xdb.Searcher searcher *xdb.Searcher
redis *redis.Client redis *redis.Client
} }
@@ -31,15 +30,14 @@ func NewUserHandler(
db *gorm.DB, db *gorm.DB,
searcher *xdb.Searcher, searcher *xdb.Searcher,
client *redis.Client) *UserHandler { client *redis.Client) *UserHandler {
handler := &UserHandler{db: db, searcher: searcher, redis: client} return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client}
handler.App = app
return handler
} }
// Register user register // Register user register
func (h *UserHandler) Register(c *gin.Context) { func (h *UserHandler) Register(c *gin.Context) {
// parameters process // parameters process
var data struct { var data struct {
RegWay string `json:"reg_way"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Code string `json:"code"` Code string `json:"code"`
@@ -57,8 +55,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// 检查验证码 // 检查验证码
var key string var key string
if utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") || if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" {
utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
key = CodeStorePrefix + data.Username key = CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result() code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code { if err != nil || code != data.Code {
@@ -70,7 +67,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// 验证邀请码 // 验证邀请码
inviteCode := model.InviteCode{} inviteCode := model.InviteCode{}
if data.InviteCode != "" { if data.InviteCode != "" {
res := h.db.Where("code = ?", data.InviteCode).First(&inviteCode) res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "无效的邀请码") resp.ERROR(c, "无效的邀请码")
return return
@@ -79,8 +76,8 @@ func (h *UserHandler) Register(c *gin.Context) {
// check if the username is exists // check if the username is exists
var item model.User var item model.User
res := h.db.Where("username = ?", data.Username).First(&item) res := h.DB.Where("username = ?", data.Username).First(&item)
if res.RowsAffected > 0 { if item.Id > 0 {
resp.ERROR(c, "该用户名已经被注册") resp.ERROR(c, "该用户名已经被注册")
return return
} }
@@ -95,18 +92,10 @@ func (h *UserHandler) Register(c *gin.Context) {
Status: true, Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色 ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型 ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
ChatConfig: utils.JsonEncode(types.UserChatConfig{ Power: h.App.SysConfig.InitPower,
ApiKeys: map[types.Platform]string{
types.OpenAI: "",
types.Azure: "",
types.ChatGLM: "",
},
}),
Calls: h.App.SysConfig.InitChatCalls,
ImgCalls: h.App.SysConfig.InitImgCalls,
} }
res = h.db.Create(&user) res = h.DB.Create(&user)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "保存数据失败") resp.ERROR(c, "保存数据失败")
logger.Error(res.Error) logger.Error(res.Error)
@@ -116,21 +105,32 @@ func (h *UserHandler) Register(c *gin.Context) {
// 记录邀请关系 // 记录邀请关系
if data.InviteCode != "" { if data.InviteCode != "" {
// 增加邀请数量 // 增加邀请数量
h.db.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1)) h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InviteChatCalls > 0 { if h.App.SysConfig.InvitePower > 0 {
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("calls", gorm.Expr("calls + ?", h.App.SysConfig.InviteChatCalls)) h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
} // 记录邀请算力充值日志
if h.App.SysConfig.InviteImgCalls > 0 { var inviter model.User
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", h.App.SysConfig.InviteImgCalls)) h.DB.Where("id", inviteCode.UserId).First(&inviter)
h.DB.Create(&model.PowerLog{
UserId: inviter.Id,
Username: inviter.Username,
Type: types.PowerInvite,
Amount: h.App.SysConfig.InvitePower,
Balance: inviter.Power,
Mark: types.PowerAdd,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
CreatedAt: time.Now(),
})
} }
// 添加邀请记录 // 添加邀请记录
h.db.Create(&model.InviteLog{ h.DB.Create(&model.InviteLog{
InviterId: inviteCode.UserId, InviterId: inviteCode.UserId,
UserId: user.Id, UserId: user.Id,
Username: user.Username, Username: user.Username,
InviteCode: inviteCode.Code, InviteCode: inviteCode.Code,
Reward: utils.JsonEncode(types.InviteReward{ChatCalls: h.App.SysConfig.InviteChatCalls, ImgCalls: h.App.SysConfig.InviteImgCalls}), Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
}) })
} }
@@ -166,7 +166,7 @@ func (h *UserHandler) Login(c *gin.Context) {
return return
} }
var user model.User var user model.User
res := h.db.Where("username = ?", data.Username).First(&user) res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "用户名不存在") resp.ERROR(c, "用户名不存在")
return return
@@ -186,9 +186,9 @@ func (h *UserHandler) Login(c *gin.Context) {
// 更新最后登录时间和IP // 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP() user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix() user.LastLoginAt = time.Now().Unix()
h.db.Model(&user).Updates(user) h.DB.Model(&user).Updates(user)
h.db.Create(&model.UserLoginLog{ h.DB.Create(&model.UserLoginLog{
UserId: user.Id, UserId: user.Id,
Username: user.Username, Username: user.Username,
LoginIp: c.ClientIP(), LoginIp: c.ClientIP(),
@@ -233,7 +233,7 @@ func (h *UserHandler) Logout(c *gin.Context) {
// Session 获取/验证会话 // Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) { func (h *UserHandler) Session(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err == nil { if err == nil {
var userVo vo.User var userVo vo.User
err := utils.CopyObject(user, &userVo) err := utils.CopyObject(user, &userVo)
@@ -249,27 +249,23 @@ func (h *UserHandler) Session(c *gin.Context) {
} }
type userProfile struct { type userProfile struct {
Id uint `json:"id"` Id uint `json:"id"`
Nickname string `json:"nickname"` Nickname string `json:"nickname"`
Username string `json:"username"` Username string `json:"username"`
Avatar string `json:"avatar"` Avatar string `json:"avatar"`
ChatConfig types.UserChatConfig `json:"chat_config"` Power int `json:"power"`
Calls int `json:"calls"` ExpiredTime int64 `json:"expired_time"`
ImgCalls int `json:"img_calls"` Vip bool `json:"vip"`
TotalTokens int64 `json:"total_tokens"`
Tokens int64 `json:"tokens"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
} }
func (h *UserHandler) Profile(c *gin.Context) { func (h *UserHandler) Profile(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }
h.db.First(&user, user.Id) h.DB.First(&user, user.Id)
var profile userProfile var profile userProfile
err = utils.CopyObject(user, &profile) err = utils.CopyObject(user, &profile)
if err != nil { if err != nil {
@@ -289,15 +285,15 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
return return
} }
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }
h.db.First(&user, user.Id) h.DB.First(&user, user.Id)
user.Avatar = data.Avatar user.Avatar = data.Avatar
user.Nickname = data.Nickname user.Nickname = data.Nickname
res := h.db.Updates(&user) res := h.DB.Updates(&user)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新用户信息失败") resp.ERROR(c, "更新用户信息失败")
return return
@@ -322,21 +318,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
return return
} }
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }
password := utils.GenPassword(data.OldPass, user.Salt) password := utils.GenPassword(data.OldPass, user.Salt)
logger.Info(user.Salt, ",", user.Password, ",", password, ",", data.OldPass) logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
if password != user.Password { if password != user.Password {
resp.ERROR(c, "原密码错误") resp.ERROR(c, "原密码错误")
return return
} }
newPass := utils.GenPassword(data.Password, user.Salt) newPass := utils.GenPassword(data.Password, user.Salt)
res := h.db.Model(&user).UpdateColumn("password", newPass) res := h.DB.Model(&user).UpdateColumn("password", newPass)
if res.Error != nil { if res.Error != nil {
logger.Error("更新数据库失败: ", res.Error) logger.Error("更新数据库失败: ", res.Error)
resp.ERROR(c, "更新数据库失败") resp.ERROR(c, "更新数据库失败")
@@ -359,7 +355,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
} }
var user model.User var user model.User
res := h.db.Where("username", data.Username).First(&user) res := h.DB.Where("username", data.Username).First(&user)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "用户不存在!") resp.ERROR(c, "用户不存在!")
return return
@@ -375,7 +371,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
password := utils.GenPassword(data.Password, user.Salt) password := utils.GenPassword(data.Password, user.Salt)
user.Password = password user.Password = password
res = h.db.Updates(&user) res = h.DB.Updates(&user)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c) resp.ERROR(c)
} else { } else {
@@ -405,19 +401,19 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
// 检查手机号是否被其他账号绑定 // 检查手机号是否被其他账号绑定
var item model.User var item model.User
res := h.db.Where("username = ?", data.Username).First(&item) res := h.DB.Where("username = ?", data.Username).First(&item)
if res.Error == nil { if res.Error == nil {
resp.ERROR(c, "该账号已经被其他账号绑定") resp.ERROR(c, "该账号已经被其他账号绑定")
return return
} }
user, err := utils.GetLoginUser(c, h.db) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }
res = h.db.Model(&user).UpdateColumn("username", data.Username) res = h.DB.Model(&user).UpdateColumn("username", data.Username)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败") resp.ERROR(c, "更新数据库失败")
return return

View File

@@ -125,6 +125,8 @@ func main() {
fx.Provide(handler.NewPaymentHandler), fx.Provide(handler.NewPaymentHandler),
fx.Provide(handler.NewOrderHandler), fx.Provide(handler.NewOrderHandler),
fx.Provide(handler.NewProductHandler), fx.Provide(handler.NewProductHandler),
fx.Provide(handler.NewConfigHandler),
fx.Provide(handler.NewPowerLogHandler),
fx.Provide(admin.NewConfigHandler), fx.Provide(admin.NewConfigHandler),
fx.Provide(admin.NewAdminHandler), fx.Provide(admin.NewAdminHandler),
@@ -137,6 +139,7 @@ func main() {
fx.Provide(admin.NewProductHandler), fx.Provide(admin.NewProductHandler),
fx.Provide(admin.NewOrderHandler), fx.Provide(admin.NewOrderHandler),
fx.Provide(admin.NewChatHandler), fx.Provide(admin.NewChatHandler),
fx.Provide(admin.NewPowerLogHandler),
// 创建服务 // 创建服务
fx.Provide(sms.NewSendServiceManager), fx.Provide(sms.NewSendServiceManager),
@@ -172,6 +175,12 @@ func main() {
// Stable Diffusion 机器人 // Stable Diffusion 机器人
fx.Provide(sd.NewServicePool), fx.Provide(sd.NewServicePool),
fx.Invoke(func(pool *sd.ServicePool) {
if pool.HasAvailableService() {
pool.CheckTaskNotify()
pool.CheckTaskStatus()
}
}),
fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay), fx.Provide(payment.NewHuPiPay),
@@ -229,6 +238,8 @@ func main() {
group := s.Engine.Group("/api/captcha/") group := s.Engine.Group("/api/captcha/")
group.GET("get", h.Get) group.GET("get", h.Get)
group.POST("check", h.Check) group.POST("check", h.Check)
group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) { fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
group := s.Engine.Group("/api/reward/") group := s.Engine.Group("/api/reward/")
@@ -241,17 +252,23 @@ func main() {
group.POST("upscale", h.Upscale) group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation) group.POST("variation", h.Variation)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove) group.POST("remove", h.Remove)
group.POST("notify", h.Notify)
group.POST("publish", h.Publish) group.POST("publish", h.Publish)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd") group := s.Engine.Group("/api/sd")
group.Any("client", h.Client)
group.POST("image", h.Image) group.POST("image", h.Image)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove) group.POST("remove", h.Remove)
group.POST("publish", h.Publish) group.POST("publish", h.Publish)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/")
group.GET("get", h.Get)
}),
// 管理后台控制器 // 管理后台控制器
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
@@ -264,13 +281,18 @@ func main() {
group.POST("login", h.Login) group.POST("login", h.Login)
group.GET("logout", h.Logout) group.GET("logout", h.Logout)
group.GET("session", h.Session) group.GET("session", h.Session)
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("enable", h.Enable)
group.GET("remove", h.Remove)
group.POST("resetPass", h.ResetPass)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
group := s.Engine.Group("/api/admin/apikey/") group := s.Engine.Group("/api/admin/apikey/")
group.POST("save", h.Save) group.POST("save", h.Save)
group.GET("list", h.List) group.GET("list", h.List)
group.POST("set", h.Set) group.POST("set", h.Set)
group.GET("remove", h.Remove) group.POST("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) { fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
group := s.Engine.Group("/api/admin/user/") group := s.Engine.Group("/api/admin/user/")
@@ -286,12 +308,12 @@ func main() {
group.POST("save", h.Save) group.POST("save", h.Save)
group.POST("sort", h.Sort) group.POST("sort", h.Sort)
group.POST("set", h.Set) group.POST("set", h.Set)
group.GET("remove", h.Remove) group.POST("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) { fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
group := s.Engine.Group("/api/admin/reward/") group := s.Engine.Group("/api/admin/reward/")
group.GET("list", h.List) group.GET("list", h.List)
group.GET("remove", h.Remove) group.POST("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) { fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
group := s.Engine.Group("/api/admin/dashboard/") group := s.Engine.Group("/api/admin/dashboard/")
@@ -315,6 +337,7 @@ func main() {
group.GET("payWays", h.GetPayWays) group.GET("payWays", h.GetPayWays)
group.POST("query", h.OrderQuery) group.POST("query", h.OrderQuery)
group.POST("qrcode", h.PayQrcode) group.POST("qrcode", h.PayQrcode)
group.POST("mobile", h.Mobile)
group.POST("alipay/notify", h.AlipayNotify) group.POST("alipay/notify", h.AlipayNotify)
group.POST("hupipay/notify", h.HuPiPayNotify) group.POST("hupipay/notify", h.HuPiPayNotify)
group.POST("payjs/notify", h.PayJsNotify) group.POST("payjs/notify", h.PayJsNotify)
@@ -349,13 +372,6 @@ func main() {
group.GET("hits", h.Hits) group.GET("hits", h.Hits)
}), }),
fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt/")
group.POST("rewrite", h.Rewrite)
group.POST("translate", h.Translate)
}),
fx.Provide(admin.NewFunctionHandler), fx.Provide(admin.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) { fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
group := s.Engine.Group("/api/admin/function/") group := s.Engine.Group("/api/admin/function/")
@@ -366,6 +382,18 @@ func main() {
group.GET("token", h.GenToken) group.GET("token", h.GenToken)
}), }),
// 验证码
fx.Provide(admin.NewCaptchaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
group := s.Engine.Group("/api/admin/login/")
group.GET("captcha", h.GetCaptcha)
}),
fx.Provide(admin.NewUploadHandler),
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
s.Engine.POST("/api/admin/upload", h.Upload)
}),
fx.Provide(handler.NewFunctionHandler), fx.Provide(handler.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) { fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
group := s.Engine.Group("/api/function/") group := s.Engine.Group("/api/function/")
@@ -381,10 +409,13 @@ func main() {
group.GET("remove", h.RemoveChat) group.GET("remove", h.RemoveChat)
group.GET("message/remove", h.RemoveMessage) group.GET("message/remove", h.RemoveMessage)
}), }),
fx.Provide(handler.NewTestHandler), fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) { group := s.Engine.Group("/api/powerLog/")
s.Engine.GET("/api/test", h.Test) group.POST("list", h.List)
s.Engine.POST("/api/test/mj", h.Mj) }),
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
group := s.Engine.Group("/api/admin/powerLog/")
group.POST("list", h.List)
}), }),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) { fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
err := s.Run(db) err := s.Run(db)
@@ -392,9 +423,6 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
}), }),
fx.Invoke(func(h *chatimpl.ChatHandler) {
h.Init()
}),
// 注册生命周期回调函数 // 注册生命周期回调函数
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) { fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
lifecycle.Append(fx.Hook{ lifecycle.Append(fx.Hook{

View File

@@ -1,80 +0,0 @@
{
"data": [
"task(cxvkpawy8onnfti)",
"a cute girl",
"",
[],
20,
"DPM++ 2M Karras",
1,
1,
7,
512,
512,
false,
0.7,
2,
"Latent",
0,
0,
0,
"Use same checkpoint",
"Use same sampler",
"",
"",
[],
"None",
false,
"",
0.8,
-1,
false,
-1,
0,
0,
0,
null,
null,
null,
null,
false,
false,
"positive",
"comma",
0,
false,
false,
"",
"Seed",
"",
[],
"Nothing",
"",
[],
"Nothing",
"",
[],
true,
false,
false,
false,
0,
null,
null,
false,
null,
null,
false,
null,
null,
false,
50,
[],
"",
"",
""
],
"event_data": null,
"fn_index": 446,
"session_hash": "nk5noh1rz1o"
}

View File

@@ -60,3 +60,44 @@ func (s *CaptchaService) Check(data interface{}) bool {
return true return true
} }
func (s *CaptchaService) SlideGet() (interface{}, error) {
if s.config.Token == "" {
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err)
}
if res.Code != types.Success {
return nil, fmt.Errorf("请求 API 失败:%s", res.Message)
}
return res.Data, nil
}
func (s *CaptchaService) SlideCheck(data interface{}) bool {
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return false
}
if res.Code != types.Success {
return false
}
return true
}

View File

@@ -1,233 +0,0 @@
package mj
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
discordgo "github.com/bg5t/mydiscordgo"
"github.com/gorilla/websocket"
"net/http"
"net/url"
"regexp"
"strings"
)
// MidJourney 机器人
var logger = logger2.GetLogger()
type Bot struct {
config types.MidJourneyConfig
bot *discordgo.Session
name string
service *Service
}
func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) {
bot, err := discordgo.New("Bot " + config.BotToken)
if err != nil {
logger.Error(err)
return nil, err
}
// use CDN reverse proxy
if config.UseCDN {
discordgo.SetEndpointDiscord(config.DiscordAPI)
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
bot.MjGateway = config.DiscordGateway + "/"
} else { // use proxy
discordgo.SetEndpointDiscord("https://discord.com")
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
discordgo.SetEndpointStatus("https://discord.com/api/v2/")
bot.MjGateway = "wss://gateway.discord.gg"
if proxy != "" {
proxy, _ := url.Parse(proxy)
bot.Client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
bot.Dialer = &websocket.Dialer{
Proxy: http.ProxyURL(proxy),
}
}
}
return &Bot{
config: config,
bot: bot,
name: name,
service: service,
}, nil
}
func (b *Bot) Run() error {
b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
b.bot.AddHandler(b.messageCreate)
b.bot.AddHandler(b.messageUpdate)
logger.Infof("Starting MidJourney %s", b.name)
err := b.bot.Open()
if err != nil {
logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
return err
}
logger.Infof("Starting MidJourney %s successfully!", b.name)
return nil
}
type TaskStatus string
const (
Start = TaskStatus("Started")
Running = TaskStatus("Running")
Stopped = TaskStatus("Stopped")
Finished = TaskStatus("Finished")
)
type Image struct {
URL string `json:"url"`
ProxyURL string `json:"proxy_url"`
Filename string `json:"filename"`
Width int `json:"width"`
Height int `json:"height"`
Size int `json:"size"`
Hash string `json:"hash"`
}
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
// ignore messages for other channels
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
return
}
// ignore messages for self
if m.Author == nil || m.Author.ID == s.State.User.ID {
return
}
logger.Debugf("CREATE: %s", utils.JsonEncode(m))
var referenceId = ""
if m.ReferencedMessage != nil {
referenceId = m.ReferencedMessage.ID
}
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
// parse content
req := CBReq{
ChannelId: m.ChannelID,
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
Content: m.Content,
Progress: 0,
Status: Start}
b.service.Notify(req)
return
}
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
// ignore messages for other channels
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
return
}
// ignore messages for self
if m.Author == nil || m.Author.ID == s.State.User.ID {
return
}
logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
var referenceId = ""
if m.ReferencedMessage != nil {
referenceId = m.ReferencedMessage.ID
}
if strings.Contains(m.Content, "(Stopped)") {
req := CBReq{
ChannelId: m.ChannelID,
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
Content: m.Content,
Progress: extractProgress(m.Content),
Status: Stopped}
b.service.Notify(req)
return
}
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
progress := extractProgress(content)
var status TaskStatus
if progress == 100 {
status = Finished
} else {
status = Running
}
for _, attachment := range attachments {
if attachment.Width == 0 || attachment.Height == 0 {
continue
}
image := Image{
URL: attachment.URL,
Height: attachment.Height,
ProxyURL: attachment.ProxyURL,
Width: attachment.Width,
Size: attachment.Size,
Filename: attachment.Filename,
Hash: extractHashFromFilename(attachment.Filename),
}
req := CBReq{
ChannelId: channelId,
MessageId: messageId,
ReferenceId: referenceId,
Image: image,
Prompt: extractPrompt(content),
Content: content,
Progress: progress,
Status: status,
}
b.service.Notify(req)
break // only get one image
}
}
// extract prompt from string
func extractPrompt(input string) string {
pattern := `\*\*(.*?)\*\*`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(input)
if len(matches) > 1 {
return strings.TrimSpace(matches[1])
}
return ""
}
func extractProgress(input string) int {
pattern := `\((\d+)\%\)`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(input)
if len(matches) > 1 {
return utils.IntValue(matches[1], 0)
}
return 100
}
func extractHashFromFilename(filename string) string {
if !strings.HasSuffix(filename, ".png") {
return ""
}
index := strings.LastIndex(filename, "_")
if index != -1 {
return filename[index+1 : len(filename)-4]
}
return ""
}

View File

@@ -1,159 +1,61 @@
package mj package mj
import ( import "chatplus/core/types"
"chatplus/core/types"
"errors"
"fmt"
"time"
"github.com/imroc/req/v3" type Client interface {
) Imagine(task types.MjTask) (ImageRes, error)
Blend(task types.MjTask) (ImageRes, error)
// MidJourney client SwapFace(task types.MjTask) (ImageRes, error)
Upscale(task types.MjTask) (ImageRes, error)
type Client struct { Variation(task types.MjTask) (ImageRes, error)
client *req.Client QueryTask(taskId string) (QueryRes, error)
Config types.MidJourneyConfig
apiURL string
} }
func NewClient(config types.MidJourneyConfig, proxy string) *Client { type ImageReq struct {
client := req.C().SetTimeout(10 * time.Second) BotType string `json:"botType,omitempty"`
var apiURL string Prompt string `json:"prompt,omitempty"`
// set proxy URL Dimensions string `json:"dimensions,omitempty"`
if config.UseCDN { Base64Array []string `json:"base64Array,omitempty"`
apiURL = config.DiscordAPI + "/api/v9/interactions" AccountFilter interface{} `json:"accountFilter,omitempty"`
} else { NotifyHook string `json:"notifyHook,omitempty"`
apiURL = "https://discord.com/api/v9/interactions" State string `json:"state,omitempty"`
if proxy != "" {
client.SetProxyURL(proxy)
}
}
return &Client{client: client, Config: config, apiURL: apiURL}
} }
func (c *Client) Imagine(task types.MjTask) error { type ImageRes struct {
interactionsReq := &InteractionsRequest{ Code int `json:"code"`
Type: 2, Description string `json:"description"`
ApplicationID: ApplicationID, Properties struct {
GuildID: c.Config.GuildId, } `json:"properties"`
ChannelID: c.Config.ChanelId, Result string `json:"result"`
SessionID: SessionID,
Data: map[string]any{
"version": "1166847114203123795",
"id": "938956540159881230",
"name": "imagine",
"type": "1",
"options": []map[string]any{
{
"type": 3,
"name": "prompt",
"value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt),
},
},
"application_command": map[string]any{
"id": "938956540159881230",
"application_id": ApplicationID,
"version": "1118961510123847772",
"default_permission": true,
"default_member_permissions": nil,
"type": 1,
"nsfw": false,
"name": "imagine",
"description": "Create images with Midjourney",
"dm_permission": true,
"options": []map[string]any{
{
"type": 3,
"name": "prompt",
"description": "The prompt to imagine",
"required": true,
},
},
"attachments": []any{},
},
},
}
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
Post(c.apiURL)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %w%v", err, r.Err)
}
return nil
} }
func (c *Client) Blend(task types.MjTask) error { type ErrRes struct {
return errors.New("function not implemented") Error struct {
Message string `json:"message"`
} `json:"error"`
} }
func (c *Client) SwapFace(task types.MjTask) error { type QueryRes struct {
return errors.New("function not implemented") Action string `json:"action"`
} Buttons []struct {
CustomId string `json:"customId"`
// Upscale 放大指定的图片 Emoji string `json:"emoji"`
func (c *Client) Upscale(task types.MjTask) error { Label string `json:"label"`
flags := 0 Style int `json:"style"`
interactionsReq := &InteractionsRequest{ Type int `json:"type"`
Type: 3, } `json:"buttons"`
ApplicationID: ApplicationID, Description string `json:"description"`
GuildID: c.Config.GuildId, FailReason string `json:"failReason"`
ChannelID: c.Config.ChanelId, FinishTime int `json:"finishTime"`
MessageFlags: flags, Id string `json:"id"`
MessageID: task.MessageId, ImageUrl string `json:"imageUrl"`
SessionID: SessionID, Progress string `json:"progress"`
Data: map[string]any{ Prompt string `json:"prompt"`
"component_type": 2, PromptEn string `json:"promptEn"`
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), Properties struct {
}, } `json:"properties"`
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), StartTime int `json:"startTime"`
} State string `json:"state"`
Status string `json:"status"`
var res InteractionsResult SubmitTime int `json:"submitTime"`
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(c.apiURL)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}
return nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(task types.MjTask) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
MessageFlags: flags,
MessageID: task.MessageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(c.apiURL)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}
return nil
} }

View File

@@ -1,204 +0,0 @@
package plus
import (
"chatplus/core/types"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"fmt"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
)
// Service MJ 绘画服务
type Service struct {
Name string // service Name
Client *Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle
HandledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
return &Service{
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: client,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time, 0),
}
}
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for {
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// if it's reference message, check if it's this channel's message
//if task.ChannelId != "" && task.ChannelId != s.Name {
// logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
// s.taskQueue.RPush(task)
// time.Sleep(time.Second)
// continue
//}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
var job model.MidJourneyJob
s.db.Where("id = ?", task.Id).First(&job)
if err != nil || (res.Code != 1 && res.Code != 22) {
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(task.UserId)
// restore img_call quota
if task.Type.String() != types.TaskUpscale.String() {
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
}
// TODO: 任务提交失败,加入队列重试
continue
}
logger.Infof("任务提交成功:%+v", res)
// lock the task until the execute timeout
s.taskStartTimes[int(task.Id)] = time.Now()
atomic.AddInt32(&s.HandledTaskNum, 1)
// 更新任务 ID/频道
job.TaskId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
}
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.HandledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
}
type CBReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": task.FailReason,
})
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
if s.Client.Config.CdnURL != "" {
job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
} else {
job.OrgURL = task.ImageUrl
}
}
job.MessageId = task.Id
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
if task.Status == "SUCCESS" {
// release lock task
atomic.AddInt32(&s.HandledTaskNum, -1)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
s.notifyQueue.RPush(job.UserId)
}
return nil
}
func GetImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}

View File

@@ -1,8 +1,7 @@
package plus package mj
import ( import (
"chatplus/core/types" "chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils" "chatplus/utils"
"encoding/base64" "encoding/base64"
"errors" "errors"
@@ -13,62 +12,21 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var logger = logger2.GetLogger() // PlusClient MidJourney Plus ProxyClient
type PlusClient struct {
// Client MidJourney Plus Client Config types.MjPlusConfig
type Client struct {
Config types.MidJourneyPlusConfig
apiURL string apiURL string
} }
func NewClient(config types.MidJourneyPlusConfig) *Client { func NewPlusClient(config types.MjPlusConfig) *PlusClient {
var apiURL string return &PlusClient{Config: config, apiURL: config.ApiURL}
if config.CdnURL != "" {
apiURL = config.CdnURL
} else {
apiURL = config.ApiURL
}
if config.Mode == "" {
config.Mode = "fast"
}
return &Client{Config: config, apiURL: apiURL}
} }
type ImageReq struct { func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
BotType string `json:"botType"`
Prompt string `json:"prompt,omitempty"`
Dimensions string `json:"dimensions,omitempty"`
Base64Array []string `json:"base64Array,omitempty"`
AccountFilter struct {
InstanceId string `json:"instanceId"`
Modes []interface{} `json:"modes"`
Remix bool `json:"remix"`
RemixAutoConsidered bool `json:"remixAutoConsidered"`
} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook"`
State string `json:"state,omitempty"`
}
type ImageRes struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode) apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
body := ImageReq{ body := ImageReq{
BotType: "MID_JOURNEY", BotType: "MID_JOURNEY",
Prompt: task.Prompt, Prompt: task.Prompt,
NotifyHook: c.Config.NotifyURL,
Base64Array: make([]string, 0), Base64Array: make([]string, 0),
} }
// 生成图片 Base64 编码 // 生成图片 Base64 编码
@@ -81,6 +39,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
} }
} }
logger.Info("API URL: ", apiURL)
var res ImageRes var res ImageRes
var errRes ErrRes var errRes ErrRes
r, err := req.C().R(). r, err := req.C().R().
@@ -90,9 +49,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes). SetErrorResult(&errRes).
Post(apiURL) Post(apiURL)
if err != nil { if err != nil {
errStr, _ := io.ReadAll(r.Body) return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
} }
if r.IsErrorState() { if r.IsErrorState() {
@@ -104,12 +61,11 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
} }
// Blend 融图 // Blend 融图
func (c *Client) Blend(task types.MjTask) (ImageRes, error) { func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
body := ImageReq{ body := ImageReq{
BotType: "MID_JOURNEY", BotType: "MID_JOURNEY",
Dimensions: "SQUARE", Dimensions: "SQUARE",
NotifyHook: c.Config.NotifyURL,
Base64Array: make([]string, 0), Base64Array: make([]string, 0),
} }
// 生成图片 Base64 编码 // 生成图片 Base64 编码
@@ -132,8 +88,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes). SetErrorResult(&errRes).
Post(apiURL) Post(apiURL)
if err != nil { if err != nil {
errStr, _ := io.ReadAll(r.Body) return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
} }
if r.IsErrorState() { if r.IsErrorState() {
@@ -144,7 +99,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
} }
// SwapFace 换脸 // SwapFace 换脸
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码 // 生成图片 Base64 编码
if len(task.ImgArr) != 2 { if len(task.ImgArr) != 2 {
@@ -171,8 +126,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
"accountFilter": gin.H{ "accountFilter": gin.H{
"instanceId": "", "instanceId": "",
}, },
"notifyHook": c.Config.NotifyURL, "state": "",
"state": "",
} }
var res ImageRes var res ImageRes
var errRes ErrRes var errRes ErrRes
@@ -183,8 +137,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes). SetErrorResult(&errRes).
Post(apiURL) Post(apiURL)
if err != nil { if err != nil {
errStr, _ := io.ReadAll(r.Body) return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
} }
if r.IsErrorState() { if r.IsErrorState() {
@@ -195,11 +148,10 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
} }
// Upscale 放大指定的图片 // Upscale 放大指定的图片
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) { func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{ body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId, "taskId": task.MessageId,
"notifyHook": c.Config.NotifyURL,
} }
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL) apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
var res ImageRes var res ImageRes
@@ -222,11 +174,10 @@ func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
} }
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 // Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(task types.MjTask) (ImageRes, error) { func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{ body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId, "taskId": task.MessageId,
"notifyHook": c.Config.NotifyURL,
} }
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL) apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
var res ImageRes var res ImageRes
@@ -248,32 +199,7 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
return res, nil return res, nil
} }
type QueryRes struct { func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}
func (c *Client) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId) apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes var res QueryRes
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey). r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
@@ -290,3 +216,5 @@ func (c *Client) QueryTask(taskId string) (QueryRes, error) {
return res, nil return res, nil
} }
var _ Client = &PlusClient{}

View File

@@ -2,13 +2,12 @@ package mj
import ( import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/service/mj/plus" logger2 "chatplus/logger"
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store" "chatplus/store"
"chatplus/store/model" "chatplus/store/model"
"fmt" "fmt"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"strings"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
@@ -16,7 +15,7 @@ import (
// ServicePool Mj service pool // ServicePool Mj service pool
type ServicePool struct { type ServicePool struct {
services []interface{} services []*Service
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
@@ -24,8 +23,10 @@ type ServicePool struct {
Clients *types.LMap[uint, *types.WsClient] // UserId => Client Clients *types.LMap[uint, *types.WsClient] // UserId => Client
} }
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]interface{}, 0) services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
@@ -33,45 +34,26 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
if config.Enabled == false { if config.Enabled == false {
continue continue
} }
client := plus.NewClient(config) cli := NewPlusClient(config)
name := fmt.Sprintf("mj-service-plus-%d", k) name := fmt.Sprintf("mj-plus-service-%d", k)
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client) service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
go func() { go func() {
servicePlus.Run() service.Run()
}() }()
services = append(services, servicePlus) services = append(services, service)
} }
if len(services) == 0 { for k, config := range appConfig.MjProxyConfigs {
// create mj client and service if config.Enabled == false {
for k, config := range appConfig.MjConfigs { continue
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
if err != nil {
continue
}
err = bot.Run()
if err != nil {
continue
}
// run mj service
go func() {
service.Run()
}()
services = append(services, service)
} }
cli := NewProxyClient(config)
name := fmt.Sprintf("mj-proxy-service-%d", k)
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
go func() {
service.Run()
}()
services = append(services, service)
} }
return &ServicePool{ return &ServicePool{
@@ -92,11 +74,11 @@ func (p *ServicePool) CheckTaskNotify() {
if err != nil { if err != nil {
continue continue
} }
client := p.Clients.Get(userId) cli := p.Clients.Get(userId)
if client == nil { if cli == nil {
continue continue
} }
err = client.Send([]byte("Task Updated")) err = cli.Send([]byte("Task Updated"))
if err != nil { if err != nil {
continue continue
} }
@@ -122,10 +104,10 @@ func (p *ServicePool) DownloadImages() {
logger.Infof("try to download image: %s", v.OrgURL) logger.Infof("try to download image: %s", v.OrgURL)
var imgURL string var imgURL string
var err error var err error
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil { if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId) task, _ := servicePlus.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 { if len(task.Buttons) > 0 {
v.Hash = plus.GetImageHash(task.Buttons[0].CustomId) v.Hash = GetImageHash(task.Buttons[0].CustomId)
} }
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false) imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
} else { } else {
@@ -141,11 +123,11 @@ func (p *ServicePool) DownloadImages() {
v.ImgURL = imgURL v.ImgURL = imgURL
p.db.Updates(&v) p.db.Updates(&v)
client := p.Clients.Get(uint(v.UserId)) cli := p.Clients.Get(uint(v.UserId))
if client == nil { if cli == nil {
continue continue
} }
err = client.Send([]byte("Task Updated")) err = cli.Send([]byte("Task Updated"))
if err != nil { if err != nil {
continue continue
} }
@@ -167,52 +149,42 @@ func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0 return len(p.services) > 0
} }
func (p *ServicePool) Notify(data plus.CBReq) error {
logger.Debugf("收到任务回调:%+v", data)
var job model.MidJourneyJob
res := p.db.Where("task_id = ?", data.Id).First(&job)
if res.Error != nil {
return fmt.Errorf("非法任务:%s", data.Id)
}
// 任务已经拉取完成
if job.Progress == 100 {
return nil
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
return servicePlus.Notify(job)
}
return nil
}
// SyncTaskProgress 异步拉取任务 // SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() { func (p *ServicePool) SyncTaskProgress() {
go func() { go func() {
var items []model.MidJourneyJob var items []model.MidJourneyJob
for { for {
res := p.db.Where("progress >= ? AND progress < ?", 0, 100).Find(&items) res := p.db.Where("progress < ?", 100).Find(&items)
if res.Error != nil { if res.Error != nil {
continue continue
} }
for _, v := range items { for _, job := range items {
// 30 分钟还没完成的任务直接删除 // 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(v.CreatedAt) > time.Minute*30 { if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
p.db.Delete(&v) // 删除任务
// 非放大任务,退回绘图次数 p.db.Delete(&job)
if v.Type != types.TaskUpscale.String() { // 退回算力
p.db.Model(&model.User{}).Where("id = ?", v.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if tx.Error == nil && tx.RowsAffected > 0 {
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
} }
continue
} }
if !strings.HasPrefix(v.ChannelId, "mj-service-plus") { if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
continue _ = servicePlus.Notify(job)
}
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(v)
} }
} }
@@ -221,12 +193,10 @@ func (p *ServicePool) SyncTaskProgress() {
}() }()
} }
func (p *ServicePool) getServicePlus(name string) *plus.Service { func (p *ServicePool) getService(name string) *Service {
for _, s := range p.services { for _, s := range p.services {
if servicePlus, ok := s.(*plus.Service); ok { if s.Name == name {
if servicePlus.Name == name { return s
return servicePlus
}
} }
} }
return nil return nil

View File

@@ -0,0 +1,176 @@
package mj
import (
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"errors"
"fmt"
"github.com/imroc/req/v3"
"io"
)
// ProxyClient MidJourney Proxy Client
type ProxyClient struct {
Config types.MjProxyConfig
apiURL string
}
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
return &ProxyClient{Config: config, apiURL: config.ApiURL}
}
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
body := ImageReq{
Prompt: task.Prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
all, err := io.ReadAll(r.Body)
logger.Info(string(all))
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
body := ImageReq{
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能请使用 MidJourney-Plus")
}
// Upscale 放大指定的图片
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "UPSCALE",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "VARIATION",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &ProxyClient{}

View File

@@ -2,8 +2,11 @@ package mj
import ( import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/service"
"chatplus/store" "chatplus/store"
"chatplus/store/model" "chatplus/store/model"
"chatplus/utils"
"fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -13,24 +16,24 @@ import (
// Service MJ 绘画服务 // Service MJ 绘画服务
type Service struct { type Service struct {
name string // service name Name string // service Name
client *Client // MJ client Client Client // MJ Client
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number HandledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64 taskTimeout int64
} }
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service {
return &Service{ return &Service{
name: name, Name: name,
db: db, db: db,
taskQueue: taskQueue, taskQueue: taskQueue,
notifyQueue: notifyQueue, notifyQueue: notifyQueue,
client: client, Client: cli,
taskTimeout: timeout, taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum, maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time, 0), taskStartTimes: make(map[int]time.Time, 0),
@@ -38,7 +41,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red
} }
func (s *Service) Run() { func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.name) logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for { for {
s.checkTasks() s.checkTasks()
if !s.canHandleTask() { if !s.canHandleTask() {
@@ -55,57 +58,72 @@ func (s *Service) Run() {
continue continue
} }
// if it's reference message, check if it's this channel's message // 如果配置了多个中转平台的 API KEY
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId { // U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task) s.taskQueue.RPush(task)
time.Sleep(time.Second) time.Sleep(time.Second)
continue continue
} }
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task) // 如果是 mj-proxy 则自动翻译提示词
if utils.HasChinese(task.Prompt) && strings.HasPrefix(s.Name, "mj-proxy-service") {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type { switch task.Type {
case types.TaskImage: case types.TaskImage:
err = s.client.Imagine(task) res, err = s.Client.Imagine(task)
break break
case types.TaskUpscale: case types.TaskUpscale:
err = s.client.Upscale(task) res, err = s.Client.Upscale(task)
break break
case types.TaskVariation: case types.TaskVariation:
err = s.client.Variation(task) res, err = s.Client.Variation(task)
break break
case types.TaskBlend: case types.TaskBlend:
err = s.client.Blend(task) res, err = s.Client.Blend(task)
break break
case types.TaskSwapFace: case types.TaskSwapFace:
err = s.client.SwapFace(task) res, err = s.Client.SwapFace(task)
break break
} }
if err != nil { var job model.MidJourneyJob
logger.Error("绘画任务执行失败:", err.Error()) s.db.Where("id = ?", task.Id).First(&job)
if err != nil || (res.Code != 1 && res.Code != 22) {
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.ErrMsg = errMsg
// update the task progress // update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{ s.db.Updates(&job)
"progress": -1, // 任务失败,通知前端
"err_msg": err.Error(),
})
s.notifyQueue.RPush(task.UserId) s.notifyQueue.RPush(task.UserId)
// restore img_call quota
if task.Type.String() != types.TaskUpscale.String() {
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
}
continue continue
} }
logger.Infof("任务提交成功:%+v", res)
// lock the task until the execute timeout // lock the task until the execute timeout
s.taskStartTimes[int(task.Id)] = time.Now() s.taskStartTimes[int(task.Id)] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1) atomic.AddInt32(&s.HandledTaskNum, 1)
// 更新任务 ID/频道
job.TaskId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
} }
} }
// check if current service instance can handle more task // check if current service instance can handle more task
func (s *Service) canHandleTask() bool { func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum) handledNum := atomic.LoadInt32(&s.HandledTaskNum)
return handledNum < s.maxHandleTaskNum return handledNum < s.maxHandleTaskNum
} }
@@ -114,64 +132,75 @@ func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes { for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout { if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k) delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1) atomic.AddInt32(&s.HandledTaskNum, -1)
// delete task from database // delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
} }
} }
} }
func (s *Service) Notify(data CBReq) { type CBReq struct {
// extract the task ID Id string `json:"id"`
split := strings.Split(data.Prompt, " ") Action string `json:"action"`
var job model.MidJourneyJob Status string `json:"status"`
res := s.db.Where("message_id = ?", data.MessageId).First(&job) Prompt string `json:"prompt"`
if res.Error == nil && data.Status == Finished { PromptEn string `json:"promptEn"`
logger.Warn("重复消息:", data.MessageId) Description string `json:"description"`
return SubmitTime int64 `json:"submitTime"`
} StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
tx := s.db.Session(&gorm.Session{}).Where("progress < ?", 100).Order("id ASC") Progress string `json:"progress"`
if data.ReferenceId != "" { ImageUrl string `json:"imageUrl"`
tx = tx.Where("reference_id = ?", data.ReferenceId) FailReason interface{} `json:"failReason"`
} else { Properties struct {
tx = tx.Where("task_id = ?", split[0]) FinalPrompt string `json:"finalPrompt"`
} } `json:"properties"`
// fixed: 修复 U/V 操作任务混淆覆盖的 Bug }
if strings.Contains(data.Prompt, "** - Image #") { // for upscale
tx = tx.Where("type = ?", types.TaskUpscale.String()) func (s *Service) Notify(job model.MidJourneyJob) error {
} else if strings.Contains(data.Prompt, "** - Variations (Strong)") { // for Variations task, err := s.Client.QueryTask(job.TaskId)
tx = tx.Where("type = ?", types.TaskVariation.String()) if err != nil {
} return err
res = tx.First(&job) }
if res.Error != nil {
logger.Warn("非法任务:", res.Error) // 任务执行失败了
return if task.FailReason != "" {
} s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
job.ChannelId = data.ChannelId "err_msg": task.FailReason,
job.MessageId = data.MessageId })
job.ReferenceId = data.ReferenceId return fmt.Errorf("task failed: %v", task.FailReason)
job.Progress = data.Progress }
job.Prompt = data.Prompt
job.Hash = data.Image.Hash if len(task.Buttons) > 0 {
job.OrgURL = data.Image.URL job.Hash = GetImageHash(task.Buttons[0].CustomId)
if s.client.Config.UseCDN { }
job.UseProxy = true oldProgress := job.Progress
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.ImgCdnURL) job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
} job.Prompt = task.PromptEn
if task.ImageUrl != "" {
res = s.db.Updates(&job) job.OrgURL = task.ImageUrl
if res.Error != nil { }
logger.Error("error with update job: ", res.Error) job.MessageId = task.Id
return tx := s.db.Updates(&job)
} if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
if data.Status == Finished { }
// release lock task if task.Status == "SUCCESS" {
atomic.AddInt32(&s.handledTaskNum, -1) // release lock task
} atomic.AddInt32(&s.HandledTaskNum, -1)
}
s.notifyQueue.RPush(job.UserId) // 通知前端更新任务进度
if oldProgress != job.Progress {
s.notifyQueue.RPush(job.UserId)
}
return nil
}
func GetImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
} }

View File

@@ -1,35 +0,0 @@
package mj
const (
ApplicationID string = "936929561302675456"
SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe"
)
type InteractionsRequest struct {
Type int `json:"type"`
ApplicationID string `json:"application_id"`
MessageFlags int `json:"message_flags,omitempty"`
MessageID string `json:"message_id,omitempty"`
GuildID string `json:"guild_id"`
ChannelID string `json:"channel_id"`
SessionID string `json:"session_id"`
Data map[string]any `json:"data"`
Nonce string `json:"nonce,omitempty"`
}
type InteractionsResult struct {
Code int `json:"code"`
Message string
Error map[string]any
}
type CBReq struct {
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"`
Image Image `json:"image"`
Content string `json:"content"`
Prompt string `json:"prompt"`
Status TaskStatus `json:"status"`
Progress int `json:"progress"`
}

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"chatplus/core/types" "chatplus/core/types"
"chatplus/utils" "chatplus/utils"
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"path/filepath" "path/filepath"
@@ -101,6 +102,20 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
} }
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
}
func (s AliYunOss) Delete(fileURL string) error { func (s AliYunOss) Delete(fileURL string) error {
var objectKey string var objectKey string
if strings.HasPrefix(fileURL, "http") { if strings.HasPrefix(fileURL, "http") {

View File

@@ -3,13 +3,13 @@ package oss
import ( import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/utils" "chatplus/utils"
"encoding/base64"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
type LocalStorage struct { type LocalStorage struct {
@@ -73,6 +73,20 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
} }
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
err = os.WriteFile(filePath, imageData, 0644)
if err != nil {
return "", fmt.Errorf("error writing to file:%v", err)
}
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) Delete(fileURL string) error { func (s LocalStorage) Delete(fileURL string) error {
if _, err := os.Stat(fileURL); err == nil { if _, err := os.Stat(fileURL); err == nil {
return os.Remove(fileURL) return os.Remove(fileURL)

View File

@@ -4,6 +4,7 @@ import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/utils" "chatplus/utils"
"context" "context"
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"path/filepath" "path/filepath"
@@ -96,6 +97,25 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil }, nil
} }
func (s MiniOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
objectKey,
strings.NewReader(string(imageData)),
int64(len(imageData)),
minio.PutObjectOptions{ContentType: "image/png"})
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
}
func (s MiniOss) Delete(fileURL string) error { func (s MiniOss) Delete(fileURL string) error {
var objectKey string var objectKey string
if strings.HasPrefix(fileURL, "http") { if strings.HasPrefix(fileURL, "http") {

View File

@@ -5,6 +5,7 @@ import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/utils" "chatplus/utils"
"context" "context"
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"path/filepath" "path/filepath"
@@ -112,6 +113,22 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
} }
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) Delete(fileURL string) error { func (s QinNiuOss) Delete(fileURL string) error {
var objectKey string var objectKey string
if strings.HasPrefix(fileURL, "http") { if strings.HasPrefix(fileURL, "http") {

View File

@@ -17,5 +17,6 @@ type File struct {
type Uploader interface { type Uploader interface {
PutFile(ctx *gin.Context, name string) (File, error) PutFile(ctx *gin.Context, name string) (File, error)
PutImg(imageURL string, useProxy bool) (string, error) PutImg(imageURL string, useProxy bool) (string, error)
PutBase64(imageData string) (string, error)
Delete(fileURL string) error Delete(fileURL string) error
} }

View File

@@ -29,16 +29,17 @@ type JPayReq struct {
OutTradeNo string `json:"out_trade_no"` OutTradeNo string `json:"out_trade_no"`
Subject string `json:"body"` Subject string `json:"body"`
NotifyURL string `json:"notify_url"` NotifyURL string `json:"notify_url"`
ReturnURL string `json:"callback_url"`
} }
type JPayReps struct { type JPayReps struct {
CodeUrl string `json:"code_url"`
OutTradeNo string `json:"out_trade_no"` OutTradeNo string `json:"out_trade_no"`
OrderId string `json:"payjs_order_id"` OrderId string `json:"payjs_order_id"`
Qrcode string `json:"qrcode"`
ReturnCode int `json:"return_code"` ReturnCode int `json:"return_code"`
ReturnMsg string `json:"return_msg"` ReturnMsg string `json:"return_msg"`
Sign string `json:"Sign"` Sign string `json:"Sign"`
TotalFee string `json:"total_fee"` TotalFee string `json:"total_fee"`
CodeUrl string `json:"code_url,omitempty"`
Qrcode string `json:"qrcode,omitempty"`
} }
func (r JPayReps) IsOK() bool { func (r JPayReps) IsOK() bool {
@@ -78,8 +79,14 @@ func (js *PayJS) Pay(param JPayReq) JPayReps {
return data return data
} }
func (js *PayJS) PayH5(p url.Values) string {
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
}
func (js *PayJS) sign(params url.Values) string { func (js *PayJS) sign(params url.Values) string {
params.Del(`Sign`) params.Del(`sign`)
var keys = make([]string, 0, 0) var keys = make([]string, 0, 0)
for key := range params { for key := range params {
if params.Get(key) != `` { if params.Get(key) != `` {
@@ -109,7 +116,7 @@ func (js *PayJS) Check(tradeNo string) error {
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL) apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
params := url.Values{} params := url.Values{}
params.Add("payjs_order_id", tradeNo) params.Add("payjs_order_id", tradeNo)
params.Add("Sign", js.sign(params)) params.Add("sign", js.sign(params))
data := strings.NewReader(params.Encode()) data := strings.NewReader(params.Encode())
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data) resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
defer resp.Body.Close() defer resp.Body.Close()
@@ -135,6 +142,7 @@ func (js *PayJS) Check(tradeNo string) error {
if r.ReturnCode == 1 && r.Status == 1 { if r.ReturnCode == 1 && r.Status == 1 {
return nil return nil
} else { } else {
logger.Errorf("PayJs 支付验证响应:%s", string(body))
return errors.New("order not paid") return errors.New("order not paid")
} }
} }

View File

@@ -4,7 +4,9 @@ import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store" "chatplus/store"
"chatplus/store/model"
"fmt" "fmt"
"time"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
@@ -14,6 +16,7 @@ type ServicePool struct {
services []*Service services []*Service
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client Clients *types.LMap[uint, *types.WsClient] // UserId => Client
} }
@@ -22,14 +25,14 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli) notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
// create mj client and service // create mj client and service
for k, config := range appConfig.SdConfigs { for _, config := range appConfig.SdConfigs {
if config.Enabled == false { if config.Enabled == false {
continue continue
} }
// create sd service // create sd service
name := fmt.Sprintf("StableDifffusion Service-%d", k) name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager) service := NewService(name, config, taskQueue, notifyQueue, db, manager)
// run sd service // run sd service
go func() { go func() {
service.Run() service.Run()
@@ -42,6 +45,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
taskQueue: taskQueue, taskQueue: taskQueue,
notifyQueue: notifyQueue, notifyQueue: notifyQueue,
services: services, services: services,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](), Clients: types.NewLMap[uint, *types.WsClient](),
} }
} }
@@ -52,6 +56,68 @@ func (p *ServicePool) PushTask(task types.SdTask) {
p.taskQueue.RPush(task) p.taskQueue.RPush(task)
} }
func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
if err != nil {
continue
}
client := p.Clients.Get(userId)
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
p.db.Delete(&job)
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
}
}
}()
}
// HasAvailableService check if it has available mj service in pool // HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool { func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0 return len(p.services) > 0

View File

@@ -2,69 +2,59 @@ package sd
import ( import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/service"
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store" "chatplus/store"
"chatplus/store/model" "chatplus/store/model"
"chatplus/utils" "chatplus/utils"
"encoding/json"
"fmt" "fmt"
"io"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"gorm.io/gorm" "gorm.io/gorm"
"strings"
"time"
) )
// SD 绘画服务 // SD 绘画服务
type Service struct { type Service struct {
httpClient *req.Client httpClient *req.Client
config types.StableDiffusionConfig config types.StableDiffusionConfig
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
name string // service name name string // service name
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
} }
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service { func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
return &Service{ return &Service{
name: name, name: name,
config: config, config: config,
httpClient: req.C(), httpClient: req.C(),
taskQueue: taskQueue, taskQueue: taskQueue,
notifyQueue: notifyQueue, notifyQueue: notifyQueue,
db: db, db: db,
uploadManager: manager, uploadManager: manager,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time),
} }
} }
func (s *Service) Run() { func (s *Service) Run() {
for { for {
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
var task types.SdTask var task types.SdTask
err := s.taskQueue.LPop(&task) err := s.taskQueue.LPop(&task)
if err != nil { if err != nil {
logger.Errorf("taking task with error: %v", err) logger.Errorf("taking task with error: %v", err)
continue continue
} }
// 翻译提示词
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
if err == nil {
task.Params.Prompt = content
}
}
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task) logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task) err = s.Txt2Img(task)
if err != nil { if err != nil {
@@ -74,240 +64,135 @@ func (s *Service) Run() {
"progress": -1, "progress": -1,
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
// restore img_call quota
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
// release task num
atomic.AddInt32(&s.handledTaskNum, -1)
// 通知前端,任务失败 // 通知前端,任务失败
s.notifyQueue.RPush(task.UserId) s.notifyQueue.RPush(task.UserId)
continue continue
} }
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
} }
} }
// check if current service instance can handle more task // Txt2ImgReq 文生图请求实体
func (s *Service) canHandleTask() bool { type Txt2ImgReq struct {
handledNum := atomic.LoadInt32(&s.handledTaskNum) Prompt string `json:"prompt"`
return handledNum < s.maxHandleTaskNum NegativePrompt string `json:"negative_prompt"`
Seed int64 `json:"seed,omitempty"`
Steps int `json:"steps"`
CfgScale float32 `json:"cfg_scale"`
Width int `json:"width"`
Height int `json:"height"`
SamplerName string `json:"sampler_name"`
EnableHr bool `json:"enable_hr,omitempty"`
HrScale int `json:"hr_scale,omitempty"`
HrUpscaler string `json:"hr_upscaler,omitempty"`
HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"`
DenoisingStrength float32 `json:"denoising_strength,omitempty"`
ForceTaskId string `json:"force_task_id,omitempty"`
} }
// remove the expired tasks // Txt2ImgResp 文生图响应实体
func (s *Service) checkTasks() { type Txt2ImgResp struct {
for k, t := range s.taskStartTimes { Images []string `json:"images"`
if time.Now().Unix()-t.Unix() > s.taskTimeout { Parameters struct {
delete(s.taskStartTimes, k) } `json:"parameters"`
atomic.AddInt32(&s.handledTaskNum, -1) Info string `json:"info"`
// delete task from database }
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
} // TaskProgressResp 任务进度响应实体
} type TaskProgressResp struct {
Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"`
CurrentImage string `json:"current_image"`
} }
// Txt2Img 文生图 API // Txt2Img 文生图 API
func (s *Service) Txt2Img(task types.SdTask) error { func (s *Service) Txt2Img(task types.SdTask) error {
var taskInfo TaskInfo body := Txt2ImgReq{
bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath) Prompt: task.Params.Prompt,
if err != nil { NegativePrompt: task.Params.NegativePrompt,
return fmt.Errorf("error with load text2img json template file: %s", err.Error()) Steps: task.Params.Steps,
CfgScale: task.Params.CfgScale,
Width: task.Params.Width,
Height: task.Params.Height,
SamplerName: task.Params.Sampler,
} }
if task.Params.Seed > 0 {
err = json.Unmarshal(bytes, &taskInfo) body.Seed = task.Params.Seed
if err != nil {
return fmt.Errorf("error with decode json params: %s", err.Error())
} }
if task.Params.HdFix {
data := taskInfo.Data body.EnableHr = true
params := task.Params body.HrScale = task.Params.HdScale
data[ParamKeys["task_id"]] = params.TaskId body.HrUpscaler = task.Params.HdScaleAlg
data[ParamKeys["prompt"]] = params.Prompt body.HrSecondPassSteps = task.Params.HdSteps
data[ParamKeys["negative_prompt"]] = params.NegativePrompt body.DenoisingStrength = task.Params.HdRedrawRate
data[ParamKeys["steps"]] = params.Steps }
data[ParamKeys["sampler"]] = params.Sampler var res Txt2ImgResp
// @fix bug: 有些 stable diffusion 没有面部修复功能 var errChan = make(chan error)
//data[ParamKeys["face_fix"]] = params.FaceFix apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
data[ParamKeys["cfg_scale"]] = params.CfgScale logger.Debugf("send image request to %s", apiURL)
data[ParamKeys["seed"]] = params.Seed
data[ParamKeys["height"]] = params.Height
data[ParamKeys["width"]] = params.Width
data[ParamKeys["hd_fix"]] = params.HdFix
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
data[ParamKeys["hd_scale"]] = params.HdScale
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
data[ParamKeys["hd_sample_num"]] = params.HdSteps
taskInfo.SessionId = task.SessionId
taskInfo.TaskId = params.TaskId
taskInfo.Data = data
taskInfo.JobId = task.Id
taskInfo.UserId = uint(task.UserId)
go func() { go func() {
s.runTask(taskInfo, s.httpClient) response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
}()
return nil
}
// 执行任务
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
body := map[string]any{
"data": taskInfo.Data,
"event_data": taskInfo.EventData,
"fn_index": taskInfo.FnIndex,
"session_hash": taskInfo.SessionHash,
}
var result = make(chan CBReq)
go func() {
var res struct {
Data []interface{} `json:"data"`
IsGenerating bool `json:"is_generating"`
Duration float64 `json:"duration"`
AverageDuration float64 `json:"average_duration"`
}
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
if err != nil { if err != nil {
cbReq.Message = "error with send request: " + err.Error() errChan <- err
cbReq.Success = false
result <- cbReq
return return
} }
if response.IsErrorState() { if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body) errChan <- fmt.Errorf("error http code status: %v", response.Status)
cbReq.Message = "error http status code: " + string(bytes)
cbReq.Success = false
result <- cbReq
return return
} }
var images []struct { // 保存 Base64 图片
Name string `json:"name"` imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
Data interface{} `json:"data"`
IsFile bool `json:"is_file"`
}
err = utils.ForceCovert(res.Data[0], &images)
if err != nil { if err != nil {
cbReq.Message = "error with decode image:" + err.Error() errChan <- fmt.Errorf("error with upload image: %v", err)
cbReq.Success = false
result <- cbReq
return return
} }
// 获取绘画真实的 seed
var info map[string]any var info map[string]interface{}
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info) err = utils.JsonDecode(res.Info, &info)
if err != nil { if err != nil {
logger.Error(res.Data) errChan <- fmt.Errorf("error with decode task response: %v", err)
cbReq.Message = "error with decode image url:" + err.Error()
cbReq.Success = false
result <- cbReq
return return
} }
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
// 获取真实的 seed 值 s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
cbReq.ImageName = images[0].Name errChan <- nil
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
cbReq.Seed = seed
cbReq.Success = true
cbReq.Progress = 100
result <- cbReq
close(result)
}() }()
for { for {
select { select {
case value := <-result: case err := <-errChan: // 任务完成
s.callback(value) if err != nil {
return return err
}
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(task.UserId)
return nil
default: default:
var progressReq = map[string]any{ err, resp := s.checkTaskProgress()
"id_task": taskInfo.TaskId, // 更新任务进度
"id_live_preview": 1, if err == nil && resp.Progress > 0 {
logger.Debugf("Check task progress: %+v", resp.Progress)
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(task.UserId)
} }
var progressRes struct {
Active bool `json:"active"`
Queued bool `json:"queued"`
Completed bool `json:"completed"`
Progress float64 `json:"progress"`
Eta float64 `json:"eta"`
LivePreview string `json:"live_preview"`
IDLivePreview int `json:"id_live_preview"`
TextInfo interface{} `json:"textinfo"`
}
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
if err != nil { // TODO: 这里可以考虑设置失败重试次数
logger.Error(err)
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
logger.Error(string(bytes))
return
}
cbReq.ImageData = progressRes.LivePreview
cbReq.Progress = int(progressRes.Progress * 100)
s.callback(cbReq)
time.Sleep(time.Second) time.Sleep(time.Second)
} }
} }
} }
func (s *Service) callback(data CBReq) { // 执行任务
// release task num func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
atomic.AddInt32(&s.handledTaskNum, -1) apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
if data.Success { // 任务成功 var res TaskProgressResp
var job model.SdJob response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
res := s.db.Where("id = ?", data.JobId).First(&job) if err != nil {
if res.Error != nil { return err, nil
logger.Warn("非法任务:", res.Error)
return
}
// 更新任务进度
job.Progress = data.Progress
// 更新任务 seed
var params types.SdTaskParams
err := utils.JsonDecode(job.Params, &params)
if err != nil {
logger.Error("任务解析失败:", err)
return
}
params.Seed = data.Seed
if data.ImageName != "" { // 下载图片
job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
if data.Progress == 100 {
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imageURL
}
}
job.Params = utils.JsonEncode(params)
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
logger.Debugf("绘图进度:%d", data.Progress)
} else { // 任务失败
logger.Error("任务执行失败:", data.Message)
// update the task progress
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": data.Message,
})
// restore img_calls
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", data.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
} }
if response.IsErrorState() {
return fmt.Errorf("error http code status: %v", response.Status), nil
}
return nil, &res
} }

4
api/service/types.go Normal file
View File

@@ -0,0 +1,4 @@
package service
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"

View File

@@ -41,7 +41,7 @@ func parseTransactionMessage(xmlData string) *Message {
} }
if se.Name.Local == "weapp_path" || se.Name.Local == "url" { if se.Name.Local == "weapp_path" || se.Name.Local == "url" {
if err := decoder.DecodeElement(&value, &se); err == nil { if err := decoder.DecodeElement(&value, &se); err == nil {
if strings.Contains(value, "trans_id=") { if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") {
message.Url = value message.Url = value
} }
} }

View File

@@ -39,13 +39,14 @@ func NewXXLJobExecutor(config *types.AppConfig, db *gorm.DB) *XXLJobExecutor {
func (e *XXLJobExecutor) Run() error { func (e *XXLJobExecutor) Run() error {
e.executor.RegTask("ClearOrders", e.ClearOrders) e.executor.RegTask("ClearOrders", e.ClearOrders)
e.executor.RegTask("ResetVipCalls", e.ResetVipCalls) e.executor.RegTask("ResetVipPower", e.ResetVipPower)
e.executor.RegTask("ResetUserPower", e.ResetUserPower)
return e.executor.Run() return e.executor.Run()
} }
// ClearOrders 清理未支付的订单,如果没有抛出异常则表示执行成功 // ClearOrders 清理未支付的订单,如果没有抛出异常则表示执行成功
func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (msg string) { func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Debug("执行清理未支付订单...") logger.Info("执行清理未支付订单...")
var sysConfig model.Config var sysConfig model.Config
res := e.db.Where("marker", "system").First(&sysConfig) res := e.db.Where("marker", "system").First(&sysConfig)
if res.Error != nil { if res.Error != nil {
@@ -64,15 +65,17 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms
timeout := time.Now().Unix() - int64(config.OrderPayTimeout) timeout := time.Now().Unix() - int64(config.OrderPayTimeout)
start := utils.Stamp2str(timeout) start := utils.Stamp2str(timeout)
// 这里不是用软删除,而是永久删除订单 // 这里不是用软删除,而是永久删除订单
res = e.db.Unscoped().Where("status != ? AND created_at < ?", types.OrderPaidSuccess, start).Delete(&model.Order{}) res = e.db.Unscoped().Where("status IN ? AND created_at < ?", []types.OrderStatus{types.OrderNotPaid, types.OrderScanned}, start).Delete(&model.Order{})
return fmt.Sprintf("Clear order successfully, affect rows: %d", res.RowsAffected) logger.Infof("Clear order successfully, affect rows: %d", res.RowsAffected)
return "success"
} }
// ResetVipCalls 清理过期的 VIP 会员 // ResetVipPower 重置VIP会员算力
func (e *XXLJobExecutor) ResetVipCalls(cxt context.Context, param *xxl.RunReq) (msg string) { // 自动将 VIP 会员的算力补充到每月赠送的最大值
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Info("开始进行月底账号盘点...") logger.Info("开始进行月底账号盘点...")
var users []model.User var users []model.User
res := e.db.Where("vip = ?", 1).Find(&users) res := e.db.Where("vip", 1).Where("status", 1).Find(&users)
if res.Error != nil { if res.Error != nil {
return "No vip users found" return "No vip users found"
} }
@@ -89,60 +92,92 @@ func (e *XXLJobExecutor) ResetVipCalls(cxt context.Context, param *xxl.RunReq) (
return "error with decode system config: " + err.Error() return "error with decode system config: " + err.Error()
} }
// 获取本月月初时间
currentTime := time.Now()
year, month, _ := currentTime.Date()
firstOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, currentTime.Location()).Unix()
for _, u := range users { for _, u := range users {
// 账号到期,直接清零 // 处理过期的 VIP
if u.ExpiredTime <= currentTime.Unix() { if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() {
logger.Info("账号过期:", u.Username)
u.Calls = 0
u.Vip = false u.Vip = false
} else { e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
if u.Calls <= 0 { continue
u.Calls = 0
}
if u.ImgCalls <= 0 {
u.ImgCalls = 0
}
// 如果该用户当月有充值点卡,则将点卡中未用完的点数结余到下个月
var orders []model.Order
e.db.Debug().Where("user_id = ? AND pay_time > ?", u.Id, firstOfMonth).Find(&orders)
var calls = 0
var imgCalls = 0
for _, o := range orders {
var remark types.OrderRemark
err = utils.JsonDecode(o.Remark, &remark)
if err != nil {
continue
}
if remark.Days > 0 { // 会员续费
continue
}
calls += remark.Calls
imgCalls += remark.ImgCalls
}
if u.Calls > calls { // 本月套餐没有用完
u.Calls = calls + config.VipMonthCalls
} else {
u.Calls = u.Calls + config.VipMonthCalls
}
if u.ImgCalls > imgCalls { // 本月套餐没有用完
u.ImgCalls = imgCalls + config.VipMonthImgCalls
} else {
u.ImgCalls = u.ImgCalls + config.VipMonthImgCalls
}
logger.Infof("%s 点卡结余:%d", u.Username, calls)
} }
u.Tokens = 0
// update user // update user
e.db.Updates(&u) tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", config.VipMonthPower))
// 记录算力变动日志
if tx.Error == nil {
var user model.User
e.db.Where("id", u.Id).First(&user)
e.db.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerRecharge,
Amount: config.VipMonthPower,
Mark: types.PowerAdd,
Balance: user.Power,
Model: "系统盘点",
Remark: fmt.Sprintf("VIP会员每月算力派发%d", config.VipMonthPower),
CreatedAt: time.Now(),
})
}
} }
logger.Info("月底盘点完成!") logger.Info("月底盘点完成!")
return "success" return "success"
} }
func (e *XXLJobExecutor) ResetUserPower(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Info("今日算力派发开始:", time.Now())
var users []model.User
res := e.db.Where("status", 1).Find(&users)
if res.Error != nil {
return "No matching users"
}
var sysConfig model.Config
res = e.db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return "error with get system config: " + res.Error.Error()
}
var config types.SystemConfig
err := utils.JsonDecode(sysConfig.Config, &config)
if err != nil {
return "error with decode system config: " + err.Error()
}
if config.DailyPower <= 0 {
return "success"
}
var counter = 0
var totalPower = 0
for _, u := range users {
if u.Power >= config.DailyPower {
continue
}
var power = config.DailyPower - u.Power
// update user
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power))
// 记录算力充值日志
if tx.Error == nil {
var user model.User
e.db.Where("id", u.Id).First(&user)
e.db.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerGift,
Amount: power,
Mark: types.PowerAdd,
Balance: user.Power,
Model: "系统赠送",
Remark: fmt.Sprintf("系统每日算力派发,今日额度:%d", config.DailyPower),
CreatedAt: time.Now(),
})
}
counter++
totalPower += power
}
logger.Infof("今日派发算力结束!累计派发 %d 人,累计派发算力:%d", counter, totalPower)
return "success"
}
type customLogger struct{} type customLogger struct{}
func (l *customLogger) Info(format string, a ...interface{}) { func (l *customLogger) Info(format string, a ...interface{}) {

View File

@@ -0,0 +1,11 @@
package model
type AdminUser struct {
BaseModel
Username string
Password string
Salt string // 密码盐
Status bool `gorm:"default:true"` // 当前状态
LastLoginAt int64 // 最后登录时间
LastLoginIp string // 最后登录 IP
}

View File

@@ -9,6 +9,6 @@ type ApiKey struct {
Value string // API Key 的值 Value string // API Key 的值
ApiURL string // 当前 KEY 的 API 地址 ApiURL string // 当前 KEY 的 API 地址
Enabled bool // 是否启用 Enabled bool // 是否启用
UseProxy bool // 是否使用代理访问 API URL ProxyURL string // 代理地址
LastUsedAt int64 // 最后使用时间 LastUsedAt int64 // 最后使用时间
} }

View File

@@ -2,11 +2,14 @@ package model
type ChatModel struct { type ChatModel struct {
BaseModel BaseModel
Platform string Platform string
Name string Name string
Value string // API Key 的值 Value string // API Key 的值
SortNum int SortNum int
Enabled bool Enabled bool
Weight int // 对话权重,每次对话扣减多少次对话额度 Power int // 每次对话消耗算力
Open bool // 是否开放模型给所有人使用 Open bool // 是否开放模型给所有人使用
MaxTokens int // 最大响应长度
MaxContext int // 最大上下文长度
Temperature float32 // 模型温度
} }

View File

@@ -4,7 +4,7 @@ import "time"
type File struct { type File struct {
Id uint `gorm:"primarykey;column:id"` Id uint `gorm:"primarykey;column:id"`
UserId uint UserId int
Name string Name string
ObjKey string ObjKey string
URL string URL string

View File

@@ -10,6 +10,6 @@ type InviteLog struct {
UserId uint UserId uint
Username string Username string
InviteCode string InviteCode string
Reward string `gorm:"column:reward_json"` // 邀请奖励 Remark string
CreatedAt time.Time CreatedAt time.Time
} }

View File

@@ -18,6 +18,7 @@ type MidJourneyJob struct {
UseProxy bool // 是否使用反代加载图片 UseProxy bool // 是否使用反代加载图片
Publish bool //是否发布图片到画廊 Publish bool //是否发布图片到画廊
ErrMsg string // 报错信息 ErrMsg string // 报错信息
Power int // 消耗算力
CreatedAt time.Time CreatedAt time.Time
} }

View File

@@ -0,0 +1,20 @@
package model
import (
"chatplus/core/types"
"time"
)
// PowerLog 算力消费日志
type PowerLog struct {
Id uint `gorm:"primarykey;column:id"`
UserId uint
Username string
Type types.PowerType
Amount int
Balance int
Model string // 模型
Remark string // 备注
Mark types.PowerMark // 资金类型
CreatedAt time.Time
}

View File

@@ -7,8 +7,7 @@ type Product struct {
Price float64 Price float64
Discount float64 Discount float64
Days int Days int
Calls int Power int
ImgCalls int
Enabled bool Enabled bool
Sales int Sales int
SortNum int SortNum int

View File

@@ -13,6 +13,7 @@ type SdJob struct {
Params string Params string
Publish bool //是否发布图片到画廊 Publish bool //是否发布图片到画廊
ErrMsg string // 报错信息 ErrMsg string // 报错信息
Power int // 消耗算力
CreatedAt time.Time CreatedAt time.Time
} }

View File

@@ -7,9 +7,7 @@ type User struct {
Password string Password string
Avatar string Avatar string
Salt string // 密码盐 Salt string // 密码盐
TotalTokens int64 // 总消耗 tokens Power int // 剩余算力
Calls int // 剩余对话次数
ImgCalls int // 剩余绘图次数
ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json
ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色 ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色
ChatModels string `gorm:"column:chat_models_json"` // AI 模型,不同的用户拥有不同的聊天模型 ChatModels string `gorm:"column:chat_models_json"` // AI 模型,不同的用户拥有不同的聊天模型
@@ -18,5 +16,4 @@ type User struct {
LastLoginAt int64 // 最后登录时间 LastLoginAt int64 // 最后登录时间
LastLoginIp string // 最后登录 IP LastLoginIp string // 最后登录 IP
Vip bool // 是否 VIP 会员 Vip bool // 是否 VIP 会员
Tokens int
} }

View File

@@ -0,0 +1,10 @@
package vo
type AdminUser struct {
BaseVo
Username string `json:"username"`
Status bool `json:"status"` // 当前状态
LastLoginAt int64 `json:"last_login_at"` // 最后登录时间
LastLoginIp string `json:"last_login_ip"` // 最后登录 IP
RoleIds interface{} `json:"role_ids"` //角色ids
}

View File

@@ -9,6 +9,6 @@ type ApiKey struct {
Value string `json:"value"` // API Key 的值 Value string `json:"value"` // API Key 的值
ApiURL string `json:"api_url"` ApiURL string `json:"api_url"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
UseProxy bool `json:"use_proxy"` ProxyURL string `json:"proxy_url"`
LastUsedAt int64 `json:"last_used_at"` // 最后使用时间 LastUsedAt int64 `json:"last_used_at"` // 最后使用时间
} }

View File

@@ -2,11 +2,14 @@ package vo
type ChatModel struct { type ChatModel struct {
BaseVo BaseVo
Platform string `json:"platform"` Platform string `json:"platform"`
Name string `json:"name"` Name string `json:"name"`
Value string `json:"value"` Value string `json:"value"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
SortNum int `json:"sort_num"` SortNum int `json:"sort_num"`
Weight int `json:"weight"` Power int `json:"power"`
Open bool `json:"open"` Open bool `json:"open"`
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
} }

View File

@@ -5,6 +5,5 @@ import "chatplus/core/types"
type Config struct { type Config struct {
Id uint `json:"id"` Id uint `json:"id"`
Key string `json:"key"` Key string `json:"key"`
ChatConfig types.ChatConfig `json:"chat_config"`
SystemConfig types.SystemConfig `json:"system_config"` SystemConfig types.SystemConfig `json:"system_config"`
} }

View File

@@ -1,15 +1,11 @@
package vo package vo
import (
"chatplus/core/types"
)
type InviteLog struct { type InviteLog struct {
Id uint `json:"id"` Id uint `json:"id"`
InviterId uint `json:"inviter_id"` InviterId uint `json:"inviter_id"`
UserId uint `json:"user_id"` UserId uint `json:"user_id"`
Username string `json:"username"` Username string `json:"username"`
InviteCode string `json:"invite_code"` InviteCode string `json:"invite_code"`
Reward types.InviteReward `json:"reward"` Remark string `json:"remark"`
CreatedAt int64 `json:"created_at"` CreatedAt int64 `json:"created_at"`
} }

View File

@@ -18,5 +18,6 @@ type MidJourneyJob struct {
UseProxy bool `json:"use_proxy"` UseProxy bool `json:"use_proxy"`
Publish bool `json:"publish"` Publish bool `json:"publish"`
ErrMsg string `json:"err_msg"` ErrMsg string `json:"err_msg"`
Power int `json:"power"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
} }

17
api/store/vo/power_log.go Normal file
View File

@@ -0,0 +1,17 @@
package vo
import "chatplus/core/types"
type PowerLog struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
Username string `json:"username"`
Type types.PowerType `json:"type"`
TypeStr string `json:"type_str"`
Amount int `json:"amount"`
Mark types.PowerMark `json:"mark"`
Balance int `json:"balance"`
Model string `json:"model"`
Remark string `json:"remark"`
CreatedAt int64 `json:"created_at"`
}

View File

@@ -6,8 +6,7 @@ type Product struct {
Price float64 `json:"price"` Price float64 `json:"price"`
Discount float64 `json:"discount"` Discount float64 `json:"discount"`
Days int `json:"days"` Days int `json:"days"`
Calls int `json:"calls"` Power int `json:"power"`
ImgCalls int `json:"img_calls"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Sales int `json:"sales"` Sales int `json:"sales"`
SortNum int `json:"sort_num"` SortNum int `json:"sort_num"`

View File

@@ -12,6 +12,5 @@ type Reward struct {
} }
type RewardExchange struct { type RewardExchange struct {
Calls int `json:"calls"` Power int `json:"power"`
ImgCalls int `json:"img_calls"`
} }

View File

@@ -16,5 +16,6 @@ type SdJob struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Publish bool `json:"publish"` Publish bool `json:"publish"`
ErrMsg string `json:"err_msg"` ErrMsg string `json:"err_msg"`
Power int `json:"power"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
} }

View File

@@ -1,23 +1,17 @@
package vo package vo
import "chatplus/core/types"
type User struct { type User struct {
BaseVo BaseVo
Username string `json:"username"` Username string `json:"username"`
Nickname string `json:"nickname"` Nickname string `json:"nickname"`
Avatar string `json:"avatar"` Avatar string `json:"avatar"`
Salt string `json:"salt"` // 密码盐 Salt string `json:"salt"` // 密码盐
TotalTokens int64 `json:"total_tokens"` // 总消耗tokens Power int `json:"power"` // 剩余算力
Calls int `json:"calls"` // 剩余对话次数 ChatRoles []string `json:"chat_roles"` // 聊天角色集合
ImgCalls int `json:"img_calls"` ChatModels []int `json:"chat_models"` // AI模型集合
ChatConfig types.UserChatConfig `json:"chat_config"` // 聊天配置 ExpiredTime int64 `json:"expired_time"` // 账户到期时间
ChatRoles []string `json:"chat_roles"` // 聊天角色集合 Status bool `json:"status"` // 当前状态
ChatModels []string `json:"chat_models"` // AI模型集合 LastLoginAt int64 `json:"last_login_at"` // 最后登录时间
ExpiredTime int64 `json:"expired_time"` // 账户到期时间 LastLoginIp string `json:"last_login_ip"` // 最后登录 IP
Status bool `json:"status"` // 当前状态 Vip bool `json:"vip"`
LastLoginAt int64 `json:"last_login_at"` // 最后登录时间
LastLoginIp string `json:"last_login_ip"` // 最后登录 IP
Vip bool `json:"vip"`
Tokens int `json:"token"` // 当月消耗的 fee
} }

View File

@@ -1,49 +1,11 @@
package main package main
import ( import (
"chatplus/utils"
"fmt" "fmt"
"regexp"
) )
func main() { func main() {
text := ` text := "一只 蜗牛在树干上爬,阳光透过树叶照在蜗牛的背上 --ar 1:1 --iw 0.250000 --v 6"
> search("Shenzhen weather January 15, 2024") fmt.Println(utils.HasChinese(text))
> mclick([0, 9, 16])
> **end-searching**
今天深圳的天气情况如下:
- 白天气温预计在21°C至24°C之间天气晴朗。
- 晚上气温预计在21°C左右云量较多可能会有间断性小雨。
- 风向主要是东南风风速大约在6至12公里每小时之间。
这些信息表明深圳今天的天气相对舒适,适合户外活动。晚上可能需要带伞以应对间断性小雨。温度较为宜人,早晚可能稍微凉爽一些【[Shenzhen weather in January 2024 | Shenzhen 14 day weather](https://www.weather25.com/asia/china/guangdong/shenzhen?page=month&month=January)】【[Hourly forecast for Shenzhen, Guangdong, China](https://www.timeanddate.com/weather/china/shenzhen/hourly)】【[Shenzhen Guangdong China 15 Day Weather Forecast](https://www.weatheravenue.com/en/asia/cn/guangdong/shenzhen-weather-15-days.html)】。
我将根据这些信息生成一张气象图,展示深圳今天的天气情况。
{"prompt":"A detailed weather map for Shenzhen, China, on January 15, 2024. The map shows a sunny day with clear skies during the day and partly cloudy skies at night. Temperatures range from 21\u00b0C to 24\u00b0C during the day and around 21\u00b0C at night. There are indications of light southeast winds during the day and evening, with wind speeds ranging from 6 to 12 km/h. The map includes symbols for sunshine, light clouds, and wind direction arrows, along with temperature readings for different times of the day. The layout is clear, with a focus on Shenzhen's geographical location and the surrounding region.","size":"1024x1024"}
![image1](https://filesystem.site/cdn/20240115/XD6EjyPDGCD4X3AQt3h3FijRmSb6fB.webp)
![下载1](https://filesystem.site/cdn/download/20240115/XD6EjyPDGCD4X3AQt3h3FijRmSb6fB.webp)
And here is another image link: ![another image](https://example.com/another-image.png).
这是根据今天深圳的天气情况制作的气象图。图中展示了白天晴朗、夜间部分多云的天气,以及相关的温度和风向信息。`
pattern := `!\[([^\]]*)]\(([^)]+)\)`
// 编译正则表达式
re := regexp.MustCompile(pattern)
// 查找匹配的字符串
matches := re.FindAllStringSubmatch(text, -1)
// 提取链接并打印
for _, match := range matches {
fmt.Println(match[2])
}
} }

Some files were not shown because too many files have changed in this diff Show More