mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-07 09:43:43 +08:00
Compare commits
407 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
754ba02263 | ||
|
|
7ddf57ae06 | ||
|
|
cc5180a6f7 | ||
|
|
9f44c34d34 | ||
|
|
b793b81768 | ||
|
|
233f6e00f0 | ||
|
|
b7dba68549 | ||
|
|
bdea12c51a | ||
|
|
a27d9ea259 | ||
|
|
7cd824c284 | ||
|
|
e27d95e2b5 | ||
|
|
6839827db0 | ||
|
|
d6a04f96fe | ||
|
|
5f820b9dc1 | ||
|
|
3cc2263dc7 | ||
|
|
f0a3c5d8ae | ||
|
|
2a4ef27774 | ||
|
|
2b057f32aa | ||
|
|
bc6451026f | ||
|
|
99fd596862 | ||
|
|
f0959b5df6 | ||
|
|
6788edbe9d | ||
|
|
3895305882 | ||
|
|
1b0938b33f | ||
|
|
c2acbaaa94 | ||
|
|
02faff461a | ||
|
|
e18e5a38c6 | ||
|
|
2f9b1b7835 | ||
|
|
717b137a6d | ||
|
|
f755bdccae | ||
|
|
4bba77ab47 | ||
|
|
6944a32ff3 | ||
|
|
5742b40aee | ||
|
|
7f1ec90748 | ||
|
|
4a99be2f15 | ||
|
|
bee19392c1 | ||
|
|
27c816cf3b | ||
|
|
0d81776212 | ||
|
|
00d31a2379 | ||
|
|
cccab31c0f | ||
|
|
5d65505ab7 | ||
|
|
3dc7d0516a | ||
|
|
50335ebc2d | ||
|
|
bcadee7290 | ||
|
|
cac3194d5b | ||
|
|
4ddf3bf2bf | ||
|
|
d45f9fbad6 | ||
|
|
d98b08d7cd | ||
|
|
5a8fe5a6cf | ||
|
|
36c27d6092 | ||
|
|
3ab29da8f0 | ||
|
|
3699f024f1 | ||
|
|
3d37a3d367 | ||
|
|
73d8236697 | ||
|
|
114d0088dc | ||
|
|
43b6665370 | ||
|
|
5fb9f84182 | ||
|
|
e35c34ad9a | ||
|
|
1a4d798f8b | ||
|
|
afb91a7023 | ||
|
|
dc4c1f7877 | ||
|
|
bbc8fe2b40 | ||
|
|
3c34e8e0e7 | ||
|
|
57c932f07c | ||
|
|
922202734a | ||
|
|
8b3b0139b0 | ||
|
|
31828a3336 | ||
|
|
b270960a04 | ||
|
|
5c4899df6e | ||
|
|
9a797bb4a5 | ||
|
|
b0c9ffc5a6 | ||
|
|
f527cc5b98 | ||
|
|
debe8dc209 | ||
|
|
2f0215ac87 | ||
|
|
dd5cc206e5 | ||
|
|
142cd553a3 | ||
|
|
657ecccee3 | ||
|
|
1232c3cd9c | ||
|
|
3ac04a3938 | ||
|
|
b7abc42209 | ||
|
|
a48179ce0e | ||
|
|
e589f25a05 | ||
|
|
cc1a3ce343 | ||
|
|
7bb76d581c | ||
|
|
0d733c0be0 | ||
|
|
8b40ac5b5c | ||
|
|
24479814e9 | ||
|
|
99df028237 | ||
|
|
b354b88876 | ||
|
|
5e0be4d10e | ||
|
|
468b48151f | ||
|
|
fa5c036041 | ||
|
|
0fdc588167 | ||
|
|
2e023cb8dc | ||
|
|
e933f32d9c | ||
|
|
bd4b0c4d65 | ||
|
|
0b2501c1d8 | ||
|
|
9d28e62142 | ||
|
|
c1d892069e | ||
|
|
61b2dbc9f1 | ||
|
|
be3245666e | ||
|
|
dacdd6fe74 | ||
|
|
6807f7e88a | ||
|
|
087f5ab2d1 | ||
|
|
47c5a0387b | ||
|
|
f9da18ad52 | ||
|
|
5c9025ca22 | ||
|
|
d02cb573fd | ||
|
|
caa538a1d0 | ||
|
|
b584b4bfb6 | ||
|
|
bda335212d | ||
|
|
06f4cdc649 | ||
|
|
336a7d5b56 | ||
|
|
a0f464830f | ||
|
|
9bf7fa4081 | ||
|
|
96ead65774 | ||
|
|
7ad41927aa | ||
|
|
4ca9dfd9c0 | ||
|
|
8a9f386d8f | ||
|
|
adfee8bf58 | ||
|
|
fbfa2a71a9 | ||
|
|
9a1368ef17 | ||
|
|
31b02b97d3 | ||
|
|
42da38c5c3 | ||
|
|
0a01b55713 | ||
|
|
3b292c2a12 | ||
|
|
db0ba0d9a0 | ||
|
|
3a23ff6b42 | ||
|
|
1e9c5adb0a | ||
|
|
abab76ccc6 | ||
|
|
6efd92806f | ||
|
|
cfe333e89f | ||
|
|
a7237fe62f | ||
|
|
c3c454b7d7 | ||
|
|
d4d708d44b | ||
|
|
7f0b6a3a46 | ||
|
|
c2a7c089d2 | ||
|
|
df5bd4df60 | ||
|
|
79b6010104 | ||
|
|
97b0a98793 | ||
|
|
5230f90540 | ||
|
|
803db4e895 | ||
|
|
7cee9f2ebb | ||
|
|
8be9a21efd | ||
|
|
6a3e26b566 | ||
|
|
0355c37bef | ||
|
|
9b7ee538c4 | ||
|
|
d900a3d08e | ||
|
|
cdf5b66729 | ||
|
|
1cff4b63cd | ||
|
|
da14309ef9 | ||
|
|
fbb216fe3b | ||
|
|
95efbd5659 | ||
|
|
4596c1049c | ||
|
|
b35d95f0c7 | ||
|
|
01419df998 | ||
|
|
a6c00c42fa | ||
|
|
4cc9db7115 | ||
|
|
4f1ed54059 | ||
|
|
8227a73e35 | ||
|
|
adfd8c1939 | ||
|
|
8eed7ff534 | ||
|
|
c79c4e74d0 | ||
|
|
f1855fd0a1 | ||
|
|
1f964c74e9 | ||
|
|
4fb2c5803c | ||
|
|
b5947545cb | ||
|
|
342b76f666 | ||
|
|
49b5906bc7 | ||
|
|
3075bfb7fc | ||
|
|
82e06fad33 | ||
|
|
4a9028747b | ||
|
|
4a8ff0ccf0 | ||
|
|
99341f0484 | ||
|
|
f58ac29ad0 | ||
|
|
7060edb3e5 | ||
|
|
41ae411f9b | ||
|
|
79b7fee47c | ||
|
|
0044bf10af | ||
|
|
e9348d3611 | ||
|
|
b9236e09a7 | ||
|
|
09b38d5f42 | ||
|
|
7bb539a06e | ||
|
|
5cdada8265 | ||
|
|
4147c217b1 | ||
|
|
8dda639b23 | ||
|
|
8487d2c9eb | ||
|
|
c5e583b215 | ||
|
|
549f618cff | ||
|
|
e9a3510346 | ||
|
|
30e6e963b3 | ||
|
|
c72d963f45 | ||
|
|
172d498618 | ||
|
|
313993532e | ||
|
|
e53db3582c | ||
|
|
72c6bd3f77 | ||
|
|
ca8b349df3 | ||
|
|
1b206c3640 | ||
|
|
c60276fc9f | ||
|
|
d00a3167c0 | ||
|
|
6b1cd8c30c | ||
|
|
46f12dc9ad | ||
|
|
a3e1d8ae21 | ||
|
|
72a066b93e | ||
|
|
0327a829ac | ||
|
|
882e9b8819 | ||
|
|
ef58cfadaa | ||
|
|
bf958d6113 | ||
|
|
71611273d7 | ||
|
|
b27c654311 | ||
|
|
90930ea9f9 | ||
|
|
1ab2185ff1 | ||
|
|
0f2f978d4c | ||
|
|
f61963b0b0 | ||
|
|
2aa413960d | ||
|
|
aa4bbba5ec | ||
|
|
eba61fea2d | ||
|
|
34e3455128 | ||
|
|
07dca3e739 | ||
|
|
4cb4b145f9 | ||
|
|
1ed417cb69 | ||
|
|
6cf91a84ca | ||
|
|
0b566980fc | ||
|
|
f86176b342 | ||
|
|
c700b32670 | ||
|
|
22641b452a | ||
|
|
d3fbb8c19e | ||
|
|
e3bb69ff10 | ||
|
|
770360c614 | ||
|
|
f302a0478f | ||
|
|
a88697b43a | ||
|
|
cc6f140812 | ||
|
|
424f2b3bdc | ||
|
|
ec0c13a600 | ||
|
|
a1f03bec4c | ||
|
|
b5bd4a5e0e | ||
|
|
7c2e49bfdb | ||
|
|
f80fe6d041 | ||
|
|
72f80a96bc | ||
|
|
2de655a1cf | ||
|
|
da2bd4a501 | ||
|
|
e0aa62c40d | ||
|
|
9d26a892d1 | ||
|
|
4ece7f2847 | ||
|
|
32368caf1b | ||
|
|
e91f54e79e | ||
|
|
bb8f4c57c4 | ||
|
|
43bfac99b6 | ||
|
|
be379b6d63 | ||
|
|
17f3c9b840 | ||
|
|
24de97fac2 | ||
|
|
bf27b44fee | ||
|
|
1802b4fe4d | ||
|
|
241a5c7bc9 | ||
|
|
557d547bf1 | ||
|
|
2e7b75affb | ||
|
|
bc21a1d443 | ||
|
|
3fc9e10a24 | ||
|
|
5fa1aa2060 | ||
|
|
be8a0ec184 | ||
|
|
b02e3aad95 | ||
|
|
08eca511ad | ||
|
|
c34e911596 | ||
|
|
8a452c3072 | ||
|
|
13bfb14107 | ||
|
|
4188b0969e | ||
|
|
0c27795a10 | ||
|
|
d05693c5c1 | ||
|
|
c0b2063b38 | ||
|
|
4d183747b1 | ||
|
|
08fe1b2f75 | ||
|
|
db3e8a267e | ||
|
|
8fc62682c4 | ||
|
|
75031914a3 | ||
|
|
a4c9fdd95a | ||
|
|
6a9bfeb5aa | ||
|
|
e654766f60 | ||
|
|
0ef6955f96 | ||
|
|
b4501557c9 | ||
|
|
a2ed99e6cb | ||
|
|
6bd6bb3885 | ||
|
|
399cf65fc9 | ||
|
|
24906a6df1 | ||
|
|
d772bbebe6 | ||
|
|
14988853a3 | ||
|
|
7b3f16ac9f | ||
|
|
82b2755c18 | ||
|
|
ff4b267858 | ||
|
|
a590d0497f | ||
|
|
ac30d906f0 | ||
|
|
5bc071e038 | ||
|
|
88b956cf98 | ||
|
|
f725cf4661 | ||
|
|
057cc1e8a6 | ||
|
|
de122735b8 | ||
|
|
e87ede981c | ||
|
|
606fb498e1 | ||
|
|
a0c06e40a4 | ||
|
|
aba8f57279 | ||
|
|
960286a350 | ||
|
|
8c93fa51f6 | ||
|
|
cb0e7d64ff | ||
|
|
8e7413da97 | ||
|
|
a36f14eb94 | ||
|
|
f2f9f6e488 | ||
|
|
85068b8ca2 | ||
|
|
f2cfcfeefc | ||
|
|
755273a898 | ||
|
|
d4a24a0f1d | ||
|
|
92281fcbb7 | ||
|
|
636db4afcc | ||
|
|
ba25b8755e | ||
|
|
6399d13a49 | ||
|
|
06fa54fd25 | ||
|
|
a335b965d0 | ||
|
|
725adaa7d0 | ||
|
|
7e7e81e974 | ||
|
|
8cfe6bfc17 | ||
|
|
33de83f2ac | ||
|
|
3f856afec8 | ||
|
|
4e4dc4cb73 | ||
|
|
02a9c422fe | ||
|
|
ca69341024 | ||
|
|
169bf069ce | ||
|
|
1bee0ab04d | ||
|
|
440d91dd0e | ||
|
|
8168e246a8 | ||
|
|
2ef07574ae | ||
|
|
37392f2bb2 | ||
|
|
a80cd3848e | ||
|
|
db6ed84451 | ||
|
|
4463cc5963 | ||
|
|
d316158fe2 | ||
|
|
e02a8d7586 | ||
|
|
9988dff885 | ||
|
|
35ef5674ff | ||
|
|
976da45bce | ||
|
|
c83ac48bd2 | ||
|
|
3d159a833e | ||
|
|
4b09878bdd | ||
|
|
b0162e6a92 | ||
|
|
8ab15e5dc4 | ||
|
|
d2ac807252 | ||
|
|
0af01f6f1f | ||
|
|
013b319fab | ||
|
|
2899ba5949 | ||
|
|
a558b7e104 | ||
|
|
7a833e2233 | ||
|
|
bf65746d00 | ||
|
|
f08a7862de | ||
|
|
023a2c2f09 | ||
|
|
1bcd0f4c1a | ||
|
|
a0f3bc8ccb | ||
|
|
dea72738c1 | ||
|
|
a1d1fe7763 | ||
|
|
a39ed9764c | ||
|
|
aaa5ba99aa | ||
|
|
2113508b6d | ||
|
|
7fe4212684 | ||
|
|
8bdda64794 | ||
|
|
ec08c24dca | ||
|
|
a992a5b3b3 | ||
|
|
0f05970141 | ||
|
|
e5e762efcd | ||
|
|
b3d0c1ef9c | ||
|
|
397078f7ff | ||
|
|
3ad8065e20 | ||
|
|
66c7717f04 | ||
|
|
412f8ecc6c | ||
|
|
51dcf642b3 | ||
|
|
bfeea555b2 | ||
|
|
479f94c372 | ||
|
|
0140713e86 | ||
|
|
15b2ec9721 | ||
|
|
c9cd082855 | ||
|
|
d7c002890c | ||
|
|
348dd22279 | ||
|
|
3e99b4cbf6 | ||
|
|
6968da3ac7 | ||
|
|
bf1c1b84c3 | ||
|
|
c70314d930 | ||
|
|
9104ca8e49 | ||
|
|
2af33b3630 | ||
|
|
654e795545 | ||
|
|
c62ba2451e | ||
|
|
d72d1b8a99 | ||
|
|
b939d6016b | ||
|
|
36a2626ccc | ||
|
|
bd057a4cc9 | ||
|
|
dc24a8c781 | ||
|
|
59fa21779b | ||
|
|
a140671aad | ||
|
|
5fe8990fb4 | ||
|
|
12799b7159 | ||
|
|
9929746b1d | ||
|
|
d70035ff0c | ||
|
|
eec90274d8 | ||
|
|
e8fff55c42 | ||
|
|
3cf3cdd705 | ||
|
|
9801fce659 | ||
|
|
4c1f51110b | ||
|
|
913d538587 | ||
|
|
9e704365fc | ||
|
|
485bdbc56a | ||
|
|
7000168fd4 | ||
|
|
5694f97a6b | ||
|
|
b677d3fac7 |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
deploy
|
||||||
|
docs
|
||||||
|
api/static
|
||||||
|
web/node_modules
|
||||||
|
desktop
|
||||||
|
|
||||||
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
@@ -1,5 +1,5 @@
|
|||||||
name: Bug 报告 🐛
|
name: Bug 报告 🐛
|
||||||
description: 为 chatgpt-plus 提交错误报告
|
description: 为 geekai 提交错误报告
|
||||||
labels: ['Bug']
|
labels: ['Bug']
|
||||||
body:
|
body:
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
2
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
@@ -1,5 +1,5 @@
|
|||||||
name: 功能优化 🚀
|
name: 功能优化 🚀
|
||||||
description: 为 chatgpt-plus 提交优化建议
|
description: 为 geekai 提交优化建议
|
||||||
labels: ['feature']
|
labels: ['feature']
|
||||||
body:
|
body:
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
|
|||||||
152
CHANGELOG.md
152
CHANGELOG.md
@@ -1,4 +1,156 @@
|
|||||||
# 更新日志
|
# 更新日志
|
||||||
|
## v4.0.8
|
||||||
|
* 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
|
||||||
|
* 功能优化:当数据库更新失败的时候记录错误日志
|
||||||
|
* 功能优化:聊天输入框会随着输入内容的增多自动调整高度
|
||||||
|
* Bug修复:修复移动端聊天页面模型切换不生效的Bug
|
||||||
|
* 功能优化:给PC端扫码支付增加签名验证和有效期验证
|
||||||
|
* Bug修复:修复支付码生成API权限控制的问题
|
||||||
|
* Bug修复:模型算力设置为0时,不扣减用户算力,并且不记录算力消费日志
|
||||||
|
* 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
|
||||||
|
* 功能新增:H5端支持 Dalle 绘图
|
||||||
|
|
||||||
|
## v4.0.7
|
||||||
|
|
||||||
|
* 功能优化:升级quic-go,支持 Go1.21
|
||||||
|
* 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片
|
||||||
|
* Bug修复:修复弹窗等于图形验证码一直验证失败的问题
|
||||||
|
* 功能重构:重构前端 UI 页面,增加顶部导航
|
||||||
|
* 功能优化:优化 Vue 非父子组件之间的通信方式
|
||||||
|
* 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
|
||||||
|
|
||||||
|
## v4.0.6
|
||||||
|
|
||||||
|
* Bug修复:修复PC端画廊页面的瀑布流组件样式错乱问题
|
||||||
|
* 功能新增:给思维导图增加 ToolBar,实现思维导图的放大缩小和定位
|
||||||
|
* Bug修复:修复思维导图不扣费的Bug
|
||||||
|
* Bug修复:修复管理后台角色删除失败的Bug
|
||||||
|
* Bug修复:兼容最新版秋叶SD懒人包的 SD API,新增 scheduler 参数
|
||||||
|
* 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
|
||||||
|
* Bug修复:修复注册用户提示注册人数达到上限的 Bug
|
||||||
|
* 功能优化:将MJ,SD,Dall绘画页面的任务列表全改成瀑布流组件
|
||||||
|
|
||||||
|
## v4.0.5
|
||||||
|
|
||||||
|
* 功能优化:已授权系统在后台显示授权信息
|
||||||
|
* 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
|
||||||
|
* 功能优化:优化首页登录注册页面的 UI
|
||||||
|
* BUG修复:修复License验证的逻辑漏洞
|
||||||
|
* Bug修复:后台添加用户的时候密码规则限制跟前台注册保持一致
|
||||||
|
* 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
|
||||||
|
* 功能新增:移动端新增 DALL-E 绘画功能
|
||||||
|
* 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
|
||||||
|
* 功能新增:移动支持免登录预览功能
|
||||||
|
* Bug修复:解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
|
||||||
|
* Bug修复:修复部分中转 API 模型会出现第一输出的字符被淹没的Bug
|
||||||
|
|
||||||
|
## v4.0.4
|
||||||
|
|
||||||
|
* Bug修复:修复统一千问第二句不回复的问题
|
||||||
|
* 功能优化:MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
|
||||||
|
* 功能新增:Dalle AI 绘画功能实现
|
||||||
|
* Bug修复:修复思维导图格式乱码问题
|
||||||
|
* 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
|
||||||
|
* 功能新增:支持从应用列表直接和某个应用对话
|
||||||
|
* 功能优化:优化算力日志的页面和首页的UI
|
||||||
|
* 功能新增:支持思维导图导出 PNG 图片下载
|
||||||
|
|
||||||
|
## v4.0.3
|
||||||
|
|
||||||
|
* 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
|
||||||
|
* Bug修复:兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
|
||||||
|
* Bug修复:修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题
|
||||||
|
* 功能新增:支持上传图片和视觉模型
|
||||||
|
* 功能优化:优化聊天页面的复制代码按钮样式乱码
|
||||||
|
* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
|
||||||
|
* 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用GPT3.5或者 GPT4
|
||||||
|
* 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
|
||||||
|
* 功能新增:支持管理后台 Logo 修改
|
||||||
|
|
||||||
|
## 4.0.2
|
||||||
|
|
||||||
|
* 功能新增:支持前端菜单可以配置
|
||||||
|
* 功能优化:在登录和注册界面标题显示软件版本号
|
||||||
|
* 功能优化:MJ 绘画支持 --sref 和 --cref 图片一致性参数
|
||||||
|
* 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
|
||||||
|
* Bug修复:解决因为图片上传使用相对路径而导致融图失败的问题。
|
||||||
|
* 功能新增:手机端支持 Stable-Diffusion 绘画
|
||||||
|
* 功能新增:管理后台登录页面增加行为验证码,防止爆破
|
||||||
|
|
||||||
|
## v4.0.1
|
||||||
|
|
||||||
|
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion
|
||||||
|
发行版,稳定性更强一些
|
||||||
|
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容
|
||||||
|
MJ-Plus 中转
|
||||||
|
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
|
||||||
|
* Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换
|
||||||
|
* Bug修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
|
||||||
|
|
||||||
|
## v4.0.0
|
||||||
|
|
||||||
|
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。
|
||||||
|
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ
|
||||||
|
对话消耗15个算力...
|
||||||
|
|
||||||
|
* 功能重构:重构整体系统,全部采用算力来进行结算
|
||||||
|
* 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
|
||||||
|
* 功能优化:移动端聊天页面图片支持预览和放大功能
|
||||||
|
* 功能优化:MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
|
||||||
|
* 功能优化:**PC端不登录也可以预览功能,只有在发起操作的时候才需要登录**
|
||||||
|
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
|
||||||
|
* 功能新增:支持H5支付
|
||||||
|
* 功能优化:支持数学公式的识别和美化输出
|
||||||
|
* 功能新增:新增算力消费日志功能
|
||||||
|
* 功能优化:整合 XXL-JOB 实现订单清理,每日算力派发,VIP 算力重置等任务
|
||||||
|
* 功能新增:管理后台新增7日内新增用户和新增订单统计
|
||||||
|
|
||||||
|
## v3.2.7
|
||||||
|
|
||||||
|
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
|
||||||
|
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
|
||||||
|
* Bug修复:修复 issue [
|
||||||
|
管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532)
|
||||||
|
* 功能优化:在对话和聊天记录表中新增冗余字段 model,存储对话模型
|
||||||
|
* Bug修复:IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
|
||||||
|
* Bug修复:重新生成按钮功能失效问题
|
||||||
|
* Bug修复:对话输入HTML标签不显示的问题
|
||||||
|
* 功能优化:gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
|
||||||
|
* 功能新增:新增删除文件功能
|
||||||
|
* Bug修复:解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
|
||||||
|
* 功能新增:后台管理新怎对话查看和检索功能
|
||||||
|
|
||||||
|
## v3.2.6
|
||||||
|
|
||||||
|
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
|
||||||
|
* 功能优化:兼用旧版本微信收款消息解析
|
||||||
|
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
|
||||||
|
* 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
|
||||||
|
* 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
|
||||||
|
* 功能新增:新增网站公告,可以在管理后台自定义配置
|
||||||
|
* 功能新增:新增阿里通义千问大模型支持
|
||||||
|
* Bug修复:修复 MJ 放大任务失败时候 img_call 会增加的 Bug
|
||||||
|
* 功能优化:新增虎皮椒和PayJS订单状态校验功能,增加安全性
|
||||||
|
* Bug修复:修复微信转账交易 ID 提取失败 Bug
|
||||||
|
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
|
||||||
|
* 功能新增:新增短信宝短信平台发送平台集成
|
||||||
|
|
||||||
|
## v3.2.5
|
||||||
|
|
||||||
|
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
|
||||||
|
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
|
||||||
|
Plus 账号了!!!
|
||||||
|
* 功能优化:增强 markdown 图片和引用块解析。
|
||||||
|
* 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
|
||||||
|
* 功能优化:function call 兼用中转 API。
|
||||||
|
* Bug修复:修复部分已知的 Bug。
|
||||||
|
|
||||||
|
## v3.2.4.1
|
||||||
|
|
||||||
|
* 功能新增:新增 PayJs 支付通道
|
||||||
|
* Bug修复:紧急修复后台添加用户失败问题
|
||||||
|
* Bug修复:紧急修复使用中转 API-KEY 无法绘图的问题
|
||||||
|
* Bug修复:允许用户关闭手机和邮箱注册通道,移除验证码依赖
|
||||||
|
|
||||||
## v3.2.4
|
## v3.2.4
|
||||||
|
|
||||||
|
|||||||
214
LICENSE
214
LICENSE
@@ -1,21 +1,201 @@
|
|||||||
MIT License
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
Copyright (c) 2023 RockYang
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
1. Definitions.
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
copies or substantial portions of the Software.
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
the copyright owner that is granting the License.
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
other entities that control, are controlled by, or are under common
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
control with that entity. For the purposes of this definition,
|
||||||
SOFTWARE.
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|||||||
130
README.md
130
README.md
@@ -1,126 +1,67 @@
|
|||||||
# ChatGPT-Plus
|
# GeekAI
|
||||||
|
> 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||||
|
|
||||||
**ChatGPT-PLUS** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
|
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
|
||||||
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。主要有如下特性:
|
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
|
||||||
|
|
||||||
* 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
主要特性:
|
||||||
* 基于 Websocket 实现,完美的打字机体验。
|
|
||||||
* 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
|
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
||||||
* 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
|
- 基于 Websocket 实现,完美的打字机体验。
|
||||||
* 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
|
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
|
||||||
* 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
- 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
|
||||||
* 已集成支付宝支付功能,支持多种会员套餐和点卡购买功能。
|
- 支持 Suno 文生音乐
|
||||||
* 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
|
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
|
||||||
|
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
||||||
|
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
||||||
|
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
|
||||||
绘画函数插件。
|
绘画函数插件。
|
||||||
|
|
||||||
|
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus)
|
||||||
|
|
||||||
|
- [x] 更友好的 UI 界面
|
||||||
|
- [x] 支持 Dall-E 文生图功能
|
||||||
|
- [x] 支持文生思维导图
|
||||||
|
- [x] 支持为模型绑定指定的 API KEY,支持为角色绑定指定的模型等功能
|
||||||
|
- [x] 支持网站 Logo 版权等信息的修改
|
||||||
|
|
||||||
## 功能截图
|
## 功能截图
|
||||||
|
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
|
||||||
### PC 端聊天界面
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### AI 对话界面
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### MidJourney 专业绘画界面
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### Stable-Diffusion 专业绘画页面
|
|
||||||
|
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
### 绘图作品展
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### AI应用列表
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### 会员充值
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
### 自动调用函数插件
|
|
||||||
|
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
### 管理后台
|
|
||||||
|
|
||||||

|
|
||||||

|
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
### 移动端 Web 页面
|
|
||||||
|
|
||||||

|
|
||||||

|
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
### 体验地址
|
### 体验地址
|
||||||
|
|
||||||
> 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/>
|
> 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/>
|
||||||
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
|
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
|
||||||
|
|
||||||
## 快速部署
|
## 快速部署
|
||||||
|
|
||||||
**演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。
|
请参考文档 [**GeekAI 快速部署**](https://docs.geekai.me/install/)。
|
||||||
|
|
||||||
```shell
|
|
||||||
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.4-7b5ff48154.sh)"
|
|
||||||
```
|
|
||||||
|
|
||||||
目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址
|
|
||||||
|
|
||||||
* 前端访问地址:http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。
|
|
||||||
* 后台管理地址:http://localhost:8080/admin
|
|
||||||
* 移动端地址:http://localhost:8080/mobile
|
|
||||||
* 初始后台管理账号:admin/admin123
|
|
||||||
* 初始前端体验账号:18575670125/12345678
|
|
||||||
|
|
||||||
服务启动成功之后不能立刻使用,需要先登录管理后台 -> API-KEY 去添加一个 OpenAI 或者文心一言,科大讯飞等至少一个平台的 API
|
|
||||||
KEY。
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
另外,如果您目前还没有 OpenAI 的 API KEY的,推荐您去 https://gpt.bemore.lol 购买,**无需魔法,高速稳定,且价格还远低于 OpenAI
|
|
||||||
官方**。
|
|
||||||
|
|
||||||
## 使用须知
|
## 使用须知
|
||||||
|
|
||||||
1. 本项目基于 MIT 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
|
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
|
||||||
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
|
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
|
||||||
|
|
||||||
## 项目地址
|
## 项目地址
|
||||||
|
|
||||||
* Github 地址:https://github.com/yangjian102621/chatgpt-plus
|
* Github 地址:https://github.com/yangjian102621/geekai
|
||||||
* 码云地址:https://gitee.com/blackfox/chatgpt-plus
|
* 码云地址:https://gitee.com/blackfox/geekai
|
||||||
|
|
||||||
## 客户端下载
|
## 客户端下载
|
||||||
|
|
||||||
目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/chatgpt-plus/releases/tag/v3.1.2
|
目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
|
||||||
|
|
||||||
## TODOLIST
|
## TODOLIST
|
||||||
|
|
||||||
* [ ] 支持基于知识库的 AI 问答
|
* [ ] 支持基于知识库的 AI 问答
|
||||||
* [ ] 会员邀请注册推广功能
|
* [ ] 文生视频,文生歌曲功能
|
||||||
* [ ] 微信支付功能
|
* [ ] 微信支付功能
|
||||||
|
|
||||||
## 项目文档
|
## 项目文档
|
||||||
|
|
||||||
*
|
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
|
||||||
|
|
||||||
*
|
详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
|
||||||
最新的部署视频教程:[https://www.bilibili.com/video/BV1ge411C7uA/](https://www.bilibili.com/video/BV1ge411C7uA/?vd_source=dee8b15703ccfcbd24a60ee9a0fabb73)
|
|
||||||
**
|
|
||||||
|
|
||||||
详细的部署和开发文档请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/)。
|
|
||||||
|
|
||||||
加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自Github!!!)。**
|
加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自Github!!!)。**
|
||||||
|
|
||||||
@@ -148,7 +89,4 @@ KEY。
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
SHELL=/usr/bin/env bash
|
SHELL=/usr/bin/env bash
|
||||||
NAME := chatgpt-plus
|
NAME := geekai
|
||||||
all: amd64 arm64
|
all: amd64 arm64
|
||||||
|
|
||||||
amd64:
|
amd64:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
Listen = "0.0.0.0:5678"
|
Listen = "0.0.0.0:5678"
|
||||||
ProxyURL = "" # 如 http://127.0.0.1:7777
|
ProxyURL = "" # 如 http://127.0.0.1:7777
|
||||||
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8&parseTime=True&loc=Local"
|
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
|
||||||
StaticDir = "./static" # 静态资源的目录
|
StaticDir = "./static" # 静态资源的目录
|
||||||
StaticUrl = "/static" # 静态资源访问 URL
|
StaticUrl = "/static" # 静态资源访问 URL
|
||||||
AesEncryptKey = ""
|
AesEncryptKey = ""
|
||||||
@@ -10,10 +10,6 @@ WeChatBot = false
|
|||||||
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
|
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
|
||||||
MaxAge = 86400
|
MaxAge = 86400
|
||||||
|
|
||||||
[Manager]
|
|
||||||
Username = "admin"
|
|
||||||
Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改
|
|
||||||
|
|
||||||
[Redis] # redis 配置信息
|
[Redis] # redis 配置信息
|
||||||
Host = "localhost"
|
Host = "localhost"
|
||||||
Port = 6379
|
Port = 6379
|
||||||
@@ -25,19 +21,28 @@ WeChatBot = false
|
|||||||
AppId = ""
|
AppId = ""
|
||||||
Token = ""
|
Token = ""
|
||||||
|
|
||||||
[SmsConfig] # 阿里云短信服务配置
|
|
||||||
AccessKey = ""
|
[SMS] # Sms 配置,用于发送短信
|
||||||
AccessSecret = ""
|
Active = "Ali" # 当前启用的短信服务,默认使用阿里云
|
||||||
Product = "Dysmsapi"
|
[SMS.Bao]
|
||||||
Domain = "dysmsapi.aliyuncs.com"
|
Username = ""
|
||||||
Sign = ""
|
Password = ""
|
||||||
CodeTempId = ""
|
Domain = "api.smsbao.com"
|
||||||
|
Sign = "【极客学长】"
|
||||||
|
CodeTemplate = "您的验证码是{code}。5分钟有效,若非本人操作,请忽略本短信。"
|
||||||
|
[SMS.Ali]
|
||||||
|
AccessKey = ""
|
||||||
|
AccessSecret = ""
|
||||||
|
Product = "Dysmsapi"
|
||||||
|
Domain = "dysmsapi.aliyuncs.com"
|
||||||
|
Sign = ""
|
||||||
|
CodeTempId = ""
|
||||||
|
|
||||||
[OSS] # OSS 配置,用于存储 MJ 绘画图片
|
[OSS] # OSS 配置,用于存储 MJ 绘画图片
|
||||||
Active = "local" # 默认使用本地文件存储引擎
|
Active = "local" # 默认使用本地文件存储引擎
|
||||||
[OSS.Local]
|
[OSS.Local]
|
||||||
BasePath = "./static/upload" # 本地文件上传根路径
|
BasePath = "./static/upload" # 本地文件上传根路径
|
||||||
BaseURL = "http://localhost:5678/static/upload" # 本地上传文件根 URL 如果是线上,则直接设置为 /static/upload 即可
|
BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP
|
||||||
[OSS.Minio]
|
[OSS.Minio]
|
||||||
Endpoint = "" # 如 172.22.11.200:9000
|
Endpoint = "" # 如 172.22.11.200:9000
|
||||||
AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
|
AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
|
||||||
@@ -51,20 +56,24 @@ WeChatBot = false
|
|||||||
AccessSecret = ""
|
AccessSecret = ""
|
||||||
Bucket = ""
|
Bucket = ""
|
||||||
Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
|
Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
|
||||||
|
[OSS.AliYun]
|
||||||
|
Endpoint = "oss-cn-hangzhou.aliyuncs.com"
|
||||||
|
AccessKey = ""
|
||||||
|
AccessSecret = ""
|
||||||
|
Bucket = "chatgpt-plus"
|
||||||
|
SubDir = ""
|
||||||
|
Domain = ""
|
||||||
|
|
||||||
[[MjConfigs]]
|
[[MjProxyConfigs]]
|
||||||
Enabled = false
|
Enabled = true
|
||||||
UserToken = ""
|
ApiURL = "http://midjourney-proxy:8082"
|
||||||
BotToken = ""
|
ApiKey = "sk-geekmaster"
|
||||||
GuildId = ""
|
|
||||||
ChanelId = ""
|
|
||||||
|
|
||||||
[[MjConfigs]]
|
[[MjPlusConfigs]]
|
||||||
Enabled = false
|
Enabled = false
|
||||||
UserToken = ""
|
ApiURL = "https://api.chat-plus.net"
|
||||||
BotToken = ""
|
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
|
||||||
GuildId = ""
|
ApiKey = "sk-xxx"
|
||||||
ChanelId = ""
|
|
||||||
|
|
||||||
[[SdConfigs]]
|
[[SdConfigs]]
|
||||||
Enabled = false
|
Enabled = false
|
||||||
@@ -72,12 +81,6 @@ WeChatBot = false
|
|||||||
ApiKey = ""
|
ApiKey = ""
|
||||||
Txt2ImgJsonPath = "res/sd/text2img.json"
|
Txt2ImgJsonPath = "res/sd/text2img.json"
|
||||||
|
|
||||||
[[SdConfigs]]
|
|
||||||
Enabled = false
|
|
||||||
ApiURL = ""
|
|
||||||
ApiKey = ""
|
|
||||||
Txt2ImgJsonPath = "res/text2img.json"
|
|
||||||
|
|
||||||
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP,如果你没有启用支付服务,则该服务也无需启动
|
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP,如果你没有启用支付服务,则该服务也无需启动
|
||||||
Enabled = false # 是否启用 XXL JOB 服务
|
Enabled = false # 是否启用 XXL JOB 服务
|
||||||
ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
|
ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
|
||||||
@@ -95,4 +98,28 @@ WeChatBot = false
|
|||||||
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
|
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
|
||||||
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
|
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
|
||||||
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
|
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
|
||||||
NotifyURL = "http://r9it.com:6004/api/payment/alipay/notify" # 支付异步回调地址
|
NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
|
||||||
|
|
||||||
|
[HuPiPayConfig]
|
||||||
|
Enabled = false
|
||||||
|
Name = "wechat"
|
||||||
|
AppId = ""
|
||||||
|
AppSecret = ""
|
||||||
|
ApiURL = "https://api.xunhupay.com"
|
||||||
|
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
|
||||||
|
|
||||||
|
[SmtpConfig] # 注意,阿里云服务器禁用了25号端口,请使用 465 端口,并开启 TLS 连接
|
||||||
|
UseTls = false
|
||||||
|
Host = "smtp.163.com"
|
||||||
|
Port = 25
|
||||||
|
AppName = "极客学长"
|
||||||
|
From = "test@163.com" # 发件邮箱人地址
|
||||||
|
Password = "" #邮箱 stmp 服务授权码
|
||||||
|
|
||||||
|
[JPayConfig] # PayJs 支付配置
|
||||||
|
Enabled = false
|
||||||
|
Name = "wechat" # 请不要改动
|
||||||
|
AppId = "" # 商户 ID
|
||||||
|
PrivateKey = "" # 秘钥
|
||||||
|
ApiURL = "https://payjs.cn"
|
||||||
|
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的
|
||||||
@@ -1,22 +1,29 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/nfnt/resize"
|
"github.com/nfnt/resize"
|
||||||
|
"golang.org/x/image/webp"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"image"
|
"image"
|
||||||
"image/jpeg"
|
"image/jpeg"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
@@ -28,10 +35,9 @@ type AppServer struct {
|
|||||||
Debug bool
|
Debug bool
|
||||||
Config *types.AppConfig
|
Config *types.AppConfig
|
||||||
Engine *gin.Engine
|
Engine *gin.Engine
|
||||||
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
|
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
|
||||||
|
|
||||||
ChatConfig *types.ChatConfig // chat config cache
|
SysConfig *types.SystemConfig // system config cache
|
||||||
SysConfig *types.SystemConfig // system config cache
|
|
||||||
|
|
||||||
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
|
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
|
||||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||||
@@ -47,7 +53,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
|||||||
Debug: false,
|
Debug: false,
|
||||||
Config: appConfig,
|
Config: appConfig,
|
||||||
Engine: gin.Default(),
|
Engine: gin.Default(),
|
||||||
ChatContexts: types.NewLMap[string, []interface{}](),
|
ChatContexts: types.NewLMap[string, []types.Message](),
|
||||||
ChatSession: types.NewLMap[string, *types.ChatSession](),
|
ChatSession: types.NewLMap[string, *types.ChatSession](),
|
||||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
@@ -69,23 +75,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *AppServer) Run(db *gorm.DB) error {
|
func (s *AppServer) Run(db *gorm.DB) error {
|
||||||
// load chat config from database
|
|
||||||
var chatConfig model.Config
|
|
||||||
res := db.Where("marker", "chat").First(&chatConfig)
|
|
||||||
if res.Error != nil {
|
|
||||||
return res.Error
|
|
||||||
}
|
|
||||||
err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// load system configs
|
// load system configs
|
||||||
var sysConfig model.Config
|
var sysConfig model.Config
|
||||||
res = db.Where("marker", "system").First(&sysConfig)
|
res := db.Where("marker", "system").First(&sysConfig)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -143,73 +139,64 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
// 用户授权验证
|
// 用户授权验证
|
||||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if c.Request.URL.Path == "/api/user/login" ||
|
|
||||||
c.Request.URL.Path == "/api/user/resetPass" ||
|
|
||||||
c.Request.URL.Path == "/api/admin/login" ||
|
|
||||||
c.Request.URL.Path == "/api/user/register" ||
|
|
||||||
c.Request.URL.Path == "/api/chat/history" ||
|
|
||||||
c.Request.URL.Path == "/api/chat/detail" ||
|
|
||||||
c.Request.URL.Path == "/api/role/list" ||
|
|
||||||
c.Request.URL.Path == "/api/mj/jobs" ||
|
|
||||||
c.Request.URL.Path == "/api/mj/client" ||
|
|
||||||
c.Request.URL.Path == "/api/invite/hits" ||
|
|
||||||
c.Request.URL.Path == "/api/sd/jobs" ||
|
|
||||||
c.Request.URL.Path == "/api/upload" ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
|
|
||||||
c.Request.URL.Path == "/api/admin/config/get" {
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var tokenString string
|
var tokenString string
|
||||||
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
|
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
||||||
|
if isAdminApi { // 后台管理 API
|
||||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||||
} else if c.Request.URL.Path == "/api/chat/new" {
|
} else if c.Request.URL.Path == "/api/chat/new" {
|
||||||
tokenString = c.Query("token")
|
tokenString = c.Query("token")
|
||||||
} else {
|
} else {
|
||||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
resp.ERROR(c, "You should put Authorization in request headers")
|
if needLogin(c) {
|
||||||
c.Abort()
|
resp.ERROR(c, "You should put Authorization in request headers")
|
||||||
return
|
c.Abort()
|
||||||
|
return
|
||||||
|
} else { // 直接放行
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
}
|
}
|
||||||
|
if isAdminApi {
|
||||||
|
return []byte(s.Config.AdminSession.SecretKey), nil
|
||||||
|
} else {
|
||||||
|
return []byte(s.Config.Session.SecretKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
return []byte(s.Config.Session.SecretKey), nil
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil && needLogin(c) {
|
||||||
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
|
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
if !ok || !token.Valid {
|
if !ok || !token.Valid && needLogin(c) {
|
||||||
resp.NotAuth(c, "Token is invalid")
|
resp.NotAuth(c, "Token is invalid")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
|
||||||
resp.NotAuth(c, "Token is expired")
|
resp.NotAuth(c, "Token is expired")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("users/%v", claims["user_id"])
|
key := fmt.Sprintf("users/%v", claims["user_id"])
|
||||||
if _, err := client.Get(context.Background(), key).Result(); err != nil {
|
if isAdminApi {
|
||||||
|
key = fmt.Sprintf("admin/%v", claims["user_id"])
|
||||||
|
}
|
||||||
|
if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
|
||||||
resp.NotAuth(c, "Token is not found in redis")
|
resp.NotAuth(c, "Token is not found in redis")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
@@ -218,6 +205,47 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func needLogin(c *gin.Context) bool {
|
||||||
|
if c.Request.URL.Path == "/api/user/login" ||
|
||||||
|
c.Request.URL.Path == "/api/user/logout" ||
|
||||||
|
c.Request.URL.Path == "/api/user/resetPass" ||
|
||||||
|
c.Request.URL.Path == "/api/admin/login" ||
|
||||||
|
c.Request.URL.Path == "/api/admin/logout" ||
|
||||||
|
c.Request.URL.Path == "/api/admin/login/captcha" ||
|
||||||
|
c.Request.URL.Path == "/api/user/register" ||
|
||||||
|
c.Request.URL.Path == "/api/user/session" ||
|
||||||
|
c.Request.URL.Path == "/api/chat/history" ||
|
||||||
|
c.Request.URL.Path == "/api/chat/detail" ||
|
||||||
|
c.Request.URL.Path == "/api/chat/list" ||
|
||||||
|
c.Request.URL.Path == "/api/role/list" ||
|
||||||
|
c.Request.URL.Path == "/api/model/list" ||
|
||||||
|
c.Request.URL.Path == "/api/mj/imgWall" ||
|
||||||
|
c.Request.URL.Path == "/api/mj/client" ||
|
||||||
|
c.Request.URL.Path == "/api/mj/notify" ||
|
||||||
|
c.Request.URL.Path == "/api/invite/hits" ||
|
||||||
|
c.Request.URL.Path == "/api/sd/imgWall" ||
|
||||||
|
c.Request.URL.Path == "/api/sd/client" ||
|
||||||
|
c.Request.URL.Path == "/api/dall/imgWall" ||
|
||||||
|
c.Request.URL.Path == "/api/dall/client" ||
|
||||||
|
c.Request.URL.Path == "/api/product/list" ||
|
||||||
|
c.Request.URL.Path == "/api/menu/list" ||
|
||||||
|
c.Request.URL.Path == "/api/markMap/client" ||
|
||||||
|
c.Request.URL.Path == "/api/payment/alipay/notify" ||
|
||||||
|
c.Request.URL.Path == "/api/payment/hupipay/notify" ||
|
||||||
|
c.Request.URL.Path == "/api/payment/payjs/notify" ||
|
||||||
|
c.Request.URL.Path == "/api/payment/doPay" ||
|
||||||
|
c.Request.URL.Path == "/api/payment/payWays" ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/static/") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// 统一参数处理
|
// 统一参数处理
|
||||||
func parameterHandlerMiddleware() gin.HandlerFunc {
|
func parameterHandlerMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
@@ -316,6 +344,10 @@ func staticResourceMiddleware() gin.HandlerFunc {
|
|||||||
|
|
||||||
// 解码图片
|
// 解码图片
|
||||||
img, _, err := image.Decode(file)
|
img, _, err := image.Decode(file)
|
||||||
|
// for .webp image
|
||||||
|
if err != nil {
|
||||||
|
img, err = webp.Decode(file)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusInternalServerError, "Error decoding image")
|
c.String(http.StatusInternalServerError, "Error decoding image")
|
||||||
return
|
return
|
||||||
@@ -332,7 +364,9 @@ func staticResourceMiddleware() gin.HandlerFunc {
|
|||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
|
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
logger.Error(err)
|
||||||
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置图片缓存有效期为一年 (365天)
|
// 设置图片缓存有效期为一年 (365天)
|
||||||
|
|||||||
@@ -1,10 +1,17 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
logger2 "chatplus/logger"
|
logger2 "geekai/logger"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
@@ -16,7 +23,6 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
return &types.AppConfig{
|
return &types.AppConfig{
|
||||||
Listen: "0.0.0.0:5678",
|
Listen: "0.0.0.0:5678",
|
||||||
ProxyURL: "",
|
ProxyURL: "",
|
||||||
Manager: types.Manager{Username: "admin", Password: "admin123"},
|
|
||||||
StaticDir: "./static",
|
StaticDir: "./static",
|
||||||
StaticUrl: "http://localhost/5678/static",
|
StaticUrl: "http://localhost/5678/static",
|
||||||
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
|
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
|
||||||
@@ -24,7 +30,7 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
SecretKey: utils.RandString(64),
|
SecretKey: utils.RandString(64),
|
||||||
MaxAge: 86400,
|
MaxAge: 86400,
|
||||||
},
|
},
|
||||||
ApiConfig: types.ChatPlusApiConfig{},
|
ApiConfig: types.ApiConfig{},
|
||||||
OSS: types.OSSConfig{
|
OSS: types.OSSConfig{
|
||||||
Active: "local",
|
Active: "local",
|
||||||
Local: types.LocalStorageConfig{
|
Local: types.LocalStorageConfig{
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
// ApiRequest API 请求实体
|
// ApiRequest API 请求实体
|
||||||
type ApiRequest struct {
|
type ApiRequest struct {
|
||||||
Model string `json:"model,omitempty"` // 兼容百度文心一言
|
Model string `json:"model,omitempty"` // 兼容百度文心一言
|
||||||
@@ -8,8 +15,13 @@ type ApiRequest struct {
|
|||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Messages []interface{} `json:"messages,omitempty"`
|
Messages []interface{} `json:"messages,omitempty"`
|
||||||
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
|
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
|
||||||
Tools []interface{} `json:"tools,omitempty"`
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
ToolChoice string `json:"tool_choice,omitempty"`
|
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
||||||
|
|
||||||
|
ToolChoice string `json:"tool_choice,omitempty"`
|
||||||
|
|
||||||
|
Input map[string]interface{} `json:"input,omitempty"` //兼容阿里通义千问
|
||||||
|
Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
@@ -28,10 +40,14 @@ type ChoiceItem struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Delta struct {
|
type Delta struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Content interface{} `json:"content"`
|
Content interface{} `json:"content"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
FunctionCall struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Arguments string `json:"arguments,omitempty"`
|
||||||
|
} `json:"function_call,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatSession 聊天会话对象
|
// ChatSession 聊天会话对象
|
||||||
@@ -45,10 +61,15 @@ type ChatSession struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChatModel struct {
|
type ChatModel struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Platform Platform `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Value string `json:"value"`
|
Name string `json:"name"`
|
||||||
Weight int `json:"weight"`
|
Value string `json:"value"`
|
||||||
|
Power int `json:"power"`
|
||||||
|
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||||
|
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||||
|
Temperature float32 `json:"temperature"` // 模型温度
|
||||||
|
KeyId int `json:"key_id"` // 绑定 API KEY
|
||||||
}
|
}
|
||||||
|
|
||||||
type ApiError struct {
|
type ApiError struct {
|
||||||
@@ -63,23 +84,36 @@ type ApiError struct {
|
|||||||
const PromptMsg = "prompt" // prompt message
|
const PromptMsg = "prompt" // prompt message
|
||||||
const ReplyMsg = "reply" // reply message
|
const ReplyMsg = "reply" // reply message
|
||||||
|
|
||||||
var ModelToTokens = map[string]int{
|
// PowerType 算力日志类型
|
||||||
"gpt-3.5-turbo": 4096,
|
type PowerType int
|
||||||
"gpt-3.5-turbo-16k": 16384,
|
|
||||||
"gpt-4": 8192,
|
const (
|
||||||
"gpt-4-32k": 32768,
|
PowerRecharge = PowerType(1) // 充值
|
||||||
"chatglm_pro": 32768, // 清华智普
|
PowerConsume = PowerType(2) // 消费
|
||||||
"chatglm_std": 16384,
|
PowerRefund = PowerType(3) // 任务(SD,MJ)执行失败,退款
|
||||||
"chatglm_lite": 4096,
|
PowerInvite = PowerType(4) // 邀请奖励
|
||||||
"ernie_bot_turbo": 8192, // 文心一言
|
PowerReward = PowerType(5) // 众筹
|
||||||
"general": 8192, // 科大讯飞
|
PowerGift = PowerType(6) // 系统赠送
|
||||||
"general2": 8192,
|
)
|
||||||
"general3": 8192,
|
|
||||||
|
func (t PowerType) String() string {
|
||||||
|
switch t {
|
||||||
|
case PowerRecharge:
|
||||||
|
return "充值"
|
||||||
|
case PowerConsume:
|
||||||
|
return "消费"
|
||||||
|
case PowerRefund:
|
||||||
|
return "退款"
|
||||||
|
case PowerReward:
|
||||||
|
return "众筹"
|
||||||
|
|
||||||
|
}
|
||||||
|
return "其他"
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetModelMaxToken(model string) int {
|
type PowerMark int
|
||||||
if token, ok := ModelToTokens[model]; ok {
|
|
||||||
return token
|
const (
|
||||||
}
|
PowerSub = PowerMark(0)
|
||||||
return 4096
|
PowerAdd = PowerMark(1)
|
||||||
}
|
)
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
|||||||
@@ -1,25 +1,33 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AppConfig struct {
|
type AppConfig struct {
|
||||||
Path string `toml:"-"`
|
Path string `toml:"-"`
|
||||||
Listen string
|
Listen string
|
||||||
Session Session
|
Session Session
|
||||||
ProxyURL string
|
AdminSession Session
|
||||||
MysqlDns string // mysql 连接地址
|
ProxyURL string
|
||||||
Manager Manager // 后台管理员账户信息
|
MysqlDns string // mysql 连接地址
|
||||||
StaticDir string // 静态资源目录
|
StaticDir string // 静态资源目录
|
||||||
StaticUrl string // 静态资源 URL
|
StaticUrl string // 静态资源 URL
|
||||||
Redis RedisConfig // redis 连接信息
|
Redis RedisConfig // redis 连接信息
|
||||||
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
ApiConfig ApiConfig // ChatPlus API authorization configs
|
||||||
SmsConfig AliYunSmsConfig // AliYun send message service config
|
SMS SMSConfig // send mobile message config
|
||||||
OSS OSSConfig // OSS config
|
OSS OSSConfig // OSS config
|
||||||
MjConfigs []MidJourneyConfig // mj AI draw service pool
|
MjProxyConfigs []MjProxyConfig // MJ proxy config
|
||||||
WeChatBot bool // 是否启用微信机器人
|
MjPlusConfigs []MjPlusConfig // MJ plus config
|
||||||
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
WeChatBot bool // 是否启用微信机器人
|
||||||
|
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
||||||
|
|
||||||
XXLConfig XXLConfig
|
XXLConfig XXLConfig
|
||||||
AlipayConfig AlipayConfig
|
AlipayConfig AlipayConfig
|
||||||
@@ -29,6 +37,7 @@ type AppConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SmtpConfig struct {
|
type SmtpConfig struct {
|
||||||
|
UseTls bool // 是否使用 TLS 发送
|
||||||
Host string
|
Host string
|
||||||
Port int
|
Port int
|
||||||
AppName string // 应用名称
|
AppName string // 应用名称
|
||||||
@@ -36,47 +45,31 @@ type SmtpConfig struct {
|
|||||||
Password string // 发件人邮箱密码
|
Password string // 发件人邮箱密码
|
||||||
}
|
}
|
||||||
|
|
||||||
// JPayConfig PayJs 支付配置
|
type ApiConfig struct {
|
||||||
type JPayConfig struct {
|
|
||||||
Enabled bool
|
|
||||||
AppId string // 商户 ID
|
|
||||||
PrivateKey string // 私钥
|
|
||||||
ApiURL string // API 网关
|
|
||||||
NotifyURL string // 异步回调地址
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatPlusApiConfig struct {
|
|
||||||
ApiURL string
|
ApiURL string
|
||||||
AppId string
|
AppId string
|
||||||
Token string
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidJourneyConfig struct {
|
type MjProxyConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
UserToken string
|
ApiURL string // api 地址
|
||||||
BotToken string
|
Mode string // 绘画模式,可选值:fast/turbo/relax
|
||||||
GuildId string // Server ID
|
ApiKey string
|
||||||
ChanelId string // Chanel ID
|
|
||||||
UseCDN bool
|
|
||||||
DiscordAPI string
|
|
||||||
DiscordCDN string
|
|
||||||
DiscordGateway string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type StableDiffusionConfig struct {
|
type StableDiffusionConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
ApiURL string
|
Model string // 模型名称
|
||||||
ApiKey string
|
ApiURL string
|
||||||
Txt2ImgJsonPath string
|
ApiKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliYunSmsConfig struct {
|
type MjPlusConfig struct {
|
||||||
AccessKey string
|
Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
|
||||||
AccessSecret string
|
ApiURL string // api 地址
|
||||||
Product string
|
Mode string // 绘画模式,可选值:fast/turbo/relax
|
||||||
Domain string
|
ApiKey string
|
||||||
Sign string // 短信签名
|
|
||||||
CodeTempId string // 验证码短信模板 ID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AlipayConfig struct {
|
type AlipayConfig struct {
|
||||||
@@ -89,6 +82,7 @@ type AlipayConfig struct {
|
|||||||
AlipayPublicKey string // 支付宝公钥文件路径
|
AlipayPublicKey string // 支付宝公钥文件路径
|
||||||
RootCert string // Root 秘钥路径
|
RootCert string // Root 秘钥路径
|
||||||
NotifyURL string // 异步通知回调
|
NotifyURL string // 异步通知回调
|
||||||
|
ReturnURL string // 支付成功返回地址
|
||||||
}
|
}
|
||||||
|
|
||||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
||||||
@@ -96,8 +90,20 @@ type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
|||||||
Name string // 支付名称,如:wechat/alipay
|
Name string // 支付名称,如:wechat/alipay
|
||||||
AppId string // App ID
|
AppId string // App ID
|
||||||
AppSecret string // app 密钥
|
AppSecret string // app 密钥
|
||||||
|
ApiURL string // 支付网关
|
||||||
NotifyURL string // 异步通知回调
|
NotifyURL string // 异步通知回调
|
||||||
PayURL string // 支付网关
|
ReturnURL string // 支付成功返回地址
|
||||||
|
}
|
||||||
|
|
||||||
|
// JPayConfig PayJs 支付配置
|
||||||
|
type JPayConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
Name string // 支付名称,默认 wechat
|
||||||
|
AppId string // 商户 ID
|
||||||
|
PrivateKey string // 私钥
|
||||||
|
ApiURL string // API 网关
|
||||||
|
NotifyURL string // 异步回调地址
|
||||||
|
ReturnURL string // 支付成功返回地址
|
||||||
}
|
}
|
||||||
|
|
||||||
type XXLConfig struct { // XXL 任务调度配置
|
type XXLConfig struct { // XXL 任务调度配置
|
||||||
@@ -116,73 +122,96 @@ type RedisConfig struct {
|
|||||||
DB int
|
DB int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LicenseKey 存储许可证书的 KEY
|
||||||
|
const LicenseKey = "Geek-AI-License"
|
||||||
|
|
||||||
|
type License struct {
|
||||||
|
Key string `json:"key"` // 许可证书密钥
|
||||||
|
MachineId string `json:"machine_id"` // 机器码
|
||||||
|
ExpiredAt int64 `json:"expired_at"` // 过期时间
|
||||||
|
IsActive bool `json:"is_active"` // 是否激活
|
||||||
|
Configs LicenseConfig `json:"configs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LicenseConfig struct {
|
||||||
|
UserNum int `json:"user_num"` // 用户数量
|
||||||
|
DeCopy bool `json:"de_copy"` // 去版权
|
||||||
|
}
|
||||||
|
|
||||||
func (c RedisConfig) Url() string {
|
func (c RedisConfig) Url() string {
|
||||||
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager 管理员
|
type Platform struct {
|
||||||
type Manager struct {
|
Name string `json:"name"`
|
||||||
Username string `json:"username"`
|
Value string `json:"value"`
|
||||||
Password string `json:"password"`
|
ChatURL string `json:"chat_url"`
|
||||||
|
ImgURL string `json:"img_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatConfig 系统默认的聊天配置
|
var OpenAI = Platform{
|
||||||
type ChatConfig struct {
|
Name: "OpenAI - GPT",
|
||||||
OpenAI ModelAPIConfig `json:"open_ai"`
|
Value: "OpenAI",
|
||||||
Azure ModelAPIConfig `json:"azure"`
|
ChatURL: "https://api.chat-plus.net/v1/chat/completions",
|
||||||
ChatGML ModelAPIConfig `json:"chat_gml"`
|
ImgURL: "https://api.chat-plus.net/v1/images/generations",
|
||||||
Baidu ModelAPIConfig `json:"baidu"`
|
|
||||||
XunFei ModelAPIConfig `json:"xun_fei"`
|
|
||||||
|
|
||||||
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
|
|
||||||
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
|
|
||||||
ContextDeep int `json:"context_deep"` // 上下文深度
|
|
||||||
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
|
|
||||||
}
|
}
|
||||||
|
var Azure = Platform{
|
||||||
type Platform string
|
Name: "微软 - Azure",
|
||||||
|
Value: "Azure",
|
||||||
const OpenAI = Platform("OpenAI")
|
ChatURL: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15",
|
||||||
const Azure = Platform("Azure")
|
|
||||||
const ChatGLM = Platform("ChatGLM")
|
|
||||||
const Baidu = Platform("Baidu")
|
|
||||||
const XunFei = Platform("XunFei")
|
|
||||||
|
|
||||||
// UserChatConfig 用户的聊天配置
|
|
||||||
type UserChatConfig struct {
|
|
||||||
ApiKeys map[Platform]string `json:"api_keys"`
|
|
||||||
}
|
}
|
||||||
|
var ChatGLM = Platform{
|
||||||
type InviteReward struct {
|
Name: "智谱 - ChatGLM",
|
||||||
ChatCalls int `json:"chat_calls"`
|
Value: "ChatGLM",
|
||||||
ImgCalls int `json:"img_calls"`
|
ChatURL: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke",
|
||||||
}
|
}
|
||||||
|
var Baidu = Platform{
|
||||||
type ModelAPIConfig struct {
|
Name: "百度 - 文心大模型",
|
||||||
Temperature float32 `json:"temperature"`
|
Value: "Baidu",
|
||||||
MaxTokens int `json:"max_tokens"`
|
ChatURL: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}",
|
||||||
|
}
|
||||||
|
var XunFei = Platform{
|
||||||
|
Name: "讯飞 - 星火大模型",
|
||||||
|
Value: "XunFei",
|
||||||
|
ChatURL: "wss://spark-api.xf-yun.com/{version}/chat",
|
||||||
|
}
|
||||||
|
var QWen = Platform{
|
||||||
|
Name: "阿里 - 通义千问",
|
||||||
|
Value: "QWen",
|
||||||
|
ChatURL: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
|
||||||
}
|
}
|
||||||
|
|
||||||
type SystemConfig struct {
|
type SystemConfig struct {
|
||||||
Title string `json:"title"`
|
Title string `json:"title,omitempty"`
|
||||||
AdminTitle string `json:"admin_title"`
|
AdminTitle string `json:"admin_title,omitempty"`
|
||||||
InitChatCalls int `json:"init_chat_calls"` // 新用户注册赠送对话次数
|
Logo string `json:"logo,omitempty"`
|
||||||
InitImgCalls int `json:"init_img_calls"` // 新用户注册赠送绘图次数
|
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||||
VipMonthCalls int `json:"vip_month_calls"` // VIP 会员每月赠送的对话次数
|
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
|
||||||
VipMonthImgCalls int `json:"vip_month_img_calls"` // VIP 会员每月赠送绘图次数
|
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||||
|
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
|
||||||
|
|
||||||
RegisterWays []string `json:"register_ways"` // 注册方式:支持手机,邮箱注册
|
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
|
||||||
|
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
|
||||||
|
|
||||||
RewardImg string `json:"reward_img"` // 众筹收款二维码地址
|
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
|
||||||
EnabledReward bool `json:"enabled_reward"` // 启用众筹功能
|
EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
|
||||||
ChatCallPrice float64 `json:"chat_call_price"` // 对话单次调用费用
|
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
|
||||||
ImgCallPrice float64 `json:"img_call_price"` // 绘图单次调用费用
|
|
||||||
|
|
||||||
OrderPayTimeout int `json:"order_pay_timeout"` //订单支付超时时间
|
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
|
||||||
DefaultModels []string `json:"default_models"` // 默认开通的 AI 模型
|
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
|
||||||
OrderPayInfoText string `json:"order_pay_info_text"` // 订单支付页面说明文字
|
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
|
||||||
InviteChatCalls int `json:"invite_chat_calls"` // 邀请用户注册奖励对话次数
|
|
||||||
InviteImgCalls int `json:"invite_img_calls"` // 邀请用户注册奖励绘图次数
|
|
||||||
|
|
||||||
ShowDemoNotice bool `json:"show_demo_notice"` // 显示演示站公告
|
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||||
|
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||||
|
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||||
|
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
|
||||||
|
|
||||||
|
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
||||||
|
|
||||||
|
EnableContext bool `json:"enable_context,omitempty"`
|
||||||
|
ContextDeep int `json:"context_deep,omitempty"`
|
||||||
|
|
||||||
|
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
|
||||||
|
|
||||||
|
RandBg bool `json:"rand_bg"` // 前端首页是否启用随机背景
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Function struct {
|
Function struct {
|
||||||
@@ -8,19 +15,13 @@ type ToolCall struct {
|
|||||||
} `json:"function"`
|
} `json:"function"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function Function `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
type Function struct {
|
type Function struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Parameters Parameters `json:"parameters"`
|
Parameters map[string]interface{} `json:"parameters"`
|
||||||
}
|
|
||||||
|
|
||||||
type Parameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]Property `json:"properties"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Property struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -9,7 +16,7 @@ type MKey interface {
|
|||||||
string | int | uint
|
string | int | uint
|
||||||
}
|
}
|
||||||
type MValue interface {
|
type MValue interface {
|
||||||
*WsClient | *ChatSession | context.CancelFunc | []interface{}
|
*WsClient | *ChatSession | context.CancelFunc | []Message
|
||||||
}
|
}
|
||||||
type LMap[K MKey, T MValue] struct {
|
type LMap[K MKey, T MValue] struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
type OrderStatus int
|
type OrderStatus int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -9,10 +16,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type OrderRemark struct {
|
type OrderRemark struct {
|
||||||
Days int `json:"days"` // 有效期
|
Days int `json:"days"` // 有效期
|
||||||
Calls int `json:"calls"` // 增加对话次数
|
Power int `json:"power"` // 增加算力点数
|
||||||
ImgCalls int `json:"img_calls"` // 增加绘图次数
|
Name string `json:"name"` // 产品名称
|
||||||
Name string `json:"name"` // 产品名称
|
|
||||||
Price float64 `json:"price"`
|
Price float64 `json:"price"`
|
||||||
Discount float64 `json:"discount"`
|
Discount float64 `json:"discount"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
type OSSConfig struct {
|
type OSSConfig struct {
|
||||||
Active string
|
Active string
|
||||||
Local LocalStorageConfig
|
Local LocalStorageConfig
|
||||||
|
|||||||
@@ -1,11 +1,17 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
const LoginUserID = "LOGIN_USER_ID"
|
const LoginUserID = "LOGIN_USER_ID"
|
||||||
const LoginUserCache = "LOGIN_USER_CACHE"
|
const LoginUserCache = "LOGIN_USER_CACHE"
|
||||||
|
|
||||||
const UserAuthHeader = "Authorization"
|
const UserAuthHeader = "Authorization"
|
||||||
const AdminAuthHeader = "Admin-Authorization"
|
const AdminAuthHeader = "Admin-Authorization"
|
||||||
const ChatTokenHeader = "Chat-Token"
|
|
||||||
|
|
||||||
// Session configs struct
|
// Session configs struct
|
||||||
type Session struct {
|
type Session struct {
|
||||||
|
|||||||
33
api/core/types/sms.go
Normal file
33
api/core/types/sms.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
type SMSConfig struct {
|
||||||
|
Active string
|
||||||
|
Ali SmsConfigAli
|
||||||
|
Bao SmsConfigBao
|
||||||
|
}
|
||||||
|
|
||||||
|
// SmsConfigAli 阿里云短信平台配置
|
||||||
|
type SmsConfigAli struct {
|
||||||
|
AccessKey string
|
||||||
|
AccessSecret string
|
||||||
|
Product string
|
||||||
|
Domain string
|
||||||
|
Sign string // 短信签名
|
||||||
|
CodeTempId string // 验证码短信模板 ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// SmsConfigBao 短信宝平台配置
|
||||||
|
type SmsConfigBao struct {
|
||||||
|
Username string //短信宝平台注册的用户名
|
||||||
|
Password string //短信宝平台注册的密码
|
||||||
|
Domain string //域名
|
||||||
|
Sign string // 短信签名
|
||||||
|
CodeTemplate string // 验证码短信模板 匹配
|
||||||
|
}
|
||||||
@@ -1,5 +1,12 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
// TaskType 任务类别
|
// TaskType 任务类别
|
||||||
type TaskType string
|
type TaskType string
|
||||||
|
|
||||||
@@ -9,18 +16,24 @@ func (t TaskType) String() string {
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
TaskImage = TaskType("image")
|
TaskImage = TaskType("image")
|
||||||
|
TaskBlend = TaskType("blend")
|
||||||
|
TaskSwapFace = TaskType("swapFace")
|
||||||
TaskUpscale = TaskType("upscale")
|
TaskUpscale = TaskType("upscale")
|
||||||
TaskVariation = TaskType("variation")
|
TaskVariation = TaskType("variation")
|
||||||
)
|
)
|
||||||
|
|
||||||
// MjTask MidJourney 任务
|
// MjTask MidJourney 任务
|
||||||
type MjTask struct {
|
type MjTask struct {
|
||||||
Id int `json:"id"`
|
Id uint `json:"id"`
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
ImgArr []string `json:"img_arr"`
|
||||||
ChannelId string `json:"channel_id"`
|
ChannelId string `json:"channel_id"`
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Type TaskType `json:"type"`
|
Type TaskType `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
NegPrompt string `json:"neg_prompt,omitempty"`
|
||||||
|
Params string `json:"full_prompt"`
|
||||||
Index int `json:"index,omitempty"`
|
Index int `json:"index,omitempty"`
|
||||||
MessageId string `json:"message_id,omitempty"`
|
MessageId string `json:"message_id,omitempty"`
|
||||||
MessageHash string `json:"message_hash,omitempty"`
|
MessageHash string `json:"message_hash,omitempty"`
|
||||||
@@ -32,25 +45,38 @@ type SdTask struct {
|
|||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Type TaskType `json:"type"`
|
Type TaskType `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
Params SdTaskParams `json:"params"`
|
Params SdTaskParams `json:"params"`
|
||||||
RetryCount int `json:"retry_count"`
|
RetryCount int `json:"retry_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SdTaskParams struct {
|
type SdTaskParams struct {
|
||||||
TaskId string `json:"task_id"`
|
TaskId string `json:"task_id"`
|
||||||
Prompt string `json:"prompt"` // 提示词
|
Prompt string `json:"prompt"` // 提示词
|
||||||
NegativePrompt string `json:"negative_prompt"` // 反向提示词
|
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
||||||
Steps int `json:"steps"` // 迭代步数,默认20
|
Steps int `json:"steps"` // 迭代步数,默认20
|
||||||
Sampler string `json:"sampler"` // 采样器
|
Sampler string `json:"sampler"` // 采样器
|
||||||
FaceFix bool `json:"face_fix"` // 面部修复
|
Scheduler string `json:"scheduler"`
|
||||||
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
|
FaceFix bool `json:"face_fix"` // 面部修复
|
||||||
Seed int64 `json:"seed"` // 随机数种子
|
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
|
||||||
Height int `json:"height"`
|
Seed int64 `json:"seed"` // 随机数种子
|
||||||
Width int `json:"width"`
|
Height int `json:"height"`
|
||||||
HdFix bool `json:"hd_fix"` // 启用高清修复
|
Width int `json:"width"`
|
||||||
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
|
HdFix bool `json:"hd_fix"` // 启用高清修复
|
||||||
HdScale int `json:"hd_scale"` // 放大倍数
|
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
|
||||||
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
|
HdScale int `json:"hd_scale"` // 放大倍数
|
||||||
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
|
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
|
||||||
|
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
|
||||||
|
}
|
||||||
|
|
||||||
|
// DallTask DALL-E task
|
||||||
|
type DallTask struct {
|
||||||
|
JobId uint `json:"job_id"`
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
N int `json:"n"`
|
||||||
|
Quality string `json:"quality"`
|
||||||
|
Size string `json:"size"`
|
||||||
|
Style string `json:"style"`
|
||||||
|
|
||||||
|
Power int `json:"power"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
// BizVo 业务返回 VO
|
// BizVo 业务返回 VO
|
||||||
type BizVo struct {
|
type BizVo struct {
|
||||||
Code BizCode `json:"code"`
|
Code BizCode `json:"code"`
|
||||||
@@ -21,7 +28,7 @@ const (
|
|||||||
WsStart = WsMsgType("start")
|
WsStart = WsMsgType("start")
|
||||||
WsMiddle = WsMsgType("middle")
|
WsMiddle = WsMsgType("middle")
|
||||||
WsEnd = WsMsgType("end")
|
WsEnd = WsMsgType("end")
|
||||||
WsMjImg = WsMsgType("mj")
|
WsErr = WsMsgType("error")
|
||||||
)
|
)
|
||||||
|
|
||||||
type BizCode int
|
type BizCode int
|
||||||
@@ -30,6 +37,7 @@ const (
|
|||||||
Success = BizCode(0)
|
Success = BizCode(0)
|
||||||
Failed = BizCode(1)
|
Failed = BizCode(1)
|
||||||
NotAuthorized = BizCode(400) // 未授权
|
NotAuthorized = BizCode(400) // 未授权
|
||||||
|
NotPermission = BizCode(403) // 没有权限
|
||||||
|
|
||||||
OkMsg = "Success"
|
OkMsg = "Success"
|
||||||
ErrorMsg = "系统开小差了"
|
ErrorMsg = "系统开小差了"
|
||||||
|
|||||||
51
api/go.mod
51
api/go.mod
@@ -1,6 +1,8 @@
|
|||||||
module chatplus
|
module geekai
|
||||||
|
|
||||||
go 1.19
|
go 1.21
|
||||||
|
|
||||||
|
toolchain go1.22.4
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/BurntSushi/toml v1.1.0
|
github.com/BurntSushi/toml v1.1.0
|
||||||
@@ -25,12 +27,28 @@ require (
|
|||||||
|
|
||||||
require github.com/xxl-job/xxl-job-executor-go v1.2.0
|
require github.com/xxl-job/xxl-job-executor-go v1.2.0
|
||||||
|
|
||||||
require github.com/bg5t/mydiscordgo v0.28.1
|
require (
|
||||||
|
github.com/mojocn/base64Captcha v1.3.1
|
||||||
|
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||||
|
github.com/shopspring/decimal v1.3.1
|
||||||
|
github.com/syndtr/goleveldb v1.0.0
|
||||||
|
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||||
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
||||||
|
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||||
|
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
|
go.uber.org/mock v0.4.0 // indirect
|
||||||
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/andybalholm/brotli v1.0.4 // indirect
|
github.com/andybalholm/brotli v1.0.4 // indirect
|
||||||
github.com/bytedance/sonic v1.9.1 // indirect
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/dlclark/regexp2 v1.8.1 // indirect
|
github.com/dlclark/regexp2 v1.8.1 // indirect
|
||||||
@@ -41,7 +59,6 @@ require (
|
|||||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
github.com/golang/mock v1.6.0 // indirect
|
|
||||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
|
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
|
||||||
github.com/google/uuid v1.3.0 // indirect
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||||
@@ -58,9 +75,7 @@ require (
|
|||||||
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
|
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||||
github.com/quic-go/qpack v0.4.0 // indirect
|
github.com/quic-go/qpack v0.4.0 // indirect
|
||||||
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
|
github.com/quic-go/quic-go v0.45.0 // indirect
|
||||||
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
|
|
||||||
github.com/quic-go/quic-go v0.35.1 // indirect
|
|
||||||
github.com/refraction-networking/utls v1.3.2 // indirect
|
github.com/refraction-networking/utls v1.3.2 // indirect
|
||||||
github.com/rs/xid v1.5.0 // indirect
|
github.com/rs/xid v1.5.0 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
@@ -70,14 +85,14 @@ require (
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
go.uber.org/dig v1.16.1 // indirect
|
go.uber.org/dig v1.16.1 // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||||
golang.org/x/mod v0.11.0 // indirect
|
golang.org/x/mod v0.17.0 // indirect
|
||||||
golang.org/x/net v0.14.0 // indirect
|
golang.org/x/net v0.25.0 // indirect
|
||||||
golang.org/x/sync v0.3.0 // indirect
|
golang.org/x/sync v0.7.0 // indirect
|
||||||
golang.org/x/text v0.12.0 // indirect
|
golang.org/x/text v0.15.0 // indirect
|
||||||
golang.org/x/time v0.3.0 // indirect
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.10.0 // indirect
|
golang.org/x/tools v0.21.0 // indirect
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
google.golang.org/protobuf v1.33.0 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
@@ -96,7 +111,7 @@ require (
|
|||||||
go.uber.org/atomic v1.9.0 // indirect
|
go.uber.org/atomic v1.9.0 // indirect
|
||||||
go.uber.org/fx v1.19.3
|
go.uber.org/fx v1.19.3
|
||||||
go.uber.org/multierr v1.6.0 // indirect
|
go.uber.org/multierr v1.6.0 // indirect
|
||||||
golang.org/x/crypto v0.12.0
|
golang.org/x/crypto v0.23.0
|
||||||
golang.org/x/sys v0.11.0 // indirect
|
golang.org/x/sys v0.20.0 // indirect
|
||||||
gorm.io/gorm v1.25.1
|
gorm.io/gorm v1.25.1
|
||||||
)
|
)
|
||||||
|
|||||||
124
api/go.sum
124
api/go.sum
@@ -7,13 +7,12 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
|
|||||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||||
github.com/bg5t/mydiscordgo v0.28.1 h1:mVH0ZWstVdJffCi/EXJAYQDtXwIKAJYVXLmECu1hEK8=
|
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||||
github.com/bg5t/mydiscordgo v0.28.1/go.mod h1:n3aba73N18k1DzM0t0mGE8rwW3Z+vwTvI8pcsBgxN/8=
|
|
||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||||
@@ -29,7 +28,9 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
|
|||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
|
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
|
||||||
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
|
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
|
||||||
|
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||||
|
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||||
github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk=
|
github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk=
|
||||||
@@ -41,8 +42,12 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU
|
|||||||
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
|
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
|
||||||
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
|
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
|
||||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||||
|
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
|
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
|
||||||
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
|
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
|
||||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
@@ -65,18 +70,20 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
|||||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
|
||||||
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
|
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
|
||||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
|
||||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
@@ -84,6 +91,7 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
|
|||||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||||
|
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||||
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
|
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
|
||||||
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
|
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
@@ -129,13 +137,21 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
|||||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
|
||||||
|
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||||
|
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||||
|
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||||
|
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||||
|
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||||
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
|
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
|
||||||
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
|
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
|
||||||
|
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||||
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
|
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
|
||||||
|
github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4=
|
||||||
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A=
|
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A=
|
||||||
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
|
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||||
@@ -153,12 +169,8 @@ github.com/qiniu/go-sdk/v7 v7.17.1/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYX
|
|||||||
github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs=
|
github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs=
|
||||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||||
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
|
github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
|
||||||
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
|
github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
|
||||||
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
|
|
||||||
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
|
|
||||||
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
|
|
||||||
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
|
|
||||||
github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8=
|
github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8=
|
||||||
github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E=
|
github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E=
|
||||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||||
@@ -166,6 +178,10 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
|||||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||||
|
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||||
|
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||||
|
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
|
||||||
|
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
@@ -190,6 +206,12 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
||||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
|
||||||
|
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=
|
||||||
|
github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY=
|
||||||
|
github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
|
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
|
||||||
@@ -200,8 +222,9 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
|
|||||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
|
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
|
||||||
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
|
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
|
||||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||||
@@ -210,6 +233,9 @@ go.uber.org/dig v1.16.1/go.mod h1:557JTAUZT5bUK0SvCwikmLPPtdQhfvLYtO5tJgQSbnk=
|
|||||||
go.uber.org/fx v1.19.3 h1:YqMRE4+2IepTYCMOvXqQpRa+QAVdiSTnsHU4XNWBceA=
|
go.uber.org/fx v1.19.3 h1:YqMRE4+2IepTYCMOvXqQpRa+QAVdiSTnsHU4XNWBceA=
|
||||||
go.uber.org/fx v1.19.3/go.mod h1:w2HrQg26ql9fLK7hlBiZ6JsRUKV+Lj/atT1KCjT8YhM=
|
go.uber.org/fx v1.19.3/go.mod h1:w2HrQg26ql9fLK7hlBiZ6JsRUKV+Lj/atT1KCjT8YhM=
|
||||||
go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
|
go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
|
||||||
|
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
|
||||||
|
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||||
|
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||||
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
|
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
|
||||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||||
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
|
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
|
||||||
@@ -218,37 +244,35 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
|||||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
|
||||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
||||||
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
|
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||||
|
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
|
||||||
|
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
|
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||||
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
|
||||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
|
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
|
||||||
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
|
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||||
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
|
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
@@ -257,8 +281,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
|
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
@@ -268,34 +292,32 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
|
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
|
||||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||||
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
|
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
|
||||||
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM=
|
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
|
||||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
|
||||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||||
|
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||||
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||||
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||||
|
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
|||||||
@@ -1,14 +1,26 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/handler"
|
"geekai/handler"
|
||||||
logger2 "chatplus/logger"
|
logger2 "geekai/logger"
|
||||||
"chatplus/utils/resp"
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/mojocn/base64Captcha"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -17,47 +29,88 @@ import (
|
|||||||
|
|
||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
// Manager 管理员
|
||||||
|
type Manager struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
Captcha string `json:"captcha"` // 验证码
|
||||||
|
CaptchaId string `json:"captcha_id"` // 验证码id
|
||||||
|
}
|
||||||
|
|
||||||
|
const SuperManagerID = 1
|
||||||
|
|
||||||
type ManagerHandler struct {
|
type ManagerHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
|
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
|
||||||
h := ManagerHandler{db: db, redis: client}
|
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login 登录
|
// Login 登录
|
||||||
func (h *ManagerHandler) Login(c *gin.Context) {
|
func (h *ManagerHandler) Login(c *gin.Context) {
|
||||||
var data types.Manager
|
var data Manager
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
manager := h.App.Config.Manager
|
|
||||||
if data.Username == manager.Username && data.Password == manager.Password {
|
// add captcha
|
||||||
// 创建 token
|
if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
resp.ERROR(c, "验证码错误!")
|
||||||
"user_id": manager.Username,
|
return
|
||||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
|
||||||
})
|
|
||||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 保存到 redis
|
|
||||||
key := "users/" + manager.Username
|
|
||||||
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
|
|
||||||
resp.ERROR(c, "error with save token: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c, tokenString)
|
|
||||||
} else {
|
|
||||||
resp.ERROR(c, "用户名或者密码错误")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var manager model.AdminUser
|
||||||
|
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "请检查用户名或者密码是否填写正确")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
password := utils.GenPassword(data.Password, manager.Salt)
|
||||||
|
if password != manager.Password {
|
||||||
|
resp.ERROR(c, "用户名或密码错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 超级管理员默认是ID:1
|
||||||
|
if manager.Id != SuperManagerID && manager.Status == false {
|
||||||
|
resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 token
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
|
"user_id": manager.Id,
|
||||||
|
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||||||
|
})
|
||||||
|
tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey))
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 保存到 redis
|
||||||
|
key := fmt.Sprintf("admin/%d", manager.Id)
|
||||||
|
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
|
||||||
|
resp.ERROR(c, "error with save token: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新最后登录时间和IP
|
||||||
|
manager.LastLoginIp = c.ClientIP()
|
||||||
|
manager.LastLoginAt = time.Now().Unix()
|
||||||
|
h.DB.Updates(&manager)
|
||||||
|
|
||||||
|
var result = struct {
|
||||||
|
IsSuperAdmin bool `json:"is_super_admin"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}{
|
||||||
|
IsSuperAdmin: manager.Id == 1,
|
||||||
|
Token: tokenString,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logout 注销
|
// Logout 注销
|
||||||
@@ -72,10 +125,155 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
|
|||||||
|
|
||||||
// Session 会话检测
|
// Session 会话检测
|
||||||
func (h *ManagerHandler) Session(c *gin.Context) {
|
func (h *ManagerHandler) Session(c *gin.Context) {
|
||||||
token := c.GetHeader(types.AdminAuthHeader)
|
id := h.GetLoginUserId(c)
|
||||||
if token == "" {
|
key := fmt.Sprintf("admin/%d", id)
|
||||||
|
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
} else {
|
return
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
}
|
||||||
|
var manager model.AdminUser
|
||||||
|
res := h.DB.Where("id", id).First(&manager)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.NotAuth(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 数据列表
|
||||||
|
func (h *ManagerHandler) List(c *gin.Context) {
|
||||||
|
var items []model.AdminUser
|
||||||
|
res := h.DB.Find(&items)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]vo.AdminUser, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
var u vo.AdminUser
|
||||||
|
err := utils.CopyObject(item, &u)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
u.Id = item.Id
|
||||||
|
u.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
users = append(users, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, users)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ManagerHandler) Save(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
Status bool `json:"status"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var user model.AdminUser
|
||||||
|
res := h.DB.Where("username", data.Username).First(&user)
|
||||||
|
if res.Error == nil {
|
||||||
|
resp.ERROR(c, "用户名已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成密码
|
||||||
|
salt := utils.RandString(8)
|
||||||
|
password := utils.GenPassword(data.Password, salt)
|
||||||
|
res = h.DB.Save(&model.AdminUser{
|
||||||
|
Username: data.Username,
|
||||||
|
Password: password,
|
||||||
|
Salt: salt,
|
||||||
|
Status: data.Status,
|
||||||
|
})
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "failed with update database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove 删除管理员
|
||||||
|
func (h *ManagerHandler) Remove(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
if id <= 0 {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == SuperManagerID {
|
||||||
|
resp.ERROR(c, "超级管理员不能删除")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := h.DB.Where("id", id).Delete(&model.AdminUser{})
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable 启用/禁用
|
||||||
|
func (h *ManagerHandler) Enable(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := h.DB.Model(&model.AdminUser{}).Where("id", data.Id).UpdateColumn("status", data.Enabled)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetPass 重置密码
|
||||||
|
func (h *ManagerHandler) ResetPass(c *gin.Context) {
|
||||||
|
id := h.GetLoginUserId(c)
|
||||||
|
if id != SuperManagerID {
|
||||||
|
resp.ERROR(c, "只有超级管理员能够进行该操作")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var data struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var user model.AdminUser
|
||||||
|
res := h.DB.Where("id", data.Id).First(&user)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
password := utils.GenPassword(data.Password, user.Salt)
|
||||||
|
user.Password = password
|
||||||
|
res = h.DB.Updates(&user)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +1,31 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/handler"
|
"geekai/handler"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ApiKeyHandler struct {
|
type ApiKeyHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
||||||
h := ApiKeyHandler{db: db}
|
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ApiKeyHandler) Save(c *gin.Context) {
|
func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||||
@@ -32,7 +37,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
|||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
ApiURL string `json:"api_url"`
|
ApiURL string `json:"api_url"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
UseProxy bool `json:"use_proxy"`
|
ProxyURL string `json:"proxy_url"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -41,17 +46,18 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
|||||||
|
|
||||||
apiKey := model.ApiKey{}
|
apiKey := model.ApiKey{}
|
||||||
if data.Id > 0 {
|
if data.Id > 0 {
|
||||||
h.db.Find(&apiKey, data.Id)
|
h.DB.Find(&apiKey, data.Id)
|
||||||
}
|
}
|
||||||
apiKey.Platform = data.Platform
|
apiKey.Platform = data.Platform
|
||||||
apiKey.Value = data.Value
|
apiKey.Value = data.Value
|
||||||
apiKey.Type = data.Type
|
apiKey.Type = data.Type
|
||||||
apiKey.ApiURL = data.ApiURL
|
apiKey.ApiURL = data.ApiURL
|
||||||
apiKey.Enabled = data.Enabled
|
apiKey.Enabled = data.Enabled
|
||||||
apiKey.UseProxy = data.UseProxy
|
apiKey.ProxyURL = data.ProxyURL
|
||||||
apiKey.Name = data.Name
|
apiKey.Name = data.Name
|
||||||
res := h.db.Save(&apiKey)
|
res := h.DB.Save(&apiKey)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -68,9 +74,24 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ApiKeyHandler) List(c *gin.Context) {
|
func (h *ApiKeyHandler) List(c *gin.Context) {
|
||||||
|
status := h.GetBool(c, "status")
|
||||||
|
t := h.GetTrim(c, "type")
|
||||||
|
platform := h.GetTrim(c, "platform")
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if status {
|
||||||
|
session = session.Where("enabled", true)
|
||||||
|
}
|
||||||
|
if t != "" {
|
||||||
|
session = session.Where("type", t)
|
||||||
|
}
|
||||||
|
if platform != "" {
|
||||||
|
session = session.Where("platform", platform)
|
||||||
|
}
|
||||||
|
|
||||||
var items []model.ApiKey
|
var items []model.ApiKey
|
||||||
var keys = make([]vo.ApiKey, 0)
|
var keys = make([]vo.ApiKey, 0)
|
||||||
res := h.db.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var key vo.ApiKey
|
var key vo.ApiKey
|
||||||
@@ -100,8 +121,9 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -110,13 +132,16 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *ApiKeyHandler) Remove(c *gin.Context) {
|
func (h *ApiKeyHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
if id <= 0 {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if id > 0 {
|
res := h.DB.Where("id", id).Delete(&model.ApiKey{})
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.ApiKey{})
|
if res.Error != nil {
|
||||||
if res.Error != nil {
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
46
api/handler/admin/captcha_handler.go
Normal file
46
api/handler/admin/captcha_handler.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mojocn/base64Captcha"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CaptchaHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
|
||||||
|
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CaptchaVo struct {
|
||||||
|
CaptchaId string `json:"captcha_id"`
|
||||||
|
PicPath string `json:"pic_path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCaptcha 获取验证码
|
||||||
|
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
|
||||||
|
var captchaVo CaptchaVo
|
||||||
|
driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
|
||||||
|
cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
|
||||||
|
// b64s是图片的base64编码
|
||||||
|
id, b64s, err := cp.Generate()
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "生成验证码错误!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
captchaVo.CaptchaId = id
|
||||||
|
captchaVo.PicPath = b64s
|
||||||
|
|
||||||
|
resp.SUCCESS(c, captchaVo)
|
||||||
|
}
|
||||||
269
api/handler/admin/chat_handler.go
Normal file
269
api/handler/admin/chat_handler.go
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||||||
|
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatItemVo struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
|
ChatId string `json:"chat_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Role vo.ChatRole `json:"role"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Token int `json:"token"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
MsgNum int `json:"msg_num"` // 消息数量
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ChatHandler) List(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
CreateAt []string `json:"created_time"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Title != "" {
|
||||||
|
session = session.Where("title LIKE ?", "%"+data.Title+"%")
|
||||||
|
}
|
||||||
|
if data.UserId > 0 {
|
||||||
|
session = session.Where("user_id = ?", data.UserId)
|
||||||
|
}
|
||||||
|
if data.Model != "" {
|
||||||
|
session = session.Where("model = ?", data.Model)
|
||||||
|
}
|
||||||
|
if len(data.CreateAt) == 2 {
|
||||||
|
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
|
||||||
|
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.ChatItem{}).Count(&total)
|
||||||
|
var items []model.ChatItem
|
||||||
|
var list = make([]chatItemVo, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
||||||
|
if res.Error == nil {
|
||||||
|
userIds := make([]uint, 0)
|
||||||
|
chatIds := make([]string, 0)
|
||||||
|
roleIds := make([]uint, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
userIds = append(userIds, item.UserId)
|
||||||
|
chatIds = append(chatIds, item.ChatId)
|
||||||
|
roleIds = append(roleIds, item.RoleId)
|
||||||
|
}
|
||||||
|
var messages []model.ChatMessage
|
||||||
|
var users []model.User
|
||||||
|
var roles []model.ChatRole
|
||||||
|
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
|
||||||
|
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||||
|
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||||
|
|
||||||
|
tokenMap := make(map[string]int)
|
||||||
|
userMap := make(map[uint]string)
|
||||||
|
msgMap := make(map[string]int)
|
||||||
|
roleMap := make(map[uint]vo.ChatRole)
|
||||||
|
for _, msg := range messages {
|
||||||
|
tokenMap[msg.ChatId] += msg.Tokens
|
||||||
|
msgMap[msg.ChatId] += 1
|
||||||
|
}
|
||||||
|
for _, user := range users {
|
||||||
|
userMap[user.Id] = user.Username
|
||||||
|
}
|
||||||
|
for _, r := range roles {
|
||||||
|
var roleVo vo.ChatRole
|
||||||
|
err := utils.CopyObject(r, &roleVo)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
roleMap[r.Id] = roleVo
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
list = append(list, chatItemVo{
|
||||||
|
UserId: item.UserId,
|
||||||
|
Username: userMap[item.UserId],
|
||||||
|
ChatId: item.ChatId,
|
||||||
|
Title: item.Title,
|
||||||
|
Model: item.Model,
|
||||||
|
Token: tokenMap[item.ChatId],
|
||||||
|
MsgNum: msgMap[item.ChatId],
|
||||||
|
Role: roleMap[item.RoleId],
|
||||||
|
CreatedAt: item.CreatedAt.Unix(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatMessageVo struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Token int `json:"token"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Messages 读取聊天记录列表
|
||||||
|
func (h *ChatHandler) Messages(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
CreateAt []string `json:"created_time"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Content != "" {
|
||||||
|
session = session.Where("content LIKE ?", "%"+data.Content+"%")
|
||||||
|
}
|
||||||
|
if data.UserId > 0 {
|
||||||
|
session = session.Where("user_id = ?", data.UserId)
|
||||||
|
}
|
||||||
|
if data.Model != "" {
|
||||||
|
session = session.Where("model = ?", data.Model)
|
||||||
|
}
|
||||||
|
if len(data.CreateAt) == 2 {
|
||||||
|
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
|
||||||
|
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.ChatMessage{}).Count(&total)
|
||||||
|
var items []model.ChatMessage
|
||||||
|
var list = make([]chatMessageVo, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
||||||
|
if res.Error == nil {
|
||||||
|
userIds := make([]uint, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
userIds = append(userIds, item.UserId)
|
||||||
|
}
|
||||||
|
var users []model.User
|
||||||
|
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||||
|
userMap := make(map[uint]string)
|
||||||
|
for _, user := range users {
|
||||||
|
userMap[user.Id] = user.Username
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
list = append(list, chatMessageVo{
|
||||||
|
Id: item.Id,
|
||||||
|
UserId: item.UserId,
|
||||||
|
Username: userMap[item.UserId],
|
||||||
|
Content: item.Content,
|
||||||
|
Model: item.Model,
|
||||||
|
Token: item.Tokens,
|
||||||
|
Icon: item.Icon,
|
||||||
|
Type: item.Type,
|
||||||
|
CreatedAt: item.CreatedAt.Unix(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||||||
|
}
|
||||||
|
|
||||||
|
// History 获取聊天历史记录
|
||||||
|
func (h *ChatHandler) History(c *gin.Context) {
|
||||||
|
chatId := c.Query("chat_id") // 会话 ID
|
||||||
|
var items []model.ChatMessage
|
||||||
|
var messages = make([]vo.HistoryMessage, 0)
|
||||||
|
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "No history message")
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
for _, item := range items {
|
||||||
|
var v vo.HistoryMessage
|
||||||
|
err := utils.CopyObject(item, &v)
|
||||||
|
v.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
v.UpdatedAt = item.UpdatedAt.Unix()
|
||||||
|
if err == nil {
|
||||||
|
messages = append(messages, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveChat 删除对话
|
||||||
|
func (h *ChatHandler) RemoveChat(c *gin.Context) {
|
||||||
|
chatId := h.GetTrim(c, "chat_id")
|
||||||
|
if chatId == "" {
|
||||||
|
resp.ERROR(c, "请传入 ChatId")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := h.DB.Begin()
|
||||||
|
// 删除聊天记录
|
||||||
|
res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "failed to remove chat message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除对话
|
||||||
|
res = tx.Unscoped().Where("chat_id = ?", chatId).Delete(model.ChatItem{})
|
||||||
|
if res.Error != nil {
|
||||||
|
tx.Rollback() // 回滚
|
||||||
|
resp.ERROR(c, "failed to remove chat")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveMessage 删除聊天记录
|
||||||
|
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
|
||||||
|
if tx.Error != nil {
|
||||||
|
logger.Error("error with update database:", tx.Error)
|
||||||
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
@@ -1,40 +1,48 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/handler"
|
"geekai/handler"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatModelHandler struct {
|
type ChatModelHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||||
h := ChatModelHandler{db: db}
|
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatModelHandler) Save(c *gin.Context) {
|
func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
SortNum int `json:"sort_num"`
|
SortNum int `json:"sort_num"`
|
||||||
Open bool `json:"open"`
|
Open bool `json:"open"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Weight int `json:"weight"`
|
Power int `json:"power"`
|
||||||
CreatedAt int64 `json:"created_at"`
|
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||||
|
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||||
|
Temperature float32 `json:"temperature"` // 模型温度
|
||||||
|
KeyId int `json:"key_id,omitempty"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -42,19 +50,26 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
item := model.ChatModel{
|
item := model.ChatModel{
|
||||||
Platform: data.Platform,
|
Platform: data.Platform,
|
||||||
Name: data.Name,
|
Name: data.Name,
|
||||||
Value: data.Value,
|
Value: data.Value,
|
||||||
Enabled: data.Enabled,
|
Enabled: data.Enabled,
|
||||||
SortNum: data.SortNum,
|
SortNum: data.SortNum,
|
||||||
Open: data.Open,
|
Open: data.Open,
|
||||||
Weight: data.Weight}
|
MaxTokens: data.MaxTokens,
|
||||||
item.Id = data.Id
|
MaxContext: data.MaxContext,
|
||||||
if item.Id > 0 {
|
Temperature: data.Temperature,
|
||||||
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
KeyId: data.KeyId,
|
||||||
|
Power: data.Power}
|
||||||
|
var res *gorm.DB
|
||||||
|
if data.Id > 0 {
|
||||||
|
item.Id = data.Id
|
||||||
|
res = h.DB.Select("*").Omit("created_at").Updates(&item)
|
||||||
|
} else {
|
||||||
|
res = h.DB.Create(&item)
|
||||||
}
|
}
|
||||||
res := h.db.Save(&item)
|
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -72,26 +87,45 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
|||||||
|
|
||||||
// List 模型列表
|
// List 模型列表
|
||||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||||
session := h.db.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
enable := h.GetBool(c, "enable")
|
enable := h.GetBool(c, "enable")
|
||||||
|
platform := h.GetTrim(c, "platform")
|
||||||
if enable {
|
if enable {
|
||||||
session = session.Where("enabled", enable)
|
session = session.Where("enabled", enable)
|
||||||
}
|
}
|
||||||
|
if platform != "" {
|
||||||
|
session = session.Where("platform", platform)
|
||||||
|
}
|
||||||
var items []model.ChatModel
|
var items []model.ChatModel
|
||||||
var cms = make([]vo.ChatModel, 0)
|
var cms = make([]vo.ChatModel, 0)
|
||||||
res := session.Order("sort_num ASC").Find(&items)
|
res := session.Order("sort_num ASC").Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error != nil {
|
||||||
for _, item := range items {
|
resp.SUCCESS(c, cms)
|
||||||
var cm vo.ChatModel
|
return
|
||||||
err := utils.CopyObject(item, &cm)
|
}
|
||||||
if err == nil {
|
|
||||||
cm.Id = item.Id
|
// initialize key name
|
||||||
cm.CreatedAt = item.CreatedAt.Unix()
|
keyIds := make([]int, 0)
|
||||||
cm.UpdatedAt = item.UpdatedAt.Unix()
|
for _, v := range items {
|
||||||
cms = append(cms, cm)
|
keyIds = append(keyIds, v.KeyId)
|
||||||
} else {
|
}
|
||||||
logger.Error(err)
|
var keys []model.ApiKey
|
||||||
}
|
keyMap := make(map[uint]string)
|
||||||
|
h.DB.Where("id IN ?", keyIds).Find(&keys)
|
||||||
|
for _, v := range keys {
|
||||||
|
keyMap[v.Id] = v.Name
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
var cm vo.ChatModel
|
||||||
|
err := utils.CopyObject(item, &cm)
|
||||||
|
if err == nil {
|
||||||
|
cm.Id = item.Id
|
||||||
|
cm.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
cm.UpdatedAt = item.UpdatedAt.Unix()
|
||||||
|
cm.KeyName = keyMap[uint(item.KeyId)]
|
||||||
|
cms = append(cms, cm)
|
||||||
|
} else {
|
||||||
|
logger.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, cms)
|
resp.SUCCESS(c, cms)
|
||||||
@@ -109,8 +143,9 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -129,8 +164,9 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.db.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -141,13 +177,16 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *ChatModelHandler) Remove(c *gin.Context) {
|
func (h *ChatModelHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
if id <= 0 {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if id > 0 {
|
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.ChatModel{})
|
if res.Error != nil {
|
||||||
if res.Error != nil {
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +1,32 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/handler"
|
"geekai/handler"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatRoleHandler struct {
|
type ChatRoleHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
||||||
h := ChatRoleHandler{db: db}
|
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save 创建或者更新某个角色
|
// Save 创建或者更新某个角色
|
||||||
@@ -41,8 +46,9 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
|||||||
if data.CreatedAt > 0 {
|
if data.CreatedAt > 0 {
|
||||||
role.CreatedAt = time.Unix(data.CreatedAt, 0)
|
role.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||||
}
|
}
|
||||||
res := h.db.Save(&role)
|
res := h.DB.Save(&role)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -55,12 +61,31 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
|||||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||||
var items []model.ChatRole
|
var items []model.ChatRole
|
||||||
var roles = make([]vo.ChatRole, 0)
|
var roles = make([]vo.ChatRole, 0)
|
||||||
res := h.db.Order("sort_num ASC").Find(&items)
|
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No data found")
|
resp.ERROR(c, "No data found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initialize model mane for role
|
||||||
|
modelIds := make([]int, 0)
|
||||||
|
for _, v := range items {
|
||||||
|
if v.ModelId > 0 {
|
||||||
|
modelIds = append(modelIds, v.ModelId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelNameMap := make(map[int]string)
|
||||||
|
if len(modelIds) > 0 {
|
||||||
|
var models []model.ChatModel
|
||||||
|
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
|
||||||
|
if tx.Error == nil {
|
||||||
|
for _, m := range models {
|
||||||
|
modelNameMap[int(m.Id)] = m.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
var role vo.ChatRole
|
var role vo.ChatRole
|
||||||
err := utils.CopyObject(v, &role)
|
err := utils.CopyObject(v, &role)
|
||||||
@@ -68,6 +93,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
role.Id = v.Id
|
role.Id = v.Id
|
||||||
role.CreatedAt = v.CreatedAt.Unix()
|
role.CreatedAt = v.CreatedAt.Unix()
|
||||||
role.UpdatedAt = v.UpdatedAt.Unix()
|
role.UpdatedAt = v.UpdatedAt.Unix()
|
||||||
|
role.ModelName = modelNameMap[role.ModelId]
|
||||||
roles = append(roles, role)
|
roles = append(roles, role)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -88,8 +114,9 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.db.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -110,8 +137,9 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -120,13 +148,14 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *ChatRoleHandler) Remove(c *gin.Context) {
|
func (h *ChatRoleHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
if id <= 0 {
|
if id <= 0 {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
res := h.DB.Where("id", id).Delete(&model.ChatRole{})
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.ChatRole{})
|
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "删除失败!")
|
resp.ERROR(c, "删除失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,49 +1,73 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/handler"
|
"geekai/handler"
|
||||||
"chatplus/store/model"
|
"geekai/service"
|
||||||
"chatplus/utils"
|
"geekai/service/mj"
|
||||||
"chatplus/utils/resp"
|
"geekai/service/sd"
|
||||||
|
"geekai/store"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/shirou/gopsutil/host"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConfigHandler struct {
|
type ConfigHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
levelDB *store.LevelDB
|
||||||
|
licenseService *service.LicenseService
|
||||||
|
mjServicePool *mj.ServicePool
|
||||||
|
sdServicePool *sd.ServicePool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
|
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler {
|
||||||
h := ConfigHandler{db: db}
|
return &ConfigHandler{
|
||||||
h.App = app
|
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||||
return &h
|
levelDB: levelDB,
|
||||||
|
mjServicePool: mjPool,
|
||||||
|
sdServicePool: sdPool,
|
||||||
|
licenseService: licenseService,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ConfigHandler) Update(c *gin.Context) {
|
func (h *ConfigHandler) Update(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
Config map[string]interface{} `json:"config"`
|
Config struct {
|
||||||
|
types.SystemConfig
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
Updated bool `json:"updated,omitempty"`
|
||||||
|
} `json:"config"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
str := utils.JsonEncode(&data.Config)
|
|
||||||
config := model.Config{Key: data.Key, Config: str}
|
value := utils.JsonEncode(&data.Config)
|
||||||
res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key})
|
config := model.Config{Key: data.Key, Config: value}
|
||||||
|
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Id > 0 {
|
if config.Id > 0 {
|
||||||
config.Config = str
|
config.Config = value
|
||||||
res := h.db.Updates(&config)
|
res := h.DB.Updates(&config)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -51,12 +75,10 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
|||||||
|
|
||||||
// update config cache for AppServer
|
// update config cache for AppServer
|
||||||
var cfg model.Config
|
var cfg model.Config
|
||||||
h.db.Where("marker", data.Key).First(&cfg)
|
h.DB.Where("marker", data.Key).First(&cfg)
|
||||||
var err error
|
var err error
|
||||||
if data.Key == "system" {
|
if data.Key == "system" {
|
||||||
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
|
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
|
||||||
} else if data.Key == "chat" {
|
|
||||||
err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig)
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "Failed to update config cache: "+err.Error())
|
resp.ERROR(c, "Failed to update config cache: "+err.Error())
|
||||||
@@ -72,18 +94,104 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
|||||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||||
key := c.Query("key")
|
key := c.Query("key")
|
||||||
var config model.Config
|
var config model.Config
|
||||||
res := h.db.Where("marker", key).First(&config)
|
res := h.DB.Where("marker", key).First(&config)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var m map[string]interface{}
|
var value map[string]interface{}
|
||||||
err := utils.JsonDecode(config.Config, &m)
|
err := utils.JsonDecode(config.Config, &value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c, m)
|
resp.SUCCESS(c, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Active 激活系统
|
||||||
|
func (h *ConfigHandler) Active(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
License string `json:"license"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
info, err := host.Info()
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.licenseService.ActiveLicense(data.License, info.HostID)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, info.HostID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLicense 获取 License 信息
|
||||||
|
func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
||||||
|
license := h.licenseService.GetLicense()
|
||||||
|
resp.SUCCESS(c, license)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAppConfig 获取内置配置
|
||||||
|
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
|
||||||
|
resp.SUCCESS(c, gin.H{
|
||||||
|
"mj_plus": h.App.Config.MjPlusConfigs,
|
||||||
|
"mj_proxy": h.App.Config.MjProxyConfigs,
|
||||||
|
"sd": h.App.Config.SdConfigs,
|
||||||
|
"platforms": Platforms,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveDrawingConfig 保存AI绘画配置
|
||||||
|
func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Sd []types.StableDiffusionConfig `json:"sd"`
|
||||||
|
MjPlus []types.MjPlusConfig `json:"mj_plus"`
|
||||||
|
MjProxy []types.MjProxyConfig `json:"mj_proxy"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
changed := false
|
||||||
|
if configChanged(data.Sd, h.App.Config.SdConfigs) {
|
||||||
|
logger.Debugf("SD 配置变动了")
|
||||||
|
h.App.Config.SdConfigs = data.Sd
|
||||||
|
h.sdServicePool.InitServices(data.Sd)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) {
|
||||||
|
logger.Debugf("MidJourney 配置变动了")
|
||||||
|
h.App.Config.MjPlusConfigs = data.MjPlus
|
||||||
|
h.App.Config.MjProxyConfigs = data.MjProxy
|
||||||
|
h.mjServicePool.InitServices(data.MjPlus, data.MjProxy)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if changed {
|
||||||
|
err := core.SaveConfig(h.App.Config)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "更新配置文档失败!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func configChanged(c1 interface{}, c2 interface{}) bool {
|
||||||
|
encode1 := utils.JsonEncode(c1)
|
||||||
|
encode2 := utils.JsonEncode(c2)
|
||||||
|
return utils.Md5(encode1) != utils.Md5(encode2)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,32 +1,38 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/handler"
|
"geekai/handler"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DashboardHandler struct {
|
type DashboardHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
|
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
|
||||||
h := DashboardHandler{db: db}
|
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type statsVo struct {
|
type statsVo struct {
|
||||||
Users int64 `json:"users"`
|
Users int64 `json:"users"`
|
||||||
Chats int64 `json:"chats"`
|
Chats int64 `json:"chats"`
|
||||||
Tokens int `json:"tokens"`
|
Tokens int `json:"tokens"`
|
||||||
Income float64 `json:"income"`
|
Income float64 `json:"income"`
|
||||||
|
Chart map[string]map[string]float64 `json:"chart"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DashboardHandler) Stats(c *gin.Context) {
|
func (h *DashboardHandler) Stats(c *gin.Context) {
|
||||||
@@ -35,37 +41,84 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
|
|||||||
var userCount int64
|
var userCount int64
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||||
res := h.db.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
|
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
stats.Users = userCount
|
stats.Users = userCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// new chats statistic
|
// new chats statistic
|
||||||
var chatCount int64
|
var chatCount int64
|
||||||
res = h.db.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
|
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
stats.Chats = chatCount
|
stats.Chats = chatCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokens took stats
|
// tokens took stats
|
||||||
var historyMessages []model.HistoryMessage
|
var historyMessages []model.ChatMessage
|
||||||
res = h.db.Where("created_at > ?", zeroTime).Find(&historyMessages)
|
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
|
||||||
for _, item := range historyMessages {
|
for _, item := range historyMessages {
|
||||||
stats.Tokens += item.Tokens
|
stats.Tokens += item.Tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
// 众筹收入
|
// 众筹收入
|
||||||
var rewards []model.Reward
|
var rewards []model.Reward
|
||||||
res = h.db.Where("created_at > ?", zeroTime).Find(&rewards)
|
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
|
||||||
for _, item := range rewards {
|
for _, item := range rewards {
|
||||||
stats.Income += item.Amount
|
stats.Income += item.Amount
|
||||||
}
|
}
|
||||||
|
|
||||||
// 订单收入
|
// 订单收入
|
||||||
var orders []model.Order
|
var orders []model.Order
|
||||||
res = h.db.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
|
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
|
||||||
for _, item := range orders {
|
for _, item := range orders {
|
||||||
stats.Income += item.Amount
|
stats.Income += item.Amount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 统计7天的订单的图表
|
||||||
|
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
|
||||||
|
var statsChart = make(map[string]map[string]float64)
|
||||||
|
//// 初始化
|
||||||
|
var userStatistic, historyMessagesStatistic, incomeStatistic = make(map[string]float64), make(map[string]float64), make(map[string]float64)
|
||||||
|
for i := 0; i < 7; i++ {
|
||||||
|
var initTime = time.Date(now.Year(), now.Month(), now.Day()-i, 0, 0, 0, 0, now.Location()).Format("2006-01-02")
|
||||||
|
userStatistic[initTime] = float64(0)
|
||||||
|
historyMessagesStatistic[initTime] = float64(0)
|
||||||
|
incomeStatistic[initTime] = float64(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计用户7天增加的曲线
|
||||||
|
var users []model.User
|
||||||
|
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
|
||||||
|
if res.Error == nil {
|
||||||
|
for _, item := range users {
|
||||||
|
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计7天Token 消耗
|
||||||
|
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
|
||||||
|
for _, item := range historyMessages {
|
||||||
|
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 浮点数相加?
|
||||||
|
// 统计最近7天的众筹
|
||||||
|
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
|
||||||
|
for _, item := range rewards {
|
||||||
|
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计最近7天的订单
|
||||||
|
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
|
||||||
|
for _, item := range orders {
|
||||||
|
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
statsChart["users"] = userStatistic
|
||||||
|
statsChart["historyMessage"] = historyMessagesStatistic
|
||||||
|
statsChart["orders"] = incomeStatistic
|
||||||
|
|
||||||
|
stats.Chart = statsChart
|
||||||
|
|
||||||
resp.SUCCESS(c, stats)
|
resp.SUCCESS(c, stats)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,20 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/handler"
|
"geekai/handler"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
@@ -17,13 +24,10 @@ import (
|
|||||||
|
|
||||||
type FunctionHandler struct {
|
type FunctionHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
|
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
|
||||||
h := FunctionHandler{db: db}
|
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *FunctionHandler) Save(c *gin.Context) {
|
func (h *FunctionHandler) Save(c *gin.Context) {
|
||||||
@@ -44,7 +48,7 @@ func (h *FunctionHandler) Save(c *gin.Context) {
|
|||||||
Enabled: data.Enabled,
|
Enabled: data.Enabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Save(&f)
|
res := h.DB.Save(&f)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "error with save data:"+res.Error.Error())
|
resp.ERROR(c, "error with save data:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -65,8 +69,9 @@ func (h *FunctionHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -75,7 +80,7 @@ func (h *FunctionHandler) Set(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *FunctionHandler) List(c *gin.Context) {
|
func (h *FunctionHandler) List(c *gin.Context) {
|
||||||
var items []model.Function
|
var items []model.Function
|
||||||
res := h.db.Find(&items)
|
res := h.DB.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No data found")
|
resp.ERROR(c, "No data found")
|
||||||
return
|
return
|
||||||
@@ -97,8 +102,9 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
|
|||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
res := h.db.Delete(&model.Function{Id: uint(id)})
|
res := h.DB.Delete(&model.Function{Id: uint(id)})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
132
api/handler/admin/menu_handler.go
Normal file
132
api/handler/admin/menu_handler.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MenuHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
||||||
|
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MenuHandler) Save(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
SortNum int `json:"sort_num"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := h.DB.Save(&model.Menu{
|
||||||
|
Id: data.Id,
|
||||||
|
Name: data.Name,
|
||||||
|
Icon: data.Icon,
|
||||||
|
URL: data.URL,
|
||||||
|
SortNum: data.SortNum,
|
||||||
|
Enabled: data.Enabled,
|
||||||
|
})
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 数据列表
|
||||||
|
func (h *MenuHandler) List(c *gin.Context) {
|
||||||
|
var items []model.Menu
|
||||||
|
var list = make([]vo.Menu, 0)
|
||||||
|
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||||
|
if res.Error == nil {
|
||||||
|
for _, item := range items {
|
||||||
|
var product vo.Menu
|
||||||
|
err := utils.CopyObject(item, &product)
|
||||||
|
if err == nil {
|
||||||
|
list = append(list, product)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MenuHandler) Enable(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MenuHandler) Sort(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Ids []uint `json:"ids"`
|
||||||
|
Sorts []int `json:"sorts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for index, id := range data.Ids {
|
||||||
|
res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index])
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MenuHandler) Remove(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
|
if id > 0 {
|
||||||
|
res := h.DB.Where("id", id).Delete(&model.Menu{})
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
@@ -1,31 +1,37 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/handler"
|
"geekai/handler"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OrderHandler struct {
|
type OrderHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||||
h := OrderHandler{db: db}
|
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OrderHandler) List(c *gin.Context) {
|
func (h *OrderHandler) List(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
OrderNo string `json:"order_no"`
|
OrderNo string `json:"order_no"`
|
||||||
|
Status int `json:"status"`
|
||||||
PayTime []string `json:"pay_time"`
|
PayTime []string `json:"pay_time"`
|
||||||
Page int `json:"page"`
|
Page int `json:"page"`
|
||||||
PageSize int `json:"page_size"`
|
PageSize int `json:"page_size"`
|
||||||
@@ -35,7 +41,7 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if data.OrderNo != "" {
|
if data.OrderNo != "" {
|
||||||
session = session.Where("order_no", data.OrderNo)
|
session = session.Where("order_no", data.OrderNo)
|
||||||
}
|
}
|
||||||
@@ -44,8 +50,9 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
end := utils.Str2stamp(data.PayTime[1] + " 00:00:00")
|
end := utils.Str2stamp(data.PayTime[1] + " 00:00:00")
|
||||||
session = session.Where("pay_time >= ? AND pay_time <= ?", start, end)
|
session = session.Where("pay_time >= ? AND pay_time <= ?", start, end)
|
||||||
}
|
}
|
||||||
session = session.Where("status = ?", types.OrderPaidSuccess)
|
if data.Status >= 0 {
|
||||||
|
session = session.Where("status", data.Status)
|
||||||
|
}
|
||||||
var total int64
|
var total int64
|
||||||
session.Model(&model.Order{}).Count(&total)
|
session.Model(&model.Order{}).Count(&total)
|
||||||
var items []model.Order
|
var items []model.Order
|
||||||
@@ -74,7 +81,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
|
|||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
var item model.Order
|
var item model.Order
|
||||||
res := h.db.First(&item, id)
|
res := h.DB.First(&item, id)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "记录不存在!")
|
resp.ERROR(c, "记录不存在!")
|
||||||
return
|
return
|
||||||
@@ -85,8 +92,9 @@ func (h *OrderHandler) Remove(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.db.Where("id = ?", id).Delete(&model.Order{})
|
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
84
api/handler/admin/power_log_handler.go
Normal file
84
api/handler/admin/power_log_handler.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PowerLogHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
||||||
|
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Type int `json:"type"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Date []string `json:"date"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Model != "" {
|
||||||
|
session = session.Where("model", data.Model)
|
||||||
|
}
|
||||||
|
if data.Type > 0 {
|
||||||
|
session = session.Where("type", data.Type)
|
||||||
|
}
|
||||||
|
if len(data.Date) == 2 {
|
||||||
|
start := data.Date[0] + " 00:00:00"
|
||||||
|
end := data.Date[1] + " 00:00:00"
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.PowerLog{}).Count(&total)
|
||||||
|
var items []model.PowerLog
|
||||||
|
var list = make([]vo.PowerLog, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
||||||
|
if res.Error == nil {
|
||||||
|
for _, item := range items {
|
||||||
|
var log vo.PowerLog
|
||||||
|
err := utils.CopyObject(item, &log)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Id = item.Id
|
||||||
|
log.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
log.TypeStr = item.Type.String()
|
||||||
|
list = append(list, log)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计消费算力总和
|
||||||
|
var totalPower float64
|
||||||
|
if len(data.Date) == 2 {
|
||||||
|
session.Where("mark", 0).Select("SUM(amount) as total_sum").Scan(&totalPower)
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, gin.H{"data": vo.NewPage(total, data.Page, data.PageSize, list), "stat": totalPower})
|
||||||
|
}
|
||||||
@@ -1,13 +1,20 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/handler"
|
"geekai/handler"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
"time"
|
||||||
@@ -15,13 +22,10 @@ import (
|
|||||||
|
|
||||||
type ProductHandler struct {
|
type ProductHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
||||||
h := ProductHandler{db: db}
|
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProductHandler) Save(c *gin.Context) {
|
func (h *ProductHandler) Save(c *gin.Context) {
|
||||||
@@ -32,8 +36,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
|||||||
Discount float64 `json:"discount"`
|
Discount float64 `json:"discount"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Days int `json:"days"`
|
Days int `json:"days"`
|
||||||
Calls int `json:"calls"`
|
Power int `json:"power"`
|
||||||
ImgCalls int `json:"img_calls"`
|
|
||||||
CreatedAt int64 `json:"created_at"`
|
CreatedAt int64 `json:"created_at"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
@@ -46,15 +49,15 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
|||||||
Price: data.Price,
|
Price: data.Price,
|
||||||
Discount: data.Discount,
|
Discount: data.Discount,
|
||||||
Days: data.Days,
|
Days: data.Days,
|
||||||
Calls: data.Calls,
|
Power: data.Power,
|
||||||
ImgCalls: data.ImgCalls,
|
|
||||||
Enabled: data.Enabled}
|
Enabled: data.Enabled}
|
||||||
item.Id = data.Id
|
item.Id = data.Id
|
||||||
if item.Id > 0 {
|
if item.Id > 0 {
|
||||||
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||||
}
|
}
|
||||||
res := h.db.Save(&item)
|
res := h.DB.Save(&item)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -70,16 +73,11 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, itemVo)
|
resp.SUCCESS(c, itemVo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 模型列表
|
// List 数据列表
|
||||||
func (h *ProductHandler) List(c *gin.Context) {
|
func (h *ProductHandler) List(c *gin.Context) {
|
||||||
session := h.db.Session(&gorm.Session{})
|
|
||||||
enable := h.GetBool(c, "enable")
|
|
||||||
if enable {
|
|
||||||
session = session.Where("enabled", enable)
|
|
||||||
}
|
|
||||||
var items []model.Product
|
var items []model.Product
|
||||||
var list = make([]vo.Product, 0)
|
var list = make([]vo.Product, 0)
|
||||||
res := session.Order("sort_num ASC").Find(&items)
|
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var product vo.Product
|
var product vo.Product
|
||||||
@@ -108,8 +106,9 @@ func (h *ProductHandler) Enable(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.Product{}).Where("id = ?", data.Id).Update("enabled", data.Enabled)
|
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -128,8 +127,9 @@ func (h *ProductHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.db.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index])
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -142,8 +142,9 @@ func (h *ProductHandler) Remove(c *gin.Context) {
|
|||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.Product{})
|
res := h.DB.Where("id", id).Delete(&model.Product{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,30 +1,35 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/handler"
|
"geekai/core/types"
|
||||||
"chatplus/store/model"
|
"geekai/handler"
|
||||||
"chatplus/store/vo"
|
"geekai/store/model"
|
||||||
"chatplus/utils"
|
"geekai/store/vo"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RewardHandler struct {
|
type RewardHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
|
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
|
||||||
h := RewardHandler{db: db}
|
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RewardHandler) List(c *gin.Context) {
|
func (h *RewardHandler) List(c *gin.Context) {
|
||||||
var items []model.Reward
|
var items []model.Reward
|
||||||
res := h.db.Order("id DESC").Find(&items)
|
res := h.DB.Order("id DESC").Find(&items)
|
||||||
var rewards = make([]vo.Reward, 0)
|
var rewards = make([]vo.Reward, 0)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
userIds := make([]uint, 0)
|
userIds := make([]uint, 0)
|
||||||
@@ -32,7 +37,7 @@ func (h *RewardHandler) List(c *gin.Context) {
|
|||||||
userIds = append(userIds, v.UserId)
|
userIds = append(userIds, v.UserId)
|
||||||
}
|
}
|
||||||
var users []model.User
|
var users []model.User
|
||||||
h.db.Where("id IN ?", userIds).Find(&users)
|
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||||
var userMap = make(map[uint]model.User)
|
var userMap = make(map[uint]model.User)
|
||||||
for _, u := range users {
|
for _, u := range users {
|
||||||
userMap[u.Id] = u
|
userMap[u.Id] = u
|
||||||
@@ -57,11 +62,17 @@ func (h *RewardHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *RewardHandler) Remove(c *gin.Context) {
|
func (h *RewardHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
var data struct {
|
||||||
|
Id uint
|
||||||
if id > 0 {
|
}
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.Reward{})
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data.Id > 0 {
|
||||||
|
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
12
api/handler/admin/types.go
Normal file
12
api/handler/admin/types.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import "geekai/core/types"
|
||||||
|
|
||||||
|
var Platforms = []types.Platform{
|
||||||
|
types.OpenAI,
|
||||||
|
types.QWen,
|
||||||
|
types.XunFei,
|
||||||
|
types.ChatGLM,
|
||||||
|
types.Baidu,
|
||||||
|
types.Azure,
|
||||||
|
}
|
||||||
52
api/handler/admin/upload_handler.go
Normal file
52
api/handler/admin/upload_handler.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UploadHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
uploaderManager *oss.UploaderManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
|
||||||
|
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||||
|
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId := 0
|
||||||
|
res := h.DB.Create(&model.File{
|
||||||
|
UserId: userId,
|
||||||
|
Name: file.Name,
|
||||||
|
ObjKey: file.ObjKey,
|
||||||
|
URL: file.URL,
|
||||||
|
Ext: file.Ext,
|
||||||
|
Size: file.Size,
|
||||||
|
CreatedAt: time.Time{},
|
||||||
|
})
|
||||||
|
if res.Error != nil || res.RowsAffected == 0 {
|
||||||
|
resp.ERROR(c, "error with update database: "+res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, file)
|
||||||
|
}
|
||||||
@@ -1,26 +1,35 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"fmt"
|
||||||
"chatplus/core/types"
|
"geekai/core"
|
||||||
"chatplus/handler"
|
"geekai/core/types"
|
||||||
"chatplus/store/model"
|
"geekai/handler"
|
||||||
"chatplus/store/vo"
|
"geekai/service"
|
||||||
"chatplus/utils"
|
"geekai/store/model"
|
||||||
"chatplus/utils/resp"
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserHandler struct {
|
type UserHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
licenseService *service.LicenseService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
|
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler {
|
||||||
h := UserHandler{db: db}
|
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 用户列表
|
// List 用户列表
|
||||||
@@ -34,7 +43,7 @@ func (h *UserHandler) List(c *gin.Context) {
|
|||||||
var users = make([]vo.User, 0)
|
var users = make([]vo.User, 0)
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if username != "" {
|
if username != "" {
|
||||||
session = session.Where("username LIKE ?", "%"+username+"%")
|
session = session.Where("username LIKE ?", "%"+username+"%")
|
||||||
}
|
}
|
||||||
@@ -63,57 +72,84 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
Mobile string `json:"mobile"`
|
Username string `json:"username"`
|
||||||
Calls int `json:"calls"`
|
|
||||||
ImgCalls int `json:"img_calls"`
|
|
||||||
ChatRoles []string `json:"chat_roles"`
|
ChatRoles []string `json:"chat_roles"`
|
||||||
ChatModels []string `json:"chat_models"`
|
ChatModels []int `json:"chat_models"`
|
||||||
ExpiredTime string `json:"expired_time"`
|
ExpiredTime string `json:"expired_time"`
|
||||||
Status bool `json:"status"`
|
Status bool `json:"status"`
|
||||||
Vip bool `json:"vip"`
|
Vip bool `json:"vip"`
|
||||||
|
Power int `json:"power"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 检测最大注册人数
|
||||||
|
var totalUser int64
|
||||||
|
h.DB.Model(&model.User{}).Count(&totalUser)
|
||||||
|
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
||||||
|
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
|
||||||
|
return
|
||||||
|
}
|
||||||
var user = model.User{}
|
var user = model.User{}
|
||||||
var res *gorm.DB
|
var res *gorm.DB
|
||||||
var userVo vo.User
|
var userVo vo.User
|
||||||
if data.Id > 0 { // 更新
|
if data.Id > 0 { // 更新
|
||||||
user.Id = data.Id
|
res = h.DB.Where("id", data.Id).First(&user)
|
||||||
// 此处需要用 map 更新,用结构体无法更新 0 值
|
if res.Error != nil {
|
||||||
res = h.db.Model(&user).Updates(map[string]interface{}{
|
resp.ERROR(c, "user not found")
|
||||||
"mobile": data.Mobile,
|
return
|
||||||
"calls": data.Calls,
|
}
|
||||||
"img_calls": data.ImgCalls,
|
var oldPower = user.Power
|
||||||
"status": data.Status,
|
user.Username = data.Username
|
||||||
"vip": data.Vip,
|
user.Status = data.Status
|
||||||
"chat_roles_json": utils.JsonEncode(data.ChatRoles),
|
user.Vip = data.Vip
|
||||||
"chat_models_json": utils.JsonEncode(data.ChatModels),
|
user.Power = data.Power
|
||||||
"expired_time": utils.Str2stamp(data.ExpiredTime),
|
user.ChatRoles = utils.JsonEncode(data.ChatRoles)
|
||||||
})
|
user.ChatModels = utils.JsonEncode(data.ChatModels)
|
||||||
|
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
|
||||||
|
|
||||||
|
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 记录算力日志
|
||||||
|
if oldPower != user.Power {
|
||||||
|
mark := types.PowerAdd
|
||||||
|
amount := user.Power - oldPower
|
||||||
|
if oldPower > user.Power {
|
||||||
|
mark = types.PowerSub
|
||||||
|
amount = oldPower - user.Power
|
||||||
|
}
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerGift,
|
||||||
|
Amount: amount,
|
||||||
|
Balance: user.Power,
|
||||||
|
Mark: mark,
|
||||||
|
Model: "管理员",
|
||||||
|
Remark: fmt.Sprintf("后台管理员强制修改用户算力,修改前:%d,修改后:%d, 管理员ID:%d", oldPower, user.Power, h.GetLoginUserId(c)),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
salt := utils.RandString(8)
|
salt := utils.RandString(8)
|
||||||
u := model.User{
|
u := model.User{
|
||||||
Username: data.Mobile,
|
Username: data.Username,
|
||||||
|
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
|
||||||
Password: utils.GenPassword(data.Password, salt),
|
Password: utils.GenPassword(data.Password, salt),
|
||||||
Avatar: "/images/avatar/user.png",
|
Avatar: "/images/avatar/user.png",
|
||||||
Salt: salt,
|
Salt: salt,
|
||||||
|
Power: data.Power,
|
||||||
Status: true,
|
Status: true,
|
||||||
ChatRoles: utils.JsonEncode(data.ChatRoles),
|
ChatRoles: utils.JsonEncode(data.ChatRoles),
|
||||||
ChatModels: utils.JsonEncode(data.ChatModels),
|
ChatModels: utils.JsonEncode(data.ChatModels),
|
||||||
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
|
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
|
||||||
ChatConfig: utils.JsonEncode(types.UserChatConfig{
|
|
||||||
ApiKeys: map[types.Platform]string{
|
|
||||||
types.OpenAI: "",
|
|
||||||
types.Azure: "",
|
|
||||||
types.ChatGLM: "",
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
Calls: data.Calls,
|
|
||||||
ImgCalls: data.ImgCalls,
|
|
||||||
}
|
}
|
||||||
res = h.db.Create(&u)
|
res = h.DB.Create(&u)
|
||||||
_ = utils.CopyObject(u, &userVo)
|
_ = utils.CopyObject(u, &userVo)
|
||||||
userVo.Id = u.Id
|
userVo.Id = u.Id
|
||||||
userVo.CreatedAt = u.CreatedAt.Unix()
|
userVo.CreatedAt = u.CreatedAt.Unix()
|
||||||
@@ -121,6 +157,7 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败")
|
resp.ERROR(c, "更新数据库失败")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -140,7 +177,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.db.First(&user, data.Id)
|
res := h.DB.First(&user, data.Id)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No user found")
|
resp.ERROR(c, "No user found")
|
||||||
return
|
return
|
||||||
@@ -148,7 +185,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
|
|
||||||
password := utils.GenPassword(data.Password, user.Salt)
|
password := utils.GenPassword(data.Password, user.Salt)
|
||||||
user.Password = password
|
user.Password = password
|
||||||
res = h.db.Updates(&user)
|
res = h.DB.Updates(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c)
|
resp.ERROR(c)
|
||||||
} else {
|
} else {
|
||||||
@@ -158,36 +195,32 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *UserHandler) Remove(c *gin.Context) {
|
func (h *UserHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
if id > 0 {
|
if id <= 0 {
|
||||||
tx := h.db.Begin()
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.User{})
|
return
|
||||||
if res.Error != nil {
|
|
||||||
resp.ERROR(c, "删除失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 删除聊天记录
|
|
||||||
res = h.db.Where("user_id = ?", id).Delete(&model.ChatItem{})
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, "删除失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 删除聊天历史记录
|
|
||||||
res = h.db.Where("user_id = ?", id).Delete(&model.HistoryMessage{})
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, "删除失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 删除登录日志
|
|
||||||
res = h.db.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, "删除失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
// 删除用户
|
||||||
|
res := h.DB.Where("id = ?", id).Delete(&model.User{})
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "删除失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除聊天记录
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
|
||||||
|
// 删除聊天历史记录
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
|
||||||
|
// 删除登录日志
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
|
||||||
|
// 删除算力日志
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
|
||||||
|
// 删除众筹日志
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
|
||||||
|
// 删除绘图任务
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
|
||||||
|
// 删除订单
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,10 +228,10 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
|
|||||||
page := h.GetInt(c, "page", 1)
|
page := h.GetInt(c, "page", 1)
|
||||||
pageSize := h.GetInt(c, "page_size", 20)
|
pageSize := h.GetInt(c, "page_size", 20)
|
||||||
var total int64
|
var total int64
|
||||||
h.db.Model(&model.UserLoginLog{}).Count(&total)
|
h.DB.Model(&model.UserLoginLog{}).Count(&total)
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
var items []model.UserLoginLog
|
var items []model.UserLoginLog
|
||||||
res := h.db.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
|
res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "获取数据失败")
|
resp.ERROR(c, "获取数据失败")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,11 +1,21 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
logger2 "chatplus/logger"
|
logger2 "geekai/logger"
|
||||||
"chatplus/utils"
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"gorm.io/gorm"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -15,6 +25,7 @@ var logger = logger2.GetLogger()
|
|||||||
|
|
||||||
type BaseHandler struct {
|
type BaseHandler struct {
|
||||||
App *core.AppServer
|
App *core.AppServer
|
||||||
|
DB *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
|
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
|
||||||
@@ -57,3 +68,27 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
|
|||||||
}
|
}
|
||||||
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
|
||||||
|
return h.GetLoginUserId(c) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
|
||||||
|
value, exists := c.Get(types.LoginUserCache)
|
||||||
|
if exists {
|
||||||
|
return value.(model.User), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userId, ok := c.Get(types.LoginUserID)
|
||||||
|
if !ok {
|
||||||
|
return model.User{}, errors.New("user not login")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user model.User
|
||||||
|
res := h.DB.First(&user, userId)
|
||||||
|
// 更新缓存
|
||||||
|
if res.Error == nil {
|
||||||
|
c.Set(types.LoginUserCache, user)
|
||||||
|
}
|
||||||
|
return user, res.Error
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +1,16 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/service"
|
"geekai/service"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,3 +52,33 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SlideGet 获取滑动验证图片
|
||||||
|
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
|
||||||
|
data, err := h.service.SlideGet()
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlideCheck 滑动验证结果校验
|
||||||
|
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
X int `json:"x"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.service.SlideCheck(data) {
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
} else {
|
||||||
|
resp.ERROR(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,48 +1,53 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatModelHandler struct {
|
type ChatModelHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||||
h := ChatModelHandler{db: db}
|
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 模型列表
|
// List 模型列表
|
||||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||||
var items []model.ChatModel
|
var items []model.ChatModel
|
||||||
var chatModels = make([]vo.ChatModel, 0)
|
var chatModels = make([]vo.ChatModel, 0)
|
||||||
// 只加载用户订阅的 AI 模型
|
var res *gorm.DB
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
// 如果用户没有登录,则加载所有开放模型
|
||||||
if err != nil {
|
if !h.IsLogin(c) {
|
||||||
resp.NotAuth(c)
|
res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items)
|
||||||
return
|
} else {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
var models []int
|
||||||
|
err := utils.JsonDecode(user.ChatModels, &models)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "当前用户没有订阅任何模型")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 查询用户有权限访问的模型以及所有开放的模型
|
||||||
|
res = h.DB.Where("enabled = ?", true).Where(
|
||||||
|
h.DB.Where("id IN ?", models).Or("open", true),
|
||||||
|
).Order("sort_num ASC").Find(&items)
|
||||||
}
|
}
|
||||||
|
|
||||||
var models []string
|
|
||||||
err = utils.JsonDecode(user.ChatModels, &models)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "当前用户没有订阅任何模型")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查询用户有权限访问的模型以及所有开放的模型
|
|
||||||
res := h.db.Where("enabled = ?", true).Where(
|
|
||||||
h.db.Where("value IN ?", models).Or("open =?", true),
|
|
||||||
).Order("sort_num ASC").Find(&items)
|
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var cm vo.ChatModel
|
var cm vo.ChatModel
|
||||||
|
|||||||
@@ -1,12 +1,19 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -14,27 +21,26 @@ import (
|
|||||||
|
|
||||||
type ChatRoleHandler struct {
|
type ChatRoleHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
||||||
handler := &ChatRoleHandler{db: db}
|
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
handler.App = app
|
|
||||||
return handler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List get user list
|
// List 获取用户聊天应用列表
|
||||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||||
all := h.GetBool(c, "all")
|
all := h.GetBool(c, "all")
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
var roles []model.ChatRole
|
var roles []model.ChatRole
|
||||||
res := h.db.Where("enable", true).Order("sort_num ASC").Find(&roles)
|
var roleVos = make([]vo.ChatRole, 0)
|
||||||
|
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No roles found,"+res.Error.Error())
|
resp.SUCCESS(c, roleVos)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取所有角色
|
// 获取所有角色
|
||||||
if all {
|
if userId == 0 || all {
|
||||||
// 转成 vo
|
// 转成 vo
|
||||||
var roleVos = make([]vo.ChatRole, 0)
|
var roleVos = make([]vo.ChatRole, 0)
|
||||||
for _, r := range roles {
|
for _, r := range roles {
|
||||||
@@ -49,21 +55,15 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
|
||||||
if userId == 0 {
|
|
||||||
resp.NotAuth(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var user model.User
|
var user model.User
|
||||||
h.db.First(&user, userId)
|
h.DB.First(&user, userId)
|
||||||
var roleKeys []string
|
var roleKeys []string
|
||||||
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "角色解析失败!")
|
resp.ERROR(c, "角色解析失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 转成 vo
|
|
||||||
var roleVos = make([]vo.ChatRole, 0)
|
|
||||||
for _, r := range roles {
|
for _, r := range roles {
|
||||||
if !utils.ContainsStr(roleKeys, r.Key) {
|
if !utils.ContainsStr(roleKeys, r.Key) {
|
||||||
continue
|
continue
|
||||||
@@ -80,7 +80,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
|
|
||||||
// UpdateRole 更新用户聊天角色
|
// UpdateRole 更新用户聊天角色
|
||||||
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
@@ -94,9 +94,9 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
|
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("添加应用失败:", err)
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +1,31 @@
|
|||||||
package chatimpl
|
package chatimpl
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// 微软 Azure 模型消息发送实现
|
// 微软 Azure 模型消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendAzureMessage(
|
func (h *ChatHandler) sendAzureMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -30,21 +36,14 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
promptCreatedAt := time.Now() // 记录提问时间
|
promptCreatedAt := time.Now() // 记录提问时间
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
var apiKey = model.ApiKey{}
|
var apiKey = model.ApiKey{}
|
||||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "context canceled") {
|
if strings.Contains(err.Error(), "context canceled") {
|
||||||
logger.Info("用户取消了请求:", prompt)
|
return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||||
return nil
|
|
||||||
} else if strings.Contains(err.Error(), "no available key") {
|
} else if strings.Contains(err.Error(), "no available key") {
|
||||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.ReplyMessage(ws, ErrorMsg)
|
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
@@ -66,10 +65,7 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
var responseBody = types.ApiResponse{}
|
var responseBody = types.ApiResponse{}
|
||||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||||
if err != nil { // 数据解析出错
|
if err != nil { // 数据解析出错
|
||||||
logger.Error(err, line)
|
return errors.New(line)
|
||||||
utils.ReplyMessage(ws, ErrorMsg)
|
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(responseBody.Choices) == 0 {
|
if len(responseBody.Choices) == 0 {
|
||||||
@@ -103,106 +99,12 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
|
||||||
message.Role = "assistant"
|
|
||||||
}
|
|
||||||
message.Content = strings.Join(contents, "")
|
|
||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
|
||||||
if h.App.ChatConfig.EnableContext {
|
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 追加聊天记录
|
|
||||||
if h.App.ChatConfig.EnableHistory {
|
|
||||||
// for prompt
|
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: true,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
|
||||||
totalTokens += getTotalTokens(req)
|
|
||||||
|
|
||||||
historyReplyMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: true,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存当前会话
|
|
||||||
var chatItem model.ChatItem
|
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
|
||||||
if res.Error != nil {
|
|
||||||
chatItem.ChatId = session.ChatId
|
|
||||||
chatItem.UserId = session.UserId
|
|
||||||
chatItem.RoleId = role.Id
|
|
||||||
chatItem.ModelId = session.Model.Id
|
|
||||||
if utf8.RuneCountInString(prompt) > 30 {
|
|
||||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
|
||||||
} else {
|
|
||||||
chatItem.Title = prompt
|
|
||||||
}
|
|
||||||
h.db.Create(&chatItem)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
body, err := io.ReadAll(response.Body)
|
body, _ := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
return fmt.Errorf("请求大模型 API 失败:%s", body)
|
||||||
return fmt.Errorf("error with reading response: %v", err)
|
|
||||||
}
|
|
||||||
var res types.ApiError
|
|
||||||
err = json.Unmarshal(body, &res)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(res.Error.Message, "maximum context length") {
|
|
||||||
logger.Error(res.Error.Message)
|
|
||||||
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
|
|
||||||
h.App.ChatContexts.Delete(session.ChatId)
|
|
||||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
|
||||||
} else {
|
|
||||||
utils.ReplyMessage(ws, "请求 Azure API 失败:"+res.Error.Message)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,20 +1,26 @@
|
|||||||
package chatimpl
|
package chatimpl
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type baiduResp struct {
|
type baiduResp struct {
|
||||||
@@ -36,7 +42,7 @@ type baiduResp struct {
|
|||||||
// 百度文心一言消息发送实现
|
// 百度文心一言消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendBaiduMessage(
|
func (h *ChatHandler) sendBaiduMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -47,21 +53,15 @@ func (h *ChatHandler) sendBaiduMessage(
|
|||||||
promptCreatedAt := time.Now() // 记录提问时间
|
promptCreatedAt := time.Now() // 记录提问时间
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
var apiKey = model.ApiKey{}
|
var apiKey = model.ApiKey{}
|
||||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
if strings.Contains(err.Error(), "context canceled") {
|
if strings.Contains(err.Error(), "context canceled") {
|
||||||
logger.Info("用户取消了请求:", prompt)
|
return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||||
return nil
|
|
||||||
} else if strings.Contains(err.Error(), "no available key") {
|
} else if strings.Contains(err.Error(), "no available key") {
|
||||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.ReplyMessage(ws, ErrorMsg)
|
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
@@ -128,101 +128,11 @@ func (h *ChatHandler) sendBaiduMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
|
||||||
message.Role = "assistant"
|
|
||||||
}
|
|
||||||
message.Content = strings.Join(contents, "")
|
|
||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
|
||||||
if h.App.ChatConfig.EnableContext {
|
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 追加聊天记录
|
|
||||||
if h.App.ChatConfig.EnableHistory {
|
|
||||||
// for prompt
|
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: true,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// for reply
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
|
||||||
totalTokens := replyToken + getTotalTokens(req)
|
|
||||||
historyReplyMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: true,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存当前会话
|
|
||||||
var chatItem model.ChatItem
|
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
|
||||||
if res.Error != nil {
|
|
||||||
chatItem.ChatId = session.ChatId
|
|
||||||
chatItem.UserId = session.UserId
|
|
||||||
chatItem.RoleId = role.Id
|
|
||||||
chatItem.ModelId = session.Model.Id
|
|
||||||
if utf8.RuneCountInString(prompt) > 30 {
|
|
||||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
|
||||||
} else {
|
|
||||||
chatItem.Title = prompt
|
|
||||||
}
|
|
||||||
h.db.Create(&chatItem)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
body, err := io.ReadAll(response.Body)
|
body, _ := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
return fmt.Errorf("请求大模型 API 失败:%s", body)
|
||||||
return fmt.Errorf("error with reading response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var res struct {
|
|
||||||
Code int `json:"error_code"`
|
|
||||||
Msg string `json:"error_msg"`
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(body, &res)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode response: %v", err)
|
|
||||||
}
|
|
||||||
utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,23 +1,35 @@
|
|||||||
package chatimpl
|
package chatimpl
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core"
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/handler"
|
|
||||||
logger2 "chatplus/logger"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
logger2 "geekai/logger"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
@@ -25,28 +37,24 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
|
|
||||||
const ErrImg = ""
|
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
type ChatHandler struct {
|
type ChatHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
redis *redis.Client
|
||||||
redis *redis.Client
|
uploadManager *oss.UploaderManager
|
||||||
|
licenseService *service.LicenseService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler {
|
||||||
h := ChatHandler{
|
return &ChatHandler{
|
||||||
db: db,
|
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||||
redis: redis,
|
redis: redis,
|
||||||
|
uploadManager: manager,
|
||||||
|
licenseService: licenseService,
|
||||||
}
|
}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var chatConfig types.ChatConfig
|
|
||||||
|
|
||||||
// ChatHandle 处理聊天 WebSocket 请求
|
// ChatHandle 处理聊天 WebSocket 请求
|
||||||
func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||||
@@ -61,9 +69,20 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
modelId := h.GetInt(c, "model_id", 0)
|
modelId := h.GetInt(c, "model_id", 0)
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
client := types.NewWsClient(ws)
|
||||||
|
var chatRole model.ChatRole
|
||||||
|
res := h.DB.First(&chatRole, roleId)
|
||||||
|
if res.Error != nil || !chatRole.Enable {
|
||||||
|
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// if the role bind a model_id, use role's bind model_id
|
||||||
|
if chatRole.ModelId > 0 {
|
||||||
|
modelId = chatRole.ModelId
|
||||||
|
}
|
||||||
// get model info
|
// get model info
|
||||||
var chatModel model.ChatModel
|
var chatModel model.ChatModel
|
||||||
res := h.db.First(&chatModel, modelId)
|
res = h.DB.First(&chatModel, modelId)
|
||||||
if res.Error != nil || chatModel.Enabled == false {
|
if res.Error != nil || chatModel.Enabled == false {
|
||||||
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
@@ -72,7 +91,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
|
|
||||||
session := h.App.ChatSession.Get(sessionId)
|
session := h.App.ChatSession.Get(sessionId)
|
||||||
if session == nil {
|
if session == nil {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("用户未登录")
|
logger.Info("用户未登录")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
@@ -89,7 +108,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
|
|
||||||
// use old chat data override the chat model and role ID
|
// use old chat data override the chat model and role ID
|
||||||
var chat model.ChatItem
|
var chat model.ChatItem
|
||||||
res = h.db.Where("chat_id=?", chatId).First(&chat)
|
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
chatModel.Id = chat.ModelId
|
chatModel.Id = chat.ModelId
|
||||||
roleId = int(chat.RoleId)
|
roleId = int(chat.RoleId)
|
||||||
@@ -97,28 +116,16 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
|
|
||||||
session.ChatId = chatId
|
session.ChatId = chatId
|
||||||
session.Model = types.ChatModel{
|
session.Model = types.ChatModel{
|
||||||
Id: chatModel.Id,
|
Id: chatModel.Id,
|
||||||
Value: chatModel.Value,
|
Name: chatModel.Name,
|
||||||
Weight: chatModel.Weight,
|
Value: chatModel.Value,
|
||||||
Platform: types.Platform(chatModel.Platform)}
|
Power: chatModel.Power,
|
||||||
|
MaxTokens: chatModel.MaxTokens,
|
||||||
|
MaxContext: chatModel.MaxContext,
|
||||||
|
Temperature: chatModel.Temperature,
|
||||||
|
KeyId: chatModel.KeyId,
|
||||||
|
Platform: chatModel.Platform}
|
||||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
||||||
var chatRole model.ChatRole
|
|
||||||
res = h.db.First(&chatRole, roleId)
|
|
||||||
if res.Error != nil || !chatRole.Enable {
|
|
||||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化聊天配置
|
|
||||||
var config model.Config
|
|
||||||
h.db.Where("marker", "chat").First(&config)
|
|
||||||
err = utils.JsonDecode(config.Config, &chatConfig)
|
|
||||||
if err != nil {
|
|
||||||
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存会话连接
|
// 保存会话连接
|
||||||
h.App.ChatClients.Put(sessionId, client)
|
h.App.ChatClients.Put(sessionId, client)
|
||||||
@@ -126,9 +133,10 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
_, msg, err := client.Receive()
|
_, msg, err := client.Receive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
||||||
client.Close()
|
client.Close()
|
||||||
h.App.ChatClients.Delete(sessionId)
|
h.App.ChatClients.Delete(sessionId)
|
||||||
|
h.App.ChatSession.Delete(sessionId)
|
||||||
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
|
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
|
||||||
if cancelFunc != nil {
|
if cancelFunc != nil {
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
@@ -137,19 +145,30 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
message := string(msg)
|
var message types.WsMessage
|
||||||
logger.Info("Receive a message: ", message)
|
err = utils.JsonDecode(string(msg), &message)
|
||||||
//utils.ReplyMessage(client, "这是一条测试消息!")
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 心跳消息
|
||||||
|
if message.Type == "heartbeat" {
|
||||||
|
logger.Debug("收到 Chat 心跳消息:", message.Content)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Receive a message: ", message.Content)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
h.App.ReqCancelFunc.Put(sessionId, cancel)
|
h.App.ReqCancelFunc.Put(sessionId, cancel)
|
||||||
// 回复消息
|
// 回复消息
|
||||||
err = h.sendMessage(ctx, session, chatRole, message, client)
|
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
utils.ReplyMessage(client, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||||
logger.Info("回答完毕: " + string(message))
|
logger.Infof("回答完毕: %v", message.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -166,10 +185,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.db.Model(&model.User{}).First(&user, session.UserId)
|
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
utils.ReplyMessage(ws, "非法用户,请联系管理员!")
|
return errors.New("未授权用户,您正在进行非法操作!")
|
||||||
return res.Error
|
|
||||||
}
|
}
|
||||||
var userVo vo.User
|
var userVo vo.User
|
||||||
err := utils.CopyObject(user, &userVo)
|
err := utils.CopyObject(user, &userVo)
|
||||||
@@ -179,124 +197,93 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userVo.Status == false {
|
if userVo.Status == false {
|
||||||
utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
|
return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if userVo.Calls < session.Model.Weight {
|
if userVo.Power < session.Model.Power {
|
||||||
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数(%d)已不足以支付当前模型的单次对话需要消耗的对话额度(%d)!", userVo.Calls, session.Model.Weight))
|
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, session.Model.Power)
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
|
|
||||||
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
|
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
||||||
utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
|
return errors.New("您的账号已经过期,请联系管理员!")
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
|
||||||
|
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
|
||||||
|
if promptTokens > session.Model.MaxContext {
|
||||||
|
|
||||||
|
return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
|
||||||
|
}
|
||||||
|
|
||||||
var req = types.ApiRequest{
|
var req = types.ApiRequest{
|
||||||
Model: session.Model.Value,
|
Model: session.Model.Value,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
}
|
||||||
switch session.Model.Platform {
|
switch session.Model.Platform {
|
||||||
case types.Azure:
|
case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value:
|
||||||
req.Temperature = h.App.ChatConfig.Azure.Temperature
|
req.Temperature = session.Model.Temperature
|
||||||
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
|
req.MaxTokens = session.Model.MaxTokens
|
||||||
break
|
break
|
||||||
case types.ChatGLM:
|
case types.OpenAI.Value:
|
||||||
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
|
req.Temperature = session.Model.Temperature
|
||||||
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
|
req.MaxTokens = session.Model.MaxTokens
|
||||||
break
|
|
||||||
case types.Baidu:
|
|
||||||
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
|
||||||
// TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
|
|
||||||
break
|
|
||||||
case types.OpenAI:
|
|
||||||
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
|
||||||
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
|
|
||||||
// OpenAI 支持函数功能
|
// OpenAI 支持函数功能
|
||||||
var items []model.Function
|
var items []model.Function
|
||||||
res := h.db.Where("enabled", true).Find(&items)
|
res := h.DB.Where("enabled", true).Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
var tools = make([]interface{}, 0)
|
var tools = make([]types.Tool, 0)
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
var parameters map[string]interface{}
|
var parameters map[string]interface{}
|
||||||
err = utils.JsonDecode(v.Parameters, ¶meters)
|
err = utils.JsonDecode(v.Parameters, ¶meters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
required := parameters["required"]
|
tool := types.Tool{
|
||||||
delete(parameters, "required")
|
Type: "function",
|
||||||
tools = append(tools, gin.H{
|
Function: types.Function{
|
||||||
"type": "function",
|
Name: v.Name,
|
||||||
"function": gin.H{
|
Description: v.Description,
|
||||||
"name": v.Name,
|
Parameters: parameters,
|
||||||
"description": v.Description,
|
|
||||||
"parameters": parameters,
|
|
||||||
"required": required,
|
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
if v, ok := parameters["required"]; v == nil || !ok {
|
||||||
|
tool.Function.Parameters["required"] = []string{}
|
||||||
|
}
|
||||||
|
tools = append(tools, tool)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tools) > 0 {
|
if len(tools) > 0 {
|
||||||
req.Tools = tools
|
req.Tools = tools
|
||||||
req.ToolChoice = "auto"
|
req.ToolChoice = "auto"
|
||||||
}
|
}
|
||||||
case types.XunFei:
|
case types.QWen.Value:
|
||||||
req.Temperature = h.App.ChatConfig.XunFei.Temperature
|
req.Parameters = map[string]interface{}{
|
||||||
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
|
"max_tokens": session.Model.MaxTokens,
|
||||||
|
"temperature": session.Model.Temperature,
|
||||||
|
}
|
||||||
break
|
break
|
||||||
|
|
||||||
default:
|
default:
|
||||||
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
|
return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 加载聊天上下文
|
// 加载聊天上下文
|
||||||
var chatCtx []interface{}
|
chatCtx := make([]types.Message, 0)
|
||||||
if h.App.ChatConfig.EnableContext {
|
messages := make([]types.Message, 0)
|
||||||
|
if h.App.SysConfig.EnableContext {
|
||||||
if h.App.ChatContexts.Has(session.ChatId) {
|
if h.App.ChatContexts.Has(session.ChatId) {
|
||||||
chatCtx = h.App.ChatContexts.Get(session.ChatId)
|
messages = h.App.ChatContexts.Get(session.ChatId)
|
||||||
} else {
|
} else {
|
||||||
// calculate the tokens of current request, to prevent to exceeding the max tokens num
|
_ = utils.JsonDecode(role.Context, &messages)
|
||||||
tokens := req.MaxTokens
|
if h.App.SysConfig.ContextDeep > 0 {
|
||||||
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
|
var historyMessages []model.ChatMessage
|
||||||
tokens += tks
|
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
|
||||||
// loading the role context
|
|
||||||
var messages []types.Message
|
|
||||||
err := utils.JsonDecode(role.Context, &messages)
|
|
||||||
if err == nil {
|
|
||||||
for _, v := range messages {
|
|
||||||
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
|
||||||
if tokens+tks >= types.GetModelMaxToken(req.Model) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
tokens += tks
|
|
||||||
chatCtx = append(chatCtx, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// loading recent chat history as chat context
|
|
||||||
if chatConfig.ContextDeep > 0 {
|
|
||||||
var historyMessages []model.HistoryMessage
|
|
||||||
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
|
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||||
msg := historyMessages[i]
|
msg := historyMessages[i]
|
||||||
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
tokens += msg.Tokens
|
|
||||||
ms := types.Message{Role: "user", Content: msg.Content}
|
ms := types.Message{Role: "user", Content: msg.Content}
|
||||||
if msg.Type == types.ReplyMsg {
|
if msg.Type == types.ReplyMsg {
|
||||||
ms.Role = "assistant"
|
ms.Role = "assistant"
|
||||||
@@ -306,6 +293,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
|
||||||
|
// MaxContextLength = Response + Tool + Prompt + Context
|
||||||
|
tokens := req.MaxTokens // 最大响应长度
|
||||||
|
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
|
||||||
|
tokens += tks + promptTokens
|
||||||
|
|
||||||
|
for _, v := range messages {
|
||||||
|
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
||||||
|
// 上下文 token 超出了模型的最大上下文长度
|
||||||
|
if tokens+tks >= session.Model.MaxContext {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 上下文的深度超出了模型的最大上下文深度
|
||||||
|
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens += tks
|
||||||
|
chatCtx = append(chatCtx, v)
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debugf("聊天上下文:%+v", chatCtx)
|
logger.Debugf("聊天上下文:%+v", chatCtx)
|
||||||
}
|
}
|
||||||
reqMgs := make([]interface{}, 0)
|
reqMgs := make([]interface{}, 0)
|
||||||
@@ -313,28 +323,65 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
reqMgs = append(reqMgs, m)
|
reqMgs = append(reqMgs, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Messages = append(reqMgs, map[string]interface{}{
|
if session.Model.Platform == types.QWen.Value {
|
||||||
"role": "user",
|
req.Input = make(map[string]interface{})
|
||||||
"content": prompt,
|
reqMgs = append(reqMgs, types.Message{
|
||||||
})
|
Role: "user",
|
||||||
|
Content: prompt,
|
||||||
|
})
|
||||||
|
req.Input["messages"] = reqMgs
|
||||||
|
} else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model
|
||||||
|
imgURLs := utils.ExtractImgURL(prompt)
|
||||||
|
logger.Debugf("detected IMG: %+v", imgURLs)
|
||||||
|
var content interface{}
|
||||||
|
if len(imgURLs) > 0 {
|
||||||
|
data := make([]interface{}, 0)
|
||||||
|
text := prompt
|
||||||
|
for _, v := range imgURLs {
|
||||||
|
text = strings.Replace(text, v, "", 1)
|
||||||
|
data = append(data, gin.H{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": gin.H{
|
||||||
|
"url": v,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
data = append(data, gin.H{
|
||||||
|
"type": "text",
|
||||||
|
"text": text,
|
||||||
|
})
|
||||||
|
content = data
|
||||||
|
} else {
|
||||||
|
content = prompt
|
||||||
|
}
|
||||||
|
req.Messages = append(reqMgs, map[string]interface{}{
|
||||||
|
"role": "user",
|
||||||
|
"content": content,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
req.Messages = append(reqMgs, map[string]interface{}{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("%+v", req.Messages)
|
||||||
|
|
||||||
switch session.Model.Platform {
|
switch session.Model.Platform {
|
||||||
case types.Azure:
|
case types.Azure.Value:
|
||||||
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
case types.OpenAI:
|
case types.OpenAI.Value:
|
||||||
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
case types.ChatGLM:
|
case types.ChatGLM.Value:
|
||||||
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
case types.Baidu:
|
case types.Baidu.Value:
|
||||||
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
case types.XunFei:
|
case types.XunFei.Value:
|
||||||
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
|
case types.QWen.Value:
|
||||||
|
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
}
|
}
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
|
||||||
Type: types.WsMiddle,
|
|
||||||
Content: fmt.Sprintf("Not supported platform: %s", session.Model.Platform),
|
|
||||||
})
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -352,9 +399,9 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
|||||||
|
|
||||||
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
|
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
|
||||||
if data.Text == "" && data.ChatId != "" {
|
if data.Text == "" && data.ChatId != "" {
|
||||||
var item model.HistoryMessage
|
var item model.ChatMessage
|
||||||
userId, _ := c.Get(types.LoginUserID)
|
userId, _ := c.Get(types.LoginUserID)
|
||||||
res := h.db.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
|
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -404,32 +451,52 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
|||||||
|
|
||||||
// 发送请求到 OpenAI 服务器
|
// 发送请求到 OpenAI 服务器
|
||||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
||||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
|
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
|
||||||
res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
|
// if the chat model bind a KEY, use it directly
|
||||||
if res.Error != nil {
|
if session.Model.KeyId > 0 {
|
||||||
|
h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
|
||||||
|
}
|
||||||
|
// use the last unused key
|
||||||
|
if apiKey.Id == 0 {
|
||||||
|
h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
|
||||||
|
}
|
||||||
|
if apiKey.Id == 0 {
|
||||||
return nil, errors.New("no available key, please import key")
|
return nil, errors.New("no available key, please import key")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ONLY allow apiURL in blank list
|
||||||
|
if session.Model.Platform == types.OpenAI.Value {
|
||||||
|
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var apiURL string
|
var apiURL string
|
||||||
switch platform {
|
switch session.Model.Platform {
|
||||||
case types.Azure:
|
case types.Azure.Value:
|
||||||
md := strings.Replace(req.Model, ".", "", 1)
|
md := strings.Replace(req.Model, ".", "", 1)
|
||||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
|
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
|
||||||
break
|
break
|
||||||
case types.ChatGLM:
|
case types.ChatGLM.Value:
|
||||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
||||||
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
|
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
|
||||||
req.Messages = nil
|
req.Messages = nil
|
||||||
break
|
break
|
||||||
case types.Baidu:
|
case types.Baidu.Value:
|
||||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
||||||
break
|
break
|
||||||
|
case types.QWen.Value:
|
||||||
|
apiURL = apiKey.ApiURL
|
||||||
|
req.Messages = nil
|
||||||
|
break
|
||||||
default:
|
default:
|
||||||
apiURL = apiKey.ApiURL
|
apiURL = apiKey.ApiURL
|
||||||
}
|
}
|
||||||
// 更新 API KEY 的最后使用时间
|
// 更新 API KEY 的最后使用时间
|
||||||
h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
// 百度文心,需要串接 access_token
|
// 百度文心,需要串接 access_token
|
||||||
if platform == types.Baidu {
|
if session.Model.Platform == types.Baidu.Value {
|
||||||
token, err := h.getBaiduToken(apiKey.Value)
|
token, err := h.getBaiduToken(apiKey.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -438,6 +505,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
|||||||
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
|
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Debugf(utils.JsonEncode(req))
|
||||||
|
|
||||||
// 创建 HttpClient 请求对象
|
// 创建 HttpClient 请求对象
|
||||||
var client *http.Client
|
var client *http.Client
|
||||||
requestBody, err := json.Marshal(req)
|
requestBody, err := json.Marshal(req)
|
||||||
@@ -451,10 +520,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
|||||||
|
|
||||||
request = request.WithContext(ctx)
|
request = request.WithContext(ctx)
|
||||||
request.Header.Set("Content-Type", "application/json")
|
request.Header.Set("Content-Type", "application/json")
|
||||||
var proxyURL string
|
if len(apiKey.ProxyURL) > 5 { // 使用代理
|
||||||
if h.App.Config.ProxyURL != "" && apiKey.UseProxy { // 使用代理
|
proxy, _ := url.Parse(apiKey.ProxyURL)
|
||||||
proxyURL = h.App.Config.ProxyURL
|
|
||||||
proxy, _ := url.Parse(proxyURL)
|
|
||||||
client = &http.Client{
|
client = &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
Proxy: http.ProxyURL(proxy),
|
Proxy: http.ProxyURL(proxy),
|
||||||
@@ -463,41 +530,172 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
|||||||
} else {
|
} else {
|
||||||
client = http.DefaultClient
|
client = http.DefaultClient
|
||||||
}
|
}
|
||||||
logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
|
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
|
||||||
switch platform {
|
switch session.Model.Platform {
|
||||||
case types.Azure:
|
case types.Azure.Value:
|
||||||
request.Header.Set("api-key", apiKey.Value)
|
request.Header.Set("api-key", apiKey.Value)
|
||||||
break
|
break
|
||||||
case types.ChatGLM:
|
case types.ChatGLM.Value:
|
||||||
token, err := h.getChatGLMToken(apiKey.Value)
|
token, err := h.getChatGLMToken(apiKey.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
break
|
break
|
||||||
case types.Baidu:
|
case types.Baidu.Value:
|
||||||
request.RequestURI = ""
|
request.RequestURI = ""
|
||||||
case types.OpenAI:
|
case types.OpenAI.Value:
|
||||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||||
|
break
|
||||||
|
case types.QWen.Value:
|
||||||
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||||
|
request.Header.Set("X-DashScope-SSE", "enable")
|
||||||
|
break
|
||||||
}
|
}
|
||||||
return client.Do(request)
|
return client.Do(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 扣减用户的对话次数
|
// 扣减用户算力
|
||||||
func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) {
|
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
|
||||||
// 仅当用户没有导入自己的 API KEY 时才进行扣减
|
power := 1
|
||||||
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
|
if session.Model.Power > 0 {
|
||||||
num := 1
|
power = session.Model.Power
|
||||||
if session.Model.Weight > 0 {
|
}
|
||||||
num = session.Model.Weight
|
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
|
||||||
|
if res.Error == nil {
|
||||||
|
// 记录算力消费日志
|
||||||
|
var u model.User
|
||||||
|
h.DB.Where("id", userVo.Id).First(&u)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
Username: userVo.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Balance: u.Power,
|
||||||
|
Model: session.Model.Value,
|
||||||
|
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ChatHandler) saveChatHistory(
|
||||||
|
req types.ApiRequest,
|
||||||
|
prompt string,
|
||||||
|
contents []string,
|
||||||
|
message types.Message,
|
||||||
|
chatCtx []types.Message,
|
||||||
|
session *types.ChatSession,
|
||||||
|
role model.ChatRole,
|
||||||
|
userVo vo.User,
|
||||||
|
promptCreatedAt time.Time,
|
||||||
|
replyCreatedAt time.Time) {
|
||||||
|
if message.Role == "" {
|
||||||
|
message.Role = "assistant"
|
||||||
|
}
|
||||||
|
message.Content = strings.Join(contents, "")
|
||||||
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
|
if h.App.SysConfig.EnableContext {
|
||||||
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 追加聊天记录
|
||||||
|
// for prompt
|
||||||
|
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
historyUserMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.PromptMsg,
|
||||||
|
Icon: userVo.Avatar,
|
||||||
|
Content: template.HTMLEscapeString(prompt),
|
||||||
|
Tokens: promptToken,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
|
res := h.DB.Save(&historyUserMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for reply
|
||||||
|
// 计算本次对话消耗的总 token 数量
|
||||||
|
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||||
|
totalTokens := replyTokens + getTotalTokens(req)
|
||||||
|
historyReplyMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.ReplyMsg,
|
||||||
|
Icon: role.Icon,
|
||||||
|
Content: message.Content,
|
||||||
|
Tokens: totalTokens,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
|
res = h.DB.Create(&historyReplyMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.Model.Power > 0 {
|
||||||
|
// 更新用户算力
|
||||||
|
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||||
|
|
||||||
|
// 保存当前会话
|
||||||
|
var chatItem model.ChatItem
|
||||||
|
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||||
|
if res.Error != nil {
|
||||||
|
chatItem.ChatId = session.ChatId
|
||||||
|
chatItem.UserId = session.UserId
|
||||||
|
chatItem.RoleId = role.Id
|
||||||
|
chatItem.ModelId = session.Model.Id
|
||||||
|
if utf8.RuneCountInString(prompt) > 30 {
|
||||||
|
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||||
|
} else {
|
||||||
|
chatItem.Title = prompt
|
||||||
|
}
|
||||||
|
chatItem.Model = req.Model
|
||||||
|
h.DB.Create(&chatItem)
|
||||||
}
|
}
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) incUserTokenFee(userId uint, tokens int) {
|
// 将AI回复消息中生成的图片链接下载到本地
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userId).
|
func (h *ChatHandler) extractImgUrl(text string) string {
|
||||||
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", tokens))
|
pattern := `!\[([^\]]*)]\(([^)]+)\)`
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userId).
|
re := regexp.MustCompile(pattern)
|
||||||
UpdateColumn("tokens", gorm.Expr("tokens + ?", tokens))
|
matches := re.FindAllStringSubmatch(text, -1)
|
||||||
|
|
||||||
|
// 下载图片并替换链接地址
|
||||||
|
for _, match := range matches {
|
||||||
|
imageURL := match[2]
|
||||||
|
logger.Debug(imageURL)
|
||||||
|
// 对于相同地址的图片,已经被替换了,就不再重复下载了
|
||||||
|
if !strings.Contains(text, imageURL) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
text = strings.ReplaceAll(text, imageURL, newImgURL)
|
||||||
|
}
|
||||||
|
return text
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,32 +1,41 @@
|
|||||||
package chatimpl
|
package chatimpl
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// List 获取会话列表
|
// List 获取会话列表
|
||||||
func (h *ChatHandler) List(c *gin.Context) {
|
func (h *ChatHandler) List(c *gin.Context) {
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
if !h.IsLogin(c) {
|
||||||
if userId == 0 {
|
resp.SUCCESS(c)
|
||||||
resp.ERROR(c, "The parameter 'user_id' is needed.")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
var items = make([]vo.ChatItem, 0)
|
var items = make([]vo.ChatItem, 0)
|
||||||
var chats []model.ChatItem
|
var chats []model.ChatItem
|
||||||
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
|
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
var roleIds = make([]uint, 0)
|
var roleIds = make([]uint, 0)
|
||||||
for _, chat := range chats {
|
for _, chat := range chats {
|
||||||
roleIds = append(roleIds, chat.RoleId)
|
roleIds = append(roleIds, chat.RoleId)
|
||||||
}
|
}
|
||||||
var roles []model.ChatRole
|
var roles []model.ChatRole
|
||||||
res = h.db.Find(&roles, roleIds)
|
res = h.DB.Find(&roles, roleIds)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
roleMap := make(map[uint]model.ChatRole)
|
roleMap := make(map[uint]model.ChatRole)
|
||||||
for _, role := range roles {
|
for _, role := range roles {
|
||||||
@@ -58,7 +67,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
|
res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Failed to update database")
|
resp.ERROR(c, "Failed to update database")
|
||||||
return
|
return
|
||||||
@@ -70,14 +79,14 @@ func (h *ChatHandler) Update(c *gin.Context) {
|
|||||||
// Clear 清空所有聊天记录
|
// Clear 清空所有聊天记录
|
||||||
func (h *ChatHandler) Clear(c *gin.Context) {
|
func (h *ChatHandler) Clear(c *gin.Context) {
|
||||||
// 获取当前登录用户所有的聊天会话
|
// 获取当前登录用户所有的聊天会话
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var chats []model.ChatItem
|
var chats []model.ChatItem
|
||||||
res := h.db.Where("user_id = ?", user.Id).Find(&chats)
|
res := h.DB.Where("user_id = ?", user.Id).Find(&chats)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No chats found")
|
resp.ERROR(c, "No chats found")
|
||||||
return
|
return
|
||||||
@@ -89,13 +98,13 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
|||||||
// 清空会话上下文
|
// 清空会话上下文
|
||||||
h.App.ChatContexts.Delete(chat.ChatId)
|
h.App.ChatContexts.Delete(chat.ChatId)
|
||||||
}
|
}
|
||||||
err = h.db.Transaction(func(tx *gorm.DB) error {
|
err = h.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
|
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.HistoryMessage{})
|
res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
@@ -116,9 +125,9 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
|||||||
// History 获取聊天历史记录
|
// History 获取聊天历史记录
|
||||||
func (h *ChatHandler) History(c *gin.Context) {
|
func (h *ChatHandler) History(c *gin.Context) {
|
||||||
chatId := c.Query("chat_id") // 会话 ID
|
chatId := c.Query("chat_id") // 会话 ID
|
||||||
var items []model.HistoryMessage
|
var items []model.ChatMessage
|
||||||
var messages = make([]vo.HistoryMessage, 0)
|
var messages = make([]vo.HistoryMessage, 0)
|
||||||
res := h.db.Where("chat_id = ?", chatId).Find(&items)
|
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No history message")
|
resp.ERROR(c, "No history message")
|
||||||
return
|
return
|
||||||
@@ -144,20 +153,20 @@ func (h *ChatHandler) Remove(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
|
res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Failed to update database")
|
resp.ERROR(c, "Failed to update database")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 删除当前会话的聊天记录
|
// 删除当前会话的聊天记录
|
||||||
res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
|
res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Failed to remove chat from database.")
|
resp.ERROR(c, "Failed to remove chat from database.")
|
||||||
return
|
return
|
||||||
@@ -179,18 +188,26 @@ func (h *ChatHandler) Detail(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
res := h.db.Where("chat_id = ?", chatId).First(&chatItem)
|
res := h.DB.Where("chat_id = ?", chatId).First(&chatItem)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No chat found")
|
resp.ERROR(c, "No chat found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 填充角色名称
|
||||||
|
var role model.ChatRole
|
||||||
|
res = h.DB.Where("id", chatItem.RoleId).First(&role)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "Role not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var chatItemVo vo.ChatItem
|
var chatItemVo vo.ChatItem
|
||||||
err := utils.CopyObject(chatItem, &chatItemVo)
|
err := utils.CopyObject(chatItem, &chatItemVo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
chatItemVo.RoleName = role.Name
|
||||||
resp.SUCCESS(c, chatItemVo)
|
resp.SUCCESS(c, chatItemVo)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +1,31 @@
|
|||||||
package chatimpl
|
package chatimpl
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"html/template"
|
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// 清华大学 ChatGML 消息发送实现
|
// 清华大学 ChatGML 消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendChatGLMMessage(
|
func (h *ChatHandler) sendChatGLMMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -31,21 +36,14 @@ func (h *ChatHandler) sendChatGLMMessage(
|
|||||||
promptCreatedAt := time.Now() // 记录提问时间
|
promptCreatedAt := time.Now() // 记录提问时间
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
var apiKey = model.ApiKey{}
|
var apiKey = model.ApiKey{}
|
||||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "context canceled") {
|
if strings.Contains(err.Error(), "context canceled") {
|
||||||
logger.Info("用户取消了请求:", prompt)
|
return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||||
return nil
|
|
||||||
} else if strings.Contains(err.Error(), "no available key") {
|
} else if strings.Contains(err.Error(), "no available key") {
|
||||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.ReplyMessage(ws, ErrorMsg)
|
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
@@ -107,104 +105,11 @@ func (h *ChatHandler) sendChatGLMMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
|
||||||
message.Role = "assistant"
|
|
||||||
}
|
|
||||||
message.Content = strings.Join(contents, "")
|
|
||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
|
||||||
if h.App.ChatConfig.EnableContext {
|
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 追加聊天记录
|
|
||||||
if h.App.ChatConfig.EnableHistory {
|
|
||||||
// for prompt
|
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: true,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// for reply
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
|
||||||
totalTokens := replyToken + getTotalTokens(req)
|
|
||||||
historyReplyMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: true,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存当前会话
|
|
||||||
var chatItem model.ChatItem
|
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
|
||||||
if res.Error != nil {
|
|
||||||
chatItem.ChatId = session.ChatId
|
|
||||||
chatItem.UserId = session.UserId
|
|
||||||
chatItem.RoleId = role.Id
|
|
||||||
chatItem.ModelId = session.Model.Id
|
|
||||||
if utf8.RuneCountInString(prompt) > 30 {
|
|
||||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
|
||||||
} else {
|
|
||||||
chatItem.Title = prompt
|
|
||||||
}
|
|
||||||
h.db.Create(&chatItem)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
body, err := io.ReadAll(response.Body)
|
body, _ := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
return fmt.Errorf("请求大模型 API 失败:%s", body)
|
||||||
return fmt.Errorf("error with reading response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var res struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Msg string `json:"msg"`
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(body, &res)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode response: %v", err)
|
|
||||||
}
|
|
||||||
if !res.Success {
|
|
||||||
utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,26 +1,31 @@
|
|||||||
package chatimpl
|
package chatimpl
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
req2 "github.com/imroc/req/v3"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
req2 "github.com/imroc/req/v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// OPenAI 消息发送实现
|
// OPenAI 消息发送实现
|
||||||
func (h *ChatHandler) sendOpenAiMessage(
|
func (h *ChatHandler) sendOpenAiMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -31,21 +36,14 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
promptCreatedAt := time.Now() // 记录提问时间
|
promptCreatedAt := time.Now() // 记录提问时间
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
var apiKey = model.ApiKey{}
|
var apiKey = model.ApiKey{}
|
||||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "context canceled") {
|
if strings.Contains(err.Error(), "context canceled") {
|
||||||
logger.Info("用户取消了请求:", prompt)
|
return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||||
return nil
|
|
||||||
} else if strings.Contains(err.Error(), "no available key") {
|
} else if strings.Contains(err.Error(), "no available key") {
|
||||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.ReplyMessage(ws, ErrorMsg)
|
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
@@ -61,6 +59,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
var toolCall = false
|
var toolCall = false
|
||||||
var arguments = make([]string, 0)
|
var arguments = make([]string, 0)
|
||||||
scanner := bufio.NewScanner(response.Body)
|
scanner := bufio.NewScanner(response.Body)
|
||||||
|
var isNew = true
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||||
@@ -69,46 +68,64 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
|
|
||||||
var responseBody = types.ApiResponse{}
|
var responseBody = types.ApiResponse{}
|
||||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||||
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
|
if err != nil { // 数据解析出错
|
||||||
logger.Error(err, line)
|
return errors.New(line)
|
||||||
utils.ReplyMessage(ws, ErrorMsg)
|
}
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
||||||
|
utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
var fun types.ToolCall
|
var tool types.ToolCall
|
||||||
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
|
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
|
||||||
fun = responseBody.Choices[0].Delta.ToolCalls[0]
|
tool = responseBody.Choices[0].Delta.ToolCalls[0]
|
||||||
if toolCall && fun.Function.Name == "" {
|
if toolCall && tool.Function.Name == "" {
|
||||||
arguments = append(arguments, fun.Function.Arguments)
|
arguments = append(arguments, tool.Function.Arguments)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !utils.IsEmptyValue(fun) {
|
// 兼容 Function Call
|
||||||
res := h.db.Where("name = ?", fun.Function.Name).First(&function)
|
fun := responseBody.Choices[0].Delta.FunctionCall
|
||||||
|
if fun.Name != "" {
|
||||||
|
tool = *new(types.ToolCall)
|
||||||
|
tool.Function.Name = fun.Name
|
||||||
|
} else if toolCall {
|
||||||
|
arguments = append(arguments, fun.Arguments)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !utils.IsEmptyValue(tool) {
|
||||||
|
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
toolCall = true
|
toolCall = true
|
||||||
|
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
|
||||||
|
contents = append(contents, callMsg)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseBody.Choices[0].FinishReason == "tool_calls" { // 函数调用完毕
|
if responseBody.Choices[0].FinishReason == "tool_calls" ||
|
||||||
|
responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化 role
|
// output stopped
|
||||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
if responseBody.Choices[0].FinishReason != "" {
|
||||||
message.Role = responseBody.Choices[0].Delta.Role
|
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
|
||||||
continue
|
|
||||||
} else if responseBody.Choices[0].FinishReason != "" {
|
|
||||||
break // 输出完成或者输出中断了
|
break // 输出完成或者输出中断了
|
||||||
} else {
|
} else {
|
||||||
content := responseBody.Choices[0].Delta.Content
|
content := responseBody.Choices[0].Delta.Content
|
||||||
contents = append(contents, utils.InterfaceToString(content))
|
contents = append(contents, utils.InterfaceToString(content))
|
||||||
|
if isNew {
|
||||||
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
|
isNew = false
|
||||||
|
}
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||||
@@ -158,126 +175,11 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
|
||||||
message.Role = "assistant"
|
|
||||||
}
|
|
||||||
message.Content = strings.Join(contents, "")
|
|
||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
|
||||||
if h.App.ChatConfig.EnableContext && toolCall == false {
|
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 追加聊天记录
|
|
||||||
if h.App.ChatConfig.EnableHistory {
|
|
||||||
useContext := true
|
|
||||||
if toolCall {
|
|
||||||
useContext = false
|
|
||||||
}
|
|
||||||
|
|
||||||
// for prompt
|
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: useContext,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
var totalTokens = 0
|
|
||||||
if toolCall { // prompt + 函数名 + 参数 token
|
|
||||||
tokens, _ := utils.CalcTokens(function.Name, req.Model)
|
|
||||||
totalTokens += tokens
|
|
||||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
|
||||||
totalTokens += tokens
|
|
||||||
} else {
|
|
||||||
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
|
||||||
}
|
|
||||||
totalTokens += getTotalTokens(req)
|
|
||||||
|
|
||||||
historyReplyMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: useContext,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存当前会话
|
|
||||||
var chatItem model.ChatItem
|
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
|
||||||
if res.Error != nil {
|
|
||||||
chatItem.ChatId = session.ChatId
|
|
||||||
chatItem.UserId = session.UserId
|
|
||||||
chatItem.RoleId = role.Id
|
|
||||||
chatItem.ModelId = session.Model.Id
|
|
||||||
if utf8.RuneCountInString(prompt) > 30 {
|
|
||||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
|
||||||
} else {
|
|
||||||
chatItem.Title = prompt
|
|
||||||
}
|
|
||||||
h.db.Create(&chatItem)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
body, err := io.ReadAll(response.Body)
|
body, _ := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
return fmt.Errorf("请求 OpenAI API 失败:%s", body)
|
||||||
return fmt.Errorf("error with reading response: %v", err)
|
|
||||||
}
|
|
||||||
var res types.ApiError
|
|
||||||
err = json.Unmarshal(body, &res)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenAI API 调用异常处理
|
|
||||||
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
|
|
||||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
|
|
||||||
// 移除当前 API key
|
|
||||||
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
|
|
||||||
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
|
|
||||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
|
|
||||||
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
|
||||||
logger.Error(res.Error.Message)
|
|
||||||
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
|
|
||||||
h.App.ChatContexts.Delete(session.ChatId)
|
|
||||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
|
||||||
} else {
|
|
||||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
150
api/handler/chatimpl/qwen_handler.go
Normal file
150
api/handler/chatimpl/qwen_handler.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package chatimpl
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"github.com/syndtr/goleveldb/leveldb/errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type qWenResp struct {
|
||||||
|
Output struct {
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"output,omitempty"`
|
||||||
|
Usage struct {
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
} `json:"usage,omitempty"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
|
||||||
|
Code string `json:"code,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通义千问消息发送实现
|
||||||
|
func (h *ChatHandler) sendQWenMessage(
|
||||||
|
chatCtx []types.Message,
|
||||||
|
req types.ApiRequest,
|
||||||
|
userVo vo.User,
|
||||||
|
ctx context.Context,
|
||||||
|
session *types.ChatSession,
|
||||||
|
role model.ChatRole,
|
||||||
|
prompt string,
|
||||||
|
ws *types.WsClient) error {
|
||||||
|
promptCreatedAt := time.Now() // 记录提问时间
|
||||||
|
start := time.Now()
|
||||||
|
var apiKey = model.ApiKey{}
|
||||||
|
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||||
|
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "context canceled") {
|
||||||
|
return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||||
|
} else if strings.Contains(err.Error(), "no available key") {
|
||||||
|
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
defer response.Body.Close()
|
||||||
|
}
|
||||||
|
contentType := response.Header.Get("Content-Type")
|
||||||
|
if strings.Contains(contentType, "text/event-stream") {
|
||||||
|
replyCreatedAt := time.Now() // 记录回复时间
|
||||||
|
// 循环读取 Chunk 消息
|
||||||
|
var message = types.Message{}
|
||||||
|
var contents = make([]string, 0)
|
||||||
|
scanner := bufio.NewScanner(response.Body)
|
||||||
|
|
||||||
|
var content, lastText, newText string
|
||||||
|
var outPutStart = false
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if len(line) < 5 || strings.HasPrefix(line, "id:") ||
|
||||||
|
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(line, "data:") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
content = line[5:]
|
||||||
|
var resp qWenResp
|
||||||
|
if len(contents) == 0 { // 发送消息头
|
||||||
|
if !outPutStart {
|
||||||
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
|
outPutStart = true
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
// 处理代码换行
|
||||||
|
content = "\n"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := utils.JsonDecode(content, &resp)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with parse data line: ", content)
|
||||||
|
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if resp.Message != "" {
|
||||||
|
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//通过比较 lastText(上一次的文本)和 currentText(当前的文本),
|
||||||
|
//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
|
||||||
|
//每次循环结束后,lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
|
||||||
|
currentText := resp.Output.Text
|
||||||
|
if currentText != lastText {
|
||||||
|
// 提取新增文本
|
||||||
|
newText = strings.Replace(currentText, lastText, "", 1)
|
||||||
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
|
Type: types.WsMiddle,
|
||||||
|
Content: utils.InterfaceToString(newText),
|
||||||
|
})
|
||||||
|
lastText = currentText // 更新 lastText
|
||||||
|
}
|
||||||
|
contents = append(contents, newText)
|
||||||
|
|
||||||
|
if resp.Output.FinishReason == "stop" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
} //end for
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
if strings.Contains(err.Error(), "context canceled") {
|
||||||
|
logger.Info("用户取消了请求:", prompt)
|
||||||
|
} else {
|
||||||
|
logger.Error("信息读取出错:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 消息发送成功
|
||||||
|
if len(contents) > 0 {
|
||||||
|
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body, _ := io.ReadAll(response.Body)
|
||||||
|
return fmt.Errorf("请求大模型 API 失败:%s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,24 +1,31 @@
|
|||||||
package chatimpl
|
package chatimpl
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"html/template"
|
"gorm.io/gorm"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type xunFeiResp struct {
|
type xunFeiResp struct {
|
||||||
@@ -50,15 +57,16 @@ type xunFeiResp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var Model2URL = map[string]string{
|
var Model2URL = map[string]string{
|
||||||
"general": "v1.1",
|
"general": "v1.1",
|
||||||
"generalv2": "v2.1",
|
"generalv2": "v2.1",
|
||||||
"generalv3": "v3.1",
|
"generalv3": "v3.1",
|
||||||
|
"generalv3.5": "v3.5",
|
||||||
}
|
}
|
||||||
|
|
||||||
// 科大讯飞消息发送实现
|
// 科大讯飞消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendXunFeiMessage(
|
func (h *ChatHandler) sendXunFeiMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -68,13 +76,20 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
ws *types.WsClient) error {
|
ws *types.WsClient) error {
|
||||||
promptCreatedAt := time.Now() // 记录提问时间
|
promptCreatedAt := time.Now() // 记录提问时间
|
||||||
var apiKey model.ApiKey
|
var apiKey model.ApiKey
|
||||||
res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
var res *gorm.DB
|
||||||
|
// use the bind key
|
||||||
|
if session.Model.KeyId > 0 {
|
||||||
|
res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
|
||||||
|
}
|
||||||
|
// use the last unused key
|
||||||
|
if apiKey.Id == 0 {
|
||||||
|
res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
|
||||||
|
}
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
// 更新 API KEY 的最后使用时间
|
// 更新 API KEY 的最后使用时间
|
||||||
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
|
|
||||||
d := websocket.Dialer{
|
d := websocket.Dialer{
|
||||||
HandshakeTimeout: 5 * time.Second,
|
HandshakeTimeout: 5 * time.Second,
|
||||||
@@ -86,6 +101,7 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
|
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
|
||||||
|
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
|
||||||
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
|
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
|
||||||
//握手并建立websocket 连接
|
//握手并建立websocket 连接
|
||||||
conn, resp, err := d.Dial(wsURL, nil)
|
conn, resp, err := d.Dial(wsURL, nil)
|
||||||
@@ -163,90 +179,10 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
|
||||||
message.Role = "assistant"
|
|
||||||
}
|
|
||||||
message.Content = strings.Join(contents, "")
|
|
||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
|
||||||
if h.App.ChatConfig.EnableContext {
|
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 追加聊天记录
|
|
||||||
if h.App.ChatConfig.EnableHistory {
|
|
||||||
// for prompt
|
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: true,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// for reply
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
|
||||||
totalTokens := replyToken + getTotalTokens(req)
|
|
||||||
historyReplyMsg := model.HistoryMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: true,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存当前会话
|
|
||||||
var chatItem model.ChatItem
|
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
|
||||||
if res.Error != nil {
|
|
||||||
chatItem.ChatId = session.ChatId
|
|
||||||
chatItem.UserId = session.UserId
|
|
||||||
chatItem.RoleId = role.Id
|
|
||||||
chatItem.ModelId = session.Model.Id
|
|
||||||
if utf8.RuneCountInString(prompt) > 30 {
|
|
||||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
|
||||||
} else {
|
|
||||||
chatItem.Title = prompt
|
|
||||||
}
|
|
||||||
h.db.Create(&chatItem)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,7 +195,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
|
|||||||
"parameter": map[string]interface{}{
|
"parameter": map[string]interface{}{
|
||||||
"chat": map[string]interface{}{
|
"chat": map[string]interface{}{
|
||||||
"domain": req.Model,
|
"domain": req.Model,
|
||||||
"temperature": float64(req.Temperature),
|
"temperature": req.Temperature,
|
||||||
"top_k": int64(6),
|
"top_k": int64(6),
|
||||||
"max_tokens": int64(req.MaxTokens),
|
"max_tokens": int64(req.MaxTokens),
|
||||||
"auditing": "default",
|
"auditing": "default",
|
||||||
|
|||||||
54
api/handler/config_handler.go
Normal file
54
api/handler/config_handler.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConfigHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
licenseService *service.LicenseService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *ConfigHandler {
|
||||||
|
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取指定的系统配置
|
||||||
|
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||||
|
key := c.Query("key")
|
||||||
|
var config model.Config
|
||||||
|
res := h.DB.Where("marker", key).First(&config)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var value map[string]interface{}
|
||||||
|
err := utils.JsonDecode(config.Config, &value)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// License 获取 License 配置
|
||||||
|
func (h *ConfigHandler) License(c *gin.Context) {
|
||||||
|
license := h.licenseService.GetLicense()
|
||||||
|
resp.SUCCESS(c, license.Configs)
|
||||||
|
}
|
||||||
262
api/handler/dalle_handler.go
Normal file
262
api/handler/dalle_handler.go
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service/dalle"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DallJobHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
redis *redis.Client
|
||||||
|
service *dalle.Service
|
||||||
|
uploader *oss.UploaderManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
|
||||||
|
return &DallJobHandler{
|
||||||
|
service: service,
|
||||||
|
uploader: manager,
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: app,
|
||||||
|
DB: db,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client WebSocket 客户端,用于通知任务状态变更
|
||||||
|
func (h *DallJobHandler) Client(c *gin.Context) {
|
||||||
|
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := h.GetInt(c, "user_id", 0)
|
||||||
|
if userId == 0 {
|
||||||
|
logger.Info("Invalid user ID")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client := types.NewWsClient(ws)
|
||||||
|
h.service.Clients.Put(uint(userId), client)
|
||||||
|
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
_, msg, err := client.Receive()
|
||||||
|
if err != nil {
|
||||||
|
client.Close()
|
||||||
|
h.service.Clients.Delete(uint(userId))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var message types.WsMessage
|
||||||
|
err = utils.JsonDecode(string(msg), &message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 心跳消息
|
||||||
|
if message.Type == "heartbeat" {
|
||||||
|
logger.Debug("收到 DallE 心跳消息:", message.Content)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
|
||||||
|
user, err := h.GetLoginUser(c)
|
||||||
|
if err != nil {
|
||||||
|
resp.NotAuth(c)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if user.Power < h.App.SysConfig.DallPower {
|
||||||
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image 创建一个绘画任务
|
||||||
|
func (h *DallJobHandler) Image(c *gin.Context) {
|
||||||
|
if !h.preCheck(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var data types.DallTask
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
|
job := model.DallJob{
|
||||||
|
UserId: uint(userId),
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
Power: h.App.SysConfig.DallPower,
|
||||||
|
}
|
||||||
|
res := h.DB.Create(&job)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "error with save job: "+res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.service.PushTask(types.DallTask{
|
||||||
|
JobId: job.Id,
|
||||||
|
UserId: uint(userId),
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
Quality: data.Quality,
|
||||||
|
Size: data.Size,
|
||||||
|
Style: data.Style,
|
||||||
|
Power: job.Power,
|
||||||
|
})
|
||||||
|
|
||||||
|
client := h.service.Clients.Get(job.UserId)
|
||||||
|
if client != nil {
|
||||||
|
_ = client.Send([]byte("Task Updated"))
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImgWall 照片墙
|
||||||
|
func (h *DallJobHandler) ImgWall(c *gin.Context) {
|
||||||
|
page := h.GetInt(c, "page", 0)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
err, jobs := h.getData(true, 0, page, pageSize, true)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取 SD 任务列表
|
||||||
|
func (h *DallJobHandler) JobList(c *gin.Context) {
|
||||||
|
status := h.GetBool(c, "status")
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
page := h.GetInt(c, "page", 0)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
publish := h.GetBool(c, "publish")
|
||||||
|
|
||||||
|
err, jobs := h.getData(status, userId, page, pageSize, publish)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取任务列表
|
||||||
|
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if finish {
|
||||||
|
session = session.Where("progress = ?", 100).Order("id DESC")
|
||||||
|
} else {
|
||||||
|
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||||
|
}
|
||||||
|
if userId > 0 {
|
||||||
|
session = session.Where("user_id = ?", userId)
|
||||||
|
}
|
||||||
|
if publish {
|
||||||
|
session = session.Where("publish", publish)
|
||||||
|
}
|
||||||
|
if page > 0 && pageSize > 0 {
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
session = session.Offset(offset).Limit(pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
var items []model.DallJob
|
||||||
|
res := session.Find(&items)
|
||||||
|
if res.Error != nil {
|
||||||
|
return res.Error, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var jobs = make([]vo.DallJob, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
var job vo.DallJob
|
||||||
|
err := utils.CopyObject(item, &job)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
jobs = append(jobs, job)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, jobs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove remove task image
|
||||||
|
func (h *DallJobHandler) Remove(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
|
ImgURL string `json:"img_url"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove job recode
|
||||||
|
res := h.DB.Delete(&model.DallJob{Id: data.Id})
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove image
|
||||||
|
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("remove image failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish 发布/取消发布图片到画廊显示
|
||||||
|
func (h *DallJobHandler) Publish(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
|
resp.ERROR(c, "更新数据库失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
@@ -1,39 +1,52 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/service/oss"
|
"geekai/service/dalle"
|
||||||
"chatplus/store/model"
|
"geekai/service/oss"
|
||||||
"chatplus/utils"
|
"geekai/store/model"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type FunctionHandler struct {
|
type FunctionHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
config types.ApiConfig
|
||||||
config types.ChatPlusApiConfig
|
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
proxyURL string
|
dallService *dalle.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
|
func NewFunctionHandler(
|
||||||
|
server *core.AppServer,
|
||||||
|
db *gorm.DB,
|
||||||
|
config *types.AppConfig,
|
||||||
|
manager *oss.UploaderManager,
|
||||||
|
dallService *dalle.Service) *FunctionHandler {
|
||||||
return &FunctionHandler{
|
return &FunctionHandler{
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: server,
|
App: server,
|
||||||
|
DB: db,
|
||||||
},
|
},
|
||||||
db: db,
|
|
||||||
config: config.ApiConfig,
|
config: config.ApiConfig,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
proxyURL: config.ProxyURL,
|
dallService: dallService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,30 +167,6 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
type imgReq struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
N int `json:"n"`
|
|
||||||
Size string `json:"size"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type imgRes struct {
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Data []struct {
|
|
||||||
RevisedPrompt string `json:"revised_prompt"`
|
|
||||||
Url string `json:"url"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ErrRes struct {
|
|
||||||
Error struct {
|
|
||||||
Code interface{} `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Param interface{} `json:"param"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
} `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dall3 DallE3 AI 绘图
|
// Dall3 DallE3 AI 绘图
|
||||||
func (h *FunctionHandler) Dall3(c *gin.Context) {
|
func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||||
if err := h.checkAuth(c); err != nil {
|
if err := h.checkAuth(c); err != nil {
|
||||||
@@ -192,88 +181,46 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Debugf("绘画参数:%+v", params)
|
logger.Debugf("绘画参数:%+v", params)
|
||||||
// check img calls
|
|
||||||
var user model.User
|
var user model.User
|
||||||
tx := h.db.Where("id = ?", params["user_id"]).First(&user)
|
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||||||
if tx.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "当前用户不存在!")
|
resp.ERROR(c, "当前用户不存在!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.ImgCalls <= 0 {
|
if user.Power < h.App.SysConfig.DallPower {
|
||||||
resp.ERROR(c, "当前用户的绘图次数额度不足!")
|
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create dall task
|
||||||
prompt := utils.InterfaceToString(params["prompt"])
|
prompt := utils.InterfaceToString(params["prompt"])
|
||||||
// get image generation API KEY
|
job := model.DallJob{
|
||||||
var apiKey model.ApiKey
|
UserId: user.Id,
|
||||||
tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
Prompt: prompt,
|
||||||
if tx.Error != nil {
|
Power: h.App.SysConfig.DallPower,
|
||||||
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
|
}
|
||||||
|
res = h.DB.Create(&job)
|
||||||
|
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// get image generation api URL
|
content, err := h.dallService.Image(types.DallTask{
|
||||||
var conf model.Config
|
JobId: job.Id,
|
||||||
var chatConfig types.ChatConfig
|
UserId: user.Id,
|
||||||
tx = h.db.Where("marker", "chat").First(&conf)
|
Prompt: job.Prompt,
|
||||||
if tx.Error != nil {
|
N: 1,
|
||||||
resp.ERROR(c, "error with get chat configs:"+tx.Error.Error())
|
Quality: "standard",
|
||||||
return
|
Size: "1024x1024",
|
||||||
}
|
Style: "vivid",
|
||||||
|
Power: job.Power,
|
||||||
err := utils.JsonDecode(conf.Config, &chatConfig)
|
}, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "error with decode chat config: "+err.Error())
|
resp.ERROR(c, "任务执行失败:"+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// translate prompt
|
|
||||||
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
|
||||||
pt, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, params["prompt"]), h.App.Config.ProxyURL)
|
|
||||||
if err == nil {
|
|
||||||
prompt = pt
|
|
||||||
}
|
|
||||||
imgNum := chatConfig.DallImgNum
|
|
||||||
if imgNum <= 0 {
|
|
||||||
imgNum = 1
|
|
||||||
}
|
|
||||||
var res imgRes
|
|
||||||
var errRes ErrRes
|
|
||||||
var request *req.Request
|
|
||||||
if apiKey.UseProxy && h.proxyURL != "" {
|
|
||||||
request = req.C().SetProxyURL(h.proxyURL).R()
|
|
||||||
} else {
|
|
||||||
request = req.C().R()
|
|
||||||
}
|
|
||||||
logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, h.proxyURL)
|
|
||||||
r, err := request.SetHeader("Content-Type", "application/json").
|
|
||||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
|
||||||
SetBody(imgReq{
|
|
||||||
Model: "dall-e-3",
|
|
||||||
Prompt: prompt,
|
|
||||||
N: imgNum,
|
|
||||||
Size: "1024x1024",
|
|
||||||
}).
|
|
||||||
SetErrorResult(&errRes).
|
|
||||||
SetSuccessResult(&res).Post(apiKey.ApiURL)
|
|
||||||
if r.IsErrorState() {
|
|
||||||
resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 更新 API KEY 的最后使用时间
|
|
||||||
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
|
||||||
// 存储图片
|
|
||||||
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "下载图片失败: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n\n", prompt, imgURL)
|
|
||||||
// update user's img_calls
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
|
||||||
|
|
||||||
resp.SUCCESS(c, content)
|
resp.SUCCESS(c, content)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,19 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -15,32 +22,29 @@ import (
|
|||||||
// InviteHandler 用户邀请
|
// InviteHandler 用户邀请
|
||||||
type InviteHandler struct {
|
type InviteHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
|
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
|
||||||
h := InviteHandler{db: db}
|
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Code 获取当前用户邀请码
|
// Code 获取当前用户邀请码
|
||||||
func (h *InviteHandler) Code(c *gin.Context) {
|
func (h *InviteHandler) Code(c *gin.Context) {
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
var inviteCode model.InviteCode
|
var inviteCode model.InviteCode
|
||||||
res := h.db.Where("user_id = ?", userId).First(&inviteCode)
|
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
|
||||||
// 如果邀请码不存在,则创建一个
|
// 如果邀请码不存在,则创建一个
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
code := strings.ToUpper(utils.RandString(8))
|
code := strings.ToUpper(utils.RandString(8))
|
||||||
for {
|
for {
|
||||||
res = h.db.Where("code = ?", code).First(&inviteCode)
|
res = h.DB.Where("code = ?", code).First(&inviteCode)
|
||||||
if res.Error != nil { // 不存在相同的邀请码则退出
|
if res.Error != nil { // 不存在相同的邀请码则退出
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inviteCode.UserId = userId
|
inviteCode.UserId = userId
|
||||||
inviteCode.Code = code
|
inviteCode.Code = code
|
||||||
h.db.Create(&inviteCode)
|
h.DB.Create(&inviteCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var codeVo vo.InviteCode
|
var codeVo vo.InviteCode
|
||||||
@@ -65,7 +69,7 @@ func (h *InviteHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
session := h.db.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
|
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
|
||||||
var total int64
|
var total int64
|
||||||
session.Model(&model.InviteLog{}).Count(&total)
|
session.Model(&model.InviteLog{}).Count(&total)
|
||||||
var items []model.InviteLog
|
var items []model.InviteLog
|
||||||
@@ -91,6 +95,6 @@ func (h *InviteHandler) List(c *gin.Context) {
|
|||||||
// Hits 访问邀请码
|
// Hits 访问邀请码
|
||||||
func (h *InviteHandler) Hits(c *gin.Context) {
|
func (h *InviteHandler) Hits(c *gin.Context) {
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
h.db.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
|
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
273
api/handler/markmap_handler.go
Normal file
273
api/handler/markmap_handler.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MarkMapHandler 生成思维导图
|
||||||
|
type MarkMapHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
clients *types.LMap[int, *types.WsClient]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler {
|
||||||
|
return &MarkMapHandler{
|
||||||
|
BaseHandler: BaseHandler{App: app, DB: db},
|
||||||
|
clients: types.NewLMap[int, *types.WsClient](),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MarkMapHandler) Client(c *gin.Context) {
|
||||||
|
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelId := h.GetInt(c, "model_id", 0)
|
||||||
|
userId := h.GetInt(c, "user_id", 0)
|
||||||
|
|
||||||
|
client := types.NewWsClient(ws)
|
||||||
|
h.clients.Put(userId, client)
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
_, msg, err := client.Receive()
|
||||||
|
if err != nil {
|
||||||
|
client.Close()
|
||||||
|
h.clients.Delete(userId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var message types.WsMessage
|
||||||
|
err = utils.JsonDecode(string(msg), &message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 心跳消息
|
||||||
|
if message.Type == "heartbeat" {
|
||||||
|
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// change model
|
||||||
|
if message.Type == "model_id" {
|
||||||
|
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Receive a message: ", message.Content)
|
||||||
|
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
|
||||||
|
var user model.User
|
||||||
|
res := h.DB.Model(&model.User{}).First(&user, userId)
|
||||||
|
if res.Error != nil {
|
||||||
|
return fmt.Errorf("error with query user info: %v", res.Error)
|
||||||
|
}
|
||||||
|
var chatModel model.ChatModel
|
||||||
|
res = h.DB.Where("id", modelId).First(&chatModel)
|
||||||
|
if res.Error != nil {
|
||||||
|
return fmt.Errorf("error with query chat model: %v", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Status == false {
|
||||||
|
return errors.New("当前用户被禁用")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Power < chatModel.Power {
|
||||||
|
return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := make([]interface{}, 0)
|
||||||
|
messages = append(messages, types.Message{Role: "system", Content: `
|
||||||
|
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
|
||||||
|
# Geek-AI 助手
|
||||||
|
|
||||||
|
## 完整的开源系统
|
||||||
|
### 前端开源
|
||||||
|
### 后端开源
|
||||||
|
|
||||||
|
## 支持各种大模型
|
||||||
|
### OpenAI
|
||||||
|
### Azure
|
||||||
|
### 文心一言
|
||||||
|
### 通义千问
|
||||||
|
|
||||||
|
## 集成多种收费方式
|
||||||
|
### 支付宝
|
||||||
|
### 微信
|
||||||
|
|
||||||
|
另外,除此之外不要任何解释性语句。
|
||||||
|
`})
|
||||||
|
messages = append(messages, types.Message{Role: "user", Content: prompt})
|
||||||
|
var req = types.ApiRequest{
|
||||||
|
Model: chatModel.Value,
|
||||||
|
Stream: true,
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
response, err := h.doRequest(req, chatModel, &apiKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
contentType := response.Header.Get("Content-Type")
|
||||||
|
if strings.Contains(contentType, "text/event-stream") {
|
||||||
|
// 循环读取 Chunk 消息
|
||||||
|
scanner := bufio.NewScanner(response.Body)
|
||||||
|
var isNew = true
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseBody = types.ApiResponse{}
|
||||||
|
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||||
|
if err != nil { // 数据解析出错
|
||||||
|
return fmt.Errorf("error with decode data: %v", line)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseBody.Choices[0].FinishReason == "stop" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNew {
|
||||||
|
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
|
||||||
|
isNew = false
|
||||||
|
}
|
||||||
|
utils.ReplyChunkMessage(client, types.WsMessage{
|
||||||
|
Type: types.WsMiddle,
|
||||||
|
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||||
|
})
|
||||||
|
} // end for
|
||||||
|
|
||||||
|
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||||
|
|
||||||
|
} else {
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("读取响应失败: %v", err)
|
||||||
|
}
|
||||||
|
var res types.ApiError
|
||||||
|
err = json.Unmarshal(body, &res)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("解析响应失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI API 调用异常处理
|
||||||
|
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
|
||||||
|
// remove key
|
||||||
|
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
|
||||||
|
return errors.New("请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
|
||||||
|
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
|
||||||
|
return errors.New("请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 扣减算力
|
||||||
|
res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power))
|
||||||
|
if res.Error == nil {
|
||||||
|
// 记录算力消费日志
|
||||||
|
var u model.User
|
||||||
|
h.DB.Where("id", userId).First(&u)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: u.Id,
|
||||||
|
Username: u.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: chatModel.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Balance: u.Power,
|
||||||
|
Model: chatModel.Value,
|
||||||
|
Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
|
||||||
|
// if the chat model bind a KEY, use it directly
|
||||||
|
var res *gorm.DB
|
||||||
|
if chatModel.KeyId > 0 {
|
||||||
|
res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
|
||||||
|
}
|
||||||
|
// use the last unused key
|
||||||
|
if apiKey.Id == 0 {
|
||||||
|
res = h.DB.Where("platform", types.OpenAI).
|
||||||
|
Where("type", "chat").
|
||||||
|
Where("enabled", true).Order("last_used_at ASC").First(apiKey)
|
||||||
|
}
|
||||||
|
if res.Error != nil {
|
||||||
|
return nil, errors.New("no available key, please import key")
|
||||||
|
}
|
||||||
|
apiURL := apiKey.ApiURL
|
||||||
|
// 更新 API KEY 的最后使用时间
|
||||||
|
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
|
|
||||||
|
// 创建 HttpClient 请求对象
|
||||||
|
var client *http.Client
|
||||||
|
requestBody, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Header.Set("Content-Type", "application/json")
|
||||||
|
if len(apiKey.ProxyURL) > 5 { // 使用代理
|
||||||
|
proxy, _ := url.Parse(apiKey.ProxyURL)
|
||||||
|
client = &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxy),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
client = http.DefaultClient
|
||||||
|
}
|
||||||
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||||
|
return client.Do(request)
|
||||||
|
}
|
||||||
43
api/handler/menu_handler.go
Normal file
43
api/handler/menu_handler.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MenuHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
||||||
|
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 数据列表
|
||||||
|
func (h *MenuHandler) List(c *gin.Context) {
|
||||||
|
var items []model.Menu
|
||||||
|
var list = make([]vo.Menu, 0)
|
||||||
|
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
|
||||||
|
if res.Error == nil {
|
||||||
|
for _, item := range items {
|
||||||
|
var product vo.Menu
|
||||||
|
err := utils.CopyObject(item, &product)
|
||||||
|
if err == nil {
|
||||||
|
list = append(list, product)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, list)
|
||||||
|
}
|
||||||
@@ -1,53 +1,61 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/service"
|
|
||||||
"chatplus/service/mj"
|
|
||||||
"chatplus/service/oss"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"geekai/core"
|
||||||
"github.com/gorilla/websocket"
|
"geekai/core/types"
|
||||||
"gorm.io/gorm"
|
"geekai/service"
|
||||||
|
"geekai/service/mj"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MidJourneyHandler struct {
|
type MidJourneyHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
pool *mj.ServicePool
|
pool *mj.ServicePool
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
|
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
|
||||||
h := MidJourneyHandler{
|
return &MidJourneyHandler{
|
||||||
db: db,
|
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
pool: pool,
|
pool: pool,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: app,
|
||||||
|
DB: db,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.ImgCalls <= 0 {
|
if user.Power < h.App.SysConfig.MjPower {
|
||||||
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,19 +92,23 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
|
|||||||
// Image 创建一个绘画任务
|
// Image 创建一个绘画任务
|
||||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Prompt string `json:"prompt"`
|
TaskType string `json:"task_type"`
|
||||||
NegPrompt string `json:"neg_prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Rate string `json:"rate"`
|
NegPrompt string `json:"neg_prompt"`
|
||||||
Model string `json:"model"`
|
Rate string `json:"rate"`
|
||||||
Chaos int `json:"chaos"`
|
Model string `json:"model"`
|
||||||
Raw bool `json:"raw"`
|
Chaos int `json:"chaos"`
|
||||||
Seed int64 `json:"seed"`
|
Raw bool `json:"raw"`
|
||||||
Stylize int `json:"stylize"`
|
Seed int64 `json:"seed"`
|
||||||
Img string `json:"img"`
|
Stylize int `json:"stylize"`
|
||||||
Tile bool `json:"tile"`
|
ImgArr []string `json:"img_arr"`
|
||||||
Quality float32 `json:"quality"`
|
Tile bool `json:"tile"`
|
||||||
Weight float32 `json:"weight"`
|
Quality float32 `json:"quality"`
|
||||||
|
Iw float32 `json:"iw"`
|
||||||
|
CRef string `json:"cref"` //生成角色一致的图像
|
||||||
|
SRef string `json:"sref"` //生成风格一致的图像
|
||||||
|
Cw int `json:"cw"` // 参考程度
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -106,39 +118,57 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var prompt = data.Prompt
|
var params = ""
|
||||||
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
|
if data.Rate != "" && !strings.Contains(params, "--ar") {
|
||||||
prompt += " --ar " + data.Rate
|
params += " --ar " + data.Rate
|
||||||
}
|
}
|
||||||
if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
|
if data.Seed > 0 && !strings.Contains(params, "--seed") {
|
||||||
prompt += fmt.Sprintf(" --seed %d", data.Seed)
|
params += fmt.Sprintf(" --seed %d", data.Seed)
|
||||||
}
|
}
|
||||||
if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
|
if data.Stylize > 0 && !strings.Contains(params, "--s") && !strings.Contains(params, "--stylize") {
|
||||||
prompt += fmt.Sprintf(" --s %d", data.Stylize)
|
params += fmt.Sprintf(" --s %d", data.Stylize)
|
||||||
}
|
}
|
||||||
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
|
if data.Chaos > 0 && !strings.Contains(params, "--c") && !strings.Contains(params, "--chaos") {
|
||||||
prompt += fmt.Sprintf(" --c %d", data.Chaos)
|
params += fmt.Sprintf(" --c %d", data.Chaos)
|
||||||
}
|
}
|
||||||
if data.Img != "" {
|
if len(data.ImgArr) > 0 && data.Iw > 0 {
|
||||||
prompt = fmt.Sprintf("%s %s", data.Img, prompt)
|
params += fmt.Sprintf(" --iw %.2f", data.Iw)
|
||||||
if data.Weight > 0 {
|
|
||||||
prompt += fmt.Sprintf(" --iw %f", data.Weight)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if data.Raw {
|
if data.Raw {
|
||||||
prompt += " --style raw"
|
params += " --style raw"
|
||||||
}
|
}
|
||||||
if data.Quality > 0 {
|
if data.Quality > 0 {
|
||||||
prompt += fmt.Sprintf(" --q %.2f", data.Quality)
|
params += fmt.Sprintf(" --q %.2f", data.Quality)
|
||||||
}
|
|
||||||
if data.NegPrompt != "" {
|
|
||||||
prompt += fmt.Sprintf(" --no %s", data.NegPrompt)
|
|
||||||
}
|
}
|
||||||
if data.Tile {
|
if data.Tile {
|
||||||
prompt += " --tile "
|
params += " --tile "
|
||||||
}
|
}
|
||||||
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
|
if data.CRef != "" {
|
||||||
prompt += fmt.Sprintf(" %s", data.Model)
|
params += fmt.Sprintf(" --cref %s", data.CRef)
|
||||||
|
if data.Cw > 0 {
|
||||||
|
params += fmt.Sprintf(" --cw %d", data.Cw)
|
||||||
|
} else {
|
||||||
|
params += " --cw 100"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.SRef != "" {
|
||||||
|
params += fmt.Sprintf(" --sref %s", data.SRef)
|
||||||
|
}
|
||||||
|
if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
|
||||||
|
params += fmt.Sprintf(" %s", data.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理融图和换脸的提示词
|
||||||
|
if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() {
|
||||||
|
params = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果本地图片上传的是相对地址,处理成绝对地址
|
||||||
|
for k, v := range data.ImgArr {
|
||||||
|
if !strings.HasPrefix(v, "http") {
|
||||||
|
data.ImgArr[k] = fmt.Sprintf("http://localhost:5678/%s", strings.TrimLeft(v, "/"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
@@ -150,31 +180,62 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
job := model.MidJourneyJob{
|
job := model.MidJourneyJob{
|
||||||
Type: types.TaskImage.String(),
|
Type: data.TaskType,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Prompt: prompt,
|
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
|
||||||
|
Power: h.App.SysConfig.MjPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if res := h.db.Create(&job); res.Error != nil {
|
opt := "绘图"
|
||||||
|
if data.TaskType == types.TaskBlend.String() {
|
||||||
|
job.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
|
||||||
|
opt = "融图"
|
||||||
|
} else if data.TaskType == types.TaskSwapFace.String() {
|
||||||
|
job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
|
||||||
|
opt = "换脸"
|
||||||
|
}
|
||||||
|
|
||||||
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.pool.PushTask(types.MjTask{
|
h.pool.PushTask(types.MjTask{
|
||||||
Id: int(job.Id),
|
Id: job.Id,
|
||||||
|
TaskId: taskId,
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Type: types.TaskImage,
|
Type: types.TaskType(data.TaskType),
|
||||||
Prompt: fmt.Sprintf("%s %s", taskId, prompt),
|
Prompt: data.Prompt,
|
||||||
|
NegPrompt: data.NegPrompt,
|
||||||
|
Params: params,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
|
ImgArr: data.ImgArr,
|
||||||
})
|
})
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
client := h.pool.Clients.Get(uint(job.UserId))
|
||||||
_ = client.Send([]byte("Task Updated"))
|
if client != nil {
|
||||||
|
_ = client.Send([]byte("Task Updated"))
|
||||||
|
}
|
||||||
|
|
||||||
// update user's img calls
|
// update user's power
|
||||||
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power - job.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "mid-journey",
|
||||||
|
Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,7 +264,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
jobId := 0
|
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
taskId, _ := h.snowflake.Next(true)
|
taskId, _ := h.snowflake.Next(true)
|
||||||
job := model.MidJourneyJob{
|
job := model.MidJourneyJob{
|
||||||
@@ -213,15 +273,16 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
|
Power: h.App.SysConfig.MjActionPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if res := h.db.Create(&job); res.Error != nil {
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.pool.PushTask(types.MjTask{
|
h.pool.PushTask(types.MjTask{
|
||||||
Id: jobId,
|
Id: job.Id,
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Type: types.TaskUpscale,
|
Type: types.TaskUpscale,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
@@ -233,8 +294,26 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
client := h.pool.Clients.Get(uint(job.UserId))
|
||||||
_ = client.Send([]byte("Task Updated"))
|
if client != nil {
|
||||||
|
_ = client.Send([]byte("Task Updated"))
|
||||||
|
}
|
||||||
|
// update user's power
|
||||||
|
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power - job.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "mid-journey",
|
||||||
|
Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,7 +330,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
jobId := 0
|
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
taskId, _ := h.snowflake.Next(true)
|
taskId, _ := h.snowflake.Next(true)
|
||||||
job := model.MidJourneyJob{
|
job := model.MidJourneyJob{
|
||||||
@@ -262,15 +340,16 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
|
Power: h.App.SysConfig.MjActionPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if res := h.db.Create(&job); res.Error != nil {
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.pool.PushTask(types.MjTask{
|
h.pool.PushTask(types.MjTask{
|
||||||
Id: jobId,
|
Id: job.Id,
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Type: types.TaskVariation,
|
Type: types.TaskVariation,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
@@ -282,22 +361,64 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
client := h.pool.Clients.Get(uint(job.UserId))
|
||||||
_ = client.Send([]byte("Task Updated"))
|
if client != nil {
|
||||||
|
_ = client.Send([]byte("Task Updated"))
|
||||||
|
}
|
||||||
|
|
||||||
// update user's img calls
|
// update user's power
|
||||||
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power - job.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "mid-journey",
|
||||||
|
Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImgWall 照片墙
|
||||||
|
func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
|
||||||
|
page := h.GetInt(c, "page", 0)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
err, jobs := h.getData(true, 0, page, pageSize, true)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
// JobList 获取 MJ 任务列表
|
// JobList 获取 MJ 任务列表
|
||||||
func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
||||||
status := h.GetInt(c, "status", 0)
|
status := h.GetBool(c, "status")
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
page := h.GetInt(c, "page", 0)
|
page := h.GetInt(c, "page", 0)
|
||||||
pageSize := h.GetInt(c, "page_size", 0)
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
publish := h.GetBool(c, "publish")
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
err, jobs := h.getData(status, userId, page, pageSize, publish)
|
||||||
if status == 1 {
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取 MJ 任务列表
|
||||||
|
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if finish {
|
||||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
session = session.Where("progress = ?", 100).Order("id DESC")
|
||||||
} else {
|
} else {
|
||||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||||
@@ -305,6 +426,9 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
if userId > 0 {
|
if userId > 0 {
|
||||||
session = session.Where("user_id = ?", userId)
|
session = session.Where("user_id = ?", userId)
|
||||||
}
|
}
|
||||||
|
if publish {
|
||||||
|
session = session.Where("publish = ?", publish)
|
||||||
|
}
|
||||||
if page > 0 && pageSize > 0 {
|
if page > 0 && pageSize > 0 {
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
session = session.Offset(offset).Limit(pageSize)
|
session = session.Offset(offset).Limit(pageSize)
|
||||||
@@ -313,8 +437,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
var items []model.MidJourneyJob
|
var items []model.MidJourneyJob
|
||||||
res := session.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, types.NoData)
|
return res.Error, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs = make([]vo.MidJourneyJob, 0)
|
var jobs = make([]vo.MidJourneyJob, 0)
|
||||||
@@ -325,31 +448,21 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if job.Progress == -1 {
|
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
|
||||||
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
|
// discord 服务器图片需要使用代理转发图片数据流
|
||||||
}
|
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
|
||||||
|
|
||||||
if item.Progress < 100 {
|
|
||||||
// 10 分钟还没完成的任务直接删除
|
|
||||||
if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
|
|
||||||
h.db.Delete(&item)
|
|
||||||
// 退回绘图次数
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", item.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 正在运行中任务使用代理访问图片
|
|
||||||
if item.ImgURL == "" && item.OrgURL != "" {
|
|
||||||
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
|
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
job.ImgURL = job.OrgURL
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, jobs)
|
return nil, jobs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove task image
|
// Remove remove task image
|
||||||
@@ -365,7 +478,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// remove job recode
|
// remove job recode
|
||||||
res := h.db.Delete(&model.MidJourneyJob{Id: data.Id})
|
res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -378,7 +491,30 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := h.pool.Clients.Get(data.UserId)
|
client := h.pool.Clients.Get(data.UserId)
|
||||||
_ = client.Send([]byte("Task Updated"))
|
if client != nil {
|
||||||
|
_ = client.Send([]byte("Task Updated"))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish 发布图片到画廊显示
|
||||||
|
func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
|
resp.ERROR(c, "更新数据库失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +1,30 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OrderHandler struct {
|
type OrderHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||||
h := OrderHandler{db: db}
|
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OrderHandler) List(c *gin.Context) {
|
func (h *OrderHandler) List(c *gin.Context) {
|
||||||
@@ -31,8 +36,8 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user, _ := utils.GetLoginUser(c, h.db)
|
userId := h.GetLoginUserId(c)
|
||||||
session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess)
|
session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
|
||||||
var total int64
|
var total int64
|
||||||
session.Model(&model.Order{}).Count(&total)
|
session.Model(&model.Order{}).Count(&total)
|
||||||
var items []model.Order
|
var items []model.Order
|
||||||
|
|||||||
@@ -1,27 +1,38 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/service"
|
|
||||||
"chatplus/service/payment"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"geekai/core"
|
||||||
"gorm.io/gorm"
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/payment"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PayWayAlipay = "支付宝"
|
PayWayAlipay = "支付宝"
|
||||||
PayWayXunHu = "虎皮椒"
|
PayWayXunHu = "虎皮椒"
|
||||||
|
PayWayJs = "PayJS"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PaymentHandler 支付服务回调 handler
|
// PaymentHandler 支付服务回调 handler
|
||||||
@@ -29,28 +40,53 @@ type PaymentHandler struct {
|
|||||||
BaseHandler
|
BaseHandler
|
||||||
alipayService *payment.AlipayService
|
alipayService *payment.AlipayService
|
||||||
huPiPayService *payment.HuPiPayService
|
huPiPayService *payment.HuPiPayService
|
||||||
|
js *payment.PayJS
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
db *gorm.DB
|
|
||||||
fs embed.FS
|
fs embed.FS
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
|
signKey string // 用来签名的随机秘钥
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPaymentHandler(server *core.AppServer, alipayService *payment.AlipayService, huPiPayService *payment.HuPiPayService, snowflake *service.Snowflake, db *gorm.DB, fs embed.FS) *PaymentHandler {
|
func NewPaymentHandler(
|
||||||
h := PaymentHandler{
|
server *core.AppServer,
|
||||||
|
alipayService *payment.AlipayService,
|
||||||
|
huPiPayService *payment.HuPiPayService,
|
||||||
|
js *payment.PayJS,
|
||||||
|
db *gorm.DB,
|
||||||
|
snowflake *service.Snowflake,
|
||||||
|
fs embed.FS) *PaymentHandler {
|
||||||
|
return &PaymentHandler{
|
||||||
alipayService: alipayService,
|
alipayService: alipayService,
|
||||||
huPiPayService: huPiPayService,
|
huPiPayService: huPiPayService,
|
||||||
|
js: js,
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
fs: fs,
|
fs: fs,
|
||||||
db: db,
|
|
||||||
lock: sync.Mutex{},
|
lock: sync.Mutex{},
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: server,
|
||||||
|
DB: db,
|
||||||
|
},
|
||||||
|
signKey: utils.RandString(32),
|
||||||
}
|
}
|
||||||
h.App = server
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PaymentHandler) DoPay(c *gin.Context) {
|
func (h *PaymentHandler) DoPay(c *gin.Context) {
|
||||||
orderNo := h.GetTrim(c, "order_no")
|
orderNo := h.GetTrim(c, "order_no")
|
||||||
payWay := h.GetTrim(c, "pay_way")
|
payWay := h.GetTrim(c, "pay_way")
|
||||||
|
t := h.GetInt(c, "t", 0)
|
||||||
|
sign := h.GetTrim(c, "sign")
|
||||||
|
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, payWay, t, h.signKey)
|
||||||
|
newSign := utils.Sha256(signStr)
|
||||||
|
if newSign != sign {
|
||||||
|
resp.ERROR(c, "订单签名错误!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查二维码是否过期
|
||||||
|
if time.Now().Unix()-int64(t) > int64(h.App.SysConfig.OrderPayTimeout) {
|
||||||
|
resp.ERROR(c, "支付二维码已过期,请重新生成!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if orderNo == "" {
|
if orderNo == "" {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -58,14 +94,20 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var order model.Order
|
var order model.Order
|
||||||
res := h.db.Where("order_no = ?", orderNo).First(&order)
|
res := h.DB.Where("order_no = ?", orderNo).First(&order)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Order not found")
|
resp.ERROR(c, "Order not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
|
||||||
|
if order.Status == types.OrderPaidSuccess {
|
||||||
|
resp.ERROR(c, "This order had been paid, please do not pay twice")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 更新扫码状态
|
// 更新扫码状态
|
||||||
h.db.Model(&order).UpdateColumn("status", types.OrderScanned)
|
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
|
||||||
if payWay == "alipay" { // 支付宝
|
if payWay == "alipay" { // 支付宝
|
||||||
// 生成支付链接
|
// 生成支付链接
|
||||||
notifyURL := h.App.Config.AlipayConfig.NotifyURL
|
notifyURL := h.App.Config.AlipayConfig.NotifyURL
|
||||||
@@ -81,41 +123,20 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
|
|||||||
c.Redirect(302, uri)
|
c.Redirect(302, uri)
|
||||||
return
|
return
|
||||||
} else if payWay == "hupi" { // 虎皮椒支付
|
} else if payWay == "hupi" { // 虎皮椒支付
|
||||||
params := map[string]string{
|
params := payment.HuPiPayReq{
|
||||||
"version": "1.1",
|
Version: "1.1",
|
||||||
"trade_order_id": orderNo,
|
TradeOrderId: orderNo,
|
||||||
"total_fee": fmt.Sprintf("%f", order.Amount),
|
TotalFee: fmt.Sprintf("%f", order.Amount),
|
||||||
"title": order.Subject,
|
Title: order.Subject,
|
||||||
"notify_url": h.App.Config.HuPiPayConfig.NotifyURL,
|
NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL,
|
||||||
"return_url": "",
|
WapName: "极客学长",
|
||||||
"wap_name": "极客学长",
|
|
||||||
"callback_url": "",
|
|
||||||
}
|
}
|
||||||
|
r, err := h.huPiPayService.Pay(params)
|
||||||
res, err := h.huPiPayService.Pay(params)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var r struct {
|
|
||||||
Openid interface{} `json:"openid"`
|
|
||||||
UrlQrcode string `json:"url_qrcode"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
ErrCode int `json:"errcode"`
|
|
||||||
ErrMsg string `json:"errmsg,omitempty"`
|
|
||||||
}
|
|
||||||
err = utils.JsonDecode(res, &r)
|
|
||||||
if err != nil {
|
|
||||||
logger.Debugf(res)
|
|
||||||
resp.ERROR(c, "error with decode payment result: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.ErrCode != 0 {
|
|
||||||
resp.ERROR(c, "error with generate pay url: "+r.ErrMsg)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Redirect(302, r.URL)
|
c.Redirect(302, r.URL)
|
||||||
}
|
}
|
||||||
resp.ERROR(c, "Invalid operations")
|
resp.ERROR(c, "Invalid operations")
|
||||||
@@ -132,7 +153,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var order model.Order
|
var order model.Order
|
||||||
res := h.db.Where("order_no = ?", data.OrderNo).First(&order)
|
res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Order not found")
|
resp.ERROR(c, "Order not found")
|
||||||
return
|
return
|
||||||
@@ -147,7 +168,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
var item model.Order
|
var item model.Order
|
||||||
h.db.Where("order_no = ?", data.OrderNo).First(&item)
|
h.DB.Where("order_no = ?", data.OrderNo).First(&item)
|
||||||
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
|
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
|
||||||
order.Status = item.Status
|
order.Status = item.Status
|
||||||
break
|
break
|
||||||
@@ -171,7 +192,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var product model.Product
|
var product model.Product
|
||||||
res := h.db.First(&product, data.ProductId)
|
res := h.DB.First(&product, data.ProductId)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Product not found")
|
resp.ERROR(c, "Product not found")
|
||||||
return
|
return
|
||||||
@@ -183,42 +204,69 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var user model.User
|
var user model.User
|
||||||
res = h.db.First(&user, data.UserId)
|
res = h.DB.First(&user, data.UserId)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Invalid user ID")
|
resp.ERROR(c, "Invalid user ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
payWay := PayWayAlipay
|
var payWay string
|
||||||
if data.PayWay == "hupi" {
|
var notifyURL string
|
||||||
|
switch data.PayWay {
|
||||||
|
case "hupi":
|
||||||
payWay = PayWayXunHu
|
payWay = PayWayXunHu
|
||||||
|
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||||
|
case "payjs":
|
||||||
|
payWay = PayWayJs
|
||||||
|
notifyURL = h.App.Config.JPayConfig.NotifyURL
|
||||||
|
default:
|
||||||
|
payWay = PayWayAlipay
|
||||||
|
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||||
}
|
}
|
||||||
// 创建订单
|
// 创建订单
|
||||||
remark := types.OrderRemark{
|
remark := types.OrderRemark{
|
||||||
Days: product.Days,
|
Days: product.Days,
|
||||||
Calls: product.Calls,
|
Power: product.Power,
|
||||||
ImgCalls: product.ImgCalls,
|
|
||||||
Name: product.Name,
|
Name: product.Name,
|
||||||
Price: product.Price,
|
Price: product.Price,
|
||||||
Discount: product.Discount,
|
Discount: product.Discount,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
||||||
order := model.Order{
|
order := model.Order{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Mobile: user.Username,
|
Username: user.Username,
|
||||||
ProductId: product.Id,
|
ProductId: product.Id,
|
||||||
OrderNo: orderNo,
|
OrderNo: orderNo,
|
||||||
Subject: product.Name,
|
Subject: product.Name,
|
||||||
Amount: product.Price - product.Discount,
|
Amount: amount,
|
||||||
Status: types.OrderNotPaid,
|
Status: types.OrderNotPaid,
|
||||||
PayWay: payWay,
|
PayWay: payWay,
|
||||||
Remark: utils.JsonEncode(remark),
|
Remark: utils.JsonEncode(remark),
|
||||||
}
|
}
|
||||||
res = h.db.Create(&order)
|
res = h.DB.Create(&order)
|
||||||
if res.Error != nil {
|
if res.Error != nil || res.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PayJs 单独处理,只能用官方生成的二维码
|
||||||
|
if data.PayWay == "payjs" {
|
||||||
|
params := payment.JPayReq{
|
||||||
|
TotalFee: int(math.Ceil(order.Amount * 100)),
|
||||||
|
OutTradeNo: order.OrderNo,
|
||||||
|
Subject: product.Name,
|
||||||
|
}
|
||||||
|
r := h.js.Pay(params)
|
||||||
|
if r.IsOK() {
|
||||||
|
resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var logo string
|
var logo string
|
||||||
if data.PayWay == "alipay" {
|
if data.PayWay == "alipay" {
|
||||||
logo = "res/img/alipay.jpg"
|
logo = "res/img/alipay.jpg"
|
||||||
@@ -236,13 +284,15 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
parse, err := url.Parse(h.App.Config.AlipayConfig.NotifyURL)
|
parse, err := url.Parse(notifyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
timestamp := time.Now().Unix()
|
||||||
imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s", parse.Scheme, parse.Host, orderNo, data.PayWay)
|
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, data.PayWay, timestamp, h.signKey)
|
||||||
|
sign := utils.Sha256(signStr)
|
||||||
|
imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign)
|
||||||
imgData, err := utils.GenQrcode(imageURL, 400, file)
|
imgData, err := utils.GenQrcode(imageURL, 400, file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
@@ -252,52 +302,141 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
|
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AlipayNotify 支付宝支付回调
|
// Mobile 移动端支付
|
||||||
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||||
err := c.Request.ParseForm()
|
var data struct {
|
||||||
|
PayWay string `json:"pay_way"` // 支付方式
|
||||||
|
ProductId uint `json:"product_id"`
|
||||||
|
UserId int `json:"user_id"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var product model.Product
|
||||||
|
res := h.DB.First(&product, data.ProductId)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "Product not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
orderNo, err := h.snowflake.Next(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusOK, "fail")
|
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var user model.User
|
||||||
|
res = h.DB.First(&user, data.UserId)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "Invalid user ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:这里最好用支付宝的公钥签名签证一下交易真假
|
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
||||||
//res := h.alipayService.TradeVerify(c.Request.Form)
|
var payWay string
|
||||||
r := h.alipayService.TradeQuery(c.Request.Form.Get("out_trade_no"))
|
var notifyURL, returnURL string
|
||||||
logger.Infof("验证支付结果:%+v", r)
|
var payURL string
|
||||||
if !r.Success() {
|
switch data.PayWay {
|
||||||
c.String(http.StatusOK, "fail")
|
case "hupi":
|
||||||
|
payWay = PayWayXunHu
|
||||||
|
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||||
|
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
||||||
|
parse, _ := url.Parse(h.App.Config.HuPiPayConfig.ReturnURL)
|
||||||
|
baseURL := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
|
||||||
|
params := payment.HuPiPayReq{
|
||||||
|
Version: "1.1",
|
||||||
|
TradeOrderId: orderNo,
|
||||||
|
TotalFee: fmt.Sprintf("%f", amount),
|
||||||
|
Title: product.Name,
|
||||||
|
NotifyURL: notifyURL,
|
||||||
|
ReturnURL: returnURL,
|
||||||
|
CallbackURL: returnURL,
|
||||||
|
WapName: "极客学长",
|
||||||
|
WapUrl: baseURL,
|
||||||
|
Type: "WAP",
|
||||||
|
}
|
||||||
|
r, err := h.huPiPayService.Pay(params)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with generating Pay URL: ", err.Error())
|
||||||
|
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payURL = r.URL
|
||||||
|
case "payjs":
|
||||||
|
payWay = PayWayJs
|
||||||
|
notifyURL = h.App.Config.JPayConfig.NotifyURL
|
||||||
|
returnURL = h.App.Config.JPayConfig.ReturnURL
|
||||||
|
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("total_fee", fmt.Sprintf("%d", totalFee))
|
||||||
|
params.Add("out_trade_no", orderNo)
|
||||||
|
params.Add("body", product.Name)
|
||||||
|
params.Add("notify_url", notifyURL)
|
||||||
|
params.Add("auto", "0")
|
||||||
|
payURL = h.js.PayH5(params)
|
||||||
|
case "alipay":
|
||||||
|
payWay = PayWayAlipay
|
||||||
|
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||||
|
returnURL = h.App.Config.AlipayConfig.ReturnURL
|
||||||
|
payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 创建订单
|
||||||
|
remark := types.OrderRemark{
|
||||||
|
Days: product.Days,
|
||||||
|
Power: product.Power,
|
||||||
|
Name: product.Name,
|
||||||
|
Price: product.Price,
|
||||||
|
Discount: product.Discount,
|
||||||
|
}
|
||||||
|
|
||||||
|
order := model.Order{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
ProductId: product.Id,
|
||||||
|
OrderNo: orderNo,
|
||||||
|
Subject: product.Name,
|
||||||
|
Amount: amount,
|
||||||
|
Status: types.OrderNotPaid,
|
||||||
|
PayWay: payWay,
|
||||||
|
Remark: utils.JsonEncode(remark),
|
||||||
|
}
|
||||||
|
res = h.DB.Create(&order)
|
||||||
|
if res.Error != nil || res.RowsAffected == 0 {
|
||||||
|
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.lock.Lock()
|
resp.SUCCESS(c, payURL)
|
||||||
defer h.lock.Unlock()
|
|
||||||
|
|
||||||
err = h.notify(r.OutTradeNo)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusOK, "fail")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.String(http.StatusOK, "success")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 异步通知回调公共逻辑
|
// 异步通知回调公共逻辑
|
||||||
func (h *PaymentHandler) notify(orderNo string) error {
|
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||||
var order model.Order
|
var order model.Order
|
||||||
res := h.db.Where("order_no = ?", orderNo).First(&order)
|
res := h.DB.Where("order_no = ?", orderNo).First(&order)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
err := fmt.Errorf("error with fetch order: %v", res.Error)
|
err := fmt.Errorf("error with fetch order: %v", res.Error)
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
// 已支付订单,直接返回
|
// 已支付订单,直接返回
|
||||||
if order.Status == types.OrderPaidSuccess {
|
if order.Status == types.OrderPaidSuccess {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res = h.db.First(&user, order.UserId)
|
res = h.DB.First(&user, order.UserId)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
err := fmt.Errorf("error with fetch user info: %v", res.Error)
|
err := fmt.Errorf("error with fetch user info: %v", res.Error)
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -312,34 +451,27 @@ func (h *PaymentHandler) notify(orderNo string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 点卡:days == 0, calls > 0
|
var opt string
|
||||||
// 2. vip 套餐:days > 0, calls == 0
|
var power int
|
||||||
if remark.Days > 0 {
|
if remark.Days > 0 { // VIP 充值
|
||||||
if user.ExpiredTime > time.Now().Unix() {
|
if user.ExpiredTime >= time.Now().Unix() {
|
||||||
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
|
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
|
||||||
|
opt = "VIP充值,VIP 没到期,只延期不增加算力"
|
||||||
} else {
|
} else {
|
||||||
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
|
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
|
||||||
|
user.Power += h.App.SysConfig.VipMonthPower
|
||||||
|
power = h.App.SysConfig.VipMonthPower
|
||||||
|
opt = "VIP充值"
|
||||||
}
|
}
|
||||||
user.Vip = true
|
user.Vip = true
|
||||||
|
} else { // 充值点卡,直接增加次数即可
|
||||||
} else if !user.Vip { // 充值点卡的非 VIP 用户
|
user.Power += remark.Power
|
||||||
user.ExpiredTime = time.Now().AddDate(0, 0, 30).Unix()
|
opt = "点卡充值"
|
||||||
}
|
power = remark.Power
|
||||||
|
|
||||||
if remark.Calls > 0 { // 充值点卡
|
|
||||||
user.Calls += remark.Calls
|
|
||||||
} else {
|
|
||||||
user.Calls += h.App.SysConfig.VipMonthCalls
|
|
||||||
}
|
|
||||||
|
|
||||||
if remark.ImgCalls > 0 {
|
|
||||||
user.ImgCalls += remark.ImgCalls
|
|
||||||
} else {
|
|
||||||
user.ImgCalls += h.App.SysConfig.VipMonthImgCalls
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新用户信息
|
// 更新用户信息
|
||||||
res = h.db.Updates(&user)
|
res = h.DB.Updates(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
err := fmt.Errorf("error with update user info: %v", res.Error)
|
err := fmt.Errorf("error with update user info: %v", res.Error)
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -349,7 +481,8 @@ func (h *PaymentHandler) notify(orderNo string) error {
|
|||||||
// 更新订单状态
|
// 更新订单状态
|
||||||
order.PayTime = time.Now().Unix()
|
order.PayTime = time.Now().Unix()
|
||||||
order.Status = types.OrderPaidSuccess
|
order.Status = types.OrderPaidSuccess
|
||||||
res = h.db.Updates(&order)
|
order.TradeNo = tradeNo
|
||||||
|
res = h.DB.Updates(&order)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
err := fmt.Errorf("error with update order info: %v", res.Error)
|
err := fmt.Errorf("error with update order info: %v", res.Error)
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -357,7 +490,23 @@ func (h *PaymentHandler) notify(orderNo string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 更新产品销量
|
// 更新产品销量
|
||||||
h.db.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
|
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
|
||||||
|
|
||||||
|
// 记录算力充值日志
|
||||||
|
if opt != "" {
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerRecharge,
|
||||||
|
Amount: power,
|
||||||
|
Balance: user.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: order.PayWay,
|
||||||
|
Remark: fmt.Sprintf("%s,金额:%f,订单号:%s", opt, order.Amount, order.OrderNo),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,6 +519,9 @@ func (h *PaymentHandler) GetPayWays(c *gin.Context) {
|
|||||||
if h.App.Config.HuPiPayConfig.Enabled {
|
if h.App.Config.HuPiPayConfig.Enabled {
|
||||||
data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name}
|
data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name}
|
||||||
}
|
}
|
||||||
|
if h.App.Config.JPayConfig.Enabled {
|
||||||
|
data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name}
|
||||||
|
}
|
||||||
resp.SUCCESS(c, data)
|
resp.SUCCESS(c, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -382,12 +534,76 @@ func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
orderNo := c.Request.Form.Get("trade_order_id")
|
orderNo := c.Request.Form.Get("trade_order_id")
|
||||||
logger.Infof("收到订单支付回调,订单 NO:%s", orderNo)
|
tradeNo := c.Request.Form.Get("open_order_id")
|
||||||
// TODO 是否要保存订单交易流水号
|
logger.Infof("收到虎皮椒订单支付回调,订单 NO:%s,交易流水号:%s", orderNo, tradeNo)
|
||||||
h.lock.Lock()
|
|
||||||
defer h.lock.Unlock()
|
|
||||||
|
|
||||||
err = h.notify(orderNo)
|
if err = h.huPiPayService.Check(tradeNo); err != nil {
|
||||||
|
logger.Error("订单校验失败:", err)
|
||||||
|
c.String(http.StatusOK, "fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = h.notify(orderNo, tradeNo)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusOK, "fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.String(http.StatusOK, "success")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlipayNotify 支付宝支付回调
|
||||||
|
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||||
|
err := c.Request.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusOK, "fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO:验证交易签名
|
||||||
|
res := h.alipayService.TradeVerify(c.Request.Form)
|
||||||
|
logger.Infof("验证支付结果:%+v", res)
|
||||||
|
if !res.Success() {
|
||||||
|
logger.Error("订单校验失败:", res.Message)
|
||||||
|
c.String(http.StatusOK, "fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tradeNo := c.Request.Form.Get("trade_no")
|
||||||
|
err = h.notify(res.OutTradeNo, tradeNo)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusOK, "fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.String(http.StatusOK, "success")
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayJsNotify PayJs 支付异步回调
|
||||||
|
func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
|
||||||
|
err := c.Request.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusOK, "fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
orderNo := c.Request.Form.Get("out_trade_no")
|
||||||
|
returnCode := c.Request.Form.Get("return_code")
|
||||||
|
logger.Infof("收到订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode)
|
||||||
|
// 支付失败
|
||||||
|
if returnCode != "1" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 校验订单支付状态
|
||||||
|
tradeNo := c.Request.Form.Get("payjs_order_id")
|
||||||
|
err = h.js.Check(tradeNo)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("订单校验失败:", err)
|
||||||
|
c.String(http.StatusOK, "fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.notify(orderNo, tradeNo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
|
|||||||
74
api/handler/power_log_handler.go
Normal file
74
api/handler/power_log_handler.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PowerLogHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
||||||
|
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Date []string `json:"date"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
session = session.Where("user_id", userId)
|
||||||
|
if data.Model != "" {
|
||||||
|
session = session.Where("model", data.Model)
|
||||||
|
}
|
||||||
|
if len(data.Date) == 2 {
|
||||||
|
start := data.Date[0] + " 00:00:00"
|
||||||
|
end := data.Date[1] + " 00:00:00"
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.PowerLog{}).Count(&total)
|
||||||
|
var items []model.PowerLog
|
||||||
|
var list = make([]vo.PowerLog, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
||||||
|
if res.Error == nil {
|
||||||
|
for _, item := range items {
|
||||||
|
var log vo.PowerLog
|
||||||
|
err := utils.CopyObject(item, &log)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Id = item.Id
|
||||||
|
log.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
log.TypeStr = item.Type.String()
|
||||||
|
list = append(list, log)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||||||
|
}
|
||||||
@@ -1,31 +1,35 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/store/model"
|
"geekai/store/model"
|
||||||
"chatplus/store/vo"
|
"geekai/store/vo"
|
||||||
"chatplus/utils"
|
"geekai/utils"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProductHandler struct {
|
type ProductHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
||||||
h := ProductHandler{db: db}
|
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 模型列表
|
// List 模型列表
|
||||||
func (h *ProductHandler) List(c *gin.Context) {
|
func (h *ProductHandler) List(c *gin.Context) {
|
||||||
var items []model.Product
|
var items []model.Product
|
||||||
var list = make([]vo.Product, 0)
|
var list = make([]vo.Product, 0)
|
||||||
res := h.db.Where("enabled", true).Order("sort_num ASC").Find(&items)
|
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var product vo.Product
|
var product vo.Product
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/core"
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
|
|
||||||
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
|
||||||
|
|
||||||
type PromptHandler struct {
|
|
||||||
BaseHandler
|
|
||||||
db *gorm.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler {
|
|
||||||
h := &PromptHandler{db: db}
|
|
||||||
h.App = app
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rewrite translate and rewrite prompt with ChatGPT
|
|
||||||
func (h *PromptHandler) Rewrite(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(rewritePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PromptHandler) Translate(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, content)
|
|
||||||
}
|
|
||||||
@@ -1,43 +1,48 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"fmt"
|
||||||
"chatplus/core/types"
|
"geekai/core"
|
||||||
"chatplus/store/model"
|
"geekai/core/types"
|
||||||
"chatplus/store/vo"
|
"geekai/store/model"
|
||||||
"chatplus/utils"
|
"geekai/store/vo"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"math"
|
"math"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RewardHandler struct {
|
type RewardHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
|
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
|
||||||
h := RewardHandler{db: db, lock: sync.Mutex{}}
|
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = server
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify 打赏码核销
|
// Verify 打赏码核销
|
||||||
func (h *RewardHandler) Verify(c *gin.Context) {
|
func (h *RewardHandler) Verify(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
TxId string `json:"tx_id"`
|
TxId string `json:"tx_id"`
|
||||||
Type string `json:"type"`
|
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.HACKER(c)
|
resp.HACKER(c)
|
||||||
return
|
return
|
||||||
@@ -50,7 +55,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
|||||||
defer h.lock.Unlock()
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
var item model.Reward
|
var item model.Reward
|
||||||
res := h.db.Where("tx_id = ?", data.TxId).First(&item)
|
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "无效的众筹交易流水号!")
|
resp.ERROR(c, "无效的众筹交易流水号!")
|
||||||
return
|
return
|
||||||
@@ -61,18 +66,14 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tx := h.db.Begin()
|
tx := h.DB.Begin()
|
||||||
exchange := vo.RewardExchange{}
|
exchange := vo.RewardExchange{}
|
||||||
if data.Type == "chat" {
|
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
|
||||||
calls := math.Ceil(item.Amount / h.App.SysConfig.ChatCallPrice)
|
exchange.Power = int(power)
|
||||||
exchange.Calls = int(calls)
|
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
|
||||||
res = h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls + ?", calls))
|
|
||||||
} else if data.Type == "img" {
|
|
||||||
calls := math.Ceil(item.Amount / h.App.SysConfig.ImgCallPrice)
|
|
||||||
exchange.ImgCalls = int(calls)
|
|
||||||
res = h.db.Model(&user).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", calls))
|
|
||||||
}
|
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
logger.Error("添加应用失败:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -81,13 +82,26 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
|||||||
item.Status = true
|
item.Status = true
|
||||||
item.UserId = user.Id
|
item.UserId = user.Id
|
||||||
item.Exchange = utils.JsonEncode(exchange)
|
item.Exchange = utils.JsonEncode(exchange)
|
||||||
res = h.db.Updates(&item)
|
res = tx.Updates(&item)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
|
logger.Error("添加应用失败:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录算力充值日志
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerReward,
|
||||||
|
Amount: exchange.Power,
|
||||||
|
Balance: user.Power + exchange.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "众筹支付",
|
||||||
|
Remark: fmt.Sprintf("众筹充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,29 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/service/oss"
|
|
||||||
"chatplus/service/sd"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/service/sd"
|
||||||
|
"geekai/store"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -20,24 +31,49 @@ import (
|
|||||||
|
|
||||||
type SdJobHandler struct {
|
type SdJobHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
db *gorm.DB
|
pool *sd.ServicePool
|
||||||
pool *sd.ServicePool
|
uploader *oss.UploaderManager
|
||||||
uploader *oss.UploaderManager
|
snowflake *service.Snowflake
|
||||||
|
leveldb *store.LevelDB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler {
|
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
|
||||||
h := SdJobHandler{
|
return &SdJobHandler{
|
||||||
db: db,
|
pool: pool,
|
||||||
pool: pool,
|
uploader: manager,
|
||||||
uploader: manager,
|
snowflake: snowflake,
|
||||||
|
leveldb: levelDB,
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: app,
|
||||||
|
DB: db,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
// Client WebSocket 客户端,用于通知任务状态变更
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
func (h *SdJobHandler) Client(c *gin.Context) {
|
||||||
|
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := h.GetInt(c, "user_id", 0)
|
||||||
|
if userId == 0 {
|
||||||
|
logger.Info("Invalid user ID")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client := types.NewWsClient(ws)
|
||||||
|
h.pool.Clients.Put(uint(userId), client)
|
||||||
|
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||||
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return false
|
return false
|
||||||
@@ -48,8 +84,8 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.ImgCalls <= 0 {
|
if user.Power < h.App.SysConfig.SdPower {
|
||||||
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +95,7 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
|||||||
|
|
||||||
// Image 创建一个绘画任务
|
// Image 创建一个绘画任务
|
||||||
func (h *SdJobHandler) Image(c *gin.Context) {
|
func (h *SdJobHandler) Image(c *gin.Context) {
|
||||||
if !h.checkLimits(c) {
|
if !h.preCheck(c) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,23 +128,29 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
params := types.SdTaskParams{
|
taskId, err := h.snowflake.Next(true)
|
||||||
TaskId: fmt.Sprintf("task(%s)", utils.RandString(15)),
|
if err != nil {
|
||||||
Prompt: data.Prompt,
|
resp.ERROR(c, "error with generate task id: "+err.Error())
|
||||||
NegativePrompt: data.NegativePrompt,
|
return
|
||||||
Steps: data.Steps,
|
|
||||||
Sampler: data.Sampler,
|
|
||||||
FaceFix: data.FaceFix,
|
|
||||||
CfgScale: data.CfgScale,
|
|
||||||
Seed: data.Seed,
|
|
||||||
Height: data.Height,
|
|
||||||
Width: data.Width,
|
|
||||||
HdFix: data.HdFix,
|
|
||||||
HdRedrawRate: data.HdRedrawRate,
|
|
||||||
HdScale: data.HdScale,
|
|
||||||
HdScaleAlg: data.HdScaleAlg,
|
|
||||||
HdSteps: data.HdSteps,
|
|
||||||
}
|
}
|
||||||
|
params := types.SdTaskParams{
|
||||||
|
TaskId: taskId,
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
NegPrompt: data.NegPrompt,
|
||||||
|
Steps: data.Steps,
|
||||||
|
Sampler: data.Sampler,
|
||||||
|
FaceFix: data.FaceFix,
|
||||||
|
CfgScale: data.CfgScale,
|
||||||
|
Seed: data.Seed,
|
||||||
|
Height: data.Height,
|
||||||
|
Width: data.Width,
|
||||||
|
HdFix: data.HdFix,
|
||||||
|
HdRedrawRate: data.HdRedrawRate,
|
||||||
|
HdScale: data.HdScale,
|
||||||
|
HdScaleAlg: data.HdScaleAlg,
|
||||||
|
HdSteps: data.HdSteps,
|
||||||
|
}
|
||||||
|
|
||||||
job := model.SdJob{
|
job := model.SdJob{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Type: types.TaskImage.String(),
|
Type: types.TaskImage.String(),
|
||||||
@@ -116,9 +158,10 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
Params: utils.JsonEncode(params),
|
Params: utils.JsonEncode(params),
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
|
Power: h.App.SysConfig.SdPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
res := h.db.Create(&job)
|
res := h.DB.Create(&job)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "error with save job: "+res.Error.Error())
|
resp.ERROR(c, "error with save job: "+res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -128,26 +171,71 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
Id: int(job.Id),
|
Id: int(job.Id),
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Type: types.TaskImage,
|
Type: types.TaskImage,
|
||||||
Prompt: data.Prompt,
|
|
||||||
Params: params,
|
Params: params,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
})
|
})
|
||||||
|
|
||||||
// update user's img calls
|
client := h.pool.Clients.Get(uint(job.UserId))
|
||||||
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
if client != nil {
|
||||||
|
_ = client.Send([]byte("Task Updated"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// update user's power
|
||||||
|
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power - job.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "stable-diffusion",
|
||||||
|
Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobList 获取 stable diffusion 任务列表
|
// ImgWall 照片墙
|
||||||
func (h *SdJobHandler) JobList(c *gin.Context) {
|
func (h *SdJobHandler) ImgWall(c *gin.Context) {
|
||||||
status := h.GetInt(c, "status", 0)
|
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
|
||||||
page := h.GetInt(c, "page", 0)
|
page := h.GetInt(c, "page", 0)
|
||||||
pageSize := h.GetInt(c, "page_size", 0)
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
err, jobs := h.getData(true, 0, page, pageSize, true)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
resp.SUCCESS(c, jobs)
|
||||||
if status == 1 {
|
}
|
||||||
|
|
||||||
|
// JobList 获取 SD 任务列表
|
||||||
|
func (h *SdJobHandler) JobList(c *gin.Context) {
|
||||||
|
status := h.GetBool(c, "status")
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
page := h.GetInt(c, "page", 0)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
publish := h.GetBool(c, "publish")
|
||||||
|
|
||||||
|
err, jobs := h.getData(status, userId, page, pageSize, publish)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取 MJ 任务列表
|
||||||
|
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if finish {
|
||||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
session = session.Where("progress = ?", 100).Order("id DESC")
|
||||||
} else {
|
} else {
|
||||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||||
@@ -155,6 +243,9 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
|
|||||||
if userId > 0 {
|
if userId > 0 {
|
||||||
session = session.Where("user_id = ?", userId)
|
session = session.Where("user_id = ?", userId)
|
||||||
}
|
}
|
||||||
|
if publish {
|
||||||
|
session = session.Where("publish", publish)
|
||||||
|
}
|
||||||
if page > 0 && pageSize > 0 {
|
if page > 0 && pageSize > 0 {
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
session = session.Offset(offset).Limit(pageSize)
|
session = session.Offset(offset).Limit(pageSize)
|
||||||
@@ -163,8 +254,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
|
|||||||
var items []model.SdJob
|
var items []model.SdJob
|
||||||
res := session.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, types.NoData)
|
return res.Error, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs = make([]vo.SdJob, 0)
|
var jobs = make([]vo.SdJob, 0)
|
||||||
@@ -175,33 +265,25 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if job.Progress == -1 {
|
|
||||||
h.db.Delete(&model.SdJob{Id: job.Id})
|
|
||||||
}
|
|
||||||
|
|
||||||
if item.Progress < 100 {
|
if item.Progress < 100 {
|
||||||
// 5 分钟还没完成的任务直接删除
|
// 从 leveldb 中获取图片预览数据
|
||||||
if time.Now().Sub(item.CreatedAt) > time.Minute*5 {
|
var imageData string
|
||||||
h.db.Delete(&item)
|
err = h.leveldb.Get(item.TaskId, &imageData)
|
||||||
// 退回绘图次数
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", item.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 正在运行中任务使用代理访问图片
|
|
||||||
image, err := utils.DownloadImage(item.ImgURL, "")
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
job.ImgURL = "data:image/png;base64," + imageData
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, jobs)
|
|
||||||
|
return nil, jobs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove task image
|
// Remove remove task image
|
||||||
func (h *SdJobHandler) Remove(c *gin.Context) {
|
func (h *SdJobHandler) Remove(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
ImgURL string `json:"img_url"`
|
ImgURL string `json:"img_url"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
@@ -210,7 +292,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// remove job recode
|
// remove job recode
|
||||||
res := h.db.Delete(&model.SdJob{Id: data.Id})
|
res := h.DB.Delete(&model.SdJob{Id: data.Id})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -222,5 +304,31 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
|
|||||||
logger.Error("remove image failed: ", err)
|
logger.Error("remove image failed: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := h.pool.Clients.Get(data.UserId)
|
||||||
|
if client != nil {
|
||||||
|
_ = client.Send([]byte(sd.Finished))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish 发布/取消发布图片到画廊显示
|
||||||
|
func (h *SdJobHandler) Publish(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database:", res.Error)
|
||||||
|
resp.ERROR(c, "更新数据库失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,23 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"chatplus/service"
|
"geekai/service"
|
||||||
"chatplus/utils"
|
"geekai/service/sms"
|
||||||
"chatplus/utils/resp"
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const CodeStorePrefix = "/verify/codes/"
|
const CodeStorePrefix = "/verify/codes/"
|
||||||
@@ -16,7 +25,7 @@ const CodeStorePrefix = "/verify/codes/"
|
|||||||
type SmsHandler struct {
|
type SmsHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
sms *service.AliYunSmsService
|
sms *sms.ServiceManager
|
||||||
smtp *service.SmtpService
|
smtp *service.SmtpService
|
||||||
captcha *service.CaptchaService
|
captcha *service.CaptchaService
|
||||||
}
|
}
|
||||||
@@ -24,12 +33,15 @@ type SmsHandler struct {
|
|||||||
func NewSmsHandler(
|
func NewSmsHandler(
|
||||||
app *core.AppServer,
|
app *core.AppServer,
|
||||||
client *redis.Client,
|
client *redis.Client,
|
||||||
sms *service.AliYunSmsService,
|
sms *sms.ServiceManager,
|
||||||
smtp *service.SmtpService,
|
smtp *service.SmtpService,
|
||||||
captcha *service.CaptchaService) *SmsHandler {
|
captcha *service.CaptchaService) *SmsHandler {
|
||||||
handler := &SmsHandler{redis: client, sms: sms, captcha: captcha, smtp: smtp}
|
return &SmsHandler{
|
||||||
handler.App = app
|
redis: client,
|
||||||
return handler
|
sms: sms,
|
||||||
|
captcha: captcha,
|
||||||
|
smtp: smtp,
|
||||||
|
BaseHandler: BaseHandler{App: app}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendCode 发送验证码
|
// SendCode 发送验证码
|
||||||
@@ -52,9 +64,18 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
|||||||
code := utils.RandomNumber(6)
|
code := utils.RandomNumber(6)
|
||||||
var err error
|
var err error
|
||||||
if strings.Contains(data.Receiver, "@") { // email
|
if strings.Contains(data.Receiver, "@") { // email
|
||||||
|
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") {
|
||||||
|
resp.ERROR(c, "系统已禁用邮箱注册!")
|
||||||
|
return
|
||||||
|
}
|
||||||
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
||||||
} else {
|
} else {
|
||||||
err = h.sms.SendVerifyCode(data.Receiver, code)
|
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
|
||||||
|
resp.ERROR(c, "系统已禁用手机号注册!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = h.sms.GetService().SendVerifyCode(data.Receiver, code)
|
||||||
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
|
|||||||
@@ -1,59 +1,17 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/service"
|
"geekai/service"
|
||||||
"chatplus/store/model"
|
"geekai/service/payment"
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TestHandler struct {
|
type TestHandler struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
|
js *payment.PayJS
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake) *TestHandler {
|
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
|
||||||
return &TestHandler{db: db, snowflake: snowflake}
|
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TestHandler) Test(c *gin.Context) {
|
|
||||||
h.initUserNickname(c)
|
|
||||||
h.initMjTaskId(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TestHandler) initUserNickname(c *gin.Context) {
|
|
||||||
var users []model.User
|
|
||||||
tx := h.db.Find(&users)
|
|
||||||
if tx.Error != nil {
|
|
||||||
resp.ERROR(c, tx.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, u := range users {
|
|
||||||
u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
|
|
||||||
h.db.Updates(&u)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TestHandler) initMjTaskId(c *gin.Context) {
|
|
||||||
var jobs []model.MidJourneyJob
|
|
||||||
tx := h.db.Find(&jobs)
|
|
||||||
if tx.Error != nil {
|
|
||||||
resp.ERROR(c, tx.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, job := range jobs {
|
|
||||||
id, _ := h.snowflake.Next(true)
|
|
||||||
job.TaskId = id
|
|
||||||
h.db.Updates(&job)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,31 +1,101 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"geekai/core"
|
||||||
"chatplus/service/oss"
|
"geekai/service/oss"
|
||||||
"chatplus/utils/resp"
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UploadHandler struct {
|
type UploadHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
uploaderManager *oss.UploaderManager
|
uploaderManager *oss.UploaderManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
|
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
|
||||||
handler := &UploadHandler{db: db, uploaderManager: manager}
|
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
|
||||||
handler.App = app
|
|
||||||
return handler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UploadHandler) Upload(c *gin.Context) {
|
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||||
fileURL, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c, fileURL)
|
userId := h.GetLoginUserId(c)
|
||||||
|
res := h.DB.Create(&model.File{
|
||||||
|
UserId: int(userId),
|
||||||
|
Name: file.Name,
|
||||||
|
ObjKey: file.ObjKey,
|
||||||
|
URL: file.URL,
|
||||||
|
Ext: file.Ext,
|
||||||
|
Size: file.Size,
|
||||||
|
CreatedAt: time.Time{},
|
||||||
|
})
|
||||||
|
if res.Error != nil || res.RowsAffected == 0 {
|
||||||
|
resp.ERROR(c, "error with update database: "+res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UploadHandler) List(c *gin.Context) {
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
var items []model.File
|
||||||
|
var files = make([]vo.File, 0)
|
||||||
|
h.DB.Where("user_id = ?", userId).Find(&items)
|
||||||
|
if len(items) > 0 {
|
||||||
|
for _, v := range items {
|
||||||
|
var file vo.File
|
||||||
|
err := utils.CopyObject(v, &file)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
file.CreatedAt = v.CreatedAt.Unix()
|
||||||
|
files = append(files, file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, files)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove remove files
|
||||||
|
func (h *UploadHandler) Remove(c *gin.Context) {
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
var file model.File
|
||||||
|
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
|
||||||
|
if tx.Error != nil || file.Id == 0 {
|
||||||
|
resp.ERROR(c, "file not existed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove database
|
||||||
|
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
|
||||||
|
if tx.Error != nil || tx.RowsAffected == 0 {
|
||||||
|
resp.ERROR(c, "failed to update database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// remove files
|
||||||
|
objectKey := file.ObjKey
|
||||||
|
if objectKey == "" {
|
||||||
|
objectKey = file.URL
|
||||||
|
}
|
||||||
|
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
|
||||||
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,18 +1,27 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
"geekai/core"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/store/vo"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
|
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -20,25 +29,30 @@ import (
|
|||||||
|
|
||||||
type UserHandler struct {
|
type UserHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
searcher *xdb.Searcher
|
||||||
searcher *xdb.Searcher
|
redis *redis.Client
|
||||||
redis *redis.Client
|
licenseService *service.LicenseService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserHandler(
|
func NewUserHandler(
|
||||||
app *core.AppServer,
|
app *core.AppServer,
|
||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
searcher *xdb.Searcher,
|
searcher *xdb.Searcher,
|
||||||
client *redis.Client) *UserHandler {
|
client *redis.Client,
|
||||||
handler := &UserHandler{db: db, searcher: searcher, redis: client}
|
licenseService *service.LicenseService) *UserHandler {
|
||||||
handler.App = app
|
return &UserHandler{
|
||||||
return handler
|
BaseHandler: BaseHandler{DB: db, App: app},
|
||||||
|
searcher: searcher,
|
||||||
|
redis: client,
|
||||||
|
licenseService: licenseService,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register user register
|
// Register user register
|
||||||
func (h *UserHandler) Register(c *gin.Context) {
|
func (h *UserHandler) Register(c *gin.Context) {
|
||||||
// parameters process
|
// parameters process
|
||||||
var data struct {
|
var data struct {
|
||||||
|
RegWay string `json:"reg_way"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
@@ -54,18 +68,29 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查验证码
|
// 检测最大注册人数
|
||||||
key := CodeStorePrefix + data.Username
|
var totalUser int64
|
||||||
code, err := h.redis.Get(c, key).Result()
|
h.DB.Model(&model.User{}).Count(&totalUser)
|
||||||
if err != nil || code != data.Code {
|
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
||||||
resp.ERROR(c, "验证码错误")
|
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查验证码
|
||||||
|
var key string
|
||||||
|
if data.RegWay == "email" || data.RegWay == "mobile" {
|
||||||
|
key = CodeStorePrefix + data.Username
|
||||||
|
code, err := h.redis.Get(c, key).Result()
|
||||||
|
if err != nil || code != data.Code {
|
||||||
|
resp.ERROR(c, "验证码错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 验证邀请码
|
// 验证邀请码
|
||||||
inviteCode := model.InviteCode{}
|
inviteCode := model.InviteCode{}
|
||||||
if data.InviteCode != "" {
|
if data.InviteCode != "" {
|
||||||
res := h.db.Where("code = ?", data.InviteCode).First(&inviteCode)
|
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "无效的邀请码")
|
resp.ERROR(c, "无效的邀请码")
|
||||||
return
|
return
|
||||||
@@ -74,8 +99,8 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
|
|
||||||
// check if the username is exists
|
// check if the username is exists
|
||||||
var item model.User
|
var item model.User
|
||||||
res := h.db.Where("username = ?", data.Username).First(&item)
|
res := h.DB.Where("username = ?", data.Username).First(&item)
|
||||||
if res.RowsAffected > 0 {
|
if item.Id > 0 {
|
||||||
resp.ERROR(c, "该用户名已经被注册")
|
resp.ERROR(c, "该用户名已经被注册")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -90,18 +115,10 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
Status: true,
|
Status: true,
|
||||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||||
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
|
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
|
||||||
ChatConfig: utils.JsonEncode(types.UserChatConfig{
|
Power: h.App.SysConfig.InitPower,
|
||||||
ApiKeys: map[types.Platform]string{
|
|
||||||
types.OpenAI: "",
|
|
||||||
types.Azure: "",
|
|
||||||
types.ChatGLM: "",
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
Calls: h.App.SysConfig.InitChatCalls,
|
|
||||||
ImgCalls: h.App.SysConfig.InitImgCalls,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.db.Create(&user)
|
res = h.DB.Create(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "保存数据失败")
|
resp.ERROR(c, "保存数据失败")
|
||||||
logger.Error(res.Error)
|
logger.Error(res.Error)
|
||||||
@@ -111,21 +128,32 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
// 记录邀请关系
|
// 记录邀请关系
|
||||||
if data.InviteCode != "" {
|
if data.InviteCode != "" {
|
||||||
// 增加邀请数量
|
// 增加邀请数量
|
||||||
h.db.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||||||
if h.App.SysConfig.InviteChatCalls > 0 {
|
if h.App.SysConfig.InvitePower > 0 {
|
||||||
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("calls", gorm.Expr("calls + ?", h.App.SysConfig.InviteChatCalls))
|
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
|
||||||
}
|
// 记录邀请算力充值日志
|
||||||
if h.App.SysConfig.InviteImgCalls > 0 {
|
var inviter model.User
|
||||||
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", h.App.SysConfig.InviteImgCalls))
|
h.DB.Where("id", inviteCode.UserId).First(&inviter)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: inviter.Id,
|
||||||
|
Username: inviter.Username,
|
||||||
|
Type: types.PowerInvite,
|
||||||
|
Amount: h.App.SysConfig.InvitePower,
|
||||||
|
Balance: inviter.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "",
|
||||||
|
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加邀请记录
|
// 添加邀请记录
|
||||||
h.db.Create(&model.InviteLog{
|
h.DB.Create(&model.InviteLog{
|
||||||
InviterId: inviteCode.UserId,
|
InviterId: inviteCode.UserId,
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
InviteCode: inviteCode.Code,
|
InviteCode: inviteCode.Code,
|
||||||
Reward: utils.JsonEncode(types.InviteReward{ChatCalls: h.App.SysConfig.InviteChatCalls, ImgCalls: h.App.SysConfig.InviteImgCalls}),
|
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,7 +189,7 @@ func (h *UserHandler) Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.db.Where("username = ?", data.Username).First(&user)
|
res := h.DB.Where("username = ?", data.Username).First(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "用户名不存在")
|
resp.ERROR(c, "用户名不存在")
|
||||||
return
|
return
|
||||||
@@ -181,9 +209,9 @@ func (h *UserHandler) Login(c *gin.Context) {
|
|||||||
// 更新最后登录时间和IP
|
// 更新最后登录时间和IP
|
||||||
user.LastLoginIp = c.ClientIP()
|
user.LastLoginIp = c.ClientIP()
|
||||||
user.LastLoginAt = time.Now().Unix()
|
user.LastLoginAt = time.Now().Unix()
|
||||||
h.db.Model(&user).Updates(user)
|
h.DB.Model(&user).Updates(user)
|
||||||
|
|
||||||
h.db.Create(&model.UserLoginLog{
|
h.DB.Create(&model.UserLoginLog{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
LoginIp: c.ClientIP(),
|
LoginIp: c.ClientIP(),
|
||||||
@@ -211,24 +239,16 @@ func (h *UserHandler) Login(c *gin.Context) {
|
|||||||
|
|
||||||
// Logout 注 销
|
// Logout 注 销
|
||||||
func (h *UserHandler) Logout(c *gin.Context) {
|
func (h *UserHandler) Logout(c *gin.Context) {
|
||||||
sessionId := c.GetHeader(types.ChatTokenHeader)
|
|
||||||
key := h.GetUserKey(c)
|
key := h.GetUserKey(c)
|
||||||
if _, err := h.redis.Del(c, key).Result(); err != nil {
|
if _, err := h.redis.Del(c, key).Result(); err != nil {
|
||||||
logger.Error("error with delete session: ", err)
|
logger.Error("error with delete session: ", err)
|
||||||
}
|
}
|
||||||
// 删除 websocket 会话列表
|
|
||||||
h.App.ChatSession.Delete(sessionId)
|
|
||||||
// 关闭 socket 连接
|
|
||||||
client := h.App.ChatClients.Get(sessionId)
|
|
||||||
if client != nil {
|
|
||||||
client.Close()
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session 获取/验证会话
|
// Session 获取/验证会话
|
||||||
func (h *UserHandler) Session(c *gin.Context) {
|
func (h *UserHandler) Session(c *gin.Context) {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var userVo vo.User
|
var userVo vo.User
|
||||||
err := utils.CopyObject(user, &userVo)
|
err := utils.CopyObject(user, &userVo)
|
||||||
@@ -244,27 +264,23 @@ func (h *UserHandler) Session(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type userProfile struct {
|
type userProfile struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Nickname string `json:"nickname"`
|
Nickname string `json:"nickname"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Avatar string `json:"avatar"`
|
Avatar string `json:"avatar"`
|
||||||
ChatConfig types.UserChatConfig `json:"chat_config"`
|
Power int `json:"power"`
|
||||||
Calls int `json:"calls"`
|
ExpiredTime int64 `json:"expired_time"`
|
||||||
ImgCalls int `json:"img_calls"`
|
Vip bool `json:"vip"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
|
||||||
Tokens int64 `json:"tokens"`
|
|
||||||
ExpiredTime int64 `json:"expired_time"`
|
|
||||||
Vip bool `json:"vip"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UserHandler) Profile(c *gin.Context) {
|
func (h *UserHandler) Profile(c *gin.Context) {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.db.First(&user, user.Id)
|
h.DB.First(&user, user.Id)
|
||||||
var profile userProfile
|
var profile userProfile
|
||||||
err = utils.CopyObject(user, &profile)
|
err = utils.CopyObject(user, &profile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -284,15 +300,15 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.db.First(&user, user.Id)
|
h.DB.First(&user, user.Id)
|
||||||
user.Avatar = data.Avatar
|
user.Avatar = data.Avatar
|
||||||
user.Nickname = data.Nickname
|
user.Nickname = data.Nickname
|
||||||
res := h.db.Updates(&user)
|
res := h.DB.Updates(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新用户信息失败")
|
resp.ERROR(c, "更新用户信息失败")
|
||||||
return
|
return
|
||||||
@@ -317,23 +333,23 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
password := utils.GenPassword(data.OldPass, user.Salt)
|
password := utils.GenPassword(data.OldPass, user.Salt)
|
||||||
logger.Info(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
|
logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
|
||||||
if password != user.Password {
|
if password != user.Password {
|
||||||
resp.ERROR(c, "原密码错误")
|
resp.ERROR(c, "原密码错误")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newPass := utils.GenPassword(data.Password, user.Salt)
|
newPass := utils.GenPassword(data.Password, user.Salt)
|
||||||
res := h.db.Model(&user).UpdateColumn("password", newPass)
|
res := h.DB.Model(&user).UpdateColumn("password", newPass)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("更新数据库失败: ", res.Error)
|
logger.Error("error with update database:", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败")
|
resp.ERROR(c, "更新数据库失败")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -354,7 +370,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.db.Where("username", data.Username).First(&user)
|
res := h.DB.Where("username", data.Username).First(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "用户不存在!")
|
resp.ERROR(c, "用户不存在!")
|
||||||
return
|
return
|
||||||
@@ -370,7 +386,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
|
|
||||||
password := utils.GenPassword(data.Password, user.Salt)
|
password := utils.GenPassword(data.Password, user.Salt)
|
||||||
user.Password = password
|
user.Password = password
|
||||||
res = h.db.Updates(&user)
|
res = h.DB.Updates(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c)
|
resp.ERROR(c)
|
||||||
} else {
|
} else {
|
||||||
@@ -400,20 +416,21 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
|
|||||||
|
|
||||||
// 检查手机号是否被其他账号绑定
|
// 检查手机号是否被其他账号绑定
|
||||||
var item model.User
|
var item model.User
|
||||||
res := h.db.Where("username = ?", data.Username).First(&item)
|
res := h.DB.Where("username = ?", data.Username).First(&item)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
resp.ERROR(c, "该账号已经被其他账号绑定")
|
resp.ERROR(c, "该账号已经被其他账号绑定")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.db.Model(&user).UpdateColumn("username", data.Username)
|
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
logger.Error(res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败")
|
resp.ERROR(c, "更新数据库失败")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
package logger
|
package logger
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"go.uber.org/zap/zapcore"
|
"go.uber.org/zap/zapcore"
|
||||||
|
|||||||
198
api/main.go
198
api/main.go
@@ -1,21 +1,30 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/handler"
|
|
||||||
"chatplus/handler/admin"
|
|
||||||
"chatplus/handler/chatimpl"
|
|
||||||
logger2 "chatplus/logger"
|
|
||||||
"chatplus/service"
|
|
||||||
"chatplus/service/mj"
|
|
||||||
"chatplus/service/oss"
|
|
||||||
"chatplus/service/payment"
|
|
||||||
"chatplus/service/sd"
|
|
||||||
"chatplus/service/wx"
|
|
||||||
"chatplus/store"
|
|
||||||
"context"
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/handler/admin"
|
||||||
|
"geekai/handler/chatimpl"
|
||||||
|
logger2 "geekai/logger"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/dalle"
|
||||||
|
"geekai/service/mj"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/service/payment"
|
||||||
|
"geekai/service/sd"
|
||||||
|
"geekai/service/sms"
|
||||||
|
"geekai/service/wx"
|
||||||
|
"geekai/store"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
@@ -42,16 +51,20 @@ type AppLifecycle struct {
|
|||||||
|
|
||||||
// OnStart 应用程序启动时执行
|
// OnStart 应用程序启动时执行
|
||||||
func (l *AppLifecycle) OnStart(context.Context) error {
|
func (l *AppLifecycle) OnStart(context.Context) error {
|
||||||
log.Println("AppLifecycle OnStart")
|
logger.Info("AppLifecycle OnStart")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnStop 应用程序停止时执行
|
// OnStop 应用程序停止时执行
|
||||||
func (l *AppLifecycle) OnStop(context.Context) error {
|
func (l *AppLifecycle) OnStop(context.Context) error {
|
||||||
log.Println("AppLifecycle OnStop")
|
logger.Info("AppLifecycle OnStop")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewAppLifeCycle() *AppLifecycle {
|
||||||
|
return &AppLifecycle{}
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
configFile := os.Getenv("CONFIG_FILE")
|
configFile := os.Getenv("CONFIG_FILE")
|
||||||
if configFile == "" {
|
if configFile == "" {
|
||||||
@@ -59,11 +72,13 @@ func main() {
|
|||||||
}
|
}
|
||||||
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
|
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
|
||||||
logger.Info("Loading config file: ", configFile)
|
logger.Info("Loading config file: ", configFile)
|
||||||
defer func() {
|
if !debug {
|
||||||
if err := recover(); err != nil {
|
defer func() {
|
||||||
logger.Error("Panic Error:", err)
|
if err := recover(); err != nil {
|
||||||
}
|
logger.Error("Panic Error:", err)
|
||||||
}()
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
app := fx.New(
|
app := fx.New(
|
||||||
// 初始化配置应用配置
|
// 初始化配置应用配置
|
||||||
@@ -89,6 +104,7 @@ func main() {
|
|||||||
fx.Provide(store.NewGormConfig),
|
fx.Provide(store.NewGormConfig),
|
||||||
fx.Provide(store.NewMysql),
|
fx.Provide(store.NewMysql),
|
||||||
fx.Provide(store.NewRedisClient),
|
fx.Provide(store.NewRedisClient),
|
||||||
|
fx.Provide(store.NewLevelDB),
|
||||||
|
|
||||||
fx.Provide(func() embed.FS {
|
fx.Provide(func() embed.FS {
|
||||||
return xdbFS
|
return xdbFS
|
||||||
@@ -122,6 +138,8 @@ func main() {
|
|||||||
fx.Provide(handler.NewPaymentHandler),
|
fx.Provide(handler.NewPaymentHandler),
|
||||||
fx.Provide(handler.NewOrderHandler),
|
fx.Provide(handler.NewOrderHandler),
|
||||||
fx.Provide(handler.NewProductHandler),
|
fx.Provide(handler.NewProductHandler),
|
||||||
|
fx.Provide(handler.NewConfigHandler),
|
||||||
|
fx.Provide(handler.NewPowerLogHandler),
|
||||||
|
|
||||||
fx.Provide(admin.NewConfigHandler),
|
fx.Provide(admin.NewConfigHandler),
|
||||||
fx.Provide(admin.NewAdminHandler),
|
fx.Provide(admin.NewAdminHandler),
|
||||||
@@ -133,17 +151,31 @@ func main() {
|
|||||||
fx.Provide(admin.NewChatModelHandler),
|
fx.Provide(admin.NewChatModelHandler),
|
||||||
fx.Provide(admin.NewProductHandler),
|
fx.Provide(admin.NewProductHandler),
|
||||||
fx.Provide(admin.NewOrderHandler),
|
fx.Provide(admin.NewOrderHandler),
|
||||||
|
fx.Provide(admin.NewChatHandler),
|
||||||
|
fx.Provide(admin.NewPowerLogHandler),
|
||||||
|
|
||||||
// 创建服务
|
// 创建服务
|
||||||
fx.Provide(service.NewAliYunSmsService),
|
fx.Provide(sms.NewSendServiceManager),
|
||||||
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
|
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
|
||||||
return service.NewCaptchaService(config.ApiConfig)
|
return service.NewCaptchaService(config.ApiConfig)
|
||||||
}),
|
}),
|
||||||
fx.Provide(oss.NewUploaderManager),
|
fx.Provide(oss.NewUploaderManager),
|
||||||
fx.Provide(mj.NewService),
|
fx.Provide(mj.NewService),
|
||||||
|
fx.Provide(dalle.NewService),
|
||||||
|
fx.Invoke(func(service *dalle.Service) {
|
||||||
|
service.Run()
|
||||||
|
service.CheckTaskNotify()
|
||||||
|
service.DownloadImages()
|
||||||
|
service.CheckTaskStatus()
|
||||||
|
}),
|
||||||
|
|
||||||
// 邮件服务
|
// 邮件服务
|
||||||
fx.Provide(service.NewSmtpService),
|
fx.Provide(service.NewSmtpService),
|
||||||
|
// License 服务
|
||||||
|
fx.Provide(service.NewLicenseService),
|
||||||
|
fx.Invoke(func(licenseService *service.LicenseService) {
|
||||||
|
licenseService.SyncLicense()
|
||||||
|
}),
|
||||||
|
|
||||||
// 微信机器人服务
|
// 微信机器人服务
|
||||||
fx.Provide(wx.NewWeChatBot),
|
fx.Provide(wx.NewWeChatBot),
|
||||||
@@ -158,18 +190,28 @@ func main() {
|
|||||||
|
|
||||||
// MidJourney service pool
|
// MidJourney service pool
|
||||||
fx.Provide(mj.NewServicePool),
|
fx.Provide(mj.NewServicePool),
|
||||||
fx.Invoke(func(pool *mj.ServicePool) {
|
fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) {
|
||||||
|
pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs)
|
||||||
if pool.HasAvailableService() {
|
if pool.HasAvailableService() {
|
||||||
pool.DownloadImages()
|
pool.DownloadImages()
|
||||||
pool.CheckTaskNotify()
|
pool.CheckTaskNotify()
|
||||||
|
pool.SyncTaskProgress()
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// Stable Diffusion 机器人
|
// Stable Diffusion 机器人
|
||||||
fx.Provide(sd.NewServicePool),
|
fx.Provide(sd.NewServicePool),
|
||||||
|
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
|
||||||
|
pool.InitServices(config.SdConfigs)
|
||||||
|
if pool.HasAvailableService() {
|
||||||
|
pool.CheckTaskNotify()
|
||||||
|
pool.CheckTaskStatus()
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
fx.Provide(payment.NewAlipayService),
|
fx.Provide(payment.NewAlipayService),
|
||||||
fx.Provide(payment.NewHuPiPay),
|
fx.Provide(payment.NewHuPiPay),
|
||||||
|
fx.Provide(payment.NewPayJS),
|
||||||
fx.Provide(service.NewSnowflake),
|
fx.Provide(service.NewSnowflake),
|
||||||
fx.Provide(service.NewXXLJobExecutor),
|
fx.Provide(service.NewXXLJobExecutor),
|
||||||
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
|
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
|
||||||
@@ -212,6 +254,8 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
|
||||||
s.Engine.POST("/api/upload", h.Upload)
|
s.Engine.POST("/api/upload", h.Upload)
|
||||||
|
s.Engine.GET("/api/upload/list", h.List)
|
||||||
|
s.Engine.GET("/api/upload/remove", h.Remove)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
|
||||||
group := s.Engine.Group("/api/sms/")
|
group := s.Engine.Group("/api/sms/")
|
||||||
@@ -221,6 +265,8 @@ func main() {
|
|||||||
group := s.Engine.Group("/api/captcha/")
|
group := s.Engine.Group("/api/captcha/")
|
||||||
group.GET("get", h.Get)
|
group.GET("get", h.Get)
|
||||||
group.POST("check", h.Check)
|
group.POST("check", h.Check)
|
||||||
|
group.GET("slide/get", h.SlideGet)
|
||||||
|
group.POST("slide/check", h.SlideCheck)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
||||||
group := s.Engine.Group("/api/reward/")
|
group := s.Engine.Group("/api/reward/")
|
||||||
@@ -233,26 +279,45 @@ func main() {
|
|||||||
group.POST("upscale", h.Upscale)
|
group.POST("upscale", h.Upscale)
|
||||||
group.POST("variation", h.Variation)
|
group.POST("variation", h.Variation)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
group.POST("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
|
group.POST("publish", h.Publish)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||||
group := s.Engine.Group("/api/sd")
|
group := s.Engine.Group("/api/sd")
|
||||||
|
group.Any("client", h.Client)
|
||||||
group.POST("image", h.Image)
|
group.POST("image", h.Image)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
group.POST("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
|
group.POST("publish", h.Publish)
|
||||||
|
}),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
|
||||||
|
group := s.Engine.Group("/api/config/")
|
||||||
|
group.GET("get", h.Get)
|
||||||
|
group.GET("license", h.License)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 管理后台控制器
|
// 管理后台控制器
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||||
group := s.Engine.Group("/api/admin/config/")
|
group := s.Engine.Group("/api/admin/")
|
||||||
group.POST("update", h.Update)
|
group.POST("config/update", h.Update)
|
||||||
group.GET("get", h.Get)
|
group.GET("config/get", h.Get)
|
||||||
|
group.POST("active", h.Active)
|
||||||
|
group.GET("config/get/license", h.GetLicense)
|
||||||
|
group.GET("config/get/app", h.GetAppConfig)
|
||||||
|
group.POST("config/update/draw", h.SaveDrawingConfig)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
||||||
group := s.Engine.Group("/api/admin/")
|
group := s.Engine.Group("/api/admin/")
|
||||||
group.POST("login", h.Login)
|
group.POST("login", h.Login)
|
||||||
group.GET("logout", h.Logout)
|
group.GET("logout", h.Logout)
|
||||||
group.GET("session", h.Session)
|
group.GET("session", h.Session)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.POST("enable", h.Enable)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.POST("resetPass", h.ResetPass)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
|
||||||
group := s.Engine.Group("/api/admin/apikey/")
|
group := s.Engine.Group("/api/admin/apikey/")
|
||||||
@@ -280,7 +345,7 @@ func main() {
|
|||||||
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
|
||||||
group := s.Engine.Group("/api/admin/reward/")
|
group := s.Engine.Group("/api/admin/reward/")
|
||||||
group.GET("list", h.List)
|
group.GET("list", h.List)
|
||||||
group.GET("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
||||||
group := s.Engine.Group("/api/admin/dashboard/")
|
group := s.Engine.Group("/api/admin/dashboard/")
|
||||||
@@ -304,8 +369,10 @@ func main() {
|
|||||||
group.GET("payWays", h.GetPayWays)
|
group.GET("payWays", h.GetPayWays)
|
||||||
group.POST("query", h.OrderQuery)
|
group.POST("query", h.OrderQuery)
|
||||||
group.POST("qrcode", h.PayQrcode)
|
group.POST("qrcode", h.PayQrcode)
|
||||||
|
group.POST("mobile", h.Mobile)
|
||||||
group.POST("alipay/notify", h.AlipayNotify)
|
group.POST("alipay/notify", h.AlipayNotify)
|
||||||
group.POST("hupipay/notify", h.HuPiPayNotify)
|
group.POST("hupipay/notify", h.HuPiPayNotify)
|
||||||
|
group.POST("payjs/notify", h.PayJsNotify)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
||||||
group := s.Engine.Group("/api/admin/product/")
|
group := s.Engine.Group("/api/admin/product/")
|
||||||
@@ -337,13 +404,6 @@ func main() {
|
|||||||
group.GET("hits", h.Hits)
|
group.GET("hits", h.Hits)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
fx.Provide(handler.NewPromptHandler),
|
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
|
||||||
group := s.Engine.Group("/api/prompt/")
|
|
||||||
group.POST("rewrite", h.Rewrite)
|
|
||||||
group.POST("translate", h.Translate)
|
|
||||||
}),
|
|
||||||
|
|
||||||
fx.Provide(admin.NewFunctionHandler),
|
fx.Provide(admin.NewFunctionHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
|
||||||
group := s.Engine.Group("/api/admin/function/")
|
group := s.Engine.Group("/api/admin/function/")
|
||||||
@@ -354,6 +414,18 @@ func main() {
|
|||||||
group.GET("token", h.GenToken)
|
group.GET("token", h.GenToken)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
// 验证码
|
||||||
|
fx.Provide(admin.NewCaptchaHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
|
||||||
|
group := s.Engine.Group("/api/admin/login/")
|
||||||
|
group.GET("captcha", h.GetCaptcha)
|
||||||
|
}),
|
||||||
|
|
||||||
|
fx.Provide(admin.NewUploadHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
|
||||||
|
s.Engine.POST("/api/admin/upload", h.Upload)
|
||||||
|
}),
|
||||||
|
|
||||||
fx.Provide(handler.NewFunctionHandler),
|
fx.Provide(handler.NewFunctionHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
|
||||||
group := s.Engine.Group("/api/function/")
|
group := s.Engine.Group("/api/function/")
|
||||||
@@ -361,18 +433,60 @@ func main() {
|
|||||||
group.POST("zaobao", h.ZaoBao)
|
group.POST("zaobao", h.ZaoBao)
|
||||||
group.POST("dalle3", h.Dall3)
|
group.POST("dalle3", h.Dall3)
|
||||||
}),
|
}),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
|
||||||
fx.Provide(handler.NewTestHandler),
|
group := s.Engine.Group("/api/admin/chat/")
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
group.POST("list", h.List)
|
||||||
s.Engine.GET("/api/test", h.Test)
|
group.POST("message", h.Messages)
|
||||||
|
group.GET("history", h.History)
|
||||||
|
group.GET("remove", h.RemoveChat)
|
||||||
|
group.GET("message/remove", h.RemoveMessage)
|
||||||
|
}),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
|
||||||
|
group := s.Engine.Group("/api/powerLog/")
|
||||||
|
group.POST("list", h.List)
|
||||||
|
}),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
|
||||||
|
group := s.Engine.Group("/api/admin/powerLog/")
|
||||||
|
group.POST("list", h.List)
|
||||||
|
}),
|
||||||
|
fx.Provide(admin.NewMenuHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
|
||||||
|
group := s.Engine.Group("/api/admin/menu/")
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("enable", h.Enable)
|
||||||
|
group.POST("sort", h.Sort)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
}),
|
||||||
|
fx.Provide(handler.NewMenuHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
|
||||||
|
group := s.Engine.Group("/api/menu/")
|
||||||
|
group.GET("list", h.List)
|
||||||
|
}),
|
||||||
|
fx.Provide(handler.NewMarkMapHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
|
||||||
|
group := s.Engine.Group("/api/markMap/")
|
||||||
|
group.Any("client", h.Client)
|
||||||
|
}),
|
||||||
|
fx.Provide(handler.NewDallJobHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
|
||||||
|
group := s.Engine.Group("/api/dall")
|
||||||
|
group.Any("client", h.Client)
|
||||||
|
group.POST("image", h.Image)
|
||||||
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
|
group.POST("remove", h.Remove)
|
||||||
|
group.POST("publish", h.Publish)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||||
err := s.Run(db)
|
go func() {
|
||||||
if err != nil {
|
err := s.Run(db)
|
||||||
log.Fatal(err)
|
if err != nil {
|
||||||
}
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
}),
|
}),
|
||||||
|
fx.Provide(NewAppLifeCycle),
|
||||||
// 注册生命周期回调函数
|
// 注册生命周期回调函数
|
||||||
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
|
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
|
||||||
lifecycle.Append(fx.Hook{
|
lifecycle.Append(fx.Hook{
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
{
|
|
||||||
"data": [
|
|
||||||
"task(cxvkpawy8onnfti)",
|
|
||||||
"a cute girl",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
20,
|
|
||||||
"DPM++ 2M Karras",
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
7,
|
|
||||||
512,
|
|
||||||
512,
|
|
||||||
false,
|
|
||||||
0.7,
|
|
||||||
2,
|
|
||||||
"Latent",
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
"Use same checkpoint",
|
|
||||||
"Use same sampler",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
"None",
|
|
||||||
false,
|
|
||||||
"",
|
|
||||||
0.8,
|
|
||||||
-1,
|
|
||||||
false,
|
|
||||||
-1,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
"positive",
|
|
||||||
"comma",
|
|
||||||
0,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
"",
|
|
||||||
"Seed",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
"Nothing",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
"Nothing",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
0,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
50,
|
|
||||||
[],
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
""
|
|
||||||
],
|
|
||||||
"event_data": null,
|
|
||||||
"fn_index": 446,
|
|
||||||
"session_hash": "nk5noh1rz1o"
|
|
||||||
}
|
|
||||||
@@ -1,19 +1,26 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CaptchaService struct {
|
type CaptchaService struct {
|
||||||
config types.ChatPlusApiConfig
|
config types.ApiConfig
|
||||||
client *req.Client
|
client *req.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCaptchaService(config types.ChatPlusApiConfig) *CaptchaService {
|
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
|
||||||
return &CaptchaService{
|
return &CaptchaService{
|
||||||
config: config,
|
config: config,
|
||||||
client: req.C().SetTimeout(10 * time.Second),
|
client: req.C().SetTimeout(10 * time.Second),
|
||||||
@@ -60,3 +67,44 @@ func (s *CaptchaService) Check(data interface{}) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *CaptchaService) SlideGet() (interface{}, error) {
|
||||||
|
if s.config.Token == "" {
|
||||||
|
return nil, errors.New("无效的 API Token")
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
|
||||||
|
var res types.BizVo
|
||||||
|
r, err := s.client.R().
|
||||||
|
SetHeader("AppId", s.config.AppId).
|
||||||
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||||
|
SetSuccessResult(&res).Get(url)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != types.Success {
|
||||||
|
return nil, fmt.Errorf("请求 API 失败:%s", res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CaptchaService) SlideCheck(data interface{}) bool {
|
||||||
|
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
|
||||||
|
var res types.BizVo
|
||||||
|
r, err := s.client.R().
|
||||||
|
SetHeader("AppId", s.config.AppId).
|
||||||
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||||
|
SetBodyJsonMarshal(data).
|
||||||
|
SetSuccessResult(&res).Post(url)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != types.Success {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
313
api/service/dalle/service.go
Normal file
313
api/service/dalle/service.go
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
package dalle
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
logger2 "geekai/logger"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/service/sd"
|
||||||
|
"geekai/store"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
// DALL-E 绘画服务
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
httpClient *req.Client
|
||||||
|
db *gorm.DB
|
||||||
|
uploadManager *oss.UploaderManager
|
||||||
|
taskQueue *store.RedisQueue
|
||||||
|
notifyQueue *store.RedisQueue
|
||||||
|
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||||
|
return &Service{
|
||||||
|
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||||
|
db: db,
|
||||||
|
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
||||||
|
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||||
|
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||||
|
uploadManager: manager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushTask push a new mj task in to task queue
|
||||||
|
func (s *Service) PushTask(task types.DallTask) {
|
||||||
|
logger.Infof("add a new DALL-E task to the task list: %+v", task)
|
||||||
|
s.taskQueue.RPush(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Run() {
|
||||||
|
logger.Info("Starting DALL-E job consumer...")
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
var task types.DallTask
|
||||||
|
err := s.taskQueue.LPop(&task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("taking task with error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Infof("handle a new DALL-E task: %+v", task)
|
||||||
|
_, err = s.Image(task, false)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("error with image task: %v", err)
|
||||||
|
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
|
||||||
|
"progress": -1,
|
||||||
|
"err_msg": err.Error(),
|
||||||
|
})
|
||||||
|
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
type imgReq struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
N int `json:"n"`
|
||||||
|
Size string `json:"size"`
|
||||||
|
Quality string `json:"quality"`
|
||||||
|
Style string `json:"style"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type imgRes struct {
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Data []struct {
|
||||||
|
RevisedPrompt string `json:"revised_prompt"`
|
||||||
|
Url string `json:"url"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrRes struct {
|
||||||
|
Error struct {
|
||||||
|
Code interface{} `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Param interface{} `json:"param"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||||
|
logger.Debugf("绘画参数:%+v", task)
|
||||||
|
prompt := task.Prompt
|
||||||
|
// translate prompt
|
||||||
|
if utils.HasChinese(task.Prompt) {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
|
prompt = content
|
||||||
|
logger.Debugf("重写后提示词:%s", prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user model.User
|
||||||
|
s.db.Where("id", task.UserId).First(&user)
|
||||||
|
if user.Power < task.Power {
|
||||||
|
return "", errors.New("insufficient of power")
|
||||||
|
}
|
||||||
|
|
||||||
|
// get image generation API KEY
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
tx := s.db.Where("platform", types.OpenAI.Value).
|
||||||
|
Where("type", "img").
|
||||||
|
Where("enabled", true).
|
||||||
|
Order("last_used_at ASC").First(&apiKey)
|
||||||
|
if tx.Error != nil {
|
||||||
|
return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var res imgRes
|
||||||
|
var errRes ErrRes
|
||||||
|
if len(apiKey.ProxyURL) > 5 {
|
||||||
|
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
|
||||||
|
}
|
||||||
|
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
|
||||||
|
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
|
||||||
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
|
SetBody(imgReq{
|
||||||
|
Model: "dall-e-3",
|
||||||
|
Prompt: prompt,
|
||||||
|
N: 1,
|
||||||
|
Size: task.Size,
|
||||||
|
Style: task.Style,
|
||||||
|
Quality: task.Quality,
|
||||||
|
}).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
SetSuccessResult(&res).Post(apiKey.ApiURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with send request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return "", fmt.Errorf("error with send request: %v", errRes.Error)
|
||||||
|
}
|
||||||
|
// update the api key last use time
|
||||||
|
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
|
// update task progress
|
||||||
|
tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
|
||||||
|
"progress": 100,
|
||||||
|
"org_url": res.Data[0].Url,
|
||||||
|
"prompt": prompt,
|
||||||
|
})
|
||||||
|
if tx.Error != nil {
|
||||||
|
return "", fmt.Errorf("err with update database: %v", tx.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
|
||||||
|
var content string
|
||||||
|
if sync {
|
||||||
|
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with download image: %v", err)
|
||||||
|
}
|
||||||
|
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", prompt, imgURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户算力
|
||||||
|
tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
var u model.User
|
||||||
|
s.db.Where("id", user.Id).First(&u)
|
||||||
|
s.db.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: task.Power,
|
||||||
|
Balance: u.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "dall-e-3",
|
||||||
|
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) CheckTaskNotify() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running DALL-E task notify checking ...")
|
||||||
|
for {
|
||||||
|
var message sd.NotifyMessage
|
||||||
|
err := s.notifyQueue.LPop(&message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
client := s.Clients.Get(uint(message.UserId))
|
||||||
|
if client == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = client.Send([]byte(message.Message))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) DownloadImages() {
|
||||||
|
go func() {
|
||||||
|
var items []model.DallJob
|
||||||
|
for {
|
||||||
|
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
|
||||||
|
if res.Error != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// download images
|
||||||
|
for _, v := range items {
|
||||||
|
if v.OrgURL == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("try to download image: %s", v.OrgURL)
|
||||||
|
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: %s, error: %v", imgURL, err)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
logger.Infof("download image %s successfully.", v.OrgURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second * 5)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
|
||||||
|
// sava image
|
||||||
|
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// update img_url
|
||||||
|
res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL)
|
||||||
|
if res.Error != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished})
|
||||||
|
return imgURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||||
|
func (s *Service) CheckTaskStatus() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||||
|
for {
|
||||||
|
var jobs []model.DallJob
|
||||||
|
res := s.db.Where("progress < ?", 100).Find(&jobs)
|
||||||
|
if res.Error != nil {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, job := range jobs {
|
||||||
|
// 5 分钟还没完成的任务直接删除
|
||||||
|
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
|
||||||
|
s.db.Delete(&job)
|
||||||
|
var user model.User
|
||||||
|
s.db.Where("id = ?", job.UserId).First(&user)
|
||||||
|
// 退回绘图次数
|
||||||
|
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
||||||
|
if res.Error == nil && res.RowsAffected > 0 {
|
||||||
|
s.db.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power + job.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "dall-e-3",
|
||||||
|
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d", job.Id),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
197
api/service/license_service.go
Normal file
197
api/service/license_service.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/store"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LicenseService struct {
|
||||||
|
config types.ApiConfig
|
||||||
|
levelDB *store.LevelDB
|
||||||
|
license *types.License
|
||||||
|
urlWhiteList []string
|
||||||
|
machineId string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
|
||||||
|
var license types.License
|
||||||
|
return &LicenseService{
|
||||||
|
config: server.Config.ApiConfig,
|
||||||
|
levelDB: levelDB,
|
||||||
|
license: &license,
|
||||||
|
machineId: "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type License struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
License string `json:"license"`
|
||||||
|
MachineId string `json:"mid"`
|
||||||
|
ActiveAt int64 `json:"active_at"`
|
||||||
|
ExpiredAt int64 `json:"expired_at"`
|
||||||
|
UserNum int `json:"user_num"`
|
||||||
|
Configs types.LicenseConfig `json:"configs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActiveLicense 激活 License
|
||||||
|
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
|
||||||
|
var res struct {
|
||||||
|
Code types.BizCode `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data License `json:"data"`
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
|
||||||
|
response, err := req.C().R().
|
||||||
|
SetBody(map[string]string{"license": license, "machine_id": machineId}).
|
||||||
|
SetSuccessResult(&res).Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("发送激活请求失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.IsErrorState() {
|
||||||
|
return fmt.Errorf("发送激活请求失败:%v", response.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != types.Success {
|
||||||
|
return fmt.Errorf("激活失败:%v", res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.license = &types.License{
|
||||||
|
Key: license,
|
||||||
|
MachineId: machineId,
|
||||||
|
Configs: res.Data.Configs,
|
||||||
|
ExpiredAt: res.Data.ExpiredAt,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
err = s.levelDB.Put(types.LicenseKey, s.license)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("保存许可证书失败:%v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncLicense 定期同步 License
|
||||||
|
func (s *LicenseService) SyncLicense() {
|
||||||
|
go func() {
|
||||||
|
retryCounter := 0
|
||||||
|
for {
|
||||||
|
license, err := s.fetchLicense()
|
||||||
|
if err != nil {
|
||||||
|
retryCounter++
|
||||||
|
if retryCounter < 5 {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
s.license.IsActive = false
|
||||||
|
} else {
|
||||||
|
s.license = license
|
||||||
|
}
|
||||||
|
|
||||||
|
urls, err := s.fetchUrlWhiteList()
|
||||||
|
if err == nil {
|
||||||
|
s.urlWhiteList = urls
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LicenseService) fetchLicense() (*types.License, error) {
|
||||||
|
//var res struct {
|
||||||
|
// Code types.BizCode `json:"code"`
|
||||||
|
// Message string `json:"message"`
|
||||||
|
// Data License `json:"data"`
|
||||||
|
//}
|
||||||
|
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
|
||||||
|
//response, err := req.C().R().
|
||||||
|
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||||
|
// SetSuccessResult(&res).Post(apiURL)
|
||||||
|
//if err != nil {
|
||||||
|
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
|
||||||
|
//}
|
||||||
|
//if response.IsErrorState() {
|
||||||
|
// return nil, fmt.Errorf("激活失败:%v", response.Status)
|
||||||
|
//}
|
||||||
|
//if res.Code != types.Success {
|
||||||
|
// return nil, fmt.Errorf("激活失败:%v", res.Message)
|
||||||
|
//}
|
||||||
|
|
||||||
|
return &types.License{
|
||||||
|
Key: "abc",
|
||||||
|
MachineId: "abc",
|
||||||
|
Configs: types.LicenseConfig{
|
||||||
|
UserNum: 10000,
|
||||||
|
DeCopy: false,
|
||||||
|
},
|
||||||
|
ExpiredAt: 0,
|
||||||
|
IsActive: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
||||||
|
var res struct {
|
||||||
|
Code types.BizCode `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data []string `json:"data"`
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
|
||||||
|
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("发送请求失败: %v", err)
|
||||||
|
}
|
||||||
|
if response.IsErrorState() {
|
||||||
|
return nil, fmt.Errorf("发送请求失败:%v", response.Status)
|
||||||
|
}
|
||||||
|
if res.Code != types.Success {
|
||||||
|
return nil, fmt.Errorf("获取白名单失败:%v", res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLicense 获取许可信息
|
||||||
|
func (s *LicenseService) GetLicense() *types.License {
|
||||||
|
return s.license
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidApiURL 判断是否合法的中转 URL
|
||||||
|
func (s *LicenseService) IsValidApiURL(uri string) error {
|
||||||
|
// 获得许可授权的直接放行
|
||||||
|
return nil
|
||||||
|
//if s.license.IsActive {
|
||||||
|
// if s.license.MachineId != s.machineId {
|
||||||
|
// return errors.New("系统使用了盗版的许可证书")
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// if time.Now().Unix() > s.license.ExpiredAt {
|
||||||
|
// return errors.New("系统许可证书已经过期")
|
||||||
|
// }
|
||||||
|
// return nil
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//if len(s.urlWhiteList) == 0 {
|
||||||
|
// urls, err := s.fetchUrlWhiteList()
|
||||||
|
// if err == nil {
|
||||||
|
// s.urlWhiteList = urls
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//for _, v := range s.urlWhiteList {
|
||||||
|
// if strings.HasPrefix(uri, v) {
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||||
|
}
|
||||||
@@ -1,233 +0,0 @@
|
|||||||
package mj
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/core/types"
|
|
||||||
logger2 "chatplus/logger"
|
|
||||||
"chatplus/utils"
|
|
||||||
discordgo "github.com/bg5t/mydiscordgo"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MidJourney 机器人
|
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
|
||||||
|
|
||||||
type Bot struct {
|
|
||||||
config types.MidJourneyConfig
|
|
||||||
bot *discordgo.Session
|
|
||||||
name string
|
|
||||||
service *Service
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) {
|
|
||||||
bot, err := discordgo.New("Bot " + config.BotToken)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// use CDN reverse proxy
|
|
||||||
if config.UseCDN {
|
|
||||||
discordgo.SetEndpointDiscord(config.DiscordAPI)
|
|
||||||
discordgo.SetEndpointCDN(config.DiscordCDN)
|
|
||||||
discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
|
|
||||||
bot.MjGateway = config.DiscordGateway + "/"
|
|
||||||
} else { // use proxy
|
|
||||||
discordgo.SetEndpointDiscord("https://discord.com")
|
|
||||||
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
|
|
||||||
discordgo.SetEndpointStatus("https://discord.com/api/v2/")
|
|
||||||
bot.MjGateway = "wss://gateway.discord.gg"
|
|
||||||
|
|
||||||
if proxy != "" {
|
|
||||||
proxy, _ := url.Parse(proxy)
|
|
||||||
bot.Client = &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
Proxy: http.ProxyURL(proxy),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
bot.Dialer = &websocket.Dialer{
|
|
||||||
Proxy: http.ProxyURL(proxy),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Bot{
|
|
||||||
config: config,
|
|
||||||
bot: bot,
|
|
||||||
name: name,
|
|
||||||
service: service,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) Run() error {
|
|
||||||
b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
|
|
||||||
b.bot.AddHandler(b.messageCreate)
|
|
||||||
b.bot.AddHandler(b.messageUpdate)
|
|
||||||
|
|
||||||
logger.Infof("Starting MidJourney %s", b.name)
|
|
||||||
err := b.bot.Open()
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logger.Infof("Starting MidJourney %s successfully!", b.name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TaskStatus string
|
|
||||||
|
|
||||||
const (
|
|
||||||
Start = TaskStatus("Started")
|
|
||||||
Running = TaskStatus("Running")
|
|
||||||
Stopped = TaskStatus("Stopped")
|
|
||||||
Finished = TaskStatus("Finished")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Image struct {
|
|
||||||
URL string `json:"url"`
|
|
||||||
ProxyURL string `json:"proxy_url"`
|
|
||||||
Filename string `json:"filename"`
|
|
||||||
Width int `json:"width"`
|
|
||||||
Height int `json:"height"`
|
|
||||||
Size int `json:"size"`
|
|
||||||
Hash string `json:"hash"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
|
||||||
// ignore messages for other channels
|
|
||||||
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// ignore messages for self
|
|
||||||
if m.Author == nil || m.Author.ID == s.State.User.ID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debugf("CREATE: %s", utils.JsonEncode(m))
|
|
||||||
var referenceId = ""
|
|
||||||
if m.ReferencedMessage != nil {
|
|
||||||
referenceId = m.ReferencedMessage.ID
|
|
||||||
}
|
|
||||||
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
|
|
||||||
// parse content
|
|
||||||
req := CBReq{
|
|
||||||
ChannelId: m.ChannelID,
|
|
||||||
MessageId: m.ID,
|
|
||||||
ReferenceId: referenceId,
|
|
||||||
Prompt: extractPrompt(m.Content),
|
|
||||||
Content: m.Content,
|
|
||||||
Progress: 0,
|
|
||||||
Status: Start}
|
|
||||||
b.service.Notify(req)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
|
||||||
// ignore messages for other channels
|
|
||||||
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// ignore messages for self
|
|
||||||
if m.Author == nil || m.Author.ID == s.State.User.ID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
|
|
||||||
|
|
||||||
var referenceId = ""
|
|
||||||
if m.ReferencedMessage != nil {
|
|
||||||
referenceId = m.ReferencedMessage.ID
|
|
||||||
}
|
|
||||||
if strings.Contains(m.Content, "(Stopped)") {
|
|
||||||
req := CBReq{
|
|
||||||
ChannelId: m.ChannelID,
|
|
||||||
MessageId: m.ID,
|
|
||||||
ReferenceId: referenceId,
|
|
||||||
Prompt: extractPrompt(m.Content),
|
|
||||||
Content: m.Content,
|
|
||||||
Progress: extractProgress(m.Content),
|
|
||||||
Status: Stopped}
|
|
||||||
b.service.Notify(req)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
|
|
||||||
progress := extractProgress(content)
|
|
||||||
var status TaskStatus
|
|
||||||
if progress == 100 {
|
|
||||||
status = Finished
|
|
||||||
} else {
|
|
||||||
status = Running
|
|
||||||
}
|
|
||||||
for _, attachment := range attachments {
|
|
||||||
if attachment.Width == 0 || attachment.Height == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
image := Image{
|
|
||||||
URL: attachment.URL,
|
|
||||||
Height: attachment.Height,
|
|
||||||
ProxyURL: attachment.ProxyURL,
|
|
||||||
Width: attachment.Width,
|
|
||||||
Size: attachment.Size,
|
|
||||||
Filename: attachment.Filename,
|
|
||||||
Hash: extractHashFromFilename(attachment.Filename),
|
|
||||||
}
|
|
||||||
req := CBReq{
|
|
||||||
ChannelId: channelId,
|
|
||||||
MessageId: messageId,
|
|
||||||
ReferenceId: referenceId,
|
|
||||||
Image: image,
|
|
||||||
Prompt: extractPrompt(content),
|
|
||||||
Content: content,
|
|
||||||
Progress: progress,
|
|
||||||
Status: status,
|
|
||||||
}
|
|
||||||
b.service.Notify(req)
|
|
||||||
break // only get one image
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract prompt from string
|
|
||||||
func extractPrompt(input string) string {
|
|
||||||
pattern := `\*\*(.*?)\*\*`
|
|
||||||
re := regexp.MustCompile(pattern)
|
|
||||||
matches := re.FindStringSubmatch(input)
|
|
||||||
if len(matches) > 1 {
|
|
||||||
return strings.TrimSpace(matches[1])
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractProgress(input string) int {
|
|
||||||
pattern := `\((\d+)\%\)`
|
|
||||||
re := regexp.MustCompile(pattern)
|
|
||||||
matches := re.FindStringSubmatch(input)
|
|
||||||
if len(matches) > 1 {
|
|
||||||
return utils.IntValue(matches[1], 0)
|
|
||||||
}
|
|
||||||
return 100
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractHashFromFilename(filename string) string {
|
|
||||||
if !strings.HasSuffix(filename, ".png") {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
index := strings.LastIndex(filename, "_")
|
|
||||||
if index != -1 {
|
|
||||||
return filename[index+1 : len(filename)-4]
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -1,150 +1,68 @@
|
|||||||
package mj
|
package mj
|
||||||
|
|
||||||
import (
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
"chatplus/core/types"
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
"fmt"
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
"time"
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
import "geekai/core/types"
|
||||||
)
|
|
||||||
|
|
||||||
// MidJourney client
|
type Client interface {
|
||||||
|
Imagine(task types.MjTask) (ImageRes, error)
|
||||||
type Client struct {
|
Blend(task types.MjTask) (ImageRes, error)
|
||||||
client *req.Client
|
SwapFace(task types.MjTask) (ImageRes, error)
|
||||||
Config types.MidJourneyConfig
|
Upscale(task types.MjTask) (ImageRes, error)
|
||||||
apiURL string
|
Variation(task types.MjTask) (ImageRes, error)
|
||||||
|
QueryTask(taskId string) (QueryRes, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
|
type ImageReq struct {
|
||||||
client := req.C().SetTimeout(10 * time.Second)
|
BotType string `json:"botType,omitempty"`
|
||||||
var apiURL string
|
Prompt string `json:"prompt,omitempty"`
|
||||||
// set proxy URL
|
Dimensions string `json:"dimensions,omitempty"`
|
||||||
if config.UseCDN {
|
Base64Array []string `json:"base64Array,omitempty"`
|
||||||
apiURL = config.DiscordAPI + "/api/v9/interactions"
|
AccountFilter interface{} `json:"accountFilter,omitempty"`
|
||||||
} else {
|
NotifyHook string `json:"notifyHook,omitempty"`
|
||||||
apiURL = "https://discord.com/api/v9/interactions"
|
State string `json:"state,omitempty"`
|
||||||
if proxy != "" {
|
|
||||||
client.SetProxyURL(proxy)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Client{client: client, Config: config, apiURL: apiURL}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Imagine(prompt string) error {
|
type ImageRes struct {
|
||||||
interactionsReq := &InteractionsRequest{
|
Code int `json:"code"`
|
||||||
Type: 2,
|
Description string `json:"description"`
|
||||||
ApplicationID: ApplicationID,
|
Properties struct {
|
||||||
GuildID: c.Config.GuildId,
|
} `json:"properties"`
|
||||||
ChannelID: c.Config.ChanelId,
|
Result string `json:"result"`
|
||||||
SessionID: SessionID,
|
|
||||||
Data: map[string]any{
|
|
||||||
"version": "1166847114203123795",
|
|
||||||
"id": "938956540159881230",
|
|
||||||
"name": "imagine",
|
|
||||||
"type": "1",
|
|
||||||
"options": []map[string]any{
|
|
||||||
{
|
|
||||||
"type": 3,
|
|
||||||
"name": "prompt",
|
|
||||||
"value": prompt,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"application_command": map[string]any{
|
|
||||||
"id": "938956540159881230",
|
|
||||||
"application_id": ApplicationID,
|
|
||||||
"version": "1118961510123847772",
|
|
||||||
"default_permission": true,
|
|
||||||
"default_member_permissions": nil,
|
|
||||||
"type": 1,
|
|
||||||
"nsfw": false,
|
|
||||||
"name": "imagine",
|
|
||||||
"description": "Create images with Midjourney",
|
|
||||||
"dm_permission": true,
|
|
||||||
"options": []map[string]any{
|
|
||||||
{
|
|
||||||
"type": 3,
|
|
||||||
"name": "prompt",
|
|
||||||
"description": "The prompt to imagine",
|
|
||||||
"required": true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"attachments": []any{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(interactionsReq).
|
|
||||||
Post(c.apiURL)
|
|
||||||
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("error with http request: %w%v", err, r.Err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upscale 放大指定的图片
|
type ErrRes struct {
|
||||||
func (c *Client) Upscale(index int, messageId string, hash string) error {
|
Error struct {
|
||||||
flags := 0
|
Message string `json:"message"`
|
||||||
interactionsReq := &InteractionsRequest{
|
} `json:"error"`
|
||||||
Type: 3,
|
|
||||||
ApplicationID: ApplicationID,
|
|
||||||
GuildID: c.Config.GuildId,
|
|
||||||
ChannelID: c.Config.ChanelId,
|
|
||||||
MessageFlags: &flags,
|
|
||||||
MessageID: &messageId,
|
|
||||||
SessionID: SessionID,
|
|
||||||
Data: map[string]any{
|
|
||||||
"component_type": 2,
|
|
||||||
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
|
|
||||||
},
|
|
||||||
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
|
||||||
}
|
|
||||||
|
|
||||||
var res InteractionsResult
|
|
||||||
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(interactionsReq).
|
|
||||||
SetErrorResult(&res).
|
|
||||||
Post(c.apiURL)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
type QueryRes struct {
|
||||||
func (c *Client) Variation(index int, messageId string, hash string) error {
|
Action string `json:"action"`
|
||||||
flags := 0
|
Buttons []struct {
|
||||||
interactionsReq := &InteractionsRequest{
|
CustomId string `json:"customId"`
|
||||||
Type: 3,
|
Emoji string `json:"emoji"`
|
||||||
ApplicationID: ApplicationID,
|
Label string `json:"label"`
|
||||||
GuildID: c.Config.GuildId,
|
Style int `json:"style"`
|
||||||
ChannelID: c.Config.ChanelId,
|
Type int `json:"type"`
|
||||||
MessageFlags: &flags,
|
} `json:"buttons"`
|
||||||
MessageID: &messageId,
|
Description string `json:"description"`
|
||||||
SessionID: SessionID,
|
FailReason string `json:"failReason"`
|
||||||
Data: map[string]any{
|
FinishTime int `json:"finishTime"`
|
||||||
"component_type": 2,
|
Id string `json:"id"`
|
||||||
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
|
ImageUrl string `json:"imageUrl"`
|
||||||
},
|
Progress string `json:"progress"`
|
||||||
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
Prompt string `json:"prompt"`
|
||||||
}
|
PromptEn string `json:"promptEn"`
|
||||||
|
Properties struct {
|
||||||
var res InteractionsResult
|
} `json:"properties"`
|
||||||
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
|
StartTime int `json:"startTime"`
|
||||||
SetHeader("Content-Type", "application/json").
|
State string `json:"state"`
|
||||||
SetBody(interactionsReq).
|
Status string `json:"status"`
|
||||||
SetErrorResult(&res).
|
SubmitTime int `json:"submitTime"`
|
||||||
Post(c.apiURL)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
267
api/service/mj/plus_client.go
Normal file
267
api/service/mj/plus_client.go
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
package mj
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/utils"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PlusClient MidJourney Plus ProxyClient
|
||||||
|
type PlusClient struct {
|
||||||
|
Config types.MjPlusConfig
|
||||||
|
apiURL string
|
||||||
|
client *req.Client
|
||||||
|
licenseService *service.LicenseService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
|
||||||
|
return &PlusClient{
|
||||||
|
Config: config,
|
||||||
|
apiURL: config.ApiURL,
|
||||||
|
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
|
||||||
|
licenseService: licenseService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PlusClient) preCheck() error {
|
||||||
|
return c.licenseService.IsValidApiURL(c.Config.ApiURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
|
||||||
|
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
||||||
|
if task.NegPrompt != "" {
|
||||||
|
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
|
||||||
|
}
|
||||||
|
body := ImageReq{
|
||||||
|
BotType: "MID_JOURNEY",
|
||||||
|
Prompt: prompt,
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
logger.Info("API URL: ", apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := c.client.R().
|
||||||
|
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
errStr, _ := io.ReadAll(r.Body)
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Blend 融图
|
||||||
|
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
|
||||||
|
logger.Info("API URL: ", apiURL)
|
||||||
|
body := ImageReq{
|
||||||
|
BotType: "MID_JOURNEY",
|
||||||
|
Dimensions: "SQUARE",
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
for _, imgURL := range task.ImgArr {
|
||||||
|
imageData, err := utils.DownloadImage(imgURL, "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := c.client.R().
|
||||||
|
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwapFace 换脸
|
||||||
|
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) != 2 {
|
||||||
|
return ImageRes{}, errors.New("参数错误,必须上传2张图片")
|
||||||
|
}
|
||||||
|
var sourceBase64 string
|
||||||
|
var targetBase64 string
|
||||||
|
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
}
|
||||||
|
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := gin.H{
|
||||||
|
"sourceBase64": sourceBase64,
|
||||||
|
"targetBase64": targetBase64,
|
||||||
|
"accountFilter": gin.H{
|
||||||
|
"instanceId": "",
|
||||||
|
},
|
||||||
|
"state": "",
|
||||||
|
}
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := c.client.SetTimeout(time.Minute).R().
|
||||||
|
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upscale 放大指定的图片
|
||||||
|
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body := map[string]string{
|
||||||
|
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
|
||||||
|
logger.Info("API URL: ", apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := c.client.R().
|
||||||
|
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||||
|
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body := map[string]string{
|
||||||
|
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
|
||||||
|
logger.Info("API URL: ", apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
||||||
|
var res QueryRes
|
||||||
|
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
Get(apiURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return QueryRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return QueryRes{}, errors.New("error status:" + r.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Client = &PlusClient{}
|
||||||
@@ -1,12 +1,23 @@
|
|||||||
package mj
|
package mj
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/service/oss"
|
|
||||||
"chatplus/store"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
logger2 "geekai/logger"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/service/sd"
|
||||||
|
"geekai/store"
|
||||||
|
"geekai/store/model"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -20,42 +31,15 @@ type ServicePool struct {
|
|||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
uploaderManager *oss.UploaderManager
|
uploaderManager *oss.UploaderManager
|
||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||||
|
licenseService *service.LicenseService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool {
|
||||||
services := make([]*Service, 0)
|
services := make([]*Service, 0)
|
||||||
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
||||||
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
||||||
// create mj client and service
|
|
||||||
for k, config := range appConfig.MjConfigs {
|
|
||||||
if config.Enabled == false {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// create mj client
|
|
||||||
client := NewClient(config, appConfig.ProxyURL)
|
|
||||||
|
|
||||||
name := fmt.Sprintf("MjService-%d", k)
|
|
||||||
// create mj service
|
|
||||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
|
|
||||||
botName := fmt.Sprintf("MjBot-%d", k)
|
|
||||||
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = bot.Run()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// run mj service
|
|
||||||
go func() {
|
|
||||||
service.Run()
|
|
||||||
}()
|
|
||||||
|
|
||||||
services = append(services, service)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ServicePool{
|
return &ServicePool{
|
||||||
taskQueue: taskQueue,
|
taskQueue: taskQueue,
|
||||||
notifyQueue: notifyQueue,
|
notifyQueue: notifyQueue,
|
||||||
@@ -63,19 +47,59 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
|||||||
uploaderManager: manager,
|
uploaderManager: manager,
|
||||||
db: db,
|
db: db,
|
||||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||||
|
licenseService: licenseService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
|
||||||
|
// stop old service
|
||||||
|
for _, s := range p.services {
|
||||||
|
s.Stop()
|
||||||
|
}
|
||||||
|
p.services = make([]*Service, 0)
|
||||||
|
|
||||||
|
for k, config := range plusConfigs {
|
||||||
|
if config.Enabled == false {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cli := NewPlusClient(config, p.licenseService)
|
||||||
|
name := fmt.Sprintf("mj-plus-service-%d", k)
|
||||||
|
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
|
||||||
|
go func() {
|
||||||
|
plusService.Run()
|
||||||
|
}()
|
||||||
|
p.services = append(p.services, plusService)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for mid-journey proxy
|
||||||
|
for k, config := range proxyConfigs {
|
||||||
|
if config.Enabled == false {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cli := NewProxyClient(config)
|
||||||
|
name := fmt.Sprintf("mj-proxy-service-%d", k)
|
||||||
|
proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
|
||||||
|
go func() {
|
||||||
|
proxyService.Run()
|
||||||
|
}()
|
||||||
|
p.services = append(p.services, proxyService)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ServicePool) CheckTaskNotify() {
|
func (p *ServicePool) CheckTaskNotify() {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
var userId uint
|
var message sd.NotifyMessage
|
||||||
err := p.notifyQueue.LPop(&userId)
|
err := p.notifyQueue.LPop(&message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
client := p.Clients.Get(userId)
|
cli := p.Clients.Get(uint(message.UserId))
|
||||||
err = client.Send([]byte("Task Updated"))
|
if cli == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = cli.Send([]byte(message.Message))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -94,17 +118,43 @@ func (p *ServicePool) DownloadImages() {
|
|||||||
|
|
||||||
// download images
|
// download images
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
|
if v.OrgURL == "" {
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download image: ", err)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Infof("try to download image: %s", v.OrgURL)
|
||||||
|
mjService := p.getService(v.ChannelId)
|
||||||
|
if mjService == nil {
|
||||||
|
logger.Errorf("Invalid task: %+v", v)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
task, _ := mjService.Client.QueryTask(v.TaskId)
|
||||||
|
if len(task.Buttons) > 0 {
|
||||||
|
v.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||||
|
}
|
||||||
|
// 如果是返回的是 discord 图片地址,则使用代理下载
|
||||||
|
proxy := false
|
||||||
|
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
|
||||||
|
proxy = true
|
||||||
|
}
|
||||||
|
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, proxy)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
logger.Infof("download image %s successfully.", v.OrgURL)
|
||||||
|
}
|
||||||
|
|
||||||
v.ImgURL = imgURL
|
v.ImgURL = imgURL
|
||||||
p.db.Updates(&v)
|
p.db.Updates(&v)
|
||||||
|
|
||||||
client := p.Clients.Get(uint(v.UserId))
|
cli := p.Clients.Get(uint(v.UserId))
|
||||||
err = client.Send([]byte("Task Updated"))
|
if cli == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = cli.Send([]byte(sd.Finished))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -125,3 +175,56 @@ func (p *ServicePool) PushTask(task types.MjTask) {
|
|||||||
func (p *ServicePool) HasAvailableService() bool {
|
func (p *ServicePool) HasAvailableService() bool {
|
||||||
return len(p.services) > 0
|
return len(p.services) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SyncTaskProgress 异步拉取任务
|
||||||
|
func (p *ServicePool) SyncTaskProgress() {
|
||||||
|
go func() {
|
||||||
|
var items []model.MidJourneyJob
|
||||||
|
for {
|
||||||
|
res := p.db.Where("progress < ?", 100).Find(&items)
|
||||||
|
if res.Error != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, job := range items {
|
||||||
|
// 失败或者 30 分钟还没完成的任务删除并退回算力
|
||||||
|
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
|
||||||
|
p.db.Delete(&job)
|
||||||
|
// 退回算力
|
||||||
|
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
var user model.User
|
||||||
|
p.db.Where("id = ?", job.UserId).First(&user)
|
||||||
|
p.db.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power + job.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "mid-journey",
|
||||||
|
Remark: fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
|
||||||
|
_ = servicePlus.Notify(job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ServicePool) getService(name string) *Service {
|
||||||
|
for _, s := range p.services {
|
||||||
|
if s.Name == name {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
185
api/service/mj/proxy_client.go
Normal file
185
api/service/mj/proxy_client.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package mj
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxyClient MidJourney Proxy Client
|
||||||
|
type ProxyClient struct {
|
||||||
|
Config types.MjProxyConfig
|
||||||
|
apiURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
|
||||||
|
return &ProxyClient{Config: config, apiURL: config.ApiURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
|
||||||
|
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
||||||
|
if task.NegPrompt != "" {
|
||||||
|
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
|
||||||
|
}
|
||||||
|
body := ImageReq{
|
||||||
|
Prompt: prompt,
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
logger.Info("API URL: ", apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
errStr, _ := io.ReadAll(r.Body)
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Blend 融图
|
||||||
|
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
|
||||||
|
body := ImageReq{
|
||||||
|
Dimensions: "SQUARE",
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
for _, imgURL := range task.ImgArr {
|
||||||
|
imageData, err := utils.DownloadImage(imgURL, "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwapFace 换脸
|
||||||
|
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
|
||||||
|
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upscale 放大指定的图片
|
||||||
|
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"action": "UPSCALE",
|
||||||
|
"index": task.Index,
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||||
|
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"action": "VARIATION",
|
||||||
|
"index": task.Index,
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
||||||
|
var res QueryRes
|
||||||
|
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
Get(apiURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return QueryRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return QueryRes{}, errors.New("error status:" + r.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Client = &ProxyClient{}
|
||||||
@@ -1,52 +1,50 @@
|
|||||||
package mj
|
package mj
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"fmt"
|
||||||
"chatplus/store"
|
"geekai/core/types"
|
||||||
"chatplus/store/model"
|
"geekai/service"
|
||||||
"gorm.io/gorm"
|
"geekai/service/sd"
|
||||||
|
"geekai/store"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Service MJ 绘画服务
|
// Service MJ 绘画服务
|
||||||
type Service struct {
|
type Service struct {
|
||||||
name string // service name
|
Name string // service Name
|
||||||
client *Client // MJ client
|
Client Client // MJ Client
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
maxHandleTaskNum int32 // max task number current service can handle
|
running bool
|
||||||
handledTaskNum int32 // already handled task number
|
|
||||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
|
||||||
taskTimeout int64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
|
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
name: name,
|
Name: name,
|
||||||
db: db,
|
db: db,
|
||||||
taskQueue: taskQueue,
|
taskQueue: taskQueue,
|
||||||
notifyQueue: notifyQueue,
|
notifyQueue: notifyQueue,
|
||||||
client: client,
|
Client: cli,
|
||||||
taskTimeout: timeout,
|
running: true,
|
||||||
maxHandleTaskNum: maxTaskNum,
|
|
||||||
taskStartTimes: make(map[int]time.Time, 0),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
logger.Infof("Starting MidJourney job consumer for %s", s.name)
|
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
||||||
for {
|
for s.running {
|
||||||
s.checkTasks()
|
|
||||||
if !s.canHandleTask() {
|
|
||||||
// current service is full, can not handle more task
|
|
||||||
// waiting for running task finish
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var task types.MjTask
|
var task types.MjTask
|
||||||
err := s.taskQueue.LPop(&task)
|
err := s.taskQueue.LPop(&task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -54,108 +52,153 @@ func (s *Service) Run() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// if it's reference message, check if it's this channel's message
|
// 如果配置了多个中转平台的 API KEY
|
||||||
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
|
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
|
||||||
|
if task.ChannelId != "" && task.ChannelId != s.Name {
|
||||||
|
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
|
||||||
s.taskQueue.RPush(task)
|
s.taskQueue.RPush(task)
|
||||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
|
||||||
s.notifyQueue.RPush(task.UserId)
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
|
// translate prompt
|
||||||
switch task.Type {
|
if utils.HasChinese(task.Prompt) {
|
||||||
case types.TaskImage:
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
|
||||||
err = s.client.Imagine(task.Prompt)
|
if err == nil {
|
||||||
break
|
task.Prompt = content
|
||||||
case types.TaskUpscale:
|
} else {
|
||||||
err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
break
|
}
|
||||||
case types.TaskVariation:
|
// translate negative prompt
|
||||||
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
|
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt))
|
||||||
|
if err == nil {
|
||||||
|
task.NegPrompt = content
|
||||||
|
} else {
|
||||||
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
var job model.MidJourneyJob
|
||||||
logger.Error("绘画任务执行失败:", err)
|
tx := s.db.Where("id = ?", task.Id).First(&job)
|
||||||
// update the task progress
|
if tx.Error != nil {
|
||||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
logger.Error("任务不存在,任务ID:", task.TaskId)
|
||||||
s.notifyQueue.RPush(task.UserId)
|
|
||||||
// restore img_call quota
|
|
||||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock the task until the execute timeout
|
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
|
||||||
s.taskStartTimes[task.Id] = time.Now()
|
var res ImageRes
|
||||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
switch task.Type {
|
||||||
|
case types.TaskImage:
|
||||||
}
|
res, err = s.Client.Imagine(task)
|
||||||
}
|
break
|
||||||
|
case types.TaskUpscale:
|
||||||
// check if current service instance can handle more task
|
res, err = s.Client.Upscale(task)
|
||||||
func (s *Service) canHandleTask() bool {
|
break
|
||||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
case types.TaskVariation:
|
||||||
return handledNum < s.maxHandleTaskNum
|
res, err = s.Client.Variation(task)
|
||||||
}
|
break
|
||||||
|
case types.TaskBlend:
|
||||||
// remove the expired tasks
|
res, err = s.Client.Blend(task)
|
||||||
func (s *Service) checkTasks() {
|
break
|
||||||
for k, t := range s.taskStartTimes {
|
case types.TaskSwapFace:
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
res, err = s.Client.SwapFace(task)
|
||||||
delete(s.taskStartTimes, k)
|
break
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
|
||||||
// delete task from database
|
|
||||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||||||
|
var errMsg string
|
||||||
|
if err != nil {
|
||||||
|
errMsg = err.Error()
|
||||||
|
} else {
|
||||||
|
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Error("绘画任务执行失败:", errMsg)
|
||||||
|
job.Progress = -1
|
||||||
|
job.ErrMsg = errMsg
|
||||||
|
// update the task progress
|
||||||
|
s.db.Updates(&job)
|
||||||
|
// 任务失败,通知前端
|
||||||
|
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Infof("任务提交成功:%+v", res)
|
||||||
|
// 更新任务 ID/频道
|
||||||
|
job.TaskId = res.Result
|
||||||
|
job.MessageId = res.Result
|
||||||
|
job.ChannelId = s.Name
|
||||||
|
s.db.Updates(&job)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Notify(data CBReq) {
|
func (s *Service) Stop() {
|
||||||
// extract the task ID
|
s.running = false
|
||||||
split := strings.Split(data.Prompt, " ")
|
}
|
||||||
var job model.MidJourneyJob
|
|
||||||
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
|
type CBReq struct {
|
||||||
if res.Error == nil && data.Status == Finished {
|
Id string `json:"id"`
|
||||||
logger.Warn("重复消息:", data.MessageId)
|
Action string `json:"action"`
|
||||||
return
|
Status string `json:"status"`
|
||||||
}
|
Prompt string `json:"prompt"`
|
||||||
|
PromptEn string `json:"promptEn"`
|
||||||
tx := s.db.Session(&gorm.Session{}).Order("id ASC")
|
Description string `json:"description"`
|
||||||
if data.ReferenceId != "" {
|
SubmitTime int64 `json:"submitTime"`
|
||||||
tx = tx.Where("reference_id = ?", data.ReferenceId)
|
StartTime int64 `json:"startTime"`
|
||||||
} else {
|
FinishTime int64 `json:"finishTime"`
|
||||||
tx = tx.Where("task_id = ?", split[0])
|
Progress string `json:"progress"`
|
||||||
}
|
ImageUrl string `json:"imageUrl"`
|
||||||
res = tx.First(&job)
|
FailReason interface{} `json:"failReason"`
|
||||||
if res.Error != nil {
|
Properties struct {
|
||||||
logger.Warn("非法任务:", res.Error)
|
FinalPrompt string `json:"finalPrompt"`
|
||||||
return
|
} `json:"properties"`
|
||||||
}
|
}
|
||||||
|
|
||||||
job.ChannelId = data.ChannelId
|
func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||||
job.MessageId = data.MessageId
|
task, err := s.Client.QueryTask(job.TaskId)
|
||||||
job.ReferenceId = data.ReferenceId
|
if err != nil {
|
||||||
job.Progress = data.Progress
|
return err
|
||||||
job.Prompt = data.Prompt
|
}
|
||||||
job.Hash = data.Image.Hash
|
|
||||||
job.OrgURL = data.Image.URL
|
// 任务执行失败了
|
||||||
if s.client.Config.UseCDN {
|
if task.FailReason != "" {
|
||||||
job.UseProxy = true
|
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||||
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN)
|
"progress": -1,
|
||||||
}
|
"err_msg": task.FailReason,
|
||||||
|
})
|
||||||
res = s.db.Updates(&job)
|
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
|
||||||
if res.Error != nil {
|
return fmt.Errorf("task failed: %v", task.FailReason)
|
||||||
logger.Error("error with update job: ", res.Error)
|
}
|
||||||
return
|
|
||||||
}
|
if len(task.Buttons) > 0 {
|
||||||
|
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||||
if data.Status == Finished {
|
}
|
||||||
// release lock task
|
oldProgress := job.Progress
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||||
}
|
job.Prompt = task.PromptEn
|
||||||
|
if task.ImageUrl != "" {
|
||||||
s.notifyQueue.RPush(job.UserId)
|
job.OrgURL = task.ImageUrl
|
||||||
|
}
|
||||||
|
tx := s.db.Updates(&job)
|
||||||
|
if tx.Error != nil {
|
||||||
|
return fmt.Errorf("error with update database: %v", tx.Error)
|
||||||
|
}
|
||||||
|
// 通知前端更新任务进度
|
||||||
|
if oldProgress != job.Progress {
|
||||||
|
message := sd.Running
|
||||||
|
if job.Progress == 100 {
|
||||||
|
message = sd.Finished
|
||||||
|
}
|
||||||
|
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetImageHash(action string) string {
|
||||||
|
split := strings.Split(action, "::")
|
||||||
|
if len(split) > 5 {
|
||||||
|
return split[4]
|
||||||
|
}
|
||||||
|
return split[len(split)-1]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +0,0 @@
|
|||||||
package mj
|
|
||||||
|
|
||||||
const (
|
|
||||||
ApplicationID string = "936929561302675456"
|
|
||||||
SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe"
|
|
||||||
)
|
|
||||||
|
|
||||||
type InteractionsRequest struct {
|
|
||||||
Type int `json:"type"`
|
|
||||||
ApplicationID string `json:"application_id"`
|
|
||||||
MessageFlags *int `json:"message_flags,omitempty"`
|
|
||||||
MessageID *string `json:"message_id,omitempty"`
|
|
||||||
GuildID string `json:"guild_id"`
|
|
||||||
ChannelID string `json:"channel_id"`
|
|
||||||
SessionID string `json:"session_id"`
|
|
||||||
Data map[string]any `json:"data"`
|
|
||||||
Nonce string `json:"nonce,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type InteractionsResult struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string
|
|
||||||
Error map[string]any
|
|
||||||
}
|
|
||||||
|
|
||||||
type CBReq struct {
|
|
||||||
ChannelId string `json:"channel_id"`
|
|
||||||
MessageId string `json:"message_id"`
|
|
||||||
ReferenceId string `json:"reference_id"`
|
|
||||||
Image Image `json:"image"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Status TaskStatus `json:"status"`
|
|
||||||
Progress int `json:"progress"`
|
|
||||||
}
|
|
||||||
@@ -1,15 +1,25 @@
|
|||||||
package oss
|
package oss
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core/types"
|
"encoding/base64"
|
||||||
"chatplus/utils"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/aliyun/aliyun-oss-go-sdk/oss"
|
"geekai/core/types"
|
||||||
"github.com/gin-gonic/gin"
|
"geekai/utils"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/aliyun/aliyun-oss-go-sdk/oss"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AliYunOss struct {
|
type AliYunOss struct {
|
||||||
@@ -44,16 +54,16 @@ func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) {
|
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||||
// 解析表单
|
// 解析表单
|
||||||
file, err := ctx.FormFile(name)
|
file, err := ctx.FormFile(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return File{}, err
|
||||||
}
|
}
|
||||||
// 打开上传文件
|
// 打开上传文件
|
||||||
src, err := file.Open()
|
src, err := file.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return File{}, err
|
||||||
}
|
}
|
||||||
defer src.Close()
|
defer src.Close()
|
||||||
|
|
||||||
@@ -62,10 +72,16 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) {
|
|||||||
// 上传文件
|
// 上传文件
|
||||||
err = s.bucket.PutObject(objectKey, src)
|
err = s.bucket.PutObject(objectKey, src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return File{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
return File{
|
||||||
|
Name: file.Filename,
|
||||||
|
ObjKey: objectKey,
|
||||||
|
URL: fmt.Sprintf("%s/%s", s.config.Domain, objectKey),
|
||||||
|
Ext: fileExt,
|
||||||
|
Size: file.Size,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||||
@@ -83,7 +99,7 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||||
}
|
}
|
||||||
fileExt := filepath.Ext(parse.Path)
|
fileExt := utils.GetImgExt(parse.Path)
|
||||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||||
// 上传文件字节数据
|
// 上传文件字节数据
|
||||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||||
@@ -93,10 +109,29 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||||
|
// 上传文件字节数据
|
||||||
|
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s AliYunOss) Delete(fileURL string) error {
|
func (s AliYunOss) Delete(fileURL string) error {
|
||||||
objectName := filepath.Base(fileURL)
|
var objectKey string
|
||||||
key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
return s.bucket.DeleteObject(key)
|
filename := filepath.Base(fileURL)
|
||||||
|
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||||
|
} else {
|
||||||
|
objectKey = fileURL
|
||||||
|
}
|
||||||
|
return s.bucket.DeleteObject(objectKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Uploader = AliYunOss{}
|
var _ Uploader = AliYunOss{}
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
package oss
|
package oss
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"encoding/base64"
|
||||||
"chatplus/utils"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -23,23 +31,30 @@ func NewLocalStorage(config *types.AppConfig) LocalStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (string, error) {
|
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||||
file, err := ctx.FormFile(name)
|
file, err := ctx.FormFile(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error with get form: %v", err)
|
return File{}, fmt.Errorf("error with get form: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
filePath, err := utils.GenUploadPath(s.config.BasePath, file.Filename)
|
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error with generate filename: %s", err.Error())
|
return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
|
||||||
}
|
}
|
||||||
// 将文件保存到指定路径
|
// 将文件保存到指定路径
|
||||||
err = ctx.SaveUploadedFile(file, filePath)
|
err = ctx.SaveUploadedFile(file, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error with save upload file: %s", err.Error())
|
return File{}, fmt.Errorf("error with save upload file: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
ext := filepath.Ext(file.Filename)
|
||||||
|
return File{
|
||||||
|
Name: file.Filename,
|
||||||
|
ObjKey: path,
|
||||||
|
URL: utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, path),
|
||||||
|
Ext: ext,
|
||||||
|
Size: file.Size,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||||
@@ -48,7 +63,7 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||||
}
|
}
|
||||||
filename := filepath.Base(parse.Path)
|
filename := filepath.Base(parse.Path)
|
||||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename)
|
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error with generate image dir: %v", err)
|
return "", fmt.Errorf("error with generate image dir: %v", err)
|
||||||
}
|
}
|
||||||
@@ -65,7 +80,24 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
|
||||||
|
err = os.WriteFile(filePath, imageData, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error writing to file:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s LocalStorage) Delete(fileURL string) error {
|
func (s LocalStorage) Delete(fileURL string) error {
|
||||||
|
if _, err := os.Stat(fileURL); err == nil {
|
||||||
|
return os.Remove(fileURL)
|
||||||
|
}
|
||||||
filePath := strings.Replace(fileURL, s.config.BaseURL, s.config.BasePath, 1)
|
filePath := strings.Replace(fileURL, s.config.BaseURL, s.config.BasePath, 1)
|
||||||
return os.Remove(filePath)
|
return os.Remove(filePath)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,26 @@
|
|||||||
package oss
|
package oss
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/utils"
|
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"geekai/core/types"
|
||||||
"github.com/minio/minio-go/v7"
|
"geekai/utils"
|
||||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/minio/minio-go/v7"
|
||||||
|
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MiniOss struct {
|
type MiniOss struct {
|
||||||
@@ -65,34 +74,64 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
|
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s MiniOss) PutFile(ctx *gin.Context, name string) (string, error) {
|
func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||||
file, err := ctx.FormFile(name)
|
file, err := ctx.FormFile(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error with get form: %v", err)
|
return File{}, fmt.Errorf("error with get form: %v", err)
|
||||||
}
|
}
|
||||||
// Open the uploaded file
|
// Open the uploaded file
|
||||||
fileReader, err := file.Open()
|
fileReader, err := file.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error opening file: %v", err)
|
return File{}, fmt.Errorf("error opening file: %v", err)
|
||||||
}
|
}
|
||||||
defer fileReader.Close()
|
defer fileReader.Close()
|
||||||
|
|
||||||
fileExt := filepath.Ext(file.Filename)
|
fileExt := utils.GetImgExt(file.Filename)
|
||||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||||
ContentType: file.Header.Get("Content-Type"),
|
ContentType: file.Header.Get("Content-Type"),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error uploading to MinIO: %v", err)
|
return File{}, fmt.Errorf("error uploading to MinIO: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return File{
|
||||||
|
Name: file.Filename,
|
||||||
|
ObjKey: info.Key,
|
||||||
|
URL: fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key),
|
||||||
|
Ext: fileExt,
|
||||||
|
Size: file.Size,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||||
|
info, err := s.client.PutObject(
|
||||||
|
context.Background(),
|
||||||
|
s.config.Bucket,
|
||||||
|
objectKey,
|
||||||
|
strings.NewReader(string(imageData)),
|
||||||
|
int64(len(imageData)),
|
||||||
|
minio.PutObjectOptions{ContentType: "image/png"})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
|
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s MiniOss) Delete(fileURL string) error {
|
func (s MiniOss) Delete(fileURL string) error {
|
||||||
objectName := filepath.Base(fileURL)
|
var objectKey string
|
||||||
key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
return s.client.RemoveObject(context.Background(), s.config.Bucket, key, minio.RemoveObjectOptions{})
|
filename := filepath.Base(fileURL)
|
||||||
|
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||||
|
} else {
|
||||||
|
objectKey = fileURL
|
||||||
|
}
|
||||||
|
return s.client.RemoveObject(context.Background(), s.config.Bucket, objectKey, minio.RemoveObjectOptions{})
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Uploader = MiniOss{}
|
var _ Uploader = MiniOss{}
|
||||||
|
|||||||
@@ -1,17 +1,27 @@
|
|||||||
package oss
|
package oss
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/utils"
|
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
|
"net/url"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
||||||
"github.com/qiniu/go-sdk/v7/storage"
|
"github.com/qiniu/go-sdk/v7/storage"
|
||||||
"net/url"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type QinNiuOss struct {
|
type QinNiuOss struct {
|
||||||
@@ -50,16 +60,16 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) {
|
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||||
// 解析表单
|
// 解析表单
|
||||||
file, err := ctx.FormFile(name)
|
file, err := ctx.FormFile(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return File{}, err
|
||||||
}
|
}
|
||||||
// 打开上传文件
|
// 打开上传文件
|
||||||
src, err := file.Open()
|
src, err := file.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return File{}, err
|
||||||
}
|
}
|
||||||
defer src.Close()
|
defer src.Close()
|
||||||
|
|
||||||
@@ -70,10 +80,17 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) {
|
|||||||
extra := storage.PutExtra{}
|
extra := storage.PutExtra{}
|
||||||
err = s.uploader.Put(ctx, &ret, s.putPolicy.UploadToken(s.mac), key, src, file.Size, &extra)
|
err = s.uploader.Put(ctx, &ret, s.putPolicy.UploadToken(s.mac), key, src, file.Size, &extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return File{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
return File{
|
||||||
|
Name: file.Filename,
|
||||||
|
ObjKey: key,
|
||||||
|
URL: fmt.Sprintf("%s/%s", s.config.Domain, ret.Key),
|
||||||
|
Ext: fileExt,
|
||||||
|
Size: file.Size,
|
||||||
|
}, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||||
@@ -91,7 +108,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||||
}
|
}
|
||||||
fileExt := filepath.Ext(parse.Path)
|
fileExt := utils.GetImgExt(parse.Path)
|
||||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||||
ret := storage.PutRet{}
|
ret := storage.PutRet{}
|
||||||
extra := storage.PutExtra{}
|
extra := storage.PutExtra{}
|
||||||
@@ -103,10 +120,32 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||||
|
ret := storage.PutRet{}
|
||||||
|
extra := storage.PutExtra{}
|
||||||
|
// 上传文件字节数据
|
||||||
|
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s QinNiuOss) Delete(fileURL string) error {
|
func (s QinNiuOss) Delete(fileURL string) error {
|
||||||
objectName := filepath.Base(fileURL)
|
var objectKey string
|
||||||
key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
return s.manager.Delete(s.config.Bucket, key)
|
filename := filepath.Base(fileURL)
|
||||||
|
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||||
|
} else {
|
||||||
|
objectKey = fileURL
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.manager.Delete(s.config.Bucket, objectKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Uploader = QinNiuOss{}
|
var _ Uploader = QinNiuOss{}
|
||||||
|
|||||||
@@ -1,9 +1,29 @@
|
|||||||
package oss
|
package oss
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import "github.com/gin-gonic/gin"
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
const Local = "LOCAL"
|
||||||
|
const Minio = "MINIO"
|
||||||
|
const QiNiu = "QINIU"
|
||||||
|
const AliYun = "ALIYUN"
|
||||||
|
|
||||||
|
type File struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
ObjKey string `json:"obj_key"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Ext string `json:"ext"`
|
||||||
|
}
|
||||||
type Uploader interface {
|
type Uploader interface {
|
||||||
PutFile(ctx *gin.Context, name string) (string, error)
|
PutFile(ctx *gin.Context, name string) (File, error)
|
||||||
PutImg(imageURL string, useProxy bool) (string, error)
|
PutImg(imageURL string, useProxy bool) (string, error)
|
||||||
|
PutBase64(imageData string) (string, error)
|
||||||
Delete(fileURL string) error
|
Delete(fileURL string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,14 @@
|
|||||||
package oss
|
package oss
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"geekai/core/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -9,11 +16,6 @@ type UploaderManager struct {
|
|||||||
handler Uploader
|
handler Uploader
|
||||||
}
|
}
|
||||||
|
|
||||||
const Local = "LOCAL"
|
|
||||||
const Minio = "MINIO"
|
|
||||||
const QiNiu = "QINIU"
|
|
||||||
const AliYun = "ALIYUN"
|
|
||||||
|
|
||||||
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
|
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
|
||||||
active := Local
|
active := Local
|
||||||
if config.OSS.Active != "" {
|
if config.OSS.Active != "" {
|
||||||
|
|||||||
@@ -1,9 +1,16 @@
|
|||||||
package payment
|
package payment
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
logger2 "chatplus/logger"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
logger2 "geekai/logger"
|
||||||
"github.com/smartwalle/alipay/v3"
|
"github.com/smartwalle/alipay/v3"
|
||||||
"log"
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|||||||
@@ -1,9 +1,19 @@
|
|||||||
package payment
|
package payment
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -16,57 +26,146 @@ import (
|
|||||||
type HuPiPayService struct {
|
type HuPiPayService struct {
|
||||||
appId string
|
appId string
|
||||||
appSecret string
|
appSecret string
|
||||||
host string
|
apiURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
|
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
|
||||||
return &HuPiPayService{
|
return &HuPiPayService{
|
||||||
appId: config.HuPiPayConfig.AppId,
|
appId: config.HuPiPayConfig.AppId,
|
||||||
appSecret: config.HuPiPayConfig.AppSecret,
|
appSecret: config.HuPiPayConfig.AppSecret,
|
||||||
host: config.HuPiPayConfig.PayURL,
|
apiURL: config.HuPiPayConfig.ApiURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HuPiPayReq struct {
|
||||||
|
AppId string `json:"appid"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
TradeOrderId string `json:"trade_order_id"`
|
||||||
|
TotalFee string `json:"total_fee"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
NotifyURL string `json:"notify_url"`
|
||||||
|
ReturnURL string `json:"return_url"`
|
||||||
|
WapName string `json:"wap_name"`
|
||||||
|
CallbackURL string `json:"callback_url"`
|
||||||
|
Time string `json:"time"`
|
||||||
|
NonceStr string `json:"nonce_str"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
WapUrl string `json:"wap_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HuPiResp struct {
|
||||||
|
Openid interface{} `json:"openid"`
|
||||||
|
UrlQrcode string `json:"url_qrcode"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
ErrCode int `json:"errcode"`
|
||||||
|
ErrMsg string `json:"errmsg,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// Pay 执行支付请求操作
|
// Pay 执行支付请求操作
|
||||||
func (s *HuPiPayService) Pay(params map[string]string) (string, error) {
|
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
simple := strconv.FormatInt(time.Now().Unix(), 10)
|
simple := strconv.FormatInt(time.Now().Unix(), 10)
|
||||||
params["appid"] = s.appId
|
params.AppId = s.appId
|
||||||
params["time"] = simple
|
params.Time = simple
|
||||||
params["nonce_str"] = simple
|
params.NonceStr = simple
|
||||||
for k, v := range params {
|
encode := utils.JsonEncode(params)
|
||||||
data.Add(k, v)
|
m := make(map[string]string)
|
||||||
|
_ = utils.JsonDecode(encode, &m)
|
||||||
|
for k, v := range m {
|
||||||
|
data.Add(k, fmt.Sprintf("%v", v))
|
||||||
}
|
}
|
||||||
data.Add("hash", s.Sign(params))
|
// 生成签名
|
||||||
resp, err := http.PostForm(s.host, data)
|
data.Add("hash", s.Sign(data))
|
||||||
|
// 发送支付请求
|
||||||
|
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
|
||||||
|
resp, err := http.PostForm(apiURL, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "error", err
|
return HuPiResp{}, fmt.Errorf("error with requst api: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
all, err := io.ReadAll(resp.Body)
|
all, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "error", err
|
return HuPiResp{}, fmt.Errorf("error with reading response: %v", err)
|
||||||
}
|
}
|
||||||
return string(all), err
|
|
||||||
|
var res HuPiResp
|
||||||
|
err = utils.JsonDecode(string(all), &res)
|
||||||
|
if err != nil {
|
||||||
|
return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.ErrCode != 0 {
|
||||||
|
return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sign 签名方法
|
// Sign 签名方法
|
||||||
func (s *HuPiPayService) Sign(params map[string]string) string {
|
func (s *HuPiPayService) Sign(params url.Values) string {
|
||||||
var data string
|
params.Del(`Sign`)
|
||||||
keys := make([]string, 0, 0)
|
var keys = make([]string, 0, 0)
|
||||||
params["appid"] = s.appId
|
|
||||||
for key := range params {
|
for key := range params {
|
||||||
keys = append(keys, key)
|
if params.Get(key) != `` {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
//拼接
|
|
||||||
for _, k := range keys {
|
var pList = make([]string, 0, 0)
|
||||||
data = fmt.Sprintf("%s%s=%s&", data, k, params[k])
|
for _, key := range keys {
|
||||||
|
var value = strings.TrimSpace(params.Get(key))
|
||||||
|
if len(value) > 0 {
|
||||||
|
pList = append(pList, key+"="+value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var src = strings.Join(pList, "&")
|
||||||
|
src += s.appSecret
|
||||||
|
|
||||||
|
md5bs := md5.Sum([]byte(src))
|
||||||
|
return hex.EncodeToString(md5bs[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check 校验订单状态
|
||||||
|
func (s *HuPiPayService) Check(tradeNo string) error {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Add("appid", s.appId)
|
||||||
|
data.Add("open_order_id", tradeNo)
|
||||||
|
stamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||||||
|
data.Add("time", stamp)
|
||||||
|
data.Add("nonce_str", stamp)
|
||||||
|
data.Add("hash", s.Sign(data))
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
|
||||||
|
resp, err := http.PostForm(apiURL, data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error with http reqeust: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error with reading response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var r struct {
|
||||||
|
ErrCode int `json:"errcode"`
|
||||||
|
Data struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
OpenOrderId string `json:"open_order_id"`
|
||||||
|
} `json:"data,omitempty"`
|
||||||
|
ErrMsg string `json:"errmsg"`
|
||||||
|
Hash string `json:"hash"`
|
||||||
|
}
|
||||||
|
err = utils.JsonDecode(string(body), &r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error with decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.ErrCode == 0 && r.Data.Status == "OD" {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
logger.Debugf("%+v", r)
|
||||||
|
return errors.New("order not paid:" + r.ErrMsg)
|
||||||
}
|
}
|
||||||
data = strings.Trim(data, "&")
|
|
||||||
data = fmt.Sprintf("%s%s", data, s.appSecret)
|
|
||||||
m := md5.New()
|
|
||||||
m.Write([]byte(data))
|
|
||||||
sign := fmt.Sprintf("%x", m.Sum(nil))
|
|
||||||
return sign
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,19 @@
|
|||||||
package payment
|
package payment
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/utils"
|
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -26,11 +34,65 @@ func NewPayJS(appConfig *types.AppConfig) *PayJS {
|
|||||||
type JPayReq struct {
|
type JPayReq struct {
|
||||||
TotalFee int `json:"total_fee"`
|
TotalFee int `json:"total_fee"`
|
||||||
OutTradeNo string `json:"out_trade_no"`
|
OutTradeNo string `json:"out_trade_no"`
|
||||||
Body string `json:"body"`
|
Subject string `json:"body"`
|
||||||
NotifyURL string `json:"notify_url"`
|
NotifyURL string `json:"notify_url"`
|
||||||
|
ReturnURL string `json:"callback_url"`
|
||||||
|
}
|
||||||
|
type JPayReps struct {
|
||||||
|
OutTradeNo string `json:"out_trade_no"`
|
||||||
|
OrderId string `json:"payjs_order_id"`
|
||||||
|
ReturnCode int `json:"return_code"`
|
||||||
|
ReturnMsg string `json:"return_msg"`
|
||||||
|
Sign string `json:"Sign"`
|
||||||
|
TotalFee string `json:"total_fee"`
|
||||||
|
CodeUrl string `json:"code_url,omitempty"`
|
||||||
|
Qrcode string `json:"qrcode,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func sign(params url.Values, priKey string) string {
|
func (r JPayReps) IsOK() bool {
|
||||||
|
return r.ReturnMsg == "SUCCESS"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (js *PayJS) Pay(param JPayReq) JPayReps {
|
||||||
|
param.NotifyURL = js.config.NotifyURL
|
||||||
|
var p = url.Values{}
|
||||||
|
encode := utils.JsonEncode(param)
|
||||||
|
m := make(map[string]interface{})
|
||||||
|
_ = utils.JsonDecode(encode, &m)
|
||||||
|
for k, v := range m {
|
||||||
|
p.Add(k, fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
p.Add("mchid", js.config.AppId)
|
||||||
|
|
||||||
|
p.Add("sign", js.sign(p))
|
||||||
|
|
||||||
|
cli := http.Client{}
|
||||||
|
apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
|
||||||
|
r, err := cli.PostForm(apiURL, p)
|
||||||
|
if err != nil {
|
||||||
|
return JPayReps{ReturnMsg: err.Error()}
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
bs, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return JPayReps{ReturnMsg: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
var data JPayReps
|
||||||
|
err = utils.JsonDecode(string(bs), &data)
|
||||||
|
if err != nil {
|
||||||
|
return JPayReps{ReturnMsg: err.Error()}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (js *PayJS) PayH5(p url.Values) string {
|
||||||
|
p.Add("mchid", js.config.AppId)
|
||||||
|
p.Add("sign", js.sign(p))
|
||||||
|
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (js *PayJS) sign(params url.Values) string {
|
||||||
params.Del(`sign`)
|
params.Del(`sign`)
|
||||||
var keys = make([]string, 0, 0)
|
var keys = make([]string, 0, 0)
|
||||||
for key := range params {
|
for key := range params {
|
||||||
@@ -48,34 +110,46 @@ func sign(params url.Values, priKey string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
var src = strings.Join(pList, "&")
|
var src = strings.Join(pList, "&")
|
||||||
src += "&key=" + priKey
|
src += "&key=" + js.config.PrivateKey
|
||||||
|
|
||||||
md5bs := md5.Sum([]byte(src))
|
md5bs := md5.Sum([]byte(src))
|
||||||
md5res := hex.EncodeToString(md5bs[:])
|
md5res := hex.EncodeToString(md5bs[:])
|
||||||
return strings.ToUpper(md5res)
|
return strings.ToUpper(md5res)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pj *PayJS) Pay(param JPayReq) (string, error) {
|
// Check 查询订单支付状态
|
||||||
var p = url.Values{}
|
// @param tradeNo 支付平台交易 ID
|
||||||
encode := utils.JsonEncode(param)
|
func (js *PayJS) Check(tradeNo string) error {
|
||||||
m := make(map[string]interface{})
|
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
|
||||||
_ = utils.JsonDecode(encode, &m)
|
params := url.Values{}
|
||||||
for k, v := range m {
|
params.Add("payjs_order_id", tradeNo)
|
||||||
p.Add(k, fmt.Sprintf("%v", v))
|
params.Add("sign", js.sign(params))
|
||||||
}
|
data := strings.NewReader(params.Encode())
|
||||||
p.Add("mchid", pj.config.AppId)
|
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
|
||||||
|
defer resp.Body.Close()
|
||||||
p.Add("sign", sign(p, pj.config.PrivateKey))
|
|
||||||
|
|
||||||
cli := http.Client{}
|
|
||||||
r, err := cli.PostForm(pj.config.ApiURL, p)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return fmt.Errorf("error with http reqeust: %v", err)
|
||||||
}
|
}
|
||||||
defer r.Body.Close()
|
|
||||||
bs, err := io.ReadAll(r.Body)
|
defer resp.Body.Close()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return fmt.Errorf("error with reading response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var r struct {
|
||||||
|
ReturnCode int `json:"return_code"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
}
|
||||||
|
err = utils.JsonDecode(string(body), &r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error with decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.ReturnCode == 1 && r.Status == 1 {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
logger.Errorf("PayJs 支付验证响应:%s", string(body))
|
||||||
|
return errors.New("order not paid")
|
||||||
}
|
}
|
||||||
return string(bs), nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,43 +1,71 @@
|
|||||||
package sd
|
package sd
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/service/oss"
|
|
||||||
"chatplus/store"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service/oss"
|
||||||
|
"geekai/store"
|
||||||
|
"geekai/store/model"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServicePool struct {
|
type ServicePool struct {
|
||||||
services []*Service
|
services []*Service
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
|
notifyQueue *store.RedisQueue
|
||||||
|
db *gorm.DB
|
||||||
|
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||||
|
uploader *oss.UploaderManager
|
||||||
|
levelDB *store.LevelDB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
|
||||||
services := make([]*Service, 0)
|
services := make([]*Service, 0)
|
||||||
queue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
||||||
// create mj client and service
|
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
||||||
for k, config := range appConfig.SdConfigs {
|
|
||||||
|
return &ServicePool{
|
||||||
|
taskQueue: taskQueue,
|
||||||
|
notifyQueue: notifyQueue,
|
||||||
|
services: services,
|
||||||
|
db: db,
|
||||||
|
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||||
|
uploader: manager,
|
||||||
|
levelDB: levelDB,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
|
||||||
|
// stop old service
|
||||||
|
for _, s := range p.services {
|
||||||
|
s.Stop()
|
||||||
|
}
|
||||||
|
p.services = make([]*Service, 0)
|
||||||
|
|
||||||
|
for k, config := range configs {
|
||||||
if config.Enabled == false {
|
if config.Enabled == false {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// create sd service
|
// create sd service
|
||||||
name := fmt.Sprintf("StableDifffusion Service-%d", k)
|
name := fmt.Sprintf(" sd-service-%d", k)
|
||||||
service := NewService(name, 1, 300, config, queue, db, manager)
|
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
|
||||||
// run sd service
|
// run sd service
|
||||||
go func() {
|
go func() {
|
||||||
service.Run()
|
service.Run()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
services = append(services, service)
|
p.services = append(p.services, service)
|
||||||
}
|
|
||||||
|
|
||||||
return &ServicePool{
|
|
||||||
taskQueue: queue,
|
|
||||||
services: services,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,6 +75,68 @@ func (p *ServicePool) PushTask(task types.SdTask) {
|
|||||||
p.taskQueue.RPush(task)
|
p.taskQueue.RPush(task)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ServicePool) CheckTaskNotify() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||||
|
for {
|
||||||
|
var message NotifyMessage
|
||||||
|
err := p.notifyQueue.LPop(&message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
client := p.Clients.Get(uint(message.UserId))
|
||||||
|
if client == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = client.Send([]byte(message.Message))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||||
|
func (p *ServicePool) CheckTaskStatus() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||||
|
for {
|
||||||
|
var jobs []model.SdJob
|
||||||
|
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
||||||
|
if res.Error != nil {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, job := range jobs {
|
||||||
|
// 5 分钟还没完成的任务直接删除
|
||||||
|
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
|
||||||
|
p.db.Delete(&job)
|
||||||
|
var user model.User
|
||||||
|
p.db.Where("id = ?", job.UserId).First(&user)
|
||||||
|
// 退回绘图次数
|
||||||
|
res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
||||||
|
if res.Error == nil && res.RowsAffected > 0 {
|
||||||
|
p.db.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power + job.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "stable-diffusion",
|
||||||
|
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// HasAvailableService check if it has available mj service in pool
|
// HasAvailableService check if it has available mj service in pool
|
||||||
func (p *ServicePool) HasAvailableService() bool {
|
func (p *ServicePool) HasAvailableService() bool {
|
||||||
return len(p.services) > 0
|
return len(p.services) > 0
|
||||||
|
|||||||
@@ -1,17 +1,21 @@
|
|||||||
package sd
|
package sd
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/service/oss"
|
|
||||||
"chatplus/store"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/utils"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"geekai/core/types"
|
||||||
"os"
|
"geekai/service"
|
||||||
"strconv"
|
"geekai/service/oss"
|
||||||
"sync/atomic"
|
"geekai/store"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
@@ -21,283 +25,223 @@ import (
|
|||||||
// SD 绘画服务
|
// SD 绘画服务
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
httpClient *req.Client
|
httpClient *req.Client
|
||||||
config types.StableDiffusionConfig
|
config types.StableDiffusionConfig
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
notifyQueue *store.RedisQueue
|
||||||
uploadManager *oss.UploaderManager
|
db *gorm.DB
|
||||||
name string // service name
|
uploadManager *oss.UploaderManager
|
||||||
maxHandleTaskNum int32 // max task number current service can handle
|
name string // service name
|
||||||
handledTaskNum int32 // already handled task number
|
leveldb *store.LevelDB
|
||||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
running bool // 运行状态
|
||||||
taskTimeout int64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, queue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
|
||||||
|
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
|
||||||
return &Service{
|
return &Service{
|
||||||
name: name,
|
name: name,
|
||||||
config: config,
|
config: config,
|
||||||
httpClient: req.C(),
|
httpClient: req.C(),
|
||||||
taskQueue: queue,
|
taskQueue: taskQueue,
|
||||||
db: db,
|
notifyQueue: notifyQueue,
|
||||||
uploadManager: manager,
|
db: db,
|
||||||
taskTimeout: timeout,
|
leveldb: levelDB,
|
||||||
maxHandleTaskNum: maxTaskNum,
|
uploadManager: manager,
|
||||||
taskStartTimes: make(map[int]time.Time),
|
running: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
for {
|
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
|
||||||
s.checkTasks()
|
for s.running {
|
||||||
if !s.canHandleTask() {
|
|
||||||
// current service is full, can not handle more task
|
|
||||||
// waiting for running task finish
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var task types.SdTask
|
var task types.SdTask
|
||||||
err := s.taskQueue.LPop(&task)
|
err := s.taskQueue.LPop(&task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("taking task with error: %v", err)
|
logger.Errorf("taking task with error: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// translate prompt
|
||||||
|
if utils.HasChinese(task.Params.Prompt) {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
|
||||||
|
if err == nil {
|
||||||
|
task.Params.Prompt = content
|
||||||
|
} else {
|
||||||
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// translate negative prompt
|
||||||
|
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt))
|
||||||
|
if err == nil {
|
||||||
|
task.Params.NegPrompt = content
|
||||||
|
} else {
|
||||||
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
|
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
|
||||||
err = s.Txt2Img(task)
|
err = s.Txt2Img(task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("绘画任务执行失败:", err)
|
logger.Error("绘画任务执行失败:", err.Error())
|
||||||
// update the task progress
|
// update the task progress
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
||||||
// restore img_call quota
|
"progress": -1,
|
||||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
"err_msg": err.Error(),
|
||||||
// release task num
|
})
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
// 通知前端,任务失败
|
||||||
|
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock the task until the execute timeout
|
|
||||||
s.taskStartTimes[task.Id] = time.Now()
|
|
||||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if current service instance can handle more task
|
func (s *Service) Stop() {
|
||||||
func (s *Service) canHandleTask() bool {
|
s.running = false
|
||||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
|
||||||
return handledNum < s.maxHandleTaskNum
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove the expired tasks
|
// Txt2ImgReq 文生图请求实体
|
||||||
func (s *Service) checkTasks() {
|
type Txt2ImgReq struct {
|
||||||
for k, t := range s.taskStartTimes {
|
Prompt string `json:"prompt"`
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
NegativePrompt string `json:"negative_prompt"`
|
||||||
delete(s.taskStartTimes, k)
|
Seed int64 `json:"seed,omitempty"`
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
Steps int `json:"steps"`
|
||||||
// delete task from database
|
CfgScale float32 `json:"cfg_scale"`
|
||||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
Width int `json:"width"`
|
||||||
}
|
Height int `json:"height"`
|
||||||
}
|
SamplerName string `json:"sampler_name"`
|
||||||
|
Scheduler string `json:"scheduler"`
|
||||||
|
EnableHr bool `json:"enable_hr,omitempty"`
|
||||||
|
HrScale int `json:"hr_scale,omitempty"`
|
||||||
|
HrUpscaler string `json:"hr_upscaler,omitempty"`
|
||||||
|
HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"`
|
||||||
|
DenoisingStrength float32 `json:"denoising_strength,omitempty"`
|
||||||
|
ForceTaskId string `json:"force_task_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Txt2ImgResp 文生图响应实体
|
||||||
|
type Txt2ImgResp struct {
|
||||||
|
Images []string `json:"images"`
|
||||||
|
Parameters struct {
|
||||||
|
} `json:"parameters"`
|
||||||
|
Info string `json:"info"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskProgressResp 任务进度响应实体
|
||||||
|
type TaskProgressResp struct {
|
||||||
|
Progress float64 `json:"progress"`
|
||||||
|
EtaRelative float64 `json:"eta_relative"`
|
||||||
|
CurrentImage string `json:"current_image"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Txt2Img 文生图 API
|
// Txt2Img 文生图 API
|
||||||
func (s *Service) Txt2Img(task types.SdTask) error {
|
func (s *Service) Txt2Img(task types.SdTask) error {
|
||||||
var taskInfo TaskInfo
|
body := Txt2ImgReq{
|
||||||
bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath)
|
Prompt: task.Params.Prompt,
|
||||||
if err != nil {
|
NegativePrompt: task.Params.NegPrompt,
|
||||||
return fmt.Errorf("error with load text2img json template file: %s", err.Error())
|
Steps: task.Params.Steps,
|
||||||
|
CfgScale: task.Params.CfgScale,
|
||||||
|
Width: task.Params.Width,
|
||||||
|
Height: task.Params.Height,
|
||||||
|
SamplerName: task.Params.Sampler,
|
||||||
|
Scheduler: task.Params.Scheduler,
|
||||||
|
ForceTaskId: task.Params.TaskId,
|
||||||
}
|
}
|
||||||
|
if task.Params.Seed > 0 {
|
||||||
err = json.Unmarshal(bytes, &taskInfo)
|
body.Seed = task.Params.Seed
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode json params: %s", err.Error())
|
|
||||||
}
|
}
|
||||||
|
if task.Params.HdFix {
|
||||||
data := taskInfo.Data
|
body.EnableHr = true
|
||||||
params := task.Params
|
body.HrScale = task.Params.HdScale
|
||||||
data[ParamKeys["task_id"]] = params.TaskId
|
body.HrUpscaler = task.Params.HdScaleAlg
|
||||||
data[ParamKeys["prompt"]] = params.Prompt
|
body.HrSecondPassSteps = task.Params.HdSteps
|
||||||
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
|
body.DenoisingStrength = task.Params.HdRedrawRate
|
||||||
data[ParamKeys["steps"]] = params.Steps
|
}
|
||||||
data[ParamKeys["sampler"]] = params.Sampler
|
var res Txt2ImgResp
|
||||||
// @fix bug: 有些 stable diffusion 没有面部修复功能
|
var errChan = make(chan error)
|
||||||
//data[ParamKeys["face_fix"]] = params.FaceFix
|
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||||
data[ParamKeys["cfg_scale"]] = params.CfgScale
|
logger.Debugf("send image request to %s", apiURL)
|
||||||
data[ParamKeys["seed"]] = params.Seed
|
// send a request to sd api endpoint
|
||||||
data[ParamKeys["height"]] = params.Height
|
|
||||||
data[ParamKeys["width"]] = params.Width
|
|
||||||
data[ParamKeys["hd_fix"]] = params.HdFix
|
|
||||||
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
|
|
||||||
data[ParamKeys["hd_scale"]] = params.HdScale
|
|
||||||
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
|
|
||||||
data[ParamKeys["hd_sample_num"]] = params.HdSteps
|
|
||||||
|
|
||||||
taskInfo.SessionId = task.SessionId
|
|
||||||
taskInfo.TaskId = params.TaskId
|
|
||||||
taskInfo.Data = data
|
|
||||||
taskInfo.JobId = task.Id
|
|
||||||
taskInfo.UserId = uint(task.UserId)
|
|
||||||
go func() {
|
go func() {
|
||||||
s.runTask(taskInfo, s.httpClient)
|
response, err := s.httpClient.R().
|
||||||
}()
|
SetHeader("Authorization", s.config.ApiKey).
|
||||||
return nil
|
SetBody(body).
|
||||||
}
|
SetSuccessResult(&res).
|
||||||
|
Post(apiURL)
|
||||||
// 执行任务
|
|
||||||
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
|
|
||||||
body := map[string]any{
|
|
||||||
"data": taskInfo.Data,
|
|
||||||
"event_data": taskInfo.EventData,
|
|
||||||
"fn_index": taskInfo.FnIndex,
|
|
||||||
"session_hash": taskInfo.SessionHash,
|
|
||||||
}
|
|
||||||
var result = make(chan CBReq)
|
|
||||||
go func() {
|
|
||||||
var res struct {
|
|
||||||
Data []interface{} `json:"data"`
|
|
||||||
IsGenerating bool `json:"is_generating"`
|
|
||||||
Duration float64 `json:"duration"`
|
|
||||||
AverageDuration float64 `json:"average_duration"`
|
|
||||||
}
|
|
||||||
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
|
||||||
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cbReq.Message = "error with send request: " + err.Error()
|
errChan <- err
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.IsErrorState() {
|
if response.IsErrorState() {
|
||||||
bytes, _ := io.ReadAll(response.Body)
|
errChan <- fmt.Errorf("error http code status: %v", response.Status)
|
||||||
cbReq.Message = "error http status code: " + string(bytes)
|
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var images []struct {
|
// 保存 Base64 图片
|
||||||
Name string `json:"name"`
|
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
||||||
Data interface{} `json:"data"`
|
|
||||||
IsFile bool `json:"is_file"`
|
|
||||||
}
|
|
||||||
err = utils.ForceCovert(res.Data[0], &images)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cbReq.Message = "error with decode image:" + err.Error()
|
errChan <- fmt.Errorf("error with upload image: %v", err)
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 获取绘画真实的 seed
|
||||||
var info map[string]any
|
var info map[string]interface{}
|
||||||
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
|
err = utils.JsonDecode(res.Info, &info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(res.Data)
|
errChan <- fmt.Errorf("error with decode task response: %v", err)
|
||||||
cbReq.Message = "error with decode image url:" + err.Error()
|
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
|
||||||
// 获取真实的 seed 值
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
|
||||||
cbReq.ImageName = images[0].Name
|
errChan <- nil
|
||||||
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
|
|
||||||
cbReq.Seed = seed
|
|
||||||
cbReq.Success = true
|
|
||||||
cbReq.Progress = 100
|
|
||||||
result <- cbReq
|
|
||||||
close(result)
|
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// waiting for task finish
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case value := <-result:
|
case err := <-errChan:
|
||||||
s.callback(value)
|
if err != nil {
|
||||||
return
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// task finished
|
||||||
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||||
|
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
|
||||||
|
// 从 leveldb 中删除预览图片数据
|
||||||
|
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||||
|
return nil
|
||||||
default:
|
default:
|
||||||
var progressReq = map[string]any{
|
err, resp := s.checkTaskProgress()
|
||||||
"id_task": taskInfo.TaskId,
|
// 更新任务进度
|
||||||
"id_live_preview": 1,
|
if err == nil && resp.Progress > 0 {
|
||||||
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||||
|
// 发送更新状态信号
|
||||||
|
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
|
||||||
|
// 保存预览图片数据
|
||||||
|
if resp.CurrentImage != "" {
|
||||||
|
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var progressRes struct {
|
|
||||||
Active bool `json:"active"`
|
|
||||||
Queued bool `json:"queued"`
|
|
||||||
Completed bool `json:"completed"`
|
|
||||||
Progress float64 `json:"progress"`
|
|
||||||
Eta float64 `json:"eta"`
|
|
||||||
LivePreview string `json:"live_preview"`
|
|
||||||
IDLivePreview int `json:"id_live_preview"`
|
|
||||||
TextInfo interface{} `json:"textinfo"`
|
|
||||||
}
|
|
||||||
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
|
|
||||||
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
|
||||||
if err != nil { // TODO: 这里可以考虑设置失败重试次数
|
|
||||||
logger.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if response.IsErrorState() {
|
|
||||||
bytes, _ := io.ReadAll(response.Body)
|
|
||||||
logger.Error(string(bytes))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cbReq.ImageData = progressRes.LivePreview
|
|
||||||
cbReq.Progress = int(progressRes.Progress * 100)
|
|
||||||
s.callback(cbReq)
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) callback(data CBReq) {
|
// 执行任务
|
||||||
// release task num
|
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
||||||
if data.Success { // 任务成功
|
var res TaskProgressResp
|
||||||
var job model.SdJob
|
response, err := s.httpClient.R().
|
||||||
res := s.db.Where("id = ?", data.JobId).First(&job)
|
SetHeader("Authorization", s.config.ApiKey).
|
||||||
if res.Error != nil {
|
SetSuccessResult(&res).
|
||||||
logger.Warn("非法任务:", res.Error)
|
Get(apiURL)
|
||||||
return
|
if err != nil {
|
||||||
}
|
return err, nil
|
||||||
// 更新任务进度
|
|
||||||
job.Progress = data.Progress
|
|
||||||
// 更新任务 seed
|
|
||||||
var params types.SdTaskParams
|
|
||||||
err := utils.JsonDecode(job.Params, ¶ms)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("任务解析失败:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
params.Seed = data.Seed
|
|
||||||
if data.ImageName != "" { // 下载图片
|
|
||||||
job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
|
|
||||||
if data.Progress == 100 {
|
|
||||||
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download img: ", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
job.ImgURL = imageURL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
job.Params = utils.JsonEncode(params)
|
|
||||||
res = s.db.Updates(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("error with update job: ", res.Error)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debugf("绘图进度:%d", data.Progress)
|
|
||||||
} else { // 任务失败
|
|
||||||
logger.Error("任务执行失败:", data.Message)
|
|
||||||
// update the task progress
|
|
||||||
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumn("progress", -1)
|
|
||||||
// restore img_calls
|
|
||||||
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", data.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
}
|
}
|
||||||
|
if response.IsErrorState() {
|
||||||
|
return fmt.Errorf("error http code status: %v", response.Status), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, &res
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,47 +1,24 @@
|
|||||||
package sd
|
package sd
|
||||||
|
|
||||||
import logger2 "chatplus/logger"
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import logger2 "geekai/logger"
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
type TaskInfo struct {
|
type NotifyMessage struct {
|
||||||
UserId uint `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
SessionId string `json:"session_id"`
|
JobId int `json:"job_id"`
|
||||||
JobId int `json:"job_id"`
|
Message string `json:"message"`
|
||||||
TaskId string `json:"task_id"`
|
|
||||||
Data []interface{} `json:"data"`
|
|
||||||
EventData interface{} `json:"event_data"`
|
|
||||||
FnIndex int `json:"fn_index"`
|
|
||||||
SessionHash string `json:"session_hash"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CBReq struct {
|
const (
|
||||||
UserId uint
|
Running = "RUNNING"
|
||||||
SessionId string
|
Finished = "FINISH"
|
||||||
JobId int
|
Failed = "FAIL"
|
||||||
TaskId string
|
)
|
||||||
ImageName string
|
|
||||||
ImageData string
|
|
||||||
Progress int
|
|
||||||
Seed int64
|
|
||||||
Success bool
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
var ParamKeys = map[string]int{
|
|
||||||
"task_id": 0,
|
|
||||||
"prompt": 1,
|
|
||||||
"negative_prompt": 2,
|
|
||||||
"steps": 4,
|
|
||||||
"sampler": 5,
|
|
||||||
"face_fix": 7, // 面部修复
|
|
||||||
"cfg_scale": 8,
|
|
||||||
"seed": 27,
|
|
||||||
"height": 10,
|
|
||||||
"width": 9,
|
|
||||||
"hd_fix": 11,
|
|
||||||
"hd_redraw_rate": 12, //高清修复重绘幅度
|
|
||||||
"hd_scale": 13, // 高清修复放大倍数
|
|
||||||
"hd_scale_alg": 14, // 高清修复放大算法
|
|
||||||
"hd_sample_num": 15, // 高清修复采样次数
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,28 +1,36 @@
|
|||||||
package service
|
package sms
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
|
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AliYunSmsService struct {
|
type AliYunSmsService struct {
|
||||||
config *types.AliYunSmsConfig
|
config *types.SmsConfigAli
|
||||||
client *dysmsapi.Client
|
client *dysmsapi.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAliYunSmsService(config *types.AppConfig) (*AliYunSmsService, error) {
|
func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) {
|
||||||
|
config := &appConfig.SMS.Ali
|
||||||
// 创建阿里云短信客户端
|
// 创建阿里云短信客户端
|
||||||
client, err := dysmsapi.NewClientWithAccessKey(
|
client, err := dysmsapi.NewClientWithAccessKey(
|
||||||
"cn-hangzhou",
|
"cn-hangzhou",
|
||||||
config.SmsConfig.AccessKey,
|
config.AccessKey,
|
||||||
config.SmsConfig.AccessSecret)
|
config.AccessSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create client: %v", err)
|
return nil, fmt.Errorf("failed to create client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AliYunSmsService{
|
return &AliYunSmsService{
|
||||||
config: &config.SmsConfig,
|
config: config,
|
||||||
client: client,
|
client: client,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -46,8 +54,7 @@ func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error {
|
|||||||
if response.Code != "OK" {
|
if response.Code != "OK" {
|
||||||
return fmt.Errorf("failed to send SMS:%v", response.Message)
|
return fmt.Errorf("failed to send SMS:%v", response.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ SmsService = &AliYunSmsService{}
|
var _ Service = &AliYunSmsService{}
|
||||||
79
api/service/sms/bao.go
Normal file
79
api/service/sms/bao.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package sms
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BaoSmsService struct {
|
||||||
|
config *types.SmsConfigBao
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService {
|
||||||
|
config := appConfig.SMS.Bao
|
||||||
|
if config.Domain == "" { // use default domain
|
||||||
|
config.Domain = "api.smsbao.com"
|
||||||
|
logger.Infof("Using default domain for SMS-BAO: %s", config.Domain)
|
||||||
|
}
|
||||||
|
return &BaoSmsService{
|
||||||
|
config: &config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errMsg = map[string]string{
|
||||||
|
"0": "短信发送成功",
|
||||||
|
"-1": "参数不全",
|
||||||
|
"-2": "服务器空间不支持,请确认支持curl或者fsocket,联系您的空间商解决或者更换空间",
|
||||||
|
"30": "密码错误",
|
||||||
|
"40": "账号不存在",
|
||||||
|
"41": "余额不足",
|
||||||
|
"42": "账户已过期",
|
||||||
|
"43": "IP地址限制",
|
||||||
|
"50": "内容含有敏感词",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error {
|
||||||
|
|
||||||
|
content := fmt.Sprintf("%s%s", s.config.Sign, s.config.CodeTemplate)
|
||||||
|
content = strings.ReplaceAll(content, "{code}", strconv.Itoa(code))
|
||||||
|
password := utils.Md5(s.config.Password)
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("u", s.config.Username)
|
||||||
|
params.Set("p", password)
|
||||||
|
params.Set("m", mobile)
|
||||||
|
params.Set("c", content)
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode())
|
||||||
|
response, err := http.Get(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result := string(body)
|
||||||
|
logger.Debugf("send SmsBao result: %v", errMsg[result])
|
||||||
|
|
||||||
|
if result != "0" {
|
||||||
|
return fmt.Errorf("failed to send SMS:%v", errMsg[result])
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Service = &BaoSmsService{}
|
||||||
15
api/service/sms/service.go
Normal file
15
api/service/sms/service.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package sms
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
const Ali = "ALI"
|
||||||
|
const Bao = "BAO"
|
||||||
|
|
||||||
|
type Service interface {
|
||||||
|
SendVerifyCode(mobile string, code int) error
|
||||||
|
}
|
||||||
46
api/service/sms/service_manager.go
Normal file
46
api/service/sms/service_manager.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package sms
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core/types"
|
||||||
|
logger2 "geekai/logger"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServiceManager struct {
|
||||||
|
handler Service
|
||||||
|
}
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) {
|
||||||
|
active := Ali
|
||||||
|
if config.SMS.Active != "" {
|
||||||
|
active = strings.ToUpper(config.SMS.Active)
|
||||||
|
}
|
||||||
|
var handler Service
|
||||||
|
switch active {
|
||||||
|
case Ali:
|
||||||
|
client, err := NewAliYunSmsService(config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
handler = client
|
||||||
|
break
|
||||||
|
case Bao:
|
||||||
|
handler = NewSmsBaoSmsService(config)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ServiceManager{handler: handler}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ServiceManager) GetService() Service {
|
||||||
|
return m.handler
|
||||||
|
}
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
type SmsService interface {
|
|
||||||
SendVerifyCode(mobile string, code int) error
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user