mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-10-29 05:13:42 +08:00
Compare commits
242 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54b45ec2ff | ||
|
|
c434f85045 | ||
|
|
4d10279870 | ||
|
|
9de9489673 | ||
|
|
9814fec930 | ||
|
|
53ba731159 | ||
|
|
b2f57aa483 | ||
|
|
4c2dba1004 | ||
|
|
79adc871ef | ||
|
|
8144fada25 | ||
|
|
754ba02263 | ||
|
|
7ddf57ae06 | ||
|
|
cc5180a6f7 | ||
|
|
96f1126d02 | ||
|
|
7f9b8d8246 | ||
|
|
5132d52a44 | ||
|
|
abdf5298fe | ||
|
|
2129f7a8b7 | ||
|
|
f6f8748521 | ||
|
|
59301df073 | ||
|
|
e17dcf4d5f | ||
|
|
09f44e6d9b | ||
|
|
59824bffc5 | ||
|
|
cb0dacd5e0 | ||
|
|
7463cfc66c | ||
|
|
b248560ba2 | ||
|
|
37368fe13f | ||
|
|
246b023624 | ||
|
|
9f44c34d34 | ||
|
|
a6b9f57a50 | ||
|
|
42bc23cacf | ||
|
|
282f55c7a3 | ||
|
|
44798f89ba | ||
|
|
596cb2b206 | ||
|
|
d1965deff1 | ||
|
|
b793b81768 | ||
|
|
a5ef4299ec | ||
|
|
cdb1a8bde1 | ||
|
|
233f6e00f0 | ||
|
|
b7dba68549 | ||
|
|
64e5fc48ba | ||
|
|
a692cf1338 | ||
|
|
6998dd7af4 | ||
|
|
9343c73e0f | ||
|
|
739cd46539 | ||
|
|
f8fed83507 | ||
|
|
d63536d5ef | ||
|
|
4905fb28d4 | ||
|
|
a3a2a8abcb | ||
|
|
839dd8dbf4 | ||
|
|
0375164f40 | ||
|
|
691294b444 | ||
|
|
bdea12c51a | ||
|
|
a27d9ea259 | ||
|
|
7cd824c284 | ||
|
|
e27d95e2b5 | ||
|
|
c24b4d7074 | ||
|
|
6839827db0 | ||
|
|
ab24398748 | ||
|
|
6110522b54 | ||
|
|
bcdf5e3776 | ||
|
|
2207830db9 | ||
|
|
d52dfbfef4 | ||
|
|
d6a04f96fe | ||
|
|
66ccb387e8 | ||
|
|
5f820b9dc1 | ||
|
|
3cc2263dc7 | ||
|
|
f0a3c5d8ae | ||
|
|
2a4ef27774 | ||
|
|
2b057f32aa | ||
|
|
bc6451026f | ||
|
|
99fd596862 | ||
|
|
f0959b5df6 | ||
|
|
6788edbe9d | ||
|
|
3895305882 | ||
|
|
1b0938b33f | ||
|
|
c2acbaaa94 | ||
|
|
02faff461a | ||
|
|
e18e5a38c6 | ||
|
|
2f9b1b7835 | ||
|
|
717b137a6d | ||
|
|
f755bdccae | ||
|
|
4bba77ab47 | ||
|
|
6944a32ff3 | ||
|
|
5742b40aee | ||
|
|
7f1ec90748 | ||
|
|
4a99be2f15 | ||
|
|
bee19392c1 | ||
|
|
27c816cf3b | ||
|
|
0d81776212 | ||
|
|
00d31a2379 | ||
|
|
cccab31c0f | ||
|
|
5d65505ab7 | ||
|
|
3dc7d0516a | ||
|
|
50335ebc2d | ||
|
|
bcadee7290 | ||
|
|
cac3194d5b | ||
|
|
4ddf3bf2bf | ||
|
|
d45f9fbad6 | ||
|
|
d98b08d7cd | ||
|
|
5a8fe5a6cf | ||
|
|
36c27d6092 | ||
|
|
3ab29da8f0 | ||
|
|
3699f024f1 | ||
|
|
3d37a3d367 | ||
|
|
73d8236697 | ||
|
|
114d0088dc | ||
|
|
43b6665370 | ||
|
|
5fb9f84182 | ||
|
|
e35c34ad9a | ||
|
|
1a4d798f8b | ||
|
|
afb91a7023 | ||
|
|
dc4c1f7877 | ||
|
|
bbc8fe2b40 | ||
|
|
3c34e8e0e7 | ||
|
|
57c932f07c | ||
|
|
922202734a | ||
|
|
8b3b0139b0 | ||
|
|
31828a3336 | ||
|
|
b270960a04 | ||
|
|
5c4899df6e | ||
|
|
9a797bb4a5 | ||
|
|
b0c9ffc5a6 | ||
|
|
f527cc5b98 | ||
|
|
debe8dc209 | ||
|
|
2f0215ac87 | ||
|
|
dd5cc206e5 | ||
|
|
142cd553a3 | ||
|
|
657ecccee3 | ||
|
|
1232c3cd9c | ||
|
|
3ac04a3938 | ||
|
|
b7abc42209 | ||
|
|
a48179ce0e | ||
|
|
e589f25a05 | ||
|
|
cc1a3ce343 | ||
|
|
7bb76d581c | ||
|
|
0d733c0be0 | ||
|
|
8b40ac5b5c | ||
|
|
24479814e9 | ||
|
|
99df028237 | ||
|
|
b354b88876 | ||
|
|
5e0be4d10e | ||
|
|
468b48151f | ||
|
|
fa5c036041 | ||
|
|
0fdc588167 | ||
|
|
2e023cb8dc | ||
|
|
e933f32d9c | ||
|
|
bd4b0c4d65 | ||
|
|
0b2501c1d8 | ||
|
|
9d28e62142 | ||
|
|
c1d892069e | ||
|
|
61b2dbc9f1 | ||
|
|
be3245666e | ||
|
|
dacdd6fe74 | ||
|
|
6807f7e88a | ||
|
|
087f5ab2d1 | ||
|
|
47c5a0387b | ||
|
|
f9da18ad52 | ||
|
|
5c9025ca22 | ||
|
|
d02cb573fd | ||
|
|
caa538a1d0 | ||
|
|
b584b4bfb6 | ||
|
|
bda335212d | ||
|
|
06f4cdc649 | ||
|
|
336a7d5b56 | ||
|
|
a0f464830f | ||
|
|
9bf7fa4081 | ||
|
|
96ead65774 | ||
|
|
7ad41927aa | ||
|
|
4ca9dfd9c0 | ||
|
|
8a9f386d8f | ||
|
|
adfee8bf58 | ||
|
|
fbfa2a71a9 | ||
|
|
9a1368ef17 | ||
|
|
31b02b97d3 | ||
|
|
42da38c5c3 | ||
|
|
0a01b55713 | ||
|
|
3b292c2a12 | ||
|
|
db0ba0d9a0 | ||
|
|
3a23ff6b42 | ||
|
|
1e9c5adb0a | ||
|
|
abab76ccc6 | ||
|
|
6efd92806f | ||
|
|
cfe333e89f | ||
|
|
a7237fe62f | ||
|
|
c3c454b7d7 | ||
|
|
d4d708d44b | ||
|
|
7f0b6a3a46 | ||
|
|
c2a7c089d2 | ||
|
|
df5bd4df60 | ||
|
|
79b6010104 | ||
|
|
97b0a98793 | ||
|
|
5230f90540 | ||
|
|
803db4e895 | ||
|
|
7cee9f2ebb | ||
|
|
8be9a21efd | ||
|
|
6a3e26b566 | ||
|
|
0355c37bef | ||
|
|
9b7ee538c4 | ||
|
|
d900a3d08e | ||
|
|
cdf5b66729 | ||
|
|
1cff4b63cd | ||
|
|
da14309ef9 | ||
|
|
fbb216fe3b | ||
|
|
95efbd5659 | ||
|
|
4596c1049c | ||
|
|
b35d95f0c7 | ||
|
|
01419df998 | ||
|
|
a6c00c42fa | ||
|
|
4cc9db7115 | ||
|
|
4f1ed54059 | ||
|
|
8227a73e35 | ||
|
|
adfd8c1939 | ||
|
|
be8a0ec184 | ||
|
|
b02e3aad95 | ||
|
|
08eca511ad | ||
|
|
c34e911596 | ||
|
|
8a452c3072 | ||
|
|
13bfb14107 | ||
|
|
4188b0969e | ||
|
|
0c27795a10 | ||
|
|
d05693c5c1 | ||
|
|
c0b2063b38 | ||
|
|
4d183747b1 | ||
|
|
08fe1b2f75 | ||
|
|
db3e8a267e | ||
|
|
8fc62682c4 | ||
|
|
75031914a3 | ||
|
|
a4c9fdd95a | ||
|
|
6a9bfeb5aa | ||
|
|
e654766f60 | ||
|
|
0ef6955f96 | ||
|
|
b4501557c9 | ||
|
|
a2ed99e6cb | ||
|
|
6bd6bb3885 | ||
|
|
399cf65fc9 | ||
|
|
24906a6df1 | ||
|
|
d772bbebe6 | ||
|
|
14988853a3 | ||
|
|
7b3f16ac9f | ||
|
|
82b2755c18 | ||
|
|
4e4dc4cb73 |
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: Bug 报告 🐛
|
||||
description: 为 chatgpt-plus 提交错误报告
|
||||
description: 为 geekai 提交错误报告
|
||||
labels: ['Bug']
|
||||
body:
|
||||
- type: checkboxes
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
2
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: 功能优化 🚀
|
||||
description: 为 chatgpt-plus 提交优化建议
|
||||
description: 为 geekai 提交优化建议
|
||||
labels: ['feature']
|
||||
body:
|
||||
- type: checkboxes
|
||||
|
||||
131
CHANGELOG.md
131
CHANGELOG.md
@@ -1,20 +1,133 @@
|
||||
# 更新日志
|
||||
## 4.0.2
|
||||
## v4.1.1
|
||||
* Bug修复:修复 GPT 模型 function call 调用后没有输出的问题
|
||||
* 功能新增:允许获取 License 授权用户可以自定义版权信息
|
||||
* 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件
|
||||
* 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求
|
||||
* 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
|
||||
* 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
|
||||
* 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式
|
||||
* 功能新增:允许在管理后台设置首页显示的导航菜单
|
||||
* Bug修复:修复注册页面先显示关闭注册组件,然后再显示注册组件
|
||||
* 功能新增:增加 Suno 文生歌曲功能
|
||||
* 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
|
||||
* 功能优化:在所有列表页面增加返回顶部按钮
|
||||
|
||||
## v4.1.0
|
||||
* bug修复:修复移动端修改聊天标题不生效的问题
|
||||
* Bug修复:修复用户注册不显示用户名的问题
|
||||
* Bug修复:修复管理后台拖动排序不生效的问题
|
||||
* 功能优化:允许用户设置自定义首页背景图片
|
||||
* 功能新增:**支持AI解读 PDF, Word, Excel等文件**
|
||||
* 功能优化:优化聊天界面的用户上传文件的列表样式
|
||||
* 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
|
||||
* 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
|
||||
|
||||
|
||||
## v4.0.9
|
||||
* 环境升级:升级 Golang 到 go1.22.4
|
||||
* 功能增加:接入微信商户号支付渠道
|
||||
* Bug修复:修复前端页面菜单把页面撑开,底部留白问题
|
||||
* 功能优化:聊天页面自动根据内容调整输入框的高度
|
||||
* Bug修复:修复Dalle绘图失败退回算力的问题
|
||||
* 功能优化:邀请码注册时被邀请人也可以获得赠送的算力
|
||||
* 功能优化:允许设置邮件验证码的抬头
|
||||
* Bug修复:修复免费模型不会记录聊天记录的bug
|
||||
* Bug修复:修复聊天输入公式显示异常的Bug
|
||||
|
||||
## v4.0.8
|
||||
* 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
|
||||
* 功能优化:当数据库更新失败的时候记录错误日志
|
||||
* 功能优化:聊天输入框会随着输入内容的增多自动调整高度
|
||||
* Bug修复:修复移动端聊天页面模型切换不生效的Bug
|
||||
* 功能优化:给PC端扫码支付增加签名验证和有效期验证
|
||||
* Bug修复:修复支付码生成API权限控制的问题
|
||||
* Bug修复:模型算力设置为0时,不扣减用户算力,并且不记录算力消费日志
|
||||
* 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
|
||||
* 功能新增:H5端支持 Dalle 绘图
|
||||
|
||||
## v4.0.7
|
||||
|
||||
* 功能优化:升级quic-go,支持 Go1.21
|
||||
* 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片
|
||||
* Bug修复:修复弹窗等于图形验证码一直验证失败的问题
|
||||
* 功能重构:重构前端 UI 页面,增加顶部导航
|
||||
* 功能优化:优化 Vue 非父子组件之间的通信方式
|
||||
* 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
|
||||
|
||||
## v4.0.6
|
||||
|
||||
* Bug修复:修复PC端画廊页面的瀑布流组件样式错乱问题
|
||||
* 功能新增:给思维导图增加 ToolBar,实现思维导图的放大缩小和定位
|
||||
* Bug修复:修复思维导图不扣费的Bug
|
||||
* Bug修复:修复管理后台角色删除失败的Bug
|
||||
* Bug修复:兼容最新版秋叶SD懒人包的 SD API,新增 scheduler 参数
|
||||
* 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
|
||||
* Bug修复:修复注册用户提示注册人数达到上限的 Bug
|
||||
* 功能优化:将MJ,SD,Dall绘画页面的任务列表全改成瀑布流组件
|
||||
|
||||
## v4.0.5
|
||||
|
||||
* 功能优化:已授权系统在后台显示授权信息
|
||||
* 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
|
||||
* 功能优化:优化首页登录注册页面的 UI
|
||||
* BUG修复:修复License验证的逻辑漏洞
|
||||
* Bug修复:后台添加用户的时候密码规则限制跟前台注册保持一致
|
||||
* 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
|
||||
* 功能新增:移动端新增 DALL-E 绘画功能
|
||||
* 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
|
||||
* 功能新增:移动支持免登录预览功能
|
||||
* Bug修复:解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
|
||||
* Bug修复:修复部分中转 API 模型会出现第一输出的字符被淹没的Bug
|
||||
|
||||
## v4.0.4
|
||||
|
||||
* Bug修复:修复统一千问第二句不回复的问题
|
||||
* 功能优化:MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
|
||||
* 功能新增:Dalle AI 绘画功能实现
|
||||
* Bug修复:修复思维导图格式乱码问题
|
||||
* 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
|
||||
* 功能新增:支持从应用列表直接和某个应用对话
|
||||
* 功能优化:优化算力日志的页面和首页的UI
|
||||
* 功能新增:支持思维导图导出 PNG 图片下载
|
||||
|
||||
## v4.0.3
|
||||
|
||||
* 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
|
||||
* Bug修复:兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
|
||||
* Bug修复:修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题
|
||||
* 功能新增:支持上传图片和视觉模型
|
||||
* 功能优化:优化聊天页面的复制代码按钮样式乱码
|
||||
* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
|
||||
* 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用GPT3.5或者 GPT4
|
||||
* 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
|
||||
* 功能新增:支持管理后台 Logo 修改
|
||||
|
||||
## 4.0.2
|
||||
|
||||
* 功能新增:支持前端菜单可以配置
|
||||
* 功能优化:手机端支持免登录预览功能
|
||||
* 功能优化:在登录和注册界面标题显示软件版本号
|
||||
* 功能优化:MJ 绘画支持 --sref 和 --cref 图片一致性参数
|
||||
* 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
|
||||
* Bug修复:解决因为图片上传使用相对路径而导致融图失败的问题。
|
||||
* 功能新增:手机端支持 Stable-Diffusion 绘画
|
||||
* 功能新增:管理后台登录页面增加行为验证码,防止爆破
|
||||
|
||||
## v4.0.1
|
||||
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion 发行版,稳定性更强一些
|
||||
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容 MJ-Plus 中转
|
||||
|
||||
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion
|
||||
发行版,稳定性更强一些
|
||||
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容
|
||||
MJ-Plus 中转
|
||||
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
|
||||
* Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换
|
||||
* Bug修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
|
||||
|
||||
## v4.0.0
|
||||
|
||||
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。
|
||||
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力...
|
||||
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ
|
||||
对话消耗15个算力...
|
||||
|
||||
* 功能重构:重构整体系统,全部采用算力来进行结算
|
||||
* 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
|
||||
@@ -29,6 +142,7 @@
|
||||
* 功能新增:管理后台新增7日内新增用户和新增订单统计
|
||||
|
||||
## v3.2.7
|
||||
|
||||
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
|
||||
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
|
||||
* Bug修复:修复 issue [
|
||||
@@ -43,6 +157,7 @@
|
||||
* 功能新增:后台管理新怎对话查看和检索功能
|
||||
|
||||
## v3.2.6
|
||||
|
||||
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
|
||||
* 功能优化:兼用旧版本微信收款消息解析
|
||||
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
|
||||
@@ -56,16 +171,18 @@
|
||||
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
|
||||
* 功能新增:新增短信宝短信平台发送平台集成
|
||||
|
||||
|
||||
## v3.2.5
|
||||
|
||||
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
|
||||
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅 Plus 账号了!!!
|
||||
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
|
||||
Plus 账号了!!!
|
||||
* 功能优化:增强 markdown 图片和引用块解析。
|
||||
* 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
|
||||
* 功能优化:function call 兼用中转 API。
|
||||
* Bug修复:修复部分已知的 Bug。
|
||||
|
||||
## v3.2.4.1
|
||||
|
||||
* 功能新增:新增 PayJs 支付通道
|
||||
* Bug修复:紧急修复后台添加用户失败问题
|
||||
* Bug修复:紧急修复使用中转 API-KEY 无法绘图的问题
|
||||
|
||||
214
LICENSE
214
LICENSE
@@ -1,21 +1,201 @@
|
||||
MIT License
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
Copyright (c) 2023 RockYang
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
1. Definitions.
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
123
README.md
123
README.md
@@ -1,124 +1,67 @@
|
||||
# ChatGPT-Plus
|
||||
# GeekAI
|
||||
> 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||
|
||||
**ChatGPT-PLUS** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
|
||||
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。主要有如下特性:
|
||||
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
|
||||
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
|
||||
|
||||
* 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
||||
* 基于 Websocket 实现,完美的打字机体验。
|
||||
* 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
|
||||
* 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
|
||||
* 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
|
||||
* 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
||||
* 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
||||
* 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
|
||||
主要特性:
|
||||
|
||||
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
||||
- 基于 Websocket 实现,完美的打字机体验。
|
||||
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
|
||||
- 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
|
||||
- 支持 Suno 文生音乐
|
||||
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
|
||||
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
||||
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
||||
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
|
||||
绘画函数插件。
|
||||
|
||||
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus)
|
||||
|
||||
- [x] 更友好的 UI 界面
|
||||
- [x] 支持 Dall-E 文生图功能
|
||||
- [x] 支持文生思维导图
|
||||
- [x] 支持为模型绑定指定的 API KEY,支持为角色绑定指定的模型等功能
|
||||
- [x] 支持网站 Logo 版权等信息的修改
|
||||
|
||||
## 功能截图
|
||||
|
||||
### PC 端聊天界面
|
||||
|
||||

|
||||
|
||||
### AI 对话界面
|
||||
|
||||

|
||||
|
||||
### MidJourney 专业绘画界面
|
||||
|
||||

|
||||
|
||||
### Stable-Diffusion 专业绘画页面
|
||||
|
||||

|
||||

|
||||
|
||||
### 绘图作品展
|
||||
|
||||

|
||||
|
||||
### AI应用列表
|
||||
|
||||

|
||||
|
||||
### 会员充值
|
||||
|
||||

|
||||
|
||||
### 自动调用函数插件
|
||||
|
||||

|
||||

|
||||
|
||||
### 管理后台
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
### 移动端 Web 页面
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
|
||||
|
||||
### 体验地址
|
||||
|
||||
> 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/>
|
||||
> 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/>
|
||||
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
|
||||
|
||||
## 快速部署
|
||||
|
||||
**演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。
|
||||
|
||||
```shell
|
||||
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.7-6c232bdaf8.sh)"
|
||||
```
|
||||
|
||||
最新版本的一键部署脚本请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/install/)。
|
||||
|
||||
目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址
|
||||
|
||||
* 前端访问地址:http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。
|
||||
* 后台管理地址:http://localhost:8080/admin
|
||||
* 移动端地址:http://localhost:8080/mobile
|
||||
* 初始后台管理账号:admin/admin123
|
||||
* 初始前端体验账号:18575670125/12345678
|
||||
|
||||
服务启动成功之后不能立刻使用,需要先登录管理后台 -> API-KEY 去添加一个 OpenAI 或者文心一言,科大讯飞等至少一个平台的 API
|
||||
KEY。
|
||||
|
||||

|
||||
|
||||
另外,如果您目前还没有 OpenAI 的 API KEY的,推荐您去 https://gpt.bemore.lol 购买,**无需魔法,高速稳定,且价格还远低于 OpenAI
|
||||
官方**。
|
||||
请参考文档 [**GeekAI 快速部署**](https://docs.geekai.me/install/)。
|
||||
|
||||
## 使用须知
|
||||
|
||||
1. 本项目基于 MIT 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
|
||||
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
|
||||
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
|
||||
|
||||
## 项目地址
|
||||
|
||||
* Github 地址:https://github.com/yangjian102621/chatgpt-plus
|
||||
* 码云地址:https://gitee.com/blackfox/chatgpt-plus
|
||||
* Github 地址:https://github.com/yangjian102621/geekai
|
||||
* 码云地址:https://gitee.com/blackfox/geekai
|
||||
|
||||
## 客户端下载
|
||||
|
||||
目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/chatgpt-plus/releases/tag/v3.1.2
|
||||
目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
|
||||
|
||||
## TODOLIST
|
||||
|
||||
* [ ] 支持基于知识库的 AI 问答
|
||||
* [ ] 会员邀请注册推广功能
|
||||
* [ ] 文生视频,文生歌曲功能
|
||||
* [ ] 微信支付功能
|
||||
|
||||
## 项目文档
|
||||
|
||||
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
|
||||
|
||||
详细的部署和开发文档请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/)。
|
||||
详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
|
||||
|
||||
加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自Github!!!)。**
|
||||
|
||||
@@ -146,4 +89,4 @@ KEY。
|
||||
|
||||

|
||||
|
||||

|
||||

|
||||
|
||||
3
api/.gitignore
vendored
3
api/.gitignore
vendored
@@ -17,4 +17,5 @@ bin
|
||||
data
|
||||
config.toml
|
||||
static/upload
|
||||
storage.json
|
||||
storage.json
|
||||
res/certs/wechat/apiclient_key.pem
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
SHELL=/usr/bin/env bash
|
||||
NAME := chatgpt-plus
|
||||
NAME := geekai
|
||||
all: amd64 arm64
|
||||
|
||||
amd64:
|
||||
|
||||
@@ -5,6 +5,7 @@ StaticDir = "./static" # 静态资源的目录
|
||||
StaticUrl = "/static" # 静态资源访问 URL
|
||||
AesEncryptKey = ""
|
||||
WeChatBot = false
|
||||
TikaHost = "http://tika:9998"
|
||||
|
||||
[Session]
|
||||
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
|
||||
@@ -17,7 +18,7 @@ WeChatBot = false
|
||||
DB = 0
|
||||
|
||||
[ApiConfig] # 微博热搜,今日头条等函数服务 API 配置,此为第三方插件服务,如需使用请联系作者开通
|
||||
ApiURL = ""
|
||||
ApiURL = "https://sapi.geekai.me"
|
||||
AppId = ""
|
||||
Token = ""
|
||||
|
||||
@@ -108,7 +109,8 @@ WeChatBot = false
|
||||
ApiURL = "https://api.xunhupay.com"
|
||||
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
|
||||
|
||||
[SmtpConfig] # 注意,阿里云服务器禁用了25号端口,所以如果需要使用邮件功能,请别用阿里云服务器
|
||||
[SmtpConfig] # 注意,阿里云服务器禁用了25号端口,请使用 465 端口,并开启 TLS 连接
|
||||
UseTls = false
|
||||
Host = "smtp.163.com"
|
||||
Port = 25
|
||||
AppName = "极客学长"
|
||||
@@ -121,4 +123,16 @@ WeChatBot = false
|
||||
AppId = "" # 商户 ID
|
||||
PrivateKey = "" # 秘钥
|
||||
ApiURL = "https://payjs.cn"
|
||||
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的
|
||||
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的
|
||||
|
||||
# 微信商户支付
|
||||
[WechatPayConfig]
|
||||
Enabled = false
|
||||
AppId = "" # 商户应用ID
|
||||
MchId = "" # 商户号
|
||||
SerialNo = "" # API 证书序列号
|
||||
PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内
|
||||
ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置的
|
||||
NotifyURL = "https://ai.r9it.com/api/payment/wechat/notify" # 支付成功异步回调地址,域名改成自己的
|
||||
ReturnURL = "" # 支付成功同步回调地址
|
||||
|
||||
|
||||
@@ -1,22 +1,29 @@
|
||||
package core
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/nfnt/resize"
|
||||
"golang.org/x/image/webp"
|
||||
"gorm.io/gorm"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
@@ -25,31 +32,19 @@ import (
|
||||
)
|
||||
|
||||
type AppServer struct {
|
||||
Debug bool
|
||||
Config *types.AppConfig
|
||||
Engine *gin.Engine
|
||||
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
|
||||
|
||||
Debug bool
|
||||
Config *types.AppConfig
|
||||
Engine *gin.Engine
|
||||
SysConfig *types.SystemConfig // system config cache
|
||||
|
||||
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
|
||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||
ChatSession *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
|
||||
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
}
|
||||
|
||||
func NewServer(appConfig *types.AppConfig) *AppServer {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
return &AppServer{
|
||||
Debug: false,
|
||||
Config: appConfig,
|
||||
Engine: gin.Default(),
|
||||
ChatContexts: types.NewLMap[string, []types.Message](),
|
||||
ChatSession: types.NewLMap[string, *types.ChatSession](),
|
||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
Debug: false,
|
||||
Config: appConfig,
|
||||
Engine: gin.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +83,7 @@ func errorHandler(c *gin.Context) {
|
||||
if r := recover(); r != nil {
|
||||
logger.Errorf("Handler Panic: %v", r)
|
||||
debug.PrintStack()
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
|
||||
c.JSON(http.StatusBadRequest, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
@@ -144,7 +139,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
|
||||
if tokenString == "" {
|
||||
if needLogin(c) {
|
||||
resp.ERROR(c, "You should put Authorization in request headers")
|
||||
resp.NotAuth(c, "You should put Authorization in request headers")
|
||||
c.Abort()
|
||||
return
|
||||
} else { // 直接放行
|
||||
@@ -200,10 +195,13 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
|
||||
func needLogin(c *gin.Context) bool {
|
||||
if c.Request.URL.Path == "/api/user/login" ||
|
||||
c.Request.URL.Path == "/api/user/logout" ||
|
||||
c.Request.URL.Path == "/api/user/resetPass" ||
|
||||
c.Request.URL.Path == "/api/admin/login" ||
|
||||
c.Request.URL.Path == "/api/admin/logout" ||
|
||||
c.Request.URL.Path == "/api/admin/login/captcha" ||
|
||||
c.Request.URL.Path == "/api/user/register" ||
|
||||
c.Request.URL.Path == "/api/user/session" ||
|
||||
c.Request.URL.Path == "/api/chat/history" ||
|
||||
c.Request.URL.Path == "/api/chat/detail" ||
|
||||
c.Request.URL.Path == "/api/chat/list" ||
|
||||
@@ -215,13 +213,26 @@ func needLogin(c *gin.Context) bool {
|
||||
c.Request.URL.Path == "/api/invite/hits" ||
|
||||
c.Request.URL.Path == "/api/sd/imgWall" ||
|
||||
c.Request.URL.Path == "/api/sd/client" ||
|
||||
c.Request.URL.Path == "/api/config/get" ||
|
||||
c.Request.URL.Path == "/api/dall/imgWall" ||
|
||||
c.Request.URL.Path == "/api/dall/client" ||
|
||||
c.Request.URL.Path == "/api/product/list" ||
|
||||
c.Request.URL.Path == "/api/menu/list" ||
|
||||
c.Request.URL.Path == "/api/markMap/client" ||
|
||||
c.Request.URL.Path == "/api/payment/alipay/notify" ||
|
||||
c.Request.URL.Path == "/api/payment/hupipay/notify" ||
|
||||
c.Request.URL.Path == "/api/payment/payjs/notify" ||
|
||||
c.Request.URL.Path == "/api/payment/wechat/notify" ||
|
||||
c.Request.URL.Path == "/api/payment/doPay" ||
|
||||
c.Request.URL.Path == "/api/payment/payWays" ||
|
||||
c.Request.URL.Path == "/api/suno/client" ||
|
||||
c.Request.URL.Path == "/api/suno/Detail" ||
|
||||
c.Request.URL.Path == "/api/suno/play" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
|
||||
return false
|
||||
}
|
||||
@@ -326,6 +337,10 @@ func staticResourceMiddleware() gin.HandlerFunc {
|
||||
|
||||
// 解码图片
|
||||
img, _, err := image.Decode(file)
|
||||
// for .webp image
|
||||
if err != nil {
|
||||
img, err = webp.Decode(file)
|
||||
}
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "Error decoding image")
|
||||
return
|
||||
@@ -342,7 +357,9 @@ func staticResourceMiddleware() gin.HandlerFunc {
|
||||
var buffer bytes.Buffer
|
||||
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
logger.Error(err)
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 设置图片缓存有效期为一年 (365天)
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
package core
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/utils"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/utils"
|
||||
"os"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
@@ -23,7 +30,7 @@ func NewDefaultConfig() *types.AppConfig {
|
||||
SecretKey: utils.RandString(64),
|
||||
MaxAge: 86400,
|
||||
},
|
||||
ApiConfig: types.ChatPlusApiConfig{},
|
||||
ApiConfig: types.ApiConfig{},
|
||||
OSS: types.OSSConfig{
|
||||
Active: "local",
|
||||
Local: types.LocalStorageConfig{
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// ApiRequest API 请求实体
|
||||
type ApiRequest struct {
|
||||
Model string `json:"model,omitempty"` // 兼容百度文心一言
|
||||
@@ -8,7 +15,7 @@ type ApiRequest struct {
|
||||
Stream bool `json:"stream"`
|
||||
Messages []interface{} `json:"messages,omitempty"`
|
||||
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
|
||||
Tools []interface{} `json:"tools,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
||||
|
||||
ToolChoice string `json:"tool_choice,omitempty"`
|
||||
@@ -46,22 +53,22 @@ type Delta struct {
|
||||
// ChatSession 聊天会话对象
|
||||
type ChatSession struct {
|
||||
SessionId string `json:"session_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
ClientIP string `json:"client_ip"` // 客户端 IP
|
||||
Username string `json:"username"` // 当前登录的 username
|
||||
UserId uint `json:"user_id"` // 当前登录的 user ID
|
||||
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
||||
Model ChatModel `json:"model"` // GPT 模型
|
||||
}
|
||||
|
||||
type ChatModel struct {
|
||||
Id uint `json:"id"`
|
||||
Platform Platform `json:"platform"`
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Power int `json:"power"`
|
||||
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||
Temperature float32 `json:"temperature"` // 模型温度
|
||||
Id uint `json:"id"`
|
||||
Platform string `json:"platform"`
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Power int `json:"power"`
|
||||
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||
Temperature float32 `json:"temperature"` // 模型温度
|
||||
KeyId int `json:"key_id"` // 绑定 API KEY
|
||||
}
|
||||
|
||||
type ApiError struct {
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
@@ -14,7 +21,7 @@ type AppConfig struct {
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
||||
ApiConfig ApiConfig // ChatPlus API authorization configs
|
||||
SMS SMSConfig // send mobile message config
|
||||
OSS OSSConfig // OSS config
|
||||
MjProxyConfigs []MjProxyConfig // MJ proxy config
|
||||
@@ -22,14 +29,17 @@ type AppConfig struct {
|
||||
WeChatBot bool // 是否启用微信机器人
|
||||
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
||||
|
||||
XXLConfig XXLConfig
|
||||
AlipayConfig AlipayConfig
|
||||
HuPiPayConfig HuPiPayConfig
|
||||
SmtpConfig SmtpConfig // 邮件发送配置
|
||||
JPayConfig JPayConfig // payjs 支付配置
|
||||
XXLConfig XXLConfig
|
||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
|
||||
SmtpConfig SmtpConfig // 邮件发送配置
|
||||
JPayConfig JPayConfig // payjs 支付配置
|
||||
WechatPayConfig WechatPayConfig // 微信支付渠道配置
|
||||
TikaHost string // TiKa 服务器地址
|
||||
}
|
||||
|
||||
type SmtpConfig struct {
|
||||
UseTls bool // 是否使用 TLS 发送
|
||||
Host string
|
||||
Port int
|
||||
AppName string // 应用名称
|
||||
@@ -37,7 +47,7 @@ type SmtpConfig struct {
|
||||
Password string // 发件人邮箱密码
|
||||
}
|
||||
|
||||
type ChatPlusApiConfig struct {
|
||||
type ApiConfig struct {
|
||||
ApiURL string
|
||||
AppId string
|
||||
Token string
|
||||
@@ -77,6 +87,17 @@ type AlipayConfig struct {
|
||||
ReturnURL string // 支付成功返回地址
|
||||
}
|
||||
|
||||
type WechatPayConfig struct {
|
||||
Enabled bool // 是否启用该支付通道
|
||||
AppId string // 公众号的APPID,如:wxd678efh567hg6787
|
||||
MchId string // 直连商户的商户号,由微信支付生成并下发
|
||||
SerialNo string // 商户证书的证书序列号
|
||||
PrivateKey string // 用户私钥文件路径
|
||||
ApiV3Key string // API V3 秘钥
|
||||
NotifyURL string // 异步通知回调
|
||||
ReturnURL string // 支付成功返回地址
|
||||
}
|
||||
|
||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
||||
Enabled bool // 是否启用该支付通道
|
||||
Name string // 支付名称,如:wechat/alipay
|
||||
@@ -114,29 +135,37 @@ type RedisConfig struct {
|
||||
DB int
|
||||
}
|
||||
|
||||
// LicenseKey 存储许可证书的 KEY
|
||||
const LicenseKey = "Geek-AI-License"
|
||||
|
||||
type License struct {
|
||||
Key string `json:"key"` // 许可证书密钥
|
||||
MachineId string `json:"machine_id"` // 机器码
|
||||
ExpiredAt int64 `json:"expired_at"` // 过期时间
|
||||
IsActive bool `json:"is_active"` // 是否激活
|
||||
Configs LicenseConfig `json:"configs"`
|
||||
}
|
||||
|
||||
type LicenseConfig struct {
|
||||
UserNum int `json:"user_num"` // 用户数量
|
||||
DeCopy bool `json:"de_copy"` // 去版权
|
||||
}
|
||||
|
||||
func (c RedisConfig) Url() string {
|
||||
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||
}
|
||||
|
||||
type Platform string
|
||||
|
||||
const OpenAI = Platform("OpenAI")
|
||||
const Azure = Platform("Azure")
|
||||
const ChatGLM = Platform("ChatGLM")
|
||||
const Baidu = Platform("Baidu")
|
||||
const XunFei = Platform("XunFei")
|
||||
const QWen = Platform("QWen")
|
||||
|
||||
type SystemConfig struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
AdminTitle string `json:"admin_title,omitempty"`
|
||||
Title string `json:"title,omitempty"` // 网站标题
|
||||
Slogan string `json:"slogan,omitempty"` // 网站 slogan
|
||||
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
|
||||
Logo string `json:"logo,omitempty"`
|
||||
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
|
||||
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
|
||||
|
||||
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机,邮箱注册,账号密码注册
|
||||
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
|
||||
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
|
||||
|
||||
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
|
||||
@@ -144,16 +173,23 @@ type SystemConfig struct {
|
||||
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
|
||||
|
||||
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
|
||||
VipInfoText string `json:"vip_info_text"` // 会员页面充值说明
|
||||
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
|
||||
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
|
||||
|
||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||
MjActionPower int `json:"mj_action_power"` // MJ 操作(放大,变换)消耗算力
|
||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
|
||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
|
||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||
|
||||
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
||||
|
||||
EnableContext bool `json:"enable_context,omitempty"`
|
||||
ContextDeep int `json:"context_deep,omitempty"`
|
||||
|
||||
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
|
||||
|
||||
IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片
|
||||
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
||||
Copyright string `json:"copyright"` // 版权信息
|
||||
}
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type ToolCall struct {
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
@@ -8,19 +15,13 @@ type ToolCall struct {
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function Function `json:"function"`
|
||||
}
|
||||
|
||||
type Function struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters Parameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]Property `json:"properties"`
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type OrderStatus int
|
||||
|
||||
const (
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type OSSConfig struct {
|
||||
Active string
|
||||
Local LocalStorageConfig
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
const LoginUserID = "LOGIN_USER_ID"
|
||||
const LoginUserCache = "LOGIN_USER_CACHE"
|
||||
|
||||
const UserAuthHeader = "Authorization"
|
||||
const AdminAuthHeader = "Admin-Authorization"
|
||||
const ChatTokenHeader = "Chat-Token"
|
||||
|
||||
// Session configs struct
|
||||
type Session struct {
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type SMSConfig struct {
|
||||
Active string
|
||||
Ali SmsConfigAli
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// TaskType 任务类别
|
||||
type TaskType string
|
||||
|
||||
@@ -21,10 +28,11 @@ type MjTask struct {
|
||||
TaskId string `json:"task_id"`
|
||||
ImgArr []string `json:"img_arr"`
|
||||
ChannelId string `json:"channel_id"`
|
||||
SessionId string `json:"session_id"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
NegPrompt string `json:"neg_prompt,omitempty"`
|
||||
Params string `json:"full_prompt"`
|
||||
Index int `json:"index,omitempty"`
|
||||
MessageId string `json:"message_id,omitempty"`
|
||||
MessageHash string `json:"message_hash,omitempty"`
|
||||
@@ -33,7 +41,6 @@ type MjTask struct {
|
||||
|
||||
type SdTask struct {
|
||||
Id int `json:"id"` // job 数据库ID
|
||||
SessionId string `json:"session_id"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
Params SdTaskParams `json:"params"`
|
||||
@@ -41,19 +48,49 @@ type SdTask struct {
|
||||
}
|
||||
|
||||
type SdTaskParams struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
NegativePrompt string `json:"negative_prompt"` // 反向提示词
|
||||
Steps int `json:"steps"` // 迭代步数,默认20
|
||||
Sampler string `json:"sampler"` // 采样器
|
||||
FaceFix bool `json:"face_fix"` // 面部修复
|
||||
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
|
||||
Seed int64 `json:"seed"` // 随机数种子
|
||||
Height int `json:"height"`
|
||||
Width int `json:"width"`
|
||||
HdFix bool `json:"hd_fix"` // 启用高清修复
|
||||
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
|
||||
HdScale int `json:"hd_scale"` // 放大倍数
|
||||
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
|
||||
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
||||
Steps int `json:"steps"` // 迭代步数,默认20
|
||||
Sampler string `json:"sampler"` // 采样器
|
||||
Scheduler string `json:"scheduler"` // 采样调度
|
||||
FaceFix bool `json:"face_fix"` // 面部修复
|
||||
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
|
||||
Seed int64 `json:"seed"` // 随机数种子
|
||||
Height int `json:"height"`
|
||||
Width int `json:"width"`
|
||||
HdFix bool `json:"hd_fix"` // 启用高清修复
|
||||
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
|
||||
HdScale int `json:"hd_scale"` // 放大倍数
|
||||
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
|
||||
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
|
||||
}
|
||||
|
||||
// DallTask DALL-E task
|
||||
type DallTask struct {
|
||||
JobId uint `json:"job_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
|
||||
Power int `json:"power"`
|
||||
}
|
||||
|
||||
type SunoTask struct {
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
Type int `json:"type"`
|
||||
TaskId string `json:"task_id"`
|
||||
Title string `json:"title"`
|
||||
RefTaskId string `json:"ref_task_id"`
|
||||
RefSongId string `json:"ref_song_id"`
|
||||
Prompt string `json:"prompt"` // 提示词/歌词
|
||||
Tags string `json:"tags"`
|
||||
Model string `json:"model"`
|
||||
Instrumental bool `json:"instrumental"` // 是否纯音乐
|
||||
ExtendSecs int `json:"extend_secs"` // 延长秒杀
|
||||
}
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// BizVo 业务返回 VO
|
||||
type BizVo struct {
|
||||
Code BizCode `json:"code"`
|
||||
@@ -15,13 +22,14 @@ type WsMessage struct {
|
||||
Type WsMsgType `json:"type"` // 消息类别,start, end, img
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type WsMsgType string
|
||||
|
||||
const (
|
||||
WsStart = WsMsgType("start")
|
||||
WsMiddle = WsMsgType("middle")
|
||||
WsEnd = WsMsgType("end")
|
||||
WsMjImg = WsMsgType("mj")
|
||||
WsErr = WsMsgType("error")
|
||||
)
|
||||
|
||||
type BizCode int
|
||||
@@ -29,11 +37,9 @@ type BizCode int
|
||||
const (
|
||||
Success = BizCode(0)
|
||||
Failed = BizCode(1)
|
||||
NotAuthorized = BizCode(400) // 未授权
|
||||
NotPermission = BizCode(403) // 没有权限
|
||||
NotAuthorized = BizCode(401) // 未授权
|
||||
|
||||
OkMsg = "Success"
|
||||
ErrorMsg = "系统开小差了"
|
||||
InvalidArgs = "非法参数或参数解析失败"
|
||||
NoData = "No Data"
|
||||
)
|
||||
|
||||
59
api/go.mod
59
api/go.mod
@@ -1,6 +1,8 @@
|
||||
module chatplus
|
||||
module geekai
|
||||
|
||||
go 1.19
|
||||
go 1.21
|
||||
|
||||
toolchain go1.22.4
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.1.0
|
||||
@@ -17,7 +19,6 @@ require (
|
||||
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/smartwalle/alipay/v3 v3.2.15
|
||||
go.uber.org/zap v1.23.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gorm.io/driver/mysql v1.4.7
|
||||
@@ -26,19 +27,37 @@ require (
|
||||
require github.com/xxl-job/xxl-job-executor-go v1.2.0
|
||||
|
||||
require (
|
||||
github.com/mojocn/base64Captcha v1.3.1
|
||||
github.com/go-pay/gopay v1.5.101
|
||||
github.com/google/go-tika v0.3.1
|
||||
github.com/microcosm-cc/bluemonday v1.0.26
|
||||
github.com/mojocn/base64Captcha v1.3.6
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.3.1
|
||||
github.com/syndtr/goleveldb v1.0.0
|
||||
golang.org/x/image v0.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-pay/crypto v0.0.1 // indirect
|
||||
github.com/go-pay/errgroup v0.0.2 // indirect
|
||||
github.com/go-pay/util v0.0.2 // indirect
|
||||
github.com/go-pay/xlog v0.0.2 // indirect
|
||||
github.com/go-pay/xtime v0.0.2 // indirect
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
||||
github.com/gorilla/css v1.0.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
||||
github.com/tklauser/numcpus v0.7.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.4 // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.8.1 // indirect
|
||||
@@ -49,7 +68,6 @@ require (
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/golang/mock v1.6.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
@@ -66,26 +84,21 @@ require (
|
||||
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
|
||||
github.com/quic-go/quic-go v0.35.1 // indirect
|
||||
github.com/quic-go/quic-go v0.45.0 // indirect
|
||||
github.com/refraction-networking/utls v1.3.2 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/smartwalle/ncrypto v1.0.2 // indirect
|
||||
github.com/smartwalle/ngx v1.0.6 // indirect
|
||||
github.com/smartwalle/nsign v1.0.8 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
go.uber.org/dig v1.16.1 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
|
||||
golang.org/x/mod v0.11.0 // indirect
|
||||
golang.org/x/net v0.14.0 // indirect
|
||||
golang.org/x/sync v0.3.0 // indirect
|
||||
golang.org/x/text v0.12.0 // indirect
|
||||
golang.org/x/time v0.3.0 // indirect
|
||||
golang.org/x/tools v0.10.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.21.0 // indirect
|
||||
google.golang.org/protobuf v1.33.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
@@ -104,7 +117,7 @@ require (
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/fx v1.19.3
|
||||
go.uber.org/multierr v1.6.0 // indirect
|
||||
golang.org/x/crypto v0.12.0
|
||||
golang.org/x/sys v0.11.0 // indirect
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
gorm.io/gorm v1.25.1
|
||||
)
|
||||
|
||||
170
api/go.sum
170
api/go.sum
@@ -6,12 +6,15 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible h1:Sg/2xHwDrioHpxTN6WMiw
|
||||
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
|
||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
@@ -27,7 +30,9 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
|
||||
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk=
|
||||
@@ -39,8 +44,24 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU
|
||||
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
|
||||
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-pay/crypto v0.0.1 h1:B6InT8CLfSLc6nGRVx9VMJRBBazFMjr293+jl0lLXUY=
|
||||
github.com/go-pay/crypto v0.0.1/go.mod h1:41oEIvHMKbNcYlWUlRWtsnC6+ASgh7u29z0gJXe5bes=
|
||||
github.com/go-pay/errgroup v0.0.2 h1:5mZMdm0TDClDm2S3G0/sm0f8AuQRtz0dOrTHDR9R8Cc=
|
||||
github.com/go-pay/errgroup v0.0.2/go.mod h1:0+4b8mvFMS71MIzsaC+gVvB4x37I93lRb2dqrwuU8x8=
|
||||
github.com/go-pay/gopay v1.5.101 h1:rVb+sfv6hiQtknAlZnTTLvU27NvFJ4p0yglN/vPpGXI=
|
||||
github.com/go-pay/gopay v1.5.101/go.mod h1:AW4Yj8jDZX9BM1/GTLTY1Gy5SHjiq8kQvG5sBTN2sxI=
|
||||
github.com/go-pay/util v0.0.2 h1:goJ4f6kNY5zzdtg1Cj8oWC+Cw7bfg/qq2rJangMAb9U=
|
||||
github.com/go-pay/util v0.0.2/go.mod h1:qM8VbyF1n7YAPZBSJONSPMPsPedhUTktewUAdf1AjPg=
|
||||
github.com/go-pay/xlog v0.0.2 h1:kUg5X8/5VZAPDg1J5eGjA3MG0/H5kK6Ew0dW/Bycsws=
|
||||
github.com/go-pay/xlog v0.0.2/go.mod h1:DbjMADPK4+Sjxj28ekK9goqn4zmyY4hql/zRiab+S9E=
|
||||
github.com/go-pay/xtime v0.0.2 h1:7YR4/iuELsEHpJ6LUO0SVK80hQxDO9MLCfuVYIiTCRM=
|
||||
github.com/go-pay/xtime v0.0.2/go.mod h1:W1yRbJaSt4CSBcdAtLBQ8xajiN/Pl5hquGczUcUE9xE=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
|
||||
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@@ -65,17 +86,22 @@ github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJ
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
|
||||
github.com/google/go-tika v0.3.1/go.mod h1:DJh5N8qxXIl85QkqmXknd+PeeRkUOTbvwyYf7ieDz6c=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
|
||||
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
@@ -83,6 +109,7 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
|
||||
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
@@ -116,6 +143,8 @@ github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259
|
||||
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
|
||||
github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.62 h1:qNYsFZHEzl+NfH8UxW4jpmlKav1qUAgfY30YNRneVhc=
|
||||
@@ -128,15 +157,21 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
|
||||
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
|
||||
github.com/mojocn/base64Captcha v1.3.6 h1:gZEKu1nsKpttuIAQgWHO+4Mhhls8cAKyiV2Ew03H+Tw=
|
||||
github.com/mojocn/base64Captcha v1.3.6/go.mod h1:i5CtHvm+oMbj1UzEPXaA8IH/xHFZ3DGY3Wh3dBpZ28E=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs=
|
||||
github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE=
|
||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
|
||||
github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4=
|
||||
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A=
|
||||
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
@@ -154,12 +189,8 @@ github.com/qiniu/go-sdk/v7 v7.17.1/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYX
|
||||
github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
|
||||
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
|
||||
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
|
||||
github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
|
||||
github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
|
||||
github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8=
|
||||
github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
@@ -167,20 +198,14 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
|
||||
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||
github.com/smartwalle/alipay/v3 v3.2.15 h1:3fvFJnINKKAOXHR/Iv20k1Z7KJ+nOh3oK214lELPqG8=
|
||||
github.com/smartwalle/alipay/v3 v3.2.15/go.mod h1:niTNB609KyUYuAx9Bex/MawEjv2yPx4XOjxSAkqmGjE=
|
||||
github.com/smartwalle/ncrypto v1.0.2 h1:pTAhCqtPCMhpOwFXX+EcMdR6PNzruBNoGQrN2S1GbGI=
|
||||
github.com/smartwalle/ncrypto v1.0.2/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk=
|
||||
github.com/smartwalle/ngx v1.0.6 h1:JPNqNOIj+2nxxFtrSkJO+vKJfeNUSEQueck/Wworjps=
|
||||
github.com/smartwalle/ngx v1.0.6/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0=
|
||||
github.com/smartwalle/nsign v1.0.8 h1:78KWtwKPrdt4Xsn+tNEBVxaTLIJBX9YRX0ZSrMUeuHo=
|
||||
github.com/smartwalle/nsign v1.0.8/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -193,6 +218,14 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
|
||||
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
|
||||
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
|
||||
github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU=
|
||||
github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=
|
||||
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
|
||||
github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY=
|
||||
github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
|
||||
@@ -203,8 +236,9 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
@@ -213,6 +247,9 @@ go.uber.org/dig v1.16.1/go.mod h1:557JTAUZT5bUK0SvCwikmLPPtdQhfvLYtO5tJgQSbnk=
|
||||
go.uber.org/fx v1.19.3 h1:YqMRE4+2IepTYCMOvXqQpRa+QAVdiSTnsHU4XNWBceA=
|
||||
go.uber.org/fx v1.19.3/go.mod h1:w2HrQg26ql9fLK7hlBiZ6JsRUKV+Lj/atT1KCjT8YhM=
|
||||
go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
|
||||
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
|
||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
|
||||
@@ -221,38 +258,43 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
||||
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
|
||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
|
||||
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
|
||||
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/image v0.13.0/go.mod h1:6mmbMOeV28HuMTgA6OSRkdXKYw/t5W9Uwn2Yv1r3Yxk=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
|
||||
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
|
||||
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -261,45 +303,55 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
|
||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
|
||||
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
|
||||
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -1,19 +1,25 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -49,12 +55,6 @@ func (h *ManagerHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// add captcha
|
||||
if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
|
||||
resp.ERROR(c, "验证码错误!")
|
||||
return
|
||||
}
|
||||
|
||||
var manager model.AdminUser
|
||||
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
|
||||
if res.Error != nil {
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -23,7 +31,6 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
||||
func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Platform string `json:"platform"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
@@ -40,7 +47,6 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||
if data.Id > 0 {
|
||||
h.DB.Find(&apiKey, data.Id)
|
||||
}
|
||||
apiKey.Platform = data.Platform
|
||||
apiKey.Value = data.Value
|
||||
apiKey.Type = data.Type
|
||||
apiKey.ApiURL = data.ApiURL
|
||||
@@ -49,6 +55,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||
apiKey.Name = data.Name
|
||||
res := h.DB.Save(&apiKey)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -65,14 +72,24 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) List(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
status := h.GetBool(c, "status")
|
||||
t := h.GetTrim(c, "type")
|
||||
platform := h.GetTrim(c, "platform")
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if status {
|
||||
session = session.Where("enabled", true)
|
||||
}
|
||||
if t != "" {
|
||||
session = session.Where("type", t)
|
||||
}
|
||||
if platform != "" {
|
||||
session = session.Where("platform", platform)
|
||||
}
|
||||
|
||||
var items []model.ApiKey
|
||||
var keys = make([]vo.ApiKey, 0)
|
||||
res := h.DB.Find(&items)
|
||||
res := session.Find(&items)
|
||||
if res.Error == nil {
|
||||
for _, item := range items {
|
||||
var key vo.ApiKey
|
||||
@@ -104,6 +121,7 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
|
||||
|
||||
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -111,19 +129,17 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) Remove(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
if id <= 0 {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if data.Id > 0 {
|
||||
res := h.DB.Where("id = ?", data.Id).Delete(&model.ApiKey{})
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
|
||||
res := h.DB.Where("id", id).Delete(&model.ApiKey{})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/handler"
|
||||
"chatplus/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
)
|
||||
|
||||
type CaptchaHandler struct {
|
||||
handler.BaseHandler
|
||||
}
|
||||
|
||||
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
|
||||
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
|
||||
}
|
||||
|
||||
type CaptchaVo struct {
|
||||
CaptchaId string `json:"captcha_id"`
|
||||
PicPath string `json:"pic_path"`
|
||||
}
|
||||
|
||||
// GetCaptcha 获取验证码
|
||||
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
|
||||
var captchaVo CaptchaVo
|
||||
driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
|
||||
cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
|
||||
// b64s是图片的base64编码
|
||||
id, b64s, err := cp.Generate()
|
||||
if err != nil {
|
||||
resp.ERROR(c, "生成验证码错误!")
|
||||
return
|
||||
}
|
||||
captchaVo.CaptchaId = id
|
||||
captchaVo.PicPath = b64s
|
||||
|
||||
resp.SUCCESS(c, captchaVo)
|
||||
}
|
||||
@@ -1,13 +1,20 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -33,11 +40,6 @@ type chatItemVo struct {
|
||||
}
|
||||
|
||||
func (h *ChatHandler) List(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Title string `json:"title"`
|
||||
UserId uint `json:"user_id"`
|
||||
@@ -259,6 +261,7 @@ func (h *ChatHandler) RemoveMessage(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
|
||||
if tx.Error != nil {
|
||||
logger.Error("error with update database:", tx.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatModelHandler struct {
|
||||
@@ -34,6 +41,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||
Temperature float32 `json:"temperature"` // 模型温度
|
||||
KeyId int `json:"key_id,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
@@ -41,24 +49,32 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
item := model.ChatModel{
|
||||
Platform: data.Platform,
|
||||
Name: data.Name,
|
||||
Value: data.Value,
|
||||
Enabled: data.Enabled,
|
||||
SortNum: data.SortNum,
|
||||
Open: data.Open,
|
||||
MaxTokens: data.MaxTokens,
|
||||
MaxContext: data.MaxContext,
|
||||
Temperature: data.Temperature,
|
||||
Power: data.Power}
|
||||
item.Id = data.Id
|
||||
if item.Id > 0 {
|
||||
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||
item := model.ChatModel{}
|
||||
// 更新
|
||||
if data.Id > 0 {
|
||||
h.DB.Where("id", data.Id).First(&item)
|
||||
}
|
||||
|
||||
item.Name = data.Name
|
||||
item.Value = data.Value
|
||||
item.Enabled = data.Enabled
|
||||
item.SortNum = data.SortNum
|
||||
item.Open = data.Open
|
||||
item.Power = data.Power
|
||||
item.MaxTokens = data.MaxTokens
|
||||
item.MaxContext = data.MaxContext
|
||||
item.Temperature = data.Temperature
|
||||
item.KeyId = data.KeyId
|
||||
|
||||
var res *gorm.DB
|
||||
if data.Id > 0 {
|
||||
res = h.DB.Save(&item)
|
||||
} else {
|
||||
res = h.DB.Create(&item)
|
||||
}
|
||||
res := h.DB.Save(&item)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,31 +91,45 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||
|
||||
// List 模型列表
|
||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
enable := h.GetBool(c, "enable")
|
||||
name := h.GetTrim(c, "name")
|
||||
if enable {
|
||||
session = session.Where("enabled", enable)
|
||||
}
|
||||
if name != "" {
|
||||
session = session.Where("name LIKE ?", name+"%")
|
||||
}
|
||||
var items []model.ChatModel
|
||||
var cms = make([]vo.ChatModel, 0)
|
||||
res := session.Order("sort_num ASC").Find(&items)
|
||||
if res.Error == nil {
|
||||
for _, item := range items {
|
||||
var cm vo.ChatModel
|
||||
err := utils.CopyObject(item, &cm)
|
||||
if err == nil {
|
||||
cm.Id = item.Id
|
||||
cm.CreatedAt = item.CreatedAt.Unix()
|
||||
cm.UpdatedAt = item.UpdatedAt.Unix()
|
||||
cms = append(cms, cm)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
if res.Error != nil {
|
||||
resp.SUCCESS(c, cms)
|
||||
return
|
||||
}
|
||||
|
||||
// initialize key name
|
||||
keyIds := make([]int, 0)
|
||||
for _, v := range items {
|
||||
keyIds = append(keyIds, v.KeyId)
|
||||
}
|
||||
var keys []model.ApiKey
|
||||
keyMap := make(map[uint]string)
|
||||
h.DB.Where("id IN ?", keyIds).Find(&keys)
|
||||
for _, v := range keys {
|
||||
keyMap[v.Id] = v.Name
|
||||
}
|
||||
for _, item := range items {
|
||||
var cm vo.ChatModel
|
||||
err := utils.CopyObject(item, &cm)
|
||||
if err == nil {
|
||||
cm.Id = item.Id
|
||||
cm.CreatedAt = item.CreatedAt.Unix()
|
||||
cm.UpdatedAt = item.UpdatedAt.Unix()
|
||||
cm.KeyName = keyMap[uint(item.KeyId)]
|
||||
cms = append(cms, cm)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, cms)
|
||||
@@ -119,6 +149,7 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
|
||||
|
||||
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -139,6 +170,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
|
||||
for index, id := range data.Ids {
|
||||
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -156,6 +188,7 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
|
||||
|
||||
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatRoleHandler struct {
|
||||
@@ -40,6 +48,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
||||
}
|
||||
res := h.DB.Save(&role)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -50,11 +59,6 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
}
|
||||
|
||||
var items []model.ChatRole
|
||||
var roles = make([]vo.ChatRole, 0)
|
||||
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||
@@ -63,6 +67,25 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// initialize model mane for role
|
||||
modelIds := make([]int, 0)
|
||||
for _, v := range items {
|
||||
if v.ModelId > 0 {
|
||||
modelIds = append(modelIds, v.ModelId)
|
||||
}
|
||||
}
|
||||
|
||||
modelNameMap := make(map[int]string)
|
||||
if len(modelIds) > 0 {
|
||||
var models []model.ChatModel
|
||||
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
|
||||
if tx.Error == nil {
|
||||
for _, m := range models {
|
||||
modelNameMap[int(m.Id)] = m.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
var role vo.ChatRole
|
||||
err := utils.CopyObject(v, &role)
|
||||
@@ -70,6 +93,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
role.Id = v.Id
|
||||
role.CreatedAt = v.CreatedAt.Unix()
|
||||
role.UpdatedAt = v.UpdatedAt.Unix()
|
||||
role.ModelName = modelNameMap[role.ModelId]
|
||||
roles = append(roles, role)
|
||||
}
|
||||
}
|
||||
@@ -92,6 +116,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
|
||||
for index, id := range data.Ids {
|
||||
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -114,6 +139,7 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
|
||||
|
||||
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -121,19 +147,15 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *ChatRoleHandler) Remove(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
|
||||
if id <= 0 {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if data.Id <= 0 {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
|
||||
res := h.DB.Where("id", id).Delete(&model.ChatRole{})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "删除失败!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,23 +1,45 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ConfigHandler struct {
|
||||
handler.BaseHandler
|
||||
levelDB *store.LevelDB
|
||||
licenseService *service.LicenseService
|
||||
mjServicePool *mj.ServicePool
|
||||
sdServicePool *sd.ServicePool
|
||||
}
|
||||
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
|
||||
return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
levelDB: levelDB,
|
||||
mjServicePool: mjPool,
|
||||
sdServicePool: sdPool,
|
||||
licenseService: licenseService,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
@@ -28,6 +50,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
Content string `json:"content,omitempty"`
|
||||
Updated bool `json:"updated,omitempty"`
|
||||
} `json:"config"`
|
||||
ConfigBak types.SystemConfig `json:"config_bak,omitempty"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
@@ -35,6 +58,12 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// ONLY authorized user can change the copyright
|
||||
if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
|
||||
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
|
||||
return
|
||||
}
|
||||
|
||||
value := utils.JsonEncode(&data.Config)
|
||||
config := model.Config{Key: data.Key, Config: value}
|
||||
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
|
||||
@@ -70,11 +99,6 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
|
||||
// Get 获取指定的系统配置
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
}
|
||||
|
||||
key := c.Query("key")
|
||||
var config model.Config
|
||||
res := h.DB.Where("marker", key).First(&config)
|
||||
@@ -92,3 +116,88 @@ func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
|
||||
resp.SUCCESS(c, value)
|
||||
}
|
||||
|
||||
// Active 激活系统
|
||||
func (h *ConfigHandler) Active(c *gin.Context) {
|
||||
var data struct {
|
||||
License string `json:"license"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
info, err := host.Info()
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = h.licenseService.ActiveLicense(data.License, info.HostID)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, info.HostID)
|
||||
}
|
||||
|
||||
// GetLicense 获取 License 信息
|
||||
func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
||||
license := h.licenseService.GetLicense()
|
||||
resp.SUCCESS(c, license)
|
||||
}
|
||||
|
||||
// GetAppConfig 获取内置配置
|
||||
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"mj_plus": h.App.Config.MjPlusConfigs,
|
||||
"mj_proxy": h.App.Config.MjProxyConfigs,
|
||||
"sd": h.App.Config.SdConfigs,
|
||||
})
|
||||
}
|
||||
|
||||
// SaveDrawingConfig 保存AI绘画配置
|
||||
func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) {
|
||||
var data struct {
|
||||
Sd []types.StableDiffusionConfig `json:"sd"`
|
||||
MjPlus []types.MjPlusConfig `json:"mj_plus"`
|
||||
MjProxy []types.MjProxyConfig `json:"mj_proxy"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
changed := false
|
||||
if configChanged(data.Sd, h.App.Config.SdConfigs) {
|
||||
logger.Debugf("SD 配置变动了")
|
||||
h.App.Config.SdConfigs = data.Sd
|
||||
h.sdServicePool.InitServices(data.Sd)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) {
|
||||
logger.Debugf("MidJourney 配置变动了")
|
||||
h.App.Config.MjPlusConfigs = data.MjPlus
|
||||
h.App.Config.MjProxyConfigs = data.MjProxy
|
||||
h.mjServicePool.InitServices(data.MjPlus, data.MjProxy)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
err := core.SaveConfig(h.App.Config)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "更新配置文档失败!")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
|
||||
}
|
||||
|
||||
func configChanged(c1 interface{}, c2 interface{}) bool {
|
||||
encode1 := utils.JsonEncode(c1)
|
||||
encode2 := utils.JsonEncode(c2)
|
||||
return utils.Md5(encode1) != utils.Md5(encode2)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
@@ -64,6 +71,7 @@ func (h *FunctionHandler) Set(c *gin.Context) {
|
||||
|
||||
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -71,11 +79,6 @@ func (h *FunctionHandler) Set(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *FunctionHandler) List(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
}
|
||||
|
||||
var items []model.Function
|
||||
res := h.DB.Find(&items)
|
||||
if res.Error != nil {
|
||||
@@ -101,6 +104,7 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
|
||||
if id > 0 {
|
||||
res := h.DB.Delete(&model.Function{Id: uint(id)})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
|
||||
132
api/handler/admin/menu_handler.go
Normal file
132
api/handler/admin/menu_handler.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type MenuHandler struct {
|
||||
handler.BaseHandler
|
||||
}
|
||||
|
||||
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
||||
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
URL string `json:"url"`
|
||||
SortNum int `json:"sort_num"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
res := h.DB.Save(&model.Menu{
|
||||
Id: data.Id,
|
||||
Name: data.Name,
|
||||
Icon: data.Icon,
|
||||
URL: data.URL,
|
||||
SortNum: data.SortNum,
|
||||
Enabled: data.Enabled,
|
||||
})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// List 数据列表
|
||||
func (h *MenuHandler) List(c *gin.Context) {
|
||||
var items []model.Menu
|
||||
var list = make([]vo.Menu, 0)
|
||||
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||
if res.Error == nil {
|
||||
for _, item := range items {
|
||||
var product vo.Menu
|
||||
err := utils.CopyObject(item, &product)
|
||||
if err == nil {
|
||||
list = append(list, product)
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, list)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Enable(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Sort(c *gin.Context) {
|
||||
var data struct {
|
||||
Ids []uint `json:"ids"`
|
||||
Sorts []int `json:"sorts"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
for index, id := range data.Ids {
|
||||
res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index])
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
|
||||
if id > 0 {
|
||||
res := h.DB.Where("id", id).Delete(&model.Menu{})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
@@ -1,13 +1,20 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -22,11 +29,6 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||
}
|
||||
|
||||
func (h *OrderHandler) List(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
OrderNo string `json:"order_no"`
|
||||
Status int `json:"status"`
|
||||
@@ -92,6 +94,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
|
||||
|
||||
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
@@ -50,6 +57,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
||||
}
|
||||
res := h.DB.Save(&item)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -65,21 +73,11 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
||||
resp.SUCCESS(c, itemVo)
|
||||
}
|
||||
|
||||
// List 模型列表
|
||||
// List 数据列表
|
||||
func (h *ProductHandler) List(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
enable := h.GetBool(c, "enable")
|
||||
if enable {
|
||||
session = session.Where("enabled", enable)
|
||||
}
|
||||
var items []model.Product
|
||||
var list = make([]vo.Product, 0)
|
||||
res := session.Order("sort_num ASC").Find(&items)
|
||||
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||
if res.Error == nil {
|
||||
for _, item := range items {
|
||||
var product vo.Product
|
||||
@@ -110,6 +108,7 @@ func (h *ProductHandler) Enable(c *gin.Context) {
|
||||
|
||||
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -128,8 +127,9 @@ func (h *ProductHandler) Sort(c *gin.Context) {
|
||||
}
|
||||
|
||||
for index, id := range data.Ids {
|
||||
res := h.DB.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
||||
res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index])
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -142,8 +142,9 @@ func (h *ProductHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
|
||||
if id > 0 {
|
||||
res := h.DB.Where("id = ?", id).Delete(&model.Product{})
|
||||
res := h.DB.Where("id", id).Delete(&model.Product{})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -21,11 +28,6 @@ func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
|
||||
}
|
||||
|
||||
func (h *RewardHandler) List(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
}
|
||||
|
||||
var items []model.Reward
|
||||
res := h.DB.Order("id DESC").Find(&items)
|
||||
var rewards = make([]vo.Reward, 0)
|
||||
@@ -70,6 +72,7 @@ func (h *RewardHandler) Remove(c *gin.Context) {
|
||||
if data.Id > 0 {
|
||||
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/handler"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/handler"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -17,19 +25,15 @@ import (
|
||||
|
||||
type UserHandler struct {
|
||||
handler.BaseHandler
|
||||
licenseService *service.LicenseService
|
||||
}
|
||||
|
||||
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
|
||||
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler {
|
||||
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
||||
}
|
||||
|
||||
// List 用户列表
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||
resp.NotPermission(c)
|
||||
return
|
||||
}
|
||||
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
username := h.GetTrim(c, "username")
|
||||
@@ -80,6 +84,13 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
// 检测最大注册人数
|
||||
var totalUser int64
|
||||
h.DB.Model(&model.User{}).Count(&totalUser)
|
||||
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
||||
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
|
||||
return
|
||||
}
|
||||
var user = model.User{}
|
||||
var res *gorm.DB
|
||||
var userVo vo.User
|
||||
@@ -100,7 +111,8 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
|
||||
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
// 记录算力日志
|
||||
@@ -124,10 +136,16 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// 检查用户是否已经存在
|
||||
h.DB.Where("username", data.Username).First(&user)
|
||||
if user.Id > 0 {
|
||||
resp.ERROR(c, "用户名已存在")
|
||||
return
|
||||
}
|
||||
|
||||
salt := utils.RandString(8)
|
||||
u := model.User{
|
||||
Username: data.Username,
|
||||
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
|
||||
Password: utils.GenPassword(data.Password, salt),
|
||||
Avatar: "/images/avatar/user.png",
|
||||
Salt: salt,
|
||||
@@ -137,6 +155,11 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
ChatModels: utils.JsonEncode(data.ChatModels),
|
||||
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
|
||||
}
|
||||
if h.licenseService.GetLicense().Configs.DeCopy {
|
||||
u.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
|
||||
} else {
|
||||
u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
|
||||
}
|
||||
res = h.DB.Create(&u)
|
||||
_ = utils.CopyObject(u, &userVo)
|
||||
userVo.Id = u.Id
|
||||
@@ -145,6 +168,7 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
}
|
||||
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -23,9 +31,14 @@ func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
var items []model.ChatModel
|
||||
var chatModels = make([]vo.ChatModel, 0)
|
||||
var res *gorm.DB
|
||||
session := h.DB.Session(&gorm.Session{}).Where("enabled", true)
|
||||
t := c.Query("type")
|
||||
if t != "" {
|
||||
session = session.Where("type", t)
|
||||
}
|
||||
// 如果用户没有登录,则加载所有开放模型
|
||||
if !h.IsLogin(c) {
|
||||
res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items)
|
||||
res = session.Where("open", true).Order("sort_num ASC").Find(&items)
|
||||
} else {
|
||||
user, _ := h.GetLoginUser(c)
|
||||
var models []int
|
||||
@@ -36,7 +49,7 @@ func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
}
|
||||
// 查询用户有权限访问的模型以及所有开放的模型
|
||||
res = h.DB.Where("enabled = ?", true).Where(
|
||||
h.DB.Where("id IN ?", models).Or("open =?", true),
|
||||
h.DB.Where("id IN ?", models).Or("open", true),
|
||||
).Order("sort_num ASC").Find(&items)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -22,45 +29,32 @@ func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
||||
|
||||
// List 获取用户聊天应用列表
|
||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
all := h.GetBool(c, "all")
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetLoginUserId(c)
|
||||
var roles []model.ChatRole
|
||||
query := h.DB.Where("enable", true)
|
||||
if userId > 0 {
|
||||
var user model.User
|
||||
h.DB.First(&user, userId)
|
||||
var roleKeys []string
|
||||
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "角色解析失败!")
|
||||
return
|
||||
}
|
||||
query = query.Where("marker IN ?", roleKeys)
|
||||
}
|
||||
if id > 0 {
|
||||
query = query.Or("id", id)
|
||||
}
|
||||
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "No roles found,"+res.Error.Error())
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取所有角色
|
||||
if userId == 0 || all {
|
||||
// 转成 vo
|
||||
var roleVos = make([]vo.ChatRole, 0)
|
||||
for _, r := range roles {
|
||||
var v vo.ChatRole
|
||||
err := utils.CopyObject(r, &v)
|
||||
if err == nil {
|
||||
v.Id = r.Id
|
||||
roleVos = append(roleVos, v)
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, roleVos)
|
||||
return
|
||||
}
|
||||
|
||||
var user model.User
|
||||
h.DB.First(&user, userId)
|
||||
var roleKeys []string
|
||||
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "角色解析失败!")
|
||||
return
|
||||
}
|
||||
// 转成 vo
|
||||
var roleVos = make([]vo.ChatRole, 0)
|
||||
for _, r := range roles {
|
||||
if !utils.ContainsStr(roleKeys, r.Key) {
|
||||
continue
|
||||
}
|
||||
var v vo.ChatRole
|
||||
err := utils.CopyObject(r, &v)
|
||||
if err == nil {
|
||||
@@ -89,7 +83,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
||||
|
||||
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
|
||||
if res.Error != nil {
|
||||
logger.Error("添加应用失败:", err)
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
package chatimpl
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// 微软 Azure 模型消息发送实现
|
||||
|
||||
func (h *ChatHandler) sendAzureMessage(
|
||||
chatCtx []types.Message,
|
||||
req types.ApiRequest,
|
||||
userVo vo.User,
|
||||
ctx context.Context,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
prompt string,
|
||||
ws *types.WsClient) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
return nil
|
||||
} else if strings.Contains(err.Error(), "no available key") {
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
return nil
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return err
|
||||
} else {
|
||||
defer response.Body.Close()
|
||||
}
|
||||
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
replyCreatedAt := time.Now() // 记录回复时间
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
continue
|
||||
}
|
||||
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil { // 数据解析出错
|
||||
logger.Error(err, line)
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
break
|
||||
}
|
||||
|
||||
if len(responseBody.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 初始化 role
|
||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||
message.Role = responseBody.Choices[0].Delta.Role
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
continue
|
||||
} else if responseBody.Choices[0].FinishReason != "" {
|
||||
break // 输出完成或者输出中断了
|
||||
} else {
|
||||
content := responseBody.Choices[0].Delta.Content
|
||||
contents = append(contents, utils.InterfaceToString(content))
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
})
|
||||
}
|
||||
} // end for
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
} else {
|
||||
logger.Error("信息读取出错:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
res := h.DB.Save(&historyUserMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
replyTokens += getTotalTokens(req)
|
||||
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: replyTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
res = h.DB.Create(&historyReplyMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||
if res.Error != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = session.UserId
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
}
|
||||
chatItem.Model = req.Model
|
||||
h.DB.Create(&chatItem)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
var res types.ApiError
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with decode response: %v", err)
|
||||
}
|
||||
|
||||
if strings.Contains(res.Error.Message, "maximum context length") {
|
||||
logger.Error(res.Error.Message)
|
||||
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
|
||||
h.App.ChatContexts.Delete(session.ChatId)
|
||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
||||
} else {
|
||||
utils.ReplyMessage(ws, "请求 Azure API 失败:"+res.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,273 +0,0 @@
|
||||
package chatimpl
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type baiduResp struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
Result string `json:"result"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
// 百度文心一言消息发送实现
|
||||
|
||||
func (h *ChatHandler) sendBaiduMessage(
|
||||
chatCtx []types.Message,
|
||||
req types.ApiRequest,
|
||||
userVo vo.User,
|
||||
ctx context.Context,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
prompt string,
|
||||
ws *types.WsClient) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
return nil
|
||||
} else if strings.Contains(err.Error(), "no available key") {
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
return nil
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return err
|
||||
} else {
|
||||
defer response.Body.Close()
|
||||
}
|
||||
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
replyCreatedAt := time.Now() // 记录回复时间
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var content string
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if len(line) < 5 || strings.HasPrefix(line, "id:") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
content = line[5:]
|
||||
}
|
||||
|
||||
// 处理代码换行
|
||||
if len(content) == 0 {
|
||||
content = "\n"
|
||||
}
|
||||
|
||||
var resp baiduResp
|
||||
err := utils.JsonDecode(content, &resp)
|
||||
if err != nil {
|
||||
logger.Error("error with parse data line: ", err)
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
|
||||
break
|
||||
}
|
||||
|
||||
if len(contents) == 0 {
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
}
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(resp.Result),
|
||||
})
|
||||
contents = append(contents, resp.Result)
|
||||
|
||||
if resp.IsTruncated {
|
||||
utils.ReplyMessage(ws, "AI 输出异常中断")
|
||||
break
|
||||
}
|
||||
|
||||
if resp.IsEnd {
|
||||
break
|
||||
}
|
||||
|
||||
} // end for
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
} else {
|
||||
logger.Error("信息读取出错:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
res := h.DB.Save(&historyUserMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: totalTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
res = h.DB.Create(&historyReplyMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||
if res.Error != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = session.UserId
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
}
|
||||
chatItem.Model = req.Model
|
||||
h.DB.Create(&chatItem)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
|
||||
var res struct {
|
||||
Code int `json:"error_code"`
|
||||
Msg string `json:"error_msg"`
|
||||
}
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with decode response: %v", err)
|
||||
}
|
||||
utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
|
||||
ctx := context.Background()
|
||||
tokenString, err := h.redis.Get(ctx, apiKey).Result()
|
||||
if err == nil {
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
expr := time.Hour * 24 * 20 // access_token 有效期
|
||||
key := strings.Split(apiKey, "|")
|
||||
if len(key) != 2 {
|
||||
return "", fmt.Errorf("invalid api key: %s", apiKey)
|
||||
}
|
||||
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("POST", url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Accept", "application/json")
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with send request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with read response: %w", err)
|
||||
}
|
||||
var r map[string]interface{}
|
||||
err = json.Unmarshal(body, &r)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse response: %w", err)
|
||||
}
|
||||
|
||||
if r["error"] != nil {
|
||||
return "", fmt.Errorf("error with api response: %s", r["error_description"])
|
||||
}
|
||||
|
||||
tokenString = fmt.Sprintf("%s", r["access_token"])
|
||||
h.redis.Set(ctx, apiKey, tokenString, expr)
|
||||
return tokenString, nil
|
||||
}
|
||||
@@ -1,25 +1,35 @@
|
||||
package chatimpl
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
@@ -27,30 +37,25 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
|
||||
|
||||
var ErrImg = ""
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type ChatHandler struct {
|
||||
handler.BaseHandler
|
||||
redis *redis.Client
|
||||
uploadManager *oss.UploaderManager
|
||||
redis *redis.Client
|
||||
uploadManager *oss.UploaderManager
|
||||
licenseService *service.LicenseService
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
|
||||
}
|
||||
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler {
|
||||
return &ChatHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
redis: redis,
|
||||
uploadManager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Init() {
|
||||
// 如果后台有上传微信客服微信二维码,则覆盖
|
||||
if h.App.SysConfig.WechatCardURL != "" {
|
||||
ErrImg = fmt.Sprintf("", h.App.SysConfig.WechatCardURL)
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
redis: redis,
|
||||
uploadManager: manager,
|
||||
licenseService: licenseService,
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
ChatContexts: types.NewLMap[string, []types.Message](),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,30 +73,30 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
modelId := h.GetInt(c, "model_id", 0)
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
var chatRole model.ChatRole
|
||||
res := h.DB.First(&chatRole, roleId)
|
||||
if res.Error != nil || !chatRole.Enable {
|
||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// if the role bind a model_id, use role's bind model_id
|
||||
if chatRole.ModelId > 0 {
|
||||
modelId = chatRole.ModelId
|
||||
}
|
||||
// get model info
|
||||
var chatModel model.ChatModel
|
||||
res := h.DB.First(&chatModel, modelId)
|
||||
res = h.DB.First(&chatModel, modelId)
|
||||
if res.Error != nil || chatModel.Enabled == false {
|
||||
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
session := h.App.ChatSession.Get(sessionId)
|
||||
if session == nil {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
logger.Info("用户未登录")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
session = &types.ChatSession{
|
||||
SessionId: sessionId,
|
||||
ClientIP: c.ClientIP(),
|
||||
Username: user.Username,
|
||||
UserId: user.Id,
|
||||
}
|
||||
h.App.ChatSession.Put(sessionId, session)
|
||||
session := &types.ChatSession{
|
||||
SessionId: sessionId,
|
||||
ClientIP: c.ClientIP(),
|
||||
UserId: h.GetLoginUserId(c),
|
||||
}
|
||||
|
||||
// use old chat data override the chat model and role ID
|
||||
@@ -111,30 +116,19 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
MaxTokens: chatModel.MaxTokens,
|
||||
MaxContext: chatModel.MaxContext,
|
||||
Temperature: chatModel.Temperature,
|
||||
Platform: types.Platform(chatModel.Platform)}
|
||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
||||
var chatRole model.ChatRole
|
||||
res = h.DB.First(&chatRole, roleId)
|
||||
if res.Error != nil || !chatRole.Enable {
|
||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
KeyId: chatModel.KeyId}
|
||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||
|
||||
h.Init()
|
||||
|
||||
// 保存会话连接
|
||||
h.App.ChatClients.Put(sessionId, client)
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
||||
client.Close()
|
||||
h.App.ChatClients.Delete(sessionId)
|
||||
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
|
||||
cancelFunc := h.ReqCancelFunc.Get(sessionId)
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
h.App.ReqCancelFunc.Delete(sessionId)
|
||||
h.ReqCancelFunc.Delete(sessionId)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -154,12 +148,12 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
logger.Info("Receive a message: ", message.Content)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.App.ReqCancelFunc.Put(sessionId, cancel)
|
||||
h.ReqCancelFunc.Put(sessionId, cancel)
|
||||
// 回复消息
|
||||
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||
utils.ReplyMessage(client, err.Error())
|
||||
} else {
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||
logger.Infof("回答完毕: %v", message.Content)
|
||||
@@ -181,8 +175,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
var user model.User
|
||||
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
|
||||
if res.Error != nil {
|
||||
utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作!")
|
||||
return res.Error
|
||||
return errors.New("未授权用户,您正在进行非法操作!")
|
||||
}
|
||||
var userVo vo.User
|
||||
err := utils.CopyObject(user, &userVo)
|
||||
@@ -192,92 +185,67 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
}
|
||||
|
||||
if userVo.Status == false {
|
||||
utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return nil
|
||||
return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
|
||||
}
|
||||
|
||||
if userVo.Power < session.Model.Power {
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型的单次对话需要消耗的算力(%d)!", userVo.Power, session.Model.Power))
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return nil
|
||||
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, session.Model.Power)
|
||||
}
|
||||
|
||||
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
||||
utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return nil
|
||||
return errors.New("您的账号已经过期,请联系管理员!")
|
||||
}
|
||||
|
||||
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
|
||||
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
|
||||
if promptTokens > session.Model.MaxContext {
|
||||
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
|
||||
return nil
|
||||
|
||||
return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
|
||||
}
|
||||
|
||||
var req = types.ApiRequest{
|
||||
Model: session.Model.Value,
|
||||
Stream: true,
|
||||
}
|
||||
switch session.Model.Platform {
|
||||
case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
|
||||
req.Temperature = session.Model.Temperature
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
break
|
||||
case types.OpenAI:
|
||||
req.Temperature = session.Model.Temperature
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
// OpenAI 支持函数功能
|
||||
var items []model.Function
|
||||
res := h.DB.Where("enabled", true).Find(&items)
|
||||
if res.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
var tools = make([]interface{}, 0)
|
||||
req.Temperature = session.Model.Temperature
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
// OpenAI 支持函数功能
|
||||
var items []model.Function
|
||||
res = h.DB.Where("enabled", true).Find(&items)
|
||||
if res.Error == nil {
|
||||
var tools = make([]types.Tool, 0)
|
||||
for _, v := range items {
|
||||
var parameters map[string]interface{}
|
||||
err = utils.JsonDecode(v.Parameters, ¶meters)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
required := parameters["required"]
|
||||
delete(parameters, "required")
|
||||
tools = append(tools, gin.H{
|
||||
"type": "function",
|
||||
"function": gin.H{
|
||||
"name": v.Name,
|
||||
"description": v.Description,
|
||||
"parameters": parameters,
|
||||
"required": required,
|
||||
tool := types.Tool{
|
||||
Type: "function",
|
||||
Function: types.Function{
|
||||
Name: v.Name,
|
||||
Description: v.Description,
|
||||
Parameters: parameters,
|
||||
},
|
||||
})
|
||||
}
|
||||
if v, ok := parameters["required"]; v == nil || !ok {
|
||||
tool.Function.Parameters["required"] = []string{}
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
req.Tools = tools
|
||||
req.ToolChoice = "auto"
|
||||
}
|
||||
case types.QWen:
|
||||
req.Parameters = map[string]interface{}{
|
||||
"max_tokens": session.Model.MaxTokens,
|
||||
"temperature": session.Model.Temperature,
|
||||
}
|
||||
break
|
||||
|
||||
default:
|
||||
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 加载聊天上下文
|
||||
chatCtx := make([]types.Message, 0)
|
||||
messages := make([]types.Message, 0)
|
||||
if h.App.SysConfig.EnableContext {
|
||||
if h.App.ChatContexts.Has(session.ChatId) {
|
||||
messages = h.App.ChatContexts.Get(session.ChatId)
|
||||
if h.ChatContexts.Has(session.ChatId) {
|
||||
messages = h.ChatContexts.Get(session.ChatId)
|
||||
} else {
|
||||
_ = utils.JsonDecode(role.Context, &messages)
|
||||
if h.App.SysConfig.ContextDeep > 0 {
|
||||
@@ -325,37 +293,69 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
reqMgs = append(reqMgs, m)
|
||||
}
|
||||
|
||||
if session.Model.Platform == types.QWen {
|
||||
req.Input = map[string]interface{}{"prompt": prompt}
|
||||
if len(reqMgs) > 0 {
|
||||
req.Input["messages"] = reqMgs
|
||||
fullPrompt := prompt
|
||||
text := prompt
|
||||
// extract files in prompt
|
||||
files := utils.ExtractFileURLs(prompt)
|
||||
logger.Debugf("detected FILES: %+v", files)
|
||||
// 如果不是逆向模型,则提取文件内容
|
||||
if len(files) > 0 && !(session.Model.Value == "gpt-4-all" ||
|
||||
strings.HasPrefix(session.Model.Value, "gpt-4-gizmo") ||
|
||||
strings.HasSuffix(session.Model.Value, "claude-3")) {
|
||||
contents := make([]string, 0)
|
||||
var file model.File
|
||||
for _, v := range files {
|
||||
h.DB.Where("url = ?", v).First(&file)
|
||||
content, err := utils.ReadFileContent(v, h.App.Config.TikaHost)
|
||||
if err != nil {
|
||||
logger.Error("error with read file: ", err)
|
||||
} else {
|
||||
contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
}
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
|
||||
}
|
||||
} else {
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
})
|
||||
}
|
||||
|
||||
switch session.Model.Platform {
|
||||
case types.Azure:
|
||||
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.OpenAI:
|
||||
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.ChatGLM:
|
||||
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.Baidu:
|
||||
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.XunFei:
|
||||
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.QWen:
|
||||
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
|
||||
if tokens > session.Model.MaxContext {
|
||||
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
|
||||
}
|
||||
}
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: fmt.Sprintf("Not supported platform: %s", session.Model.Platform),
|
||||
logger.Debug("最终Prompt:", fullPrompt)
|
||||
|
||||
// extract images from prompt
|
||||
imgURLs := utils.ExtractImgURLs(prompt)
|
||||
logger.Debugf("detected IMG: %+v", imgURLs)
|
||||
var content interface{}
|
||||
if len(imgURLs) > 0 {
|
||||
data := make([]interface{}, 0)
|
||||
for _, v := range imgURLs {
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
data = append(data, gin.H{
|
||||
"type": "image_url",
|
||||
"image_url": gin.H{
|
||||
"url": v,
|
||||
},
|
||||
})
|
||||
}
|
||||
data = append(data, gin.H{
|
||||
"type": "text",
|
||||
"text": strings.TrimSpace(text),
|
||||
})
|
||||
content = data
|
||||
} else {
|
||||
content = fullPrompt
|
||||
}
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
})
|
||||
return nil
|
||||
|
||||
logger.Debugf("%+v", req.Messages)
|
||||
|
||||
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
}
|
||||
|
||||
// Tokens 统计 token 数量
|
||||
@@ -415,55 +415,36 @@ func getTotalTokens(req types.ApiRequest) int {
|
||||
// StopGenerate 停止生成
|
||||
func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
sessionId := c.Query("session_id")
|
||||
if h.App.ReqCancelFunc.Has(sessionId) {
|
||||
h.App.ReqCancelFunc.Get(sessionId)()
|
||||
h.App.ReqCancelFunc.Delete(sessionId)
|
||||
if h.ReqCancelFunc.Has(sessionId) {
|
||||
h.ReqCancelFunc.Get(sessionId)()
|
||||
h.ReqCancelFunc.Delete(sessionId)
|
||||
}
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
// 发送请求到 OpenAI 服务器
|
||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
|
||||
if res.Error != nil {
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if session.Model.KeyId > 0 {
|
||||
h.DB.Where("id", session.Model.KeyId).Find(apiKey)
|
||||
}
|
||||
// use the last unused key
|
||||
if apiKey.Id == 0 {
|
||||
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
|
||||
}
|
||||
if apiKey.Id == 0 {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
var apiURL string
|
||||
switch platform {
|
||||
case types.Azure:
|
||||
md := strings.Replace(req.Model, ".", "", 1)
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
|
||||
break
|
||||
case types.ChatGLM:
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
||||
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
|
||||
req.Messages = nil
|
||||
break
|
||||
case types.Baidu:
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
||||
break
|
||||
case types.QWen:
|
||||
apiURL = apiKey.ApiURL
|
||||
req.Messages = nil
|
||||
break
|
||||
default:
|
||||
apiURL = apiKey.ApiURL
|
||||
}
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
// 百度文心,需要串接 access_token
|
||||
if platform == types.Baidu {
|
||||
token, err := h.getBaiduToken(apiKey.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Info("百度文心 Access_Token:", token)
|
||||
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
|
||||
}
|
||||
|
||||
// ONLY allow apiURL in blank list
|
||||
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Debugf(utils.JsonEncode(req))
|
||||
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
// 创建 HttpClient 请求对象
|
||||
var client *http.Client
|
||||
requestBody, err := json.Marshal(req)
|
||||
@@ -477,8 +458,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
||||
|
||||
request = request.WithContext(ctx)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
var proxyURL string
|
||||
if apiKey.ProxyURL != "" { // 使用代理
|
||||
if len(apiKey.ProxyURL) > 5 { // 使用代理
|
||||
proxy, _ := url.Parse(apiKey.ProxyURL)
|
||||
client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
@@ -488,28 +468,10 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
||||
} else {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
|
||||
switch platform {
|
||||
case types.Azure:
|
||||
request.Header.Set("api-key", apiKey.Value)
|
||||
break
|
||||
case types.ChatGLM:
|
||||
token, err := h.getChatGLMToken(apiKey.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
break
|
||||
case types.Baidu:
|
||||
request.RequestURI = ""
|
||||
case types.OpenAI:
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
break
|
||||
case types.QWen:
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
request.Header.Set("X-DashScope-SSE", "enable")
|
||||
break
|
||||
}
|
||||
logger.Debugf("Sending %s request, Channel:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
// 更新API KEY 最后使用时间
|
||||
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
return client.Do(request)
|
||||
}
|
||||
|
||||
@@ -539,6 +501,98 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
|
||||
|
||||
}
|
||||
|
||||
func (h *ChatHandler) saveChatHistory(
|
||||
req types.ApiRequest,
|
||||
prompt string,
|
||||
contents []string,
|
||||
message types.Message,
|
||||
chatCtx []types.Message,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
userVo vo.User,
|
||||
promptCreatedAt time.Time,
|
||||
replyCreatedAt time.Time) {
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
res := h.DB.Save(&historyUserMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: totalTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
res = h.DB.Create(&historyReplyMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
if session.Model.Power > 0 {
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
}
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||
if res.Error != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = userVo.Id
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
}
|
||||
chatItem.Model = req.Model
|
||||
h.DB.Create(&chatItem)
|
||||
}
|
||||
}
|
||||
|
||||
// 将AI回复消息中生成的图片链接下载到本地
|
||||
func (h *ChatHandler) extractImgUrl(text string) string {
|
||||
pattern := `!\[([^\]]*)]\(([^)]+)\)`
|
||||
@@ -554,7 +608,7 @@ func (h *ChatHandler) extractImgUrl(text string) string {
|
||||
continue
|
||||
}
|
||||
|
||||
newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
|
||||
newImgURL, err := h.uploadManager.GetUploadHandler().PutUrlFile(imageURL, false)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
continue
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package chatimpl
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -89,7 +96,7 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
||||
for _, chat := range chats {
|
||||
chatIds = append(chatIds, chat.ChatId)
|
||||
// 清空会话上下文
|
||||
h.App.ChatContexts.Delete(chat.ChatId)
|
||||
h.ChatContexts.Delete(chat.ChatId)
|
||||
}
|
||||
err = h.DB.Transaction(func(tx *gorm.DB) error {
|
||||
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
|
||||
@@ -101,8 +108,6 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
|
||||
// TODO: 是否要删除 MidJourney 绘画记录和图片文件?
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -168,7 +173,7 @@ func (h *ChatHandler) Remove(c *gin.Context) {
|
||||
// TODO: 是否要删除 MidJourney 绘画记录和图片文件?
|
||||
|
||||
// 清空会话上下文
|
||||
h.App.ChatContexts.Delete(chatId)
|
||||
h.ChatContexts.Delete(chatId)
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
@@ -187,12 +192,20 @@ func (h *ChatHandler) Detail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 填充角色名称
|
||||
var role model.ChatRole
|
||||
res = h.DB.Where("id", chatItem.RoleId).First(&role)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Role not found")
|
||||
return
|
||||
}
|
||||
|
||||
var chatItemVo vo.ChatItem
|
||||
err := utils.CopyObject(chatItem, &chatItemVo)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
chatItemVo.RoleName = role.Name
|
||||
resp.SUCCESS(c, chatItemVo)
|
||||
}
|
||||
|
||||
@@ -1,236 +0,0 @@
|
||||
package chatimpl
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"html/template"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// 清华大学 ChatGML 消息发送实现
|
||||
|
||||
func (h *ChatHandler) sendChatGLMMessage(
|
||||
chatCtx []types.Message,
|
||||
req types.ApiRequest,
|
||||
userVo vo.User,
|
||||
ctx context.Context,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
prompt string,
|
||||
ws *types.WsClient) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
return nil
|
||||
} else if strings.Contains(err.Error(), "no available key") {
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
return nil
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return err
|
||||
} else {
|
||||
defer response.Body.Close()
|
||||
}
|
||||
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
replyCreatedAt := time.Now() // 记录回复时间
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var event, content string
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if len(line) < 5 || strings.HasPrefix(line, "id:") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
event = line[6:]
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
content = line[5:]
|
||||
}
|
||||
// 处理代码换行
|
||||
if len(content) == 0 {
|
||||
content = "\n"
|
||||
}
|
||||
switch event {
|
||||
case "add":
|
||||
if len(contents) == 0 {
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
}
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(content),
|
||||
})
|
||||
contents = append(contents, content)
|
||||
case "finish":
|
||||
break
|
||||
case "error":
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
|
||||
break
|
||||
case "interrupted":
|
||||
utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
|
||||
}
|
||||
|
||||
} // end for
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
} else {
|
||||
logger.Error("信息读取出错:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
res := h.DB.Save(&historyUserMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: totalTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
res = h.DB.Create(&historyReplyMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||
if res.Error != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = session.UserId
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
}
|
||||
chatItem.Model = req.Model
|
||||
h.DB.Create(&chatItem)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
|
||||
var res struct {
|
||||
Code int `json:"code"`
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with decode response: %v", err)
|
||||
}
|
||||
if !res.Success {
|
||||
utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) {
|
||||
ctx := context.Background()
|
||||
tokenString, err := h.redis.Get(ctx, apiKey).Result()
|
||||
if err == nil {
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
expr := time.Hour * 2
|
||||
key := strings.Split(apiKey, ".")
|
||||
if len(key) != 2 {
|
||||
return "", fmt.Errorf("invalid api key: %s", apiKey)
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"api_key": key[0],
|
||||
"timestamp": time.Now().Unix(),
|
||||
"exp": time.Now().Add(expr).Add(time.Second * 10).Unix(),
|
||||
})
|
||||
token.Header["alg"] = "HS256"
|
||||
token.Header["sign_type"] = "SIGN"
|
||||
delete(token.Header, "typ")
|
||||
// Sign and get the complete encoded token as a string using the secret
|
||||
tokenString, err = token.SignedString([]byte(key[1]))
|
||||
h.redis.Set(ctx, apiKey, tokenString, expr)
|
||||
return tokenString, err
|
||||
}
|
||||
@@ -1,21 +1,26 @@
|
||||
package chatimpl
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
req2 "github.com/imroc/req/v3"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
req2 "github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// OPenAI 消息发送实现
|
||||
@@ -31,24 +36,13 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
return nil
|
||||
return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||
} else if strings.Contains(err.Error(), "no available key") {
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
return nil
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
if response.Body != nil {
|
||||
all, _ := io.ReadAll(response.Body)
|
||||
logger.Error(string(all))
|
||||
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
}
|
||||
return err
|
||||
} else {
|
||||
@@ -65,18 +59,26 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
var toolCall = false
|
||||
var arguments = make([]string, 0)
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
var isNew = true
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
continue
|
||||
}
|
||||
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
|
||||
logger.Error(err, line)
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
if err != nil { // 数据解析出错
|
||||
return errors.New(line)
|
||||
}
|
||||
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
continue
|
||||
}
|
||||
if responseBody.Choices[0].Delta.Content == nil && responseBody.Choices[0].Delta.ToolCalls == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -103,8 +105,10 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
|
||||
if res.Error == nil {
|
||||
toolCall = true
|
||||
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
|
||||
contents = append(contents, callMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -114,16 +118,16 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
break
|
||||
}
|
||||
|
||||
// 初始化 role
|
||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||
message.Role = responseBody.Choices[0].Delta.Role
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
continue
|
||||
} else if responseBody.Choices[0].FinishReason != "" {
|
||||
// output stopped
|
||||
if responseBody.Choices[0].FinishReason != "" {
|
||||
break // 输出完成或者输出中断了
|
||||
} else {
|
||||
content := responseBody.Choices[0].Delta.Content
|
||||
contents = append(contents, utils.InterfaceToString(content))
|
||||
if isNew {
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
isNew = false
|
||||
}
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
@@ -140,7 +144,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
}
|
||||
|
||||
if toolCall { // 调用函数完成任务
|
||||
var params map[string]interface{}
|
||||
params := make(map[string]interface{})
|
||||
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
||||
params["user_id"] = userVo.Id
|
||||
@@ -173,126 +177,11 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.SysConfig.EnableContext && toolCall == false {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
useContext := true
|
||||
if toolCall {
|
||||
useContext = false
|
||||
}
|
||||
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: useContext,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
res := h.DB.Save(&historyUserMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var replyTokens = 0
|
||||
if toolCall { // prompt + 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(function.Name, req.Model)
|
||||
replyTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
replyTokens += tokens
|
||||
} else {
|
||||
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
}
|
||||
replyTokens += getTotalTokens(req)
|
||||
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: h.extractImgUrl(message.Content),
|
||||
Tokens: replyTokens,
|
||||
UseContext: useContext,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
res = h.DB.Create(&historyReplyMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||
if res.Error != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = session.UserId
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
}
|
||||
chatItem.Model = req.Model
|
||||
h.DB.Create(&chatItem)
|
||||
}
|
||||
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||
}
|
||||
} else {
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
|
||||
return fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
var res types.ApiError
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
|
||||
return fmt.Errorf("error with decode response: %v", err)
|
||||
}
|
||||
|
||||
// OpenAI API 调用异常处理
|
||||
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
|
||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
|
||||
// 移除当前 API key
|
||||
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
|
||||
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
|
||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
|
||||
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
||||
logger.Error(res.Error.Message)
|
||||
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
|
||||
h.App.ChatContexts.Delete(session.ChatId)
|
||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
||||
} else {
|
||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
||||
}
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("请求 OpenAI API 失败:%s", body)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
package chatimpl
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type qWenResp struct {
|
||||
Output struct {
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Text string `json:"text"`
|
||||
} `json:"output,omitempty"`
|
||||
Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage,omitempty"`
|
||||
RequestID string `json:"request_id"`
|
||||
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// 通义千问消息发送实现
|
||||
func (h *ChatHandler) sendQWenMessage(
|
||||
chatCtx []types.Message,
|
||||
req types.ApiRequest,
|
||||
userVo vo.User,
|
||||
ctx context.Context,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
prompt string,
|
||||
ws *types.WsClient) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
return nil
|
||||
} else if strings.Contains(err.Error(), "no available key") {
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
return nil
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return err
|
||||
} else {
|
||||
defer response.Body.Close()
|
||||
}
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
replyCreatedAt := time.Now() // 记录回复时间
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
|
||||
var content, lastText, newText string
|
||||
var outPutStart = false
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if len(line) < 5 || strings.HasPrefix(line, "id:") ||
|
||||
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
content = line[5:]
|
||||
}
|
||||
|
||||
var resp qWenResp
|
||||
if len(contents) == 0 { // 发送消息头
|
||||
if !outPutStart {
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
outPutStart = true
|
||||
continue
|
||||
} else {
|
||||
// 处理代码换行
|
||||
content = "\n"
|
||||
}
|
||||
} else {
|
||||
err := utils.JsonDecode(content, &resp)
|
||||
if err != nil {
|
||||
logger.Error("error with parse data line: ", content)
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
|
||||
break
|
||||
}
|
||||
if resp.Message != "" {
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
//通过比较 lastText(上一次的文本)和 currentText(当前的文本),
|
||||
//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
|
||||
//每次循环结束后,lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
|
||||
currentText := resp.Output.Text
|
||||
if currentText != lastText {
|
||||
// 提取新增文本
|
||||
newText = strings.Replace(currentText, lastText, "", 1)
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(newText),
|
||||
})
|
||||
lastText = currentText // 更新 lastText
|
||||
}
|
||||
contents = append(contents, newText)
|
||||
|
||||
if resp.Output.FinishReason == "stop" {
|
||||
break
|
||||
}
|
||||
|
||||
} //end for
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
} else {
|
||||
logger.Error("信息读取出错:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
res := h.DB.Save(&historyUserMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: totalTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
res = h.DB.Create(&historyReplyMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||
if res.Error != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = session.UserId
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
}
|
||||
chatItem.Model = req.Model
|
||||
h.DB.Create(&chatItem)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
|
||||
var res struct {
|
||||
Code int `json:"error_code"`
|
||||
Msg string `json:"error_msg"`
|
||||
}
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with decode response: %v", err)
|
||||
}
|
||||
utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
package chatimpl
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/websocket"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type xunFeiResp struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
} `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
Text struct {
|
||||
QuestionTokens int `json:"question_tokens"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
var Model2URL = map[string]string{
|
||||
"general": "v1.1",
|
||||
"generalv2": "v2.1",
|
||||
"generalv3": "v3.1",
|
||||
"generalv3.5": "v3.5",
|
||||
}
|
||||
|
||||
// 科大讯飞消息发送实现
|
||||
|
||||
func (h *ChatHandler) sendXunFeiMessage(
|
||||
chatCtx []types.Message,
|
||||
req types.ApiRequest,
|
||||
userVo vo.User,
|
||||
ctx context.Context,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
prompt string,
|
||||
ws *types.WsClient) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
var apiKey model.ApiKey
|
||||
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
||||
if res.Error != nil {
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
return nil
|
||||
}
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
d := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
key := strings.Split(apiKey.Value, "|")
|
||||
if len(key) != 3 {
|
||||
utils.ReplyMessage(ws, "非法的 API KEY!")
|
||||
return nil
|
||||
}
|
||||
|
||||
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
|
||||
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
|
||||
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
|
||||
//握手并建立websocket 连接
|
||||
conn, resp, err := d.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
logger.Error(readResp(resp) + err.Error())
|
||||
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
|
||||
return nil
|
||||
} else if resp.StatusCode != 101 {
|
||||
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
data := buildRequest(key[0], req)
|
||||
fmt.Printf("%+v", data)
|
||||
fmt.Println(apiURL)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
utils.ReplyMessage(ws, "发送消息失败:"+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
replyCreatedAt := time.Now() // 记录回复时间
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var content string
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
logger.Error("error with read message:", err)
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err))
|
||||
break
|
||||
}
|
||||
|
||||
// 解析数据
|
||||
var result xunFeiResp
|
||||
err = json.Unmarshal(msg, &result)
|
||||
if err != nil {
|
||||
logger.Error("error with parsing JSON:", err)
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
if result.Header.Code != 0 {
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message))
|
||||
return nil
|
||||
}
|
||||
|
||||
content = result.Payload.Choices.Text[0].Content
|
||||
// 处理代码换行
|
||||
if len(content) == 0 {
|
||||
content = "\n"
|
||||
}
|
||||
contents = append(contents, content)
|
||||
// 第一个结果
|
||||
if result.Payload.Choices.Status == 0 {
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
}
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(content),
|
||||
})
|
||||
|
||||
if result.Payload.Choices.Status == 2 { // 最终结果
|
||||
_ = conn.Close() // 关闭连接
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
utils.ReplyMessage(ws, "**用户取消了生成指令!**")
|
||||
return nil
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
res := h.DB.Save(&historyUserMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: totalTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
res = h.DB.Create(&historyReplyMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||
if res.Error != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = session.UserId
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
}
|
||||
chatItem.Model = req.Model
|
||||
h.DB.Create(&chatItem)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 构建 websocket 请求实体
|
||||
func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"header": map[string]interface{}{
|
||||
"app_id": appid,
|
||||
},
|
||||
"parameter": map[string]interface{}{
|
||||
"chat": map[string]interface{}{
|
||||
"domain": req.Model,
|
||||
"temperature": req.Temperature,
|
||||
"top_k": int64(6),
|
||||
"max_tokens": int64(req.MaxTokens),
|
||||
"auditing": "default",
|
||||
},
|
||||
},
|
||||
"payload": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"text": req.Messages,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 创建鉴权 URL
|
||||
func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) {
|
||||
ul, err := url.Parse(hostURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
date := time.Now().UTC().Format(time.RFC1123)
|
||||
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
||||
//拼接签名字符串
|
||||
signStr := strings.Join(signString, "\n")
|
||||
sha := hmacWithSha256(signStr, apiSecret)
|
||||
|
||||
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
||||
"hmac-sha256", "host date request-line", sha)
|
||||
//将请求参数使用base64编码
|
||||
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
||||
v := url.Values{}
|
||||
v.Add("host", ul.Host)
|
||||
v.Add("date", date)
|
||||
v.Add("authorization", authorization)
|
||||
//将编码后的字符串url encode后添加到url后面
|
||||
return hostURL + "?" + v.Encode(), nil
|
||||
}
|
||||
|
||||
// 使用 sha256 签名
|
||||
func hmacWithSha256(data, key string) string {
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
mac.Write([]byte(data))
|
||||
encodeData := mac.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(encodeData)
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
func readResp(resp *http.Response) string {
|
||||
if resp == nil {
|
||||
return ""
|
||||
}
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
|
||||
}
|
||||
@@ -1,10 +1,18 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -12,10 +20,11 @@ import (
|
||||
|
||||
type ConfigHandler struct {
|
||||
BaseHandler
|
||||
licenseService *service.LicenseService
|
||||
}
|
||||
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
|
||||
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *ConfigHandler {
|
||||
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
||||
}
|
||||
|
||||
// Get 获取指定的系统配置
|
||||
@@ -37,3 +46,9 @@ func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
|
||||
resp.SUCCESS(c, value)
|
||||
}
|
||||
|
||||
// License 获取 License 配置
|
||||
func (h *ConfigHandler) License(c *gin.Context) {
|
||||
license := h.licenseService.GetLicense()
|
||||
resp.SUCCESS(c, license.Configs)
|
||||
}
|
||||
|
||||
255
api/handler/dalle_handler.go
Normal file
255
api/handler/dalle_handler.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DallJobHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
service *dalle.Service
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
|
||||
return &DallJobHandler{
|
||||
service: service,
|
||||
uploader: manager,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *DallJobHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
logger.Info("Invalid user ID")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.service.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
client.Close()
|
||||
h.service.Clients.Delete(uint(userId))
|
||||
return
|
||||
}
|
||||
|
||||
var message types.WsMessage
|
||||
err = utils.JsonDecode(string(msg), &message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 心跳消息
|
||||
if message.Type == "heartbeat" {
|
||||
logger.Debug("收到 DallE 心跳消息:", message.Content)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return false
|
||||
}
|
||||
if user.Power < h.App.SysConfig.DallPower {
|
||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
if !h.preCheck(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var data types.DallTask
|
||||
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
job := model.DallJob{
|
||||
UserId: uint(userId),
|
||||
Prompt: data.Prompt,
|
||||
Power: h.App.SysConfig.DallPower,
|
||||
}
|
||||
res := h.DB.Create(&job)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "error with save job: "+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.service.PushTask(types.DallTask{
|
||||
JobId: job.Id,
|
||||
UserId: uint(userId),
|
||||
Prompt: data.Prompt,
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
Power: job.Power,
|
||||
})
|
||||
|
||||
client := h.service.Clients.Get(job.UserId)
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// ImgWall 照片墙
|
||||
func (h *DallJobHandler) ImgWall(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 0)
|
||||
pageSize := h.GetInt(c, "page_size", 0)
|
||||
err, jobs := h.getData(true, 0, page, pageSize, true)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, jobs)
|
||||
}
|
||||
|
||||
// JobList 获取 SD 任务列表
|
||||
func (h *DallJobHandler) JobList(c *gin.Context) {
|
||||
finish := h.GetBool(c, "finish")
|
||||
userId := h.GetLoginUserId(c)
|
||||
page := h.GetInt(c, "page", 0)
|
||||
pageSize := h.GetInt(c, "page_size", 0)
|
||||
publish := h.GetBool(c, "publish")
|
||||
|
||||
err, jobs := h.getData(finish, userId, page, pageSize, publish)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, jobs)
|
||||
}
|
||||
|
||||
// JobList 获取任务列表
|
||||
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if finish {
|
||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
||||
} else {
|
||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||
}
|
||||
if userId > 0 {
|
||||
session = session.Where("user_id = ?", userId)
|
||||
}
|
||||
if publish {
|
||||
session = session.Where("publish", publish)
|
||||
}
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
session = session.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
|
||||
var items []model.DallJob
|
||||
res := session.Find(&items)
|
||||
if res.Error != nil {
|
||||
return res.Error, nil
|
||||
}
|
||||
|
||||
var jobs = make([]vo.DallJob, 0)
|
||||
for _, item := range items {
|
||||
var job vo.DallJob
|
||||
err := utils.CopyObject(item, &job)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
return nil, jobs
|
||||
}
|
||||
|
||||
// Remove remove task image
|
||||
func (h *DallJobHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
var job model.DallJob
|
||||
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// remove job recode
|
||||
res := h.DB.Delete(&model.DallJob{Id: job.Id})
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// remove image
|
||||
err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||
if err != nil {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// Publish 发布/取消发布图片到画廊显示
|
||||
func (h *DallJobHandler) Publish(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||
|
||||
res := h.DB.Model(&model.DallJob{Id: uint(id), UserId: uint(userId)}).UpdateColumn("publish", action)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
@@ -1,29 +1,44 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FunctionHandler struct {
|
||||
BaseHandler
|
||||
config types.ChatPlusApiConfig
|
||||
config types.ApiConfig
|
||||
uploadManager *oss.UploaderManager
|
||||
dallService *dalle.Service
|
||||
}
|
||||
|
||||
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
|
||||
func NewFunctionHandler(
|
||||
server *core.AppServer,
|
||||
db *gorm.DB,
|
||||
config *types.AppConfig,
|
||||
manager *oss.UploaderManager,
|
||||
dallService *dalle.Service) *FunctionHandler {
|
||||
return &FunctionHandler{
|
||||
BaseHandler: BaseHandler{
|
||||
App: server,
|
||||
@@ -31,6 +46,7 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo
|
||||
},
|
||||
config: config.ApiConfig,
|
||||
uploadManager: manager,
|
||||
dallService: dallService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,30 +167,6 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
||||
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
||||
}
|
||||
|
||||
type imgReq struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Size string `json:"size"`
|
||||
}
|
||||
|
||||
type imgRes struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []struct {
|
||||
RevisedPrompt string `json:"revised_prompt"`
|
||||
Url string `json:"url"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type ErrRes struct {
|
||||
Error struct {
|
||||
Code interface{} `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Param interface{} `json:"param"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// Dall3 DallE3 AI 绘图
|
||||
func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
if err := h.checkAuth(c); err != nil {
|
||||
@@ -190,85 +182,45 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
|
||||
logger.Debugf("绘画参数:%+v", params)
|
||||
var user model.User
|
||||
tx := h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||||
if tx.Error != nil {
|
||||
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "当前用户不存在!")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.DallPower {
|
||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
|
||||
return
|
||||
}
|
||||
|
||||
// create dall task
|
||||
prompt := utils.InterfaceToString(params["prompt"])
|
||||
// get image generation API KEY
|
||||
var apiKey model.ApiKey
|
||||
tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
|
||||
job := model.DallJob{
|
||||
UserId: user.Id,
|
||||
Prompt: prompt,
|
||||
Power: h.App.SysConfig.DallPower,
|
||||
}
|
||||
res = h.DB.Create(&job)
|
||||
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// translate prompt
|
||||
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||
pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
|
||||
if err == nil {
|
||||
logger.Debugf("翻译绘画提示词,原文:%s,译文:%s", prompt, pt)
|
||||
prompt = pt
|
||||
}
|
||||
var res imgRes
|
||||
var errRes ErrRes
|
||||
var request *req.Request
|
||||
if apiKey.ProxyURL != "" {
|
||||
request = req.C().SetProxyURL(apiKey.ProxyURL).R()
|
||||
} else {
|
||||
request = req.C().R()
|
||||
}
|
||||
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
|
||||
r, err := request.SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(imgReq{
|
||||
Model: "dall-e-3",
|
||||
Prompt: prompt,
|
||||
N: 1,
|
||||
Size: "1024x1024",
|
||||
}).
|
||||
SetErrorResult(&errRes).
|
||||
SetSuccessResult(&res).Post(apiKey.ApiURL)
|
||||
if r.IsErrorState() {
|
||||
resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message)
|
||||
return
|
||||
}
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
logger.Debugf("%+v", res)
|
||||
// 存储图片
|
||||
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
|
||||
content, err := h.dallService.Image(types.DallTask{
|
||||
JobId: job.Id,
|
||||
UserId: user.Id,
|
||||
Prompt: job.Prompt,
|
||||
N: 1,
|
||||
Quality: "standard",
|
||||
Size: "1024x1024",
|
||||
Style: "vivid",
|
||||
Power: job.Power,
|
||||
}, true)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "下载图片失败: "+err.Error())
|
||||
resp.ERROR(c, "任务执行失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n\n", prompt, imgURL)
|
||||
// 更新用户算力
|
||||
tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
|
||||
// 记录算力变化日志
|
||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||
var u model.User
|
||||
h.DB.Where("id", user.Id).First(&u)
|
||||
h.DB.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: types.PowerConsume,
|
||||
Amount: h.App.SysConfig.DallPower,
|
||||
Balance: u.Power,
|
||||
Mark: types.PowerSub,
|
||||
Model: "dall-e-3",
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
|
||||
257
api/handler/markmap_handler.go
Normal file
257
api/handler/markmap_handler.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MarkMapHandler 生成思维导图
|
||||
type MarkMapHandler struct {
|
||||
BaseHandler
|
||||
clients *types.LMap[int, *types.WsClient]
|
||||
}
|
||||
|
||||
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler {
|
||||
return &MarkMapHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
clients: types.NewLMap[int, *types.WsClient](),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
modelId := h.GetInt(c, "model_id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.clients.Put(userId, client)
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
client.Close()
|
||||
h.clients.Delete(userId)
|
||||
return
|
||||
}
|
||||
|
||||
var message types.WsMessage
|
||||
err = utils.JsonDecode(string(msg), &message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 心跳消息
|
||||
if message.Type == "heartbeat" {
|
||||
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
|
||||
continue
|
||||
}
|
||||
// change model
|
||||
if message.Type == "model_id" {
|
||||
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("Receive a message: ", message.Content)
|
||||
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
|
||||
var user model.User
|
||||
res := h.DB.Model(&model.User{}).First(&user, userId)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with query user info: %v", res.Error)
|
||||
}
|
||||
var chatModel model.ChatModel
|
||||
res = h.DB.Where("id", modelId).First(&chatModel)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with query chat model: %v", res.Error)
|
||||
}
|
||||
|
||||
if user.Status == false {
|
||||
return errors.New("当前用户被禁用")
|
||||
}
|
||||
|
||||
if user.Power < chatModel.Power {
|
||||
return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)
|
||||
}
|
||||
|
||||
messages := make([]interface{}, 0)
|
||||
messages = append(messages, types.Message{Role: "system", Content: `
|
||||
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
|
||||
# Geek-AI 助手
|
||||
|
||||
## 完整的开源系统
|
||||
### 前端开源
|
||||
### 后端开源
|
||||
|
||||
## 支持各种大模型
|
||||
### OpenAI
|
||||
### Azure
|
||||
### 文心一言
|
||||
### 通义千问
|
||||
|
||||
## 集成多种收费方式
|
||||
### 支付宝
|
||||
### 微信
|
||||
|
||||
另外,除此之外不要任何解释性语句。
|
||||
`})
|
||||
messages = append(messages, types.Message{Role: "user", Content: prompt})
|
||||
var req = types.ApiRequest{
|
||||
Model: chatModel.Value,
|
||||
Stream: true,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
var apiKey model.ApiKey
|
||||
response, err := h.doRequest(req, chatModel, &apiKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
// 循环读取 Chunk 消息
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
var isNew = true
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
continue
|
||||
}
|
||||
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil { // 数据解析出错
|
||||
return fmt.Errorf("error with decode data: %v", line)
|
||||
}
|
||||
|
||||
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "stop" {
|
||||
break
|
||||
}
|
||||
|
||||
if isNew {
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
|
||||
isNew = false
|
||||
}
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
})
|
||||
} // end for
|
||||
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||
|
||||
} else {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
|
||||
}
|
||||
|
||||
// 扣减算力
|
||||
if chatModel.Power > 0 {
|
||||
res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power))
|
||||
if res.Error == nil {
|
||||
// 记录算力消费日志
|
||||
var u model.User
|
||||
h.DB.Where("id", userId).First(&u)
|
||||
h.DB.Create(&model.PowerLog{
|
||||
UserId: u.Id,
|
||||
Username: u.Username,
|
||||
Type: types.PowerConsume,
|
||||
Amount: chatModel.Power,
|
||||
Mark: types.PowerSub,
|
||||
Balance: u.Power,
|
||||
Model: chatModel.Value,
|
||||
Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if chatModel.KeyId > 0 {
|
||||
session = session.Where("id", chatModel.KeyId)
|
||||
} else { // use the last unused key
|
||||
session = session.Where("type", "chat").
|
||||
Where("enabled", true).Order("last_used_at ASC")
|
||||
}
|
||||
|
||||
res := session.First(apiKey)
|
||||
if res.Error != nil {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
// 创建 HttpClient 请求对象
|
||||
var client *http.Client
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if len(apiKey.ProxyURL) > 5 { // 使用代理
|
||||
proxy, _ := url.Parse(apiKey.ProxyURL)
|
||||
client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
return client.Do(request)
|
||||
}
|
||||
49
api/handler/menu_handler.go
Normal file
49
api/handler/menu_handler.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type MenuHandler struct {
|
||||
BaseHandler
|
||||
}
|
||||
|
||||
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
||||
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// List 数据列表
|
||||
func (h *MenuHandler) List(c *gin.Context) {
|
||||
index := h.GetBool(c, "index")
|
||||
var items []model.Menu
|
||||
var list = make([]vo.Menu, 0)
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
session = session.Where("enabled", true)
|
||||
if index {
|
||||
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
|
||||
}
|
||||
res := session.Order("sort_num ASC").Find(&items)
|
||||
if res.Error == nil {
|
||||
for _, item := range items {
|
||||
var product vo.Menu
|
||||
err := utils.CopyObject(item, &product)
|
||||
if err == nil {
|
||||
list = append(list, product)
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, list)
|
||||
}
|
||||
@@ -1,17 +1,24 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/service/mj"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -85,20 +92,22 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
|
||||
// Image 创建一个绘画任务
|
||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
var data struct {
|
||||
SessionId string `json:"session_id"`
|
||||
TaskType string `json:"task_type"`
|
||||
Prompt string `json:"prompt"`
|
||||
NegPrompt string `json:"neg_prompt"`
|
||||
Rate string `json:"rate"`
|
||||
Model string `json:"model"`
|
||||
Chaos int `json:"chaos"`
|
||||
Raw bool `json:"raw"`
|
||||
Seed int64 `json:"seed"`
|
||||
Stylize int `json:"stylize"`
|
||||
Model string `json:"model"` // 模型
|
||||
Chaos int `json:"chaos"` // 创意度取值范围: 0-100
|
||||
Raw bool `json:"raw"` // 是否开启原始模型
|
||||
Seed int64 `json:"seed"` // 随机数
|
||||
Stylize int `json:"stylize"` // 风格化
|
||||
ImgArr []string `json:"img_arr"`
|
||||
Tile bool `json:"tile"`
|
||||
Quality float32 `json:"quality"`
|
||||
Weight float32 `json:"weight"`
|
||||
Tile bool `json:"tile"` // 重复平铺
|
||||
Quality float32 `json:"quality"` // 画质
|
||||
Iw float32 `json:"iw"`
|
||||
CRef string `json:"cref"` //生成角色一致的图像
|
||||
SRef string `json:"sref"` //生成风格一致的图像
|
||||
Cw int `json:"cw"` // 参考程度
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -108,41 +117,57 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var prompt = data.Prompt
|
||||
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
|
||||
prompt += " --ar " + data.Rate
|
||||
var params = ""
|
||||
if data.Rate != "" && !strings.Contains(params, "--ar") {
|
||||
params += " --ar " + data.Rate
|
||||
}
|
||||
if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
|
||||
prompt += fmt.Sprintf(" --seed %d", data.Seed)
|
||||
if data.Seed > 0 && !strings.Contains(params, "--seed") {
|
||||
params += fmt.Sprintf(" --seed %d", data.Seed)
|
||||
}
|
||||
if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
|
||||
prompt += fmt.Sprintf(" --s %d", data.Stylize)
|
||||
if data.Stylize > 0 && !strings.Contains(params, "--s") && !strings.Contains(params, "--stylize") {
|
||||
params += fmt.Sprintf(" --s %d", data.Stylize)
|
||||
}
|
||||
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
|
||||
prompt += fmt.Sprintf(" --c %d", data.Chaos)
|
||||
if data.Chaos > 0 && !strings.Contains(params, "--c") && !strings.Contains(params, "--chaos") {
|
||||
params += fmt.Sprintf(" --c %d", data.Chaos)
|
||||
}
|
||||
if data.Weight > 0 {
|
||||
prompt += fmt.Sprintf(" --iw %f", data.Weight)
|
||||
if len(data.ImgArr) > 0 && data.Iw > 0 {
|
||||
params += fmt.Sprintf(" --iw %.2f", data.Iw)
|
||||
}
|
||||
if data.Raw {
|
||||
prompt += " --style raw"
|
||||
params += " --style raw"
|
||||
}
|
||||
if data.Quality > 0 {
|
||||
prompt += fmt.Sprintf(" --q %.2f", data.Quality)
|
||||
}
|
||||
if data.NegPrompt != "" {
|
||||
prompt += fmt.Sprintf(" --no %s", data.NegPrompt)
|
||||
params += fmt.Sprintf(" --q %.2f", data.Quality)
|
||||
}
|
||||
if data.Tile {
|
||||
prompt += " --tile "
|
||||
params += " --tile "
|
||||
}
|
||||
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
|
||||
prompt += fmt.Sprintf(" %s", data.Model)
|
||||
if data.CRef != "" {
|
||||
params += fmt.Sprintf(" --cref %s", data.CRef)
|
||||
if data.Cw > 0 {
|
||||
params += fmt.Sprintf(" --cw %d", data.Cw)
|
||||
} else {
|
||||
params += " --cw 100"
|
||||
}
|
||||
}
|
||||
|
||||
if data.SRef != "" {
|
||||
params += fmt.Sprintf(" --sref %s", data.SRef)
|
||||
}
|
||||
if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
|
||||
params += fmt.Sprintf(" %s", data.Model)
|
||||
}
|
||||
|
||||
// 处理融图和换脸的提示词
|
||||
if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() {
|
||||
prompt = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
|
||||
params = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
|
||||
}
|
||||
|
||||
// 如果本地图片上传的是相对地址,处理成绝对地址
|
||||
for k, v := range data.ImgArr {
|
||||
if !strings.HasPrefix(v, "http") {
|
||||
data.ImgArr[k] = fmt.Sprintf("http://localhost:5678/%s", strings.TrimLeft(v, "/"))
|
||||
}
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
@@ -158,7 +183,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Progress: 0,
|
||||
Prompt: prompt,
|
||||
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
|
||||
Power: h.App.SysConfig.MjPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
@@ -179,9 +204,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
h.pool.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
TaskId: taskId,
|
||||
SessionId: data.SessionId,
|
||||
Type: types.TaskType(data.TaskType),
|
||||
Prompt: prompt,
|
||||
Prompt: data.Prompt,
|
||||
NegPrompt: data.NegPrompt,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
ImgArr: data.ImgArr,
|
||||
})
|
||||
@@ -216,17 +242,12 @@ type reqVo struct {
|
||||
ChannelId string `json:"channel_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
MessageHash string `json:"message_hash"`
|
||||
SessionId string `json:"session_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
ChatId string `json:"chat_id"`
|
||||
RoleId int `json:"role_id"`
|
||||
Icon string `json:"icon"`
|
||||
}
|
||||
|
||||
// Upscale send upscale command to MidJourney Bot
|
||||
func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
var data reqVo
|
||||
if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" {
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
@@ -244,7 +265,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Progress: 0,
|
||||
Prompt: data.Prompt,
|
||||
Power: h.App.SysConfig.MjActionPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
@@ -255,9 +275,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
SessionId: data.SessionId,
|
||||
Type: types.TaskUpscale,
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
ChannelId: data.ChannelId,
|
||||
Index: data.Index,
|
||||
@@ -292,7 +310,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
// Variation send variation command to MidJourney Bot
|
||||
func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
var data reqVo
|
||||
if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" {
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
@@ -311,7 +329,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Progress: 0,
|
||||
Prompt: data.Prompt,
|
||||
Power: h.App.SysConfig.MjActionPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
@@ -322,9 +339,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
SessionId: data.SessionId,
|
||||
Type: types.TaskVariation,
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
Index: data.Index,
|
||||
ChannelId: data.ChannelId,
|
||||
@@ -372,13 +387,13 @@ func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
|
||||
|
||||
// JobList 获取 MJ 任务列表
|
||||
func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
||||
status := h.GetBool(c, "status")
|
||||
finish := h.GetBool(c, "finish")
|
||||
userId := h.GetLoginUserId(c)
|
||||
page := h.GetInt(c, "page", 0)
|
||||
pageSize := h.GetInt(c, "page_size", 0)
|
||||
publish := h.GetBool(c, "publish")
|
||||
|
||||
err, jobs := h.getData(status, userId, page, pageSize, publish)
|
||||
err, jobs := h.getData(finish, userId, page, pageSize, publish)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -391,7 +406,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
||||
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if finish {
|
||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
||||
session = session.Where("progress >= ?", 100).Order("id DESC")
|
||||
} else {
|
||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||
}
|
||||
@@ -421,14 +436,9 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
|
||||
}
|
||||
|
||||
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
|
||||
// discord 服务器图片需要使用代理转发图片数据流
|
||||
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
|
||||
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
} else {
|
||||
job.ImgURL = job.OrgURL
|
||||
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -439,30 +449,56 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
|
||||
|
||||
// Remove remove task image
|
||||
func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
ImgURL string `json:"img_url"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
var job model.MidJourneyJob
|
||||
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// remove job recode
|
||||
res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Delete(&job).Error; err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// refund power
|
||||
err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
h.DB.Where("id = ?", job.UserId).First(&user)
|
||||
err = tx.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: types.PowerConsume,
|
||||
Amount: job.Power,
|
||||
Balance: user.Power + job.Power,
|
||||
Mark: types.PowerAdd,
|
||||
Model: "mid-journey",
|
||||
Remark: fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId),
|
||||
CreatedAt: time.Now(),
|
||||
}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
// remove image
|
||||
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
|
||||
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||
if err != nil {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
client := h.pool.Clients.Get(data.UserId)
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
@@ -472,17 +508,12 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
||||
|
||||
// Publish 发布图片到画廊显示
|
||||
func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||
res := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -20,23 +28,18 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// List 订单列表
|
||||
func (h *OrderHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
userId := h.GetLoginUserId(c)
|
||||
session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
|
||||
var total int64
|
||||
session.Model(&model.Order{}).Count(&total)
|
||||
var items []model.Order
|
||||
var list = make([]vo.Order, 0)
|
||||
offset := (data.Page - 1) * data.PageSize
|
||||
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
||||
offset := (page - 1) * pageSize
|
||||
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
|
||||
if res.Error == nil {
|
||||
for _, item := range items {
|
||||
var order vo.Order
|
||||
@@ -51,5 +54,35 @@ func (h *OrderHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
||||
}
|
||||
|
||||
// Query 查询订单状态
|
||||
func (h *OrderHandler) Query(c *gin.Context) {
|
||||
orderNo := h.GetTrim(c, "order_no")
|
||||
var order model.Order
|
||||
res := h.DB.Where("order_no = ?", orderNo).First(&order)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Order not found")
|
||||
return
|
||||
}
|
||||
|
||||
if order.Status == types.OrderPaidSuccess {
|
||||
resp.SUCCESS(c, gin.H{"status": order.Status})
|
||||
return
|
||||
}
|
||||
|
||||
counter := 0
|
||||
for {
|
||||
time.Sleep(time.Second)
|
||||
var item model.Order
|
||||
h.DB.Where("order_no = ?", orderNo).First(&item)
|
||||
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
|
||||
order.Status = item.Status
|
||||
break
|
||||
}
|
||||
counter++
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"status": order.Status})
|
||||
}
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/service/payment"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"embed"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/payment"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/shopspring/decimal"
|
||||
"math"
|
||||
"net/http"
|
||||
@@ -22,48 +29,73 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
PayWayAlipay = "支付宝"
|
||||
PayWayXunHu = "虎皮椒"
|
||||
PayWayJs = "PayJS"
|
||||
type PayWay struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
var (
|
||||
PayWayAlipay = PayWay{Name: "支付宝", Value: "alipay"}
|
||||
PayWayXunHu = PayWay{Name: "虎皮椒", Value: "hupi"}
|
||||
PayWayJs = PayWay{Name: "PayJS", Value: "payjs"}
|
||||
PayWayWechat = PayWay{Name: "微信支付", Value: "wechat"}
|
||||
)
|
||||
|
||||
// PaymentHandler 支付服务回调 handler
|
||||
type PaymentHandler struct {
|
||||
BaseHandler
|
||||
alipayService *payment.AlipayService
|
||||
huPiPayService *payment.HuPiPayService
|
||||
js *payment.PayJS
|
||||
snowflake *service.Snowflake
|
||||
fs embed.FS
|
||||
lock sync.Mutex
|
||||
alipayService *payment.AlipayService
|
||||
huPiPayService *payment.HuPiPayService
|
||||
jsPayService *payment.JPayService
|
||||
wechatPayService *payment.WechatPayService
|
||||
snowflake *service.Snowflake
|
||||
fs embed.FS
|
||||
lock sync.Mutex
|
||||
signKey string // 用来签名的随机秘钥
|
||||
}
|
||||
|
||||
func NewPaymentHandler(
|
||||
server *core.AppServer,
|
||||
alipayService *payment.AlipayService,
|
||||
huPiPayService *payment.HuPiPayService,
|
||||
js *payment.PayJS,
|
||||
jsPayService *payment.JPayService,
|
||||
wechatPayService *payment.WechatPayService,
|
||||
db *gorm.DB,
|
||||
snowflake *service.Snowflake,
|
||||
fs embed.FS) *PaymentHandler {
|
||||
return &PaymentHandler{
|
||||
alipayService: alipayService,
|
||||
huPiPayService: huPiPayService,
|
||||
js: js,
|
||||
snowflake: snowflake,
|
||||
fs: fs,
|
||||
lock: sync.Mutex{},
|
||||
alipayService: alipayService,
|
||||
huPiPayService: huPiPayService,
|
||||
jsPayService: jsPayService,
|
||||
wechatPayService: wechatPayService,
|
||||
snowflake: snowflake,
|
||||
fs: fs,
|
||||
lock: sync.Mutex{},
|
||||
BaseHandler: BaseHandler{
|
||||
App: server,
|
||||
DB: db,
|
||||
},
|
||||
signKey: utils.RandString(32),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PaymentHandler) DoPay(c *gin.Context) {
|
||||
orderNo := h.GetTrim(c, "order_no")
|
||||
payWay := h.GetTrim(c, "pay_way")
|
||||
t := h.GetInt(c, "t", 0)
|
||||
sign := h.GetTrim(c, "sign")
|
||||
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, payWay, t, h.signKey)
|
||||
newSign := utils.Sha256(signStr)
|
||||
if newSign != sign {
|
||||
resp.ERROR(c, "订单签名错误!")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查二维码是否过期
|
||||
if time.Now().Unix()-int64(t) > int64(h.App.SysConfig.OrderPayTimeout) {
|
||||
resp.ERROR(c, "支付二维码已过期,请重新生成!")
|
||||
return
|
||||
}
|
||||
|
||||
if orderNo == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -79,19 +111,16 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
|
||||
|
||||
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
|
||||
if order.Status == types.OrderPaidSuccess {
|
||||
resp.ERROR(c, "This order had been paid, please do not pay twice")
|
||||
resp.ERROR(c, "订单已支付成功,无需重复支付!")
|
||||
return
|
||||
}
|
||||
|
||||
// 更新扫码状态
|
||||
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
|
||||
if payWay == "alipay" { // 支付宝
|
||||
// 生成支付链接
|
||||
notifyURL := h.App.Config.AlipayConfig.NotifyURL
|
||||
returnURL := "" // 关闭同步回跳
|
||||
amount := fmt.Sprintf("%.2f", order.Amount)
|
||||
|
||||
uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject)
|
||||
if payWay == "alipay" { // 支付宝
|
||||
amount := fmt.Sprintf("%.2f", order.Amount)
|
||||
uri, err := h.alipayService.PayUrlMobile(order.OrderNo, amount, order.Subject)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
||||
return
|
||||
@@ -119,49 +148,11 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
|
||||
resp.ERROR(c, "Invalid operations")
|
||||
}
|
||||
|
||||
// OrderQuery 查询订单状态
|
||||
func (h *PaymentHandler) OrderQuery(c *gin.Context) {
|
||||
var data struct {
|
||||
OrderNo string `json:"order_no"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
var order model.Order
|
||||
res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Order not found")
|
||||
return
|
||||
}
|
||||
|
||||
if order.Status == types.OrderPaidSuccess {
|
||||
resp.SUCCESS(c, gin.H{"status": order.Status})
|
||||
return
|
||||
}
|
||||
|
||||
counter := 0
|
||||
for {
|
||||
time.Sleep(time.Second)
|
||||
var item model.Order
|
||||
h.DB.Where("order_no = ?", data.OrderNo).First(&item)
|
||||
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
|
||||
order.Status = item.Status
|
||||
break
|
||||
}
|
||||
counter++
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"status": order.Status})
|
||||
}
|
||||
|
||||
// PayQrcode 生成支付 URL 二维码
|
||||
func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
||||
var data struct {
|
||||
PayWay string `json:"pay_way"` // 支付方式
|
||||
ProductId uint `json:"product_id"`
|
||||
UserId int `json:"user_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -180,10 +171,9 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
res = h.DB.First(&user, data.UserId)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Invalid user ID")
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -191,14 +181,21 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
||||
var notifyURL string
|
||||
switch data.PayWay {
|
||||
case "hupi":
|
||||
payWay = PayWayXunHu
|
||||
payWay = PayWayXunHu.Value
|
||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||
break
|
||||
case "payjs":
|
||||
payWay = PayWayJs
|
||||
payWay = PayWayJs.Value
|
||||
notifyURL = h.App.Config.JPayConfig.NotifyURL
|
||||
default:
|
||||
payWay = PayWayAlipay
|
||||
break
|
||||
case "alipay":
|
||||
payWay = PayWayAlipay.Value
|
||||
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||
break
|
||||
default:
|
||||
payWay = PayWayWechat.Value
|
||||
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
|
||||
|
||||
}
|
||||
// 创建订单
|
||||
remark := types.OrderRemark{
|
||||
@@ -234,7 +231,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
||||
OutTradeNo: order.OrderNo,
|
||||
Subject: product.Name,
|
||||
}
|
||||
r := h.js.Pay(params)
|
||||
r := h.jsPayService.Pay(params)
|
||||
if r.IsOK() {
|
||||
resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
|
||||
return
|
||||
@@ -253,6 +250,8 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
||||
} else {
|
||||
logo = "res/img/alipay.jpg"
|
||||
}
|
||||
} else if data.PayWay == "wechat" {
|
||||
logo = "res/img/wechat-pay.jpg"
|
||||
}
|
||||
|
||||
file, err := h.fs.Open(logo)
|
||||
@@ -266,8 +265,21 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s", parse.Scheme, parse.Host, orderNo, data.PayWay)
|
||||
timestamp := time.Now().Unix()
|
||||
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, data.PayWay, timestamp, h.signKey)
|
||||
sign := utils.Sha256(signStr)
|
||||
var imageURL string
|
||||
if data.PayWay == "wechat" {
|
||||
payUrl, err := h.wechatPayService.PayUrlNative(order.OrderNo, int(math.Floor(order.Amount*100)), product.Name)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generating wechat payment qrcode: "+err.Error())
|
||||
return
|
||||
} else {
|
||||
imageURL = payUrl
|
||||
}
|
||||
} else {
|
||||
imageURL = fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign)
|
||||
}
|
||||
imgData, err := utils.GenQrcode(imageURL, 400, file)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
@@ -282,7 +294,6 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||
var data struct {
|
||||
PayWay string `json:"pay_way"` // 支付方式
|
||||
ProductId uint `json:"product_id"`
|
||||
UserId int `json:"user_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -301,10 +312,9 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
res = h.DB.First(&user, data.UserId)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Invalid user ID")
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -314,9 +324,11 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||
var payURL string
|
||||
switch data.PayWay {
|
||||
case "hupi":
|
||||
payWay = PayWayXunHu
|
||||
payWay = PayWayXunHu.Name
|
||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
||||
parse, _ := url.Parse(h.App.Config.HuPiPayConfig.ReturnURL)
|
||||
baseURL := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
|
||||
params := payment.HuPiPayReq{
|
||||
Version: "1.1",
|
||||
TradeOrderId: orderNo,
|
||||
@@ -326,16 +338,19 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||
ReturnURL: returnURL,
|
||||
CallbackURL: returnURL,
|
||||
WapName: "极客学长",
|
||||
WapUrl: baseURL,
|
||||
Type: "WAP",
|
||||
}
|
||||
r, err := h.huPiPayService.Pay(params)
|
||||
if err != nil {
|
||||
logger.Error("error with generating Pay URL: ", err.Error())
|
||||
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
|
||||
errMsg := "error with generating Pay Hupi URL: " + err.Error()
|
||||
logger.Error(errMsg)
|
||||
resp.ERROR(c, errMsg)
|
||||
return
|
||||
}
|
||||
payURL = r.URL
|
||||
case "payjs":
|
||||
payWay = PayWayJs
|
||||
payWay = PayWayJs.Name
|
||||
notifyURL = h.App.Config.JPayConfig.NotifyURL
|
||||
returnURL = h.App.Config.JPayConfig.ReturnURL
|
||||
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
|
||||
@@ -345,14 +360,22 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||
params.Add("body", product.Name)
|
||||
params.Add("notify_url", notifyURL)
|
||||
params.Add("auto", "0")
|
||||
payURL = h.js.PayH5(params)
|
||||
payURL = h.jsPayService.PayH5(params)
|
||||
case "alipay":
|
||||
payWay = PayWayAlipay
|
||||
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||
returnURL = h.App.Config.AlipayConfig.ReturnURL
|
||||
payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
|
||||
payWay = PayWayAlipay.Name
|
||||
payURL, err = h.alipayService.PayUrlMobile(orderNo, fmt.Sprintf("%.2f", amount), product.Name)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
|
||||
errMsg := "error with generating Alipay URL: " + err.Error()
|
||||
resp.ERROR(c, errMsg)
|
||||
return
|
||||
}
|
||||
case "wechat":
|
||||
payWay = PayWayWechat.Name
|
||||
payURL, err = h.wechatPayService.PayUrlH5(orderNo, int(amount*100), product.Name, c.ClientIP())
|
||||
if err != nil {
|
||||
errMsg := "error with generating Wechat URL: " + err.Error()
|
||||
logger.Error(errMsg)
|
||||
resp.ERROR(c, errMsg)
|
||||
return
|
||||
}
|
||||
default:
|
||||
@@ -385,7 +408,7 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, payURL)
|
||||
resp.SUCCESS(c, gin.H{"url": payURL, "order_no": orderNo})
|
||||
}
|
||||
|
||||
// 异步通知回调公共逻辑
|
||||
@@ -424,27 +447,21 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
|
||||
var opt string
|
||||
var power int
|
||||
if user.Vip { // 已经是 VIP 用户
|
||||
if remark.Days > 0 { // 只延期 VIP,不增加调用次数
|
||||
if remark.Days > 0 { // VIP 充值
|
||||
if user.ExpiredTime >= time.Now().Unix() {
|
||||
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
|
||||
} else { // 充值点卡,直接增加次数即可
|
||||
user.Power += remark.Power
|
||||
opt = "点卡充值"
|
||||
power = remark.Power
|
||||
}
|
||||
|
||||
} else { // 非 VIP 用户
|
||||
if remark.Days > 0 { // vip 套餐:days > 0, power == 0
|
||||
opt = "VIP充值,VIP 没到期,只延期不增加算力"
|
||||
} else {
|
||||
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
|
||||
user.Power += h.App.SysConfig.VipMonthPower
|
||||
user.Vip = true
|
||||
opt = "VIP充值"
|
||||
power = h.App.SysConfig.VipMonthPower
|
||||
} else { //点卡:days == 0, calls > 0
|
||||
user.Power += remark.Power
|
||||
opt = "点卡充值"
|
||||
power = remark.Power
|
||||
opt = "VIP充值"
|
||||
}
|
||||
user.Vip = true
|
||||
} else { // 充值点卡,直接增加次数即可
|
||||
user.Power += remark.Power
|
||||
opt = "点卡充值"
|
||||
power = remark.Power
|
||||
}
|
||||
|
||||
// 更新用户信息
|
||||
@@ -470,7 +487,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
|
||||
|
||||
// 记录算力充值日志
|
||||
if opt != "" {
|
||||
if power > 0 {
|
||||
h.DB.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
@@ -499,6 +516,9 @@ func (h *PaymentHandler) GetPayWays(c *gin.Context) {
|
||||
if h.App.Config.JPayConfig.Enabled {
|
||||
data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name}
|
||||
}
|
||||
if h.App.Config.WechatPayConfig.Enabled {
|
||||
data["wechat"] = gin.H{"name": "wechat"}
|
||||
}
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
@@ -537,7 +557,7 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||
}
|
||||
|
||||
// TODO:验证交易签名
|
||||
res := h.alipayService.TradeVerify(c.Request.Form)
|
||||
res := h.alipayService.TradeVerify(c.Request)
|
||||
logger.Infof("验证支付结果:%+v", res)
|
||||
if !res.Success() {
|
||||
logger.Error("订单校验失败:", res.Message)
|
||||
@@ -565,7 +585,7 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
|
||||
|
||||
orderNo := c.Request.Form.Get("out_trade_no")
|
||||
returnCode := c.Request.Form.Get("return_code")
|
||||
logger.Infof("收到订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode)
|
||||
logger.Infof("收到PayJs订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode)
|
||||
// 支付失败
|
||||
if returnCode != "1" {
|
||||
return
|
||||
@@ -573,7 +593,7 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
|
||||
|
||||
// 校验订单支付状态
|
||||
tradeNo := c.Request.Form.Get("payjs_order_id")
|
||||
err = h.js.Check(tradeNo)
|
||||
err = h.jsPayService.TradeVerify(tradeNo)
|
||||
if err != nil {
|
||||
logger.Error("订单校验失败:", err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
@@ -588,3 +608,30 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
|
||||
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
|
||||
// WechatPayNotify 微信商户支付异步回调
|
||||
func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
|
||||
err := c.Request.ParseForm()
|
||||
if err != nil {
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
result := h.wechatPayService.TradeVerify(c.Request)
|
||||
if !result.Success() {
|
||||
logger.Error("订单校验失败:", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": "FAIL",
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.notify(result.OutTradeNo, result.TradeId)
|
||||
if err != nil {
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"math"
|
||||
@@ -50,12 +57,12 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
||||
var item model.Reward
|
||||
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "无效的众筹交易流水号!")
|
||||
resp.ERROR(c, "无效的交易流水号!")
|
||||
return
|
||||
}
|
||||
|
||||
if item.Status {
|
||||
resp.ERROR(c, "当前众筹交易流水号已经被核销,请不要重复核销!")
|
||||
resp.ERROR(c, "当前交易流水号已经被核销,请不要重复核销!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -66,6 +73,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
||||
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
|
||||
if res.Error != nil {
|
||||
tx.Rollback()
|
||||
logger.Error("添加应用失败:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -77,6 +85,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
||||
res = tx.Updates(&item)
|
||||
if res.Error != nil {
|
||||
tx.Rollback()
|
||||
logger.Error("添加应用失败:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败!")
|
||||
return
|
||||
}
|
||||
@@ -90,7 +99,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
||||
Balance: user.Power + exchange.Power,
|
||||
Mark: types.PowerAdd,
|
||||
Model: "众筹支付",
|
||||
Remark: fmt.Sprintf("众筹充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice),
|
||||
Remark: fmt.Sprintf("充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
tx.Commit()
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/service/sd"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -23,15 +31,19 @@ import (
|
||||
|
||||
type SdJobHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
pool *sd.ServicePool
|
||||
uploader *oss.UploaderManager
|
||||
redis *redis.Client
|
||||
pool *sd.ServicePool
|
||||
uploader *oss.UploaderManager
|
||||
snowflake *service.Snowflake
|
||||
leveldb *store.LevelDB
|
||||
}
|
||||
|
||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler {
|
||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
|
||||
return &SdJobHandler{
|
||||
pool: pool,
|
||||
uploader: manager,
|
||||
pool: pool,
|
||||
uploader: manager,
|
||||
snowflake: snowflake,
|
||||
leveldb: levelDB,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
@@ -60,7 +72,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
||||
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
@@ -83,14 +95,11 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
if !h.checkLimits(c) {
|
||||
if !h.preCheck(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
SessionId string `json:"session_id"`
|
||||
types.SdTaskParams
|
||||
}
|
||||
var data types.SdTaskParams
|
||||
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
@@ -116,22 +125,27 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
}
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
taskId, err := h.snowflake.Next(true)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate task id: "+err.Error())
|
||||
return
|
||||
}
|
||||
params := types.SdTaskParams{
|
||||
TaskId: fmt.Sprintf("task(%s)", utils.RandString(15)),
|
||||
Prompt: data.Prompt,
|
||||
NegativePrompt: data.NegativePrompt,
|
||||
Steps: data.Steps,
|
||||
Sampler: data.Sampler,
|
||||
FaceFix: data.FaceFix,
|
||||
CfgScale: data.CfgScale,
|
||||
Seed: data.Seed,
|
||||
Height: data.Height,
|
||||
Width: data.Width,
|
||||
HdFix: data.HdFix,
|
||||
HdRedrawRate: data.HdRedrawRate,
|
||||
HdScale: data.HdScale,
|
||||
HdScaleAlg: data.HdScaleAlg,
|
||||
HdSteps: data.HdSteps,
|
||||
TaskId: taskId,
|
||||
Prompt: data.Prompt,
|
||||
NegPrompt: data.NegPrompt,
|
||||
Steps: data.Steps,
|
||||
Sampler: data.Sampler,
|
||||
FaceFix: data.FaceFix,
|
||||
CfgScale: data.CfgScale,
|
||||
Seed: data.Seed,
|
||||
Height: data.Height,
|
||||
Width: data.Width,
|
||||
HdFix: data.HdFix,
|
||||
HdRedrawRate: data.HdRedrawRate,
|
||||
HdScale: data.HdScale,
|
||||
HdScaleAlg: data.HdScaleAlg,
|
||||
HdSteps: data.HdSteps,
|
||||
}
|
||||
|
||||
job := model.SdJob{
|
||||
@@ -151,11 +165,10 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.SdTask{
|
||||
Id: int(job.Id),
|
||||
SessionId: data.SessionId,
|
||||
Type: types.TaskImage,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
Id: int(job.Id),
|
||||
Type: types.TaskImage,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
@@ -199,13 +212,13 @@ func (h *SdJobHandler) ImgWall(c *gin.Context) {
|
||||
|
||||
// JobList 获取 SD 任务列表
|
||||
func (h *SdJobHandler) JobList(c *gin.Context) {
|
||||
status := h.GetBool(c, "status")
|
||||
finish := h.GetBool(c, "finish")
|
||||
userId := h.GetLoginUserId(c)
|
||||
page := h.GetInt(c, "page", 0)
|
||||
pageSize := h.GetInt(c, "page_size", 0)
|
||||
publish := h.GetBool(c, "publish")
|
||||
|
||||
err, jobs := h.getData(status, userId, page, pageSize, publish)
|
||||
err, jobs := h.getData(finish, userId, page, pageSize, publish)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -249,10 +262,11 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
|
||||
}
|
||||
|
||||
if item.Progress < 100 {
|
||||
// 正在运行中任务使用代理访问图片
|
||||
image, err := utils.DownloadImage(item.ImgURL, "")
|
||||
// 从 leveldb 中获取图片预览数据
|
||||
var imageData string
|
||||
err = h.leveldb.Get(item.TaskId, &imageData)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
job.ImgURL = "data:image/png;base64," + imageData
|
||||
}
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
@@ -263,32 +277,30 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
|
||||
|
||||
// Remove remove task image
|
||||
func (h *SdJobHandler) Remove(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
ImgURL string `json:"img_url"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
var job model.SdJob
|
||||
if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// remove job recode
|
||||
res := h.DB.Delete(&model.SdJob{Id: data.Id})
|
||||
res := h.DB.Delete(&model.SdJob{Id: job.Id})
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// remove image
|
||||
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
|
||||
err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||
if err != nil {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
client := h.pool.Clients.Get(data.UserId)
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
_ = client.Send([]byte(sd.Finished))
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
@@ -296,17 +308,13 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
|
||||
|
||||
// Publish 发布/取消发布图片到画廊显示
|
||||
func (h *SdJobHandler) Publish(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||
|
||||
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
|
||||
res := h.DB.Model(&model.SdJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/service/sms"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/sms"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -42,14 +49,20 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
var data struct {
|
||||
Receiver string `json:"receiver"` // 接收者
|
||||
Key string `json:"key"`
|
||||
Dots string `json:"dots"`
|
||||
Dots string `json:"dots,omitempty"`
|
||||
X int `json:"x,omitempty"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if !h.captcha.Check(data) {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captcha.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captcha.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "验证码错误,请先完人机验证")
|
||||
return
|
||||
}
|
||||
@@ -57,13 +70,13 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
code := utils.RandomNumber(6)
|
||||
var err error
|
||||
if strings.Contains(data.Receiver, "@") { // email
|
||||
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") {
|
||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "email") {
|
||||
resp.ERROR(c, "系统已禁用邮箱注册!")
|
||||
return
|
||||
}
|
||||
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
||||
} else {
|
||||
if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
|
||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
|
||||
resp.ERROR(c, "系统已禁用手机号注册!")
|
||||
return
|
||||
}
|
||||
@@ -82,5 +95,9 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
if h.App.Debug {
|
||||
resp.SUCCESS(c, code)
|
||||
} else {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
}
|
||||
|
||||
345
api/handler/suno_handler.go
Normal file
345
api/handler/suno_handler.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/suno"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SunoHandler struct {
|
||||
BaseHandler
|
||||
service *suno.Service
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler {
|
||||
return &SunoHandler{
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
},
|
||||
service: service,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *SunoHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
logger.Info("Invalid user ID")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.service.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Instrumental bool `json:"instrumental"`
|
||||
Lyrics string `json:"lyrics"`
|
||||
Model string `json:"model"`
|
||||
Tags string `json:"tags"`
|
||||
Title string `json:"title"`
|
||||
Type int `json:"type"`
|
||||
RefTaskId string `json:"ref_task_id"` // 续写的任务id
|
||||
ExtendSecs int `json:"extend_secs"` // 续写秒数
|
||||
RefSongId string `json:"ref_song_id"` // 续写的歌曲id
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 插入数据库
|
||||
job := model.SunoJob{
|
||||
UserId: int(h.GetLoginUserId(c)),
|
||||
Prompt: data.Prompt,
|
||||
Instrumental: data.Instrumental,
|
||||
ModelName: data.Model,
|
||||
Tags: data.Tags,
|
||||
Title: data.Title,
|
||||
Type: data.Type,
|
||||
RefSongId: data.RefSongId,
|
||||
RefTaskId: data.RefTaskId,
|
||||
ExtendSecs: data.ExtendSecs,
|
||||
Power: h.App.SysConfig.SunoPower,
|
||||
}
|
||||
if data.Lyrics != "" {
|
||||
job.Prompt = data.Lyrics
|
||||
}
|
||||
tx := h.DB.Create(&job)
|
||||
if tx.Error != nil {
|
||||
resp.ERROR(c, tx.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
h.service.PushTask(types.SunoTask{
|
||||
Id: job.Id,
|
||||
UserId: job.UserId,
|
||||
Type: job.Type,
|
||||
Title: job.Title,
|
||||
RefTaskId: data.RefTaskId,
|
||||
RefSongId: data.RefSongId,
|
||||
ExtendSecs: data.ExtendSecs,
|
||||
Prompt: job.Prompt,
|
||||
Tags: data.Tags,
|
||||
Model: data.Model,
|
||||
Instrumental: data.Instrumental,
|
||||
})
|
||||
|
||||
// update user's power
|
||||
tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
||||
// 记录算力变化日志
|
||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||
user, _ := h.GetLoginUser(c)
|
||||
h.DB.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: types.PowerConsume,
|
||||
Amount: job.Power,
|
||||
Balance: user.Power - job.Power,
|
||||
Mark: types.PowerSub,
|
||||
Model: job.ModelName,
|
||||
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *SunoHandler) List(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
page := h.GetInt(c, "page", 0)
|
||||
pageSize := h.GetInt(c, "page_size", 0)
|
||||
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
|
||||
|
||||
// 统计总数
|
||||
var total int64
|
||||
session.Model(&model.SunoJob{}).Count(&total)
|
||||
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
session = session.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
var list []model.SunoJob
|
||||
err := session.Order("id desc").Find(&list).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 初始化续写关系
|
||||
songIds := make([]string, 0)
|
||||
for _, v := range list {
|
||||
if v.RefTaskId != "" {
|
||||
songIds = append(songIds, v.RefSongId)
|
||||
}
|
||||
}
|
||||
var tasks []model.SunoJob
|
||||
h.DB.Where("song_id IN ?", songIds).Find(&tasks)
|
||||
songMap := make(map[string]model.SunoJob)
|
||||
for _, t := range tasks {
|
||||
songMap[t.SongId] = t
|
||||
}
|
||||
// 转换为 VO
|
||||
items := make([]vo.SunoJob, 0)
|
||||
for _, v := range list {
|
||||
var item vo.SunoJob
|
||||
err = utils.CopyObject(v, &item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
item.CreatedAt = v.CreatedAt.Unix()
|
||||
if s, ok := songMap[v.RefSongId]; ok {
|
||||
item.RefSong = map[string]interface{}{
|
||||
"id": s.Id,
|
||||
"title": s.Title,
|
||||
"cover": s.CoverURL,
|
||||
"audio": s.AudioURL,
|
||||
}
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
|
||||
}
|
||||
|
||||
func (h *SunoHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetLoginUserId(c)
|
||||
var job model.SunoJob
|
||||
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 删除任务
|
||||
h.DB.Delete(&job)
|
||||
// 删除文件
|
||||
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
|
||||
_ = h.uploader.GetUploadHandler().Delete(job.AudioURL)
|
||||
}
|
||||
|
||||
func (h *SunoHandler) Publish(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetLoginUserId(c)
|
||||
publish := h.GetBool(c, "publish")
|
||||
err := h.DB.Model(&model.SunoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *SunoHandler) Update(c *gin.Context) {
|
||||
var data struct {
|
||||
Id int `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Cover string `json:"cover"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Id == 0 || data.Title == "" || data.Cover == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetLoginUserId(c)
|
||||
var item model.SunoJob
|
||||
if err := h.DB.Where("id", data.Id).Where("user_id", userId).First(&item).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
item.Title = data.Title
|
||||
item.CoverURL = data.Cover
|
||||
|
||||
if err := h.DB.Updates(&item).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// Detail 歌曲详情
|
||||
func (h *SunoHandler) Detail(c *gin.Context) {
|
||||
songId := c.Query("song_id")
|
||||
if songId == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
var item model.SunoJob
|
||||
if err := h.DB.Where("song_id", songId).First(&item).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 读取用户信息
|
||||
var user model.User
|
||||
if err := h.DB.Where("id", item.UserId).First(&user).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var itemVo vo.SunoJob
|
||||
if err := utils.CopyObject(item, &itemVo); err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
itemVo.CreatedAt = item.CreatedAt.Unix()
|
||||
itemVo.User = map[string]interface{}{
|
||||
"nickname": user.Nickname,
|
||||
"avatar": user.Avatar,
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, itemVo)
|
||||
}
|
||||
|
||||
// Play 增加歌曲播放次数
|
||||
func (h *SunoHandler) Play(c *gin.Context) {
|
||||
songId := c.Query("song_id")
|
||||
if songId == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1))
|
||||
}
|
||||
|
||||
const genLyricTemplate = `
|
||||
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
|
||||
请以【%s】为主题创作一首歌曲,歌曲时间不要太短,3分钟左右,不要输出任何解释性的内容。
|
||||
输出格式如下:
|
||||
歌曲名称
|
||||
第一节:
|
||||
{{歌词内容}}
|
||||
副歌:
|
||||
{{歌词内容}}
|
||||
|
||||
第二节:
|
||||
{{歌词内容}}
|
||||
副歌:
|
||||
{{歌词内容}}
|
||||
|
||||
尾声:
|
||||
{{歌词内容}}
|
||||
`
|
||||
|
||||
// Lyric 生成歌词
|
||||
func (h *SunoHandler) Lyric(c *gin.Context) {
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
@@ -1,17 +1,17 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"chatplus/service"
|
||||
"chatplus/service/payment"
|
||||
"geekai/service"
|
||||
"geekai/service/payment"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TestHandler struct {
|
||||
db *gorm.DB
|
||||
snowflake *service.Snowflake
|
||||
js *payment.PayJS
|
||||
js *payment.JPayService
|
||||
}
|
||||
|
||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
|
||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.JPayService) *TestHandler {
|
||||
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
@@ -28,6 +36,12 @@ func (h *UploadHandler) Upload(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("upload file: ", file.Name)
|
||||
// cut the file name if it's too long
|
||||
if len(file.Name) > 100 {
|
||||
file.Name = file.Name[:90] + file.Ext
|
||||
}
|
||||
|
||||
userId := h.GetLoginUserId(c)
|
||||
res := h.DB.Create(&model.File{
|
||||
UserId: int(userId),
|
||||
@@ -47,10 +61,23 @@ func (h *UploadHandler) Upload(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *UploadHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
Urls []string `json:"urls,omitempty"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetLoginUserId(c)
|
||||
var items []model.File
|
||||
var files = make([]vo.File, 0)
|
||||
h.DB.Where("user_id = ?", userId).Find(&items)
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
session = session.Where("user_id = ?", userId)
|
||||
if len(data.Urls) > 0 {
|
||||
session = session.Where("url IN ?", data.Urls)
|
||||
}
|
||||
session.Find(&items)
|
||||
if len(items) > 0 {
|
||||
for _, v := range items {
|
||||
var file vo.File
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/imroc/req/v3"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,16 +30,23 @@ import (
|
||||
|
||||
type UserHandler struct {
|
||||
BaseHandler
|
||||
searcher *xdb.Searcher
|
||||
redis *redis.Client
|
||||
searcher *xdb.Searcher
|
||||
redis *redis.Client
|
||||
licenseService *service.LicenseService
|
||||
}
|
||||
|
||||
func NewUserHandler(
|
||||
app *core.AppServer,
|
||||
db *gorm.DB,
|
||||
searcher *xdb.Searcher,
|
||||
client *redis.Client) *UserHandler {
|
||||
return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client}
|
||||
client *redis.Client,
|
||||
licenseService *service.LicenseService) *UserHandler {
|
||||
return &UserHandler{
|
||||
BaseHandler: BaseHandler{DB: db, App: app},
|
||||
searcher: searcher,
|
||||
redis: client,
|
||||
licenseService: licenseService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register user register
|
||||
@@ -53,9 +69,17 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检测最大注册人数
|
||||
var totalUser int64
|
||||
h.DB.Model(&model.User{}).Count(&totalUser)
|
||||
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
||||
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查验证码
|
||||
var key string
|
||||
if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" {
|
||||
if data.RegWay == "email" || data.RegWay == "mobile" {
|
||||
key = CodeStorePrefix + data.Username
|
||||
code, err := h.redis.Get(c, key).Result()
|
||||
if err != nil || code != data.Code {
|
||||
@@ -74,7 +98,7 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// check if the username is exists
|
||||
// check if the username is existing
|
||||
var item model.User
|
||||
res := h.DB.Where("username = ?", data.Username).First(&item)
|
||||
if item.Id > 0 {
|
||||
@@ -86,7 +110,6 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
user := model.User{
|
||||
Username: data.Username,
|
||||
Password: utils.GenPassword(data.Password, salt),
|
||||
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
|
||||
Avatar: "/images/avatar/user.png",
|
||||
Salt: salt,
|
||||
Status: true,
|
||||
@@ -95,6 +118,16 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
Power: h.App.SysConfig.InitPower,
|
||||
}
|
||||
|
||||
// 被邀请人也获得赠送算力
|
||||
if data.InviteCode != "" {
|
||||
user.Power += h.App.SysConfig.InvitePower
|
||||
}
|
||||
if h.licenseService.GetLicense().Configs.DeCopy {
|
||||
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
|
||||
} else {
|
||||
user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
|
||||
}
|
||||
|
||||
res = h.DB.Create(&user)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "保存数据失败")
|
||||
@@ -152,7 +185,7 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c, tokenString)
|
||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
@@ -211,26 +244,142 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c, tokenString)
|
||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
||||
}
|
||||
|
||||
// Logout 注 销
|
||||
func (h *UserHandler) Logout(c *gin.Context) {
|
||||
sessionId := c.GetHeader(types.ChatTokenHeader)
|
||||
key := h.GetUserKey(c)
|
||||
if _, err := h.redis.Del(c, key).Result(); err != nil {
|
||||
logger.Error("error with delete session: ", err)
|
||||
}
|
||||
// 删除 websocket 会话列表
|
||||
h.App.ChatSession.Delete(sessionId)
|
||||
// 关闭 socket 连接
|
||||
client := h.App.ChatClients.Get(sessionId)
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// CLogin 第三方登录请求二维码
|
||||
func (h *UserHandler) CLogin(c *gin.Context) {
|
||||
returnURL := h.GetTrim(c, "return_url")
|
||||
var res types.BizVo
|
||||
apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL)
|
||||
r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}).
|
||||
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
if r.IsErrorState() {
|
||||
resp.ERROR(c, "error with login http status: "+r.Status)
|
||||
return
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
resp.ERROR(c, "error with http response: "+res.Message)
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, res.Data)
|
||||
}
|
||||
|
||||
// CLoginCallback 第三方登录回调
|
||||
func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
||||
loginType := h.GetTrim(c, "login_type")
|
||||
code := h.GetTrim(c, "code")
|
||||
|
||||
var res types.BizVo
|
||||
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
|
||||
r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
|
||||
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
if r.IsErrorState() {
|
||||
resp.ERROR(c, "error with login http status: "+r.Status)
|
||||
return
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
resp.ERROR(c, "error with http response: "+res.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// login successfully
|
||||
data := res.Data.(map[string]interface{})
|
||||
session := gin.H{}
|
||||
var user model.User
|
||||
tx := h.DB.Debug().Where("openid", data["openid"]).First(&user)
|
||||
if tx.Error != nil { // user not exist, create new user
|
||||
// 检测最大注册人数
|
||||
var totalUser int64
|
||||
h.DB.Model(&model.User{}).Count(&totalUser)
|
||||
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
||||
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
|
||||
return
|
||||
}
|
||||
|
||||
salt := utils.RandString(8)
|
||||
password := fmt.Sprintf("%d", utils.RandomNumber(8))
|
||||
user = model.User{
|
||||
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
|
||||
Password: utils.GenPassword(password, salt),
|
||||
Avatar: fmt.Sprintf("%s", data["avatar"]),
|
||||
Salt: salt,
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
|
||||
Power: h.App.SysConfig.InitPower,
|
||||
OpenId: fmt.Sprintf("%s", data["openid"]),
|
||||
Nickname: fmt.Sprintf("%s", data["nickname"]),
|
||||
}
|
||||
|
||||
tx = h.DB.Create(&user)
|
||||
if tx.Error != nil {
|
||||
resp.ERROR(c, "保存数据失败")
|
||||
logger.Error(tx.Error)
|
||||
return
|
||||
}
|
||||
session["username"] = user.Username
|
||||
session["password"] = password
|
||||
} else { // login directly
|
||||
// 更新最后登录时间和IP
|
||||
user.LastLoginIp = c.ClientIP()
|
||||
user.LastLoginAt = time.Now().Unix()
|
||||
h.DB.Model(&user).Updates(user)
|
||||
|
||||
h.DB.Create(&model.UserLoginLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
LoginIp: c.ClientIP(),
|
||||
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
|
||||
})
|
||||
}
|
||||
|
||||
// 创建 token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.Id,
|
||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||
return
|
||||
}
|
||||
// 保存到 redis
|
||||
key := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
session["token"] = tokenString
|
||||
resp.SUCCESS(c, session)
|
||||
}
|
||||
|
||||
// Session 获取/验证会话
|
||||
func (h *UserHandler) Session(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
@@ -334,7 +483,7 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
|
||||
newPass := utils.GenPassword(data.Password, user.Salt)
|
||||
res := h.DB.Model(&user).UpdateColumn("password", newPass)
|
||||
if res.Error != nil {
|
||||
logger.Error("更新数据库失败: ", res.Error)
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "更新数据库失败")
|
||||
return
|
||||
}
|
||||
@@ -415,6 +564,7 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
|
||||
|
||||
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
|
||||
if res.Error != nil {
|
||||
logger.Error(res.Error)
|
||||
resp.ERROR(c, "更新数据库失败")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package logger
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
169
api/main.go
169
api/main.go
@@ -1,22 +1,31 @@
|
||||
package main
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/handler/admin"
|
||||
"chatplus/handler/chatimpl"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/service"
|
||||
"chatplus/service/mj"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/service/payment"
|
||||
"chatplus/service/sd"
|
||||
"chatplus/service/sms"
|
||||
"chatplus/service/wx"
|
||||
"chatplus/store"
|
||||
"context"
|
||||
"embed"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/handler/admin"
|
||||
"geekai/handler/chatimpl"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/payment"
|
||||
"geekai/service/sd"
|
||||
"geekai/service/sms"
|
||||
"geekai/service/suno"
|
||||
"geekai/service/wx"
|
||||
"geekai/store"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
@@ -43,16 +52,20 @@ type AppLifecycle struct {
|
||||
|
||||
// OnStart 应用程序启动时执行
|
||||
func (l *AppLifecycle) OnStart(context.Context) error {
|
||||
log.Println("AppLifecycle OnStart")
|
||||
logger.Info("AppLifecycle OnStart")
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStop 应用程序停止时执行
|
||||
func (l *AppLifecycle) OnStop(context.Context) error {
|
||||
log.Println("AppLifecycle OnStop")
|
||||
logger.Info("AppLifecycle OnStop")
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewAppLifeCycle() *AppLifecycle {
|
||||
return &AppLifecycle{}
|
||||
}
|
||||
|
||||
func main() {
|
||||
configFile := os.Getenv("CONFIG_FILE")
|
||||
if configFile == "" {
|
||||
@@ -92,6 +105,7 @@ func main() {
|
||||
fx.Provide(store.NewGormConfig),
|
||||
fx.Provide(store.NewMysql),
|
||||
fx.Provide(store.NewRedisClient),
|
||||
fx.Provide(store.NewLevelDB),
|
||||
|
||||
fx.Provide(func() embed.FS {
|
||||
return xdbFS
|
||||
@@ -148,9 +162,21 @@ func main() {
|
||||
}),
|
||||
fx.Provide(oss.NewUploaderManager),
|
||||
fx.Provide(mj.NewService),
|
||||
fx.Provide(dalle.NewService),
|
||||
fx.Invoke(func(service *dalle.Service) {
|
||||
service.Run()
|
||||
service.CheckTaskNotify()
|
||||
service.DownloadImages()
|
||||
service.CheckTaskStatus()
|
||||
}),
|
||||
|
||||
// 邮件服务
|
||||
fx.Provide(service.NewSmtpService),
|
||||
// License 服务
|
||||
fx.Provide(service.NewLicenseService),
|
||||
fx.Invoke(func(licenseService *service.LicenseService) {
|
||||
licenseService.SyncLicense()
|
||||
}),
|
||||
|
||||
// 微信机器人服务
|
||||
fx.Provide(wx.NewWeChatBot),
|
||||
@@ -165,7 +191,8 @@ func main() {
|
||||
|
||||
// MidJourney service pool
|
||||
fx.Provide(mj.NewServicePool),
|
||||
fx.Invoke(func(pool *mj.ServicePool) {
|
||||
fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) {
|
||||
pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs)
|
||||
if pool.HasAvailableService() {
|
||||
pool.DownloadImages()
|
||||
pool.CheckTaskNotify()
|
||||
@@ -175,16 +202,26 @@ func main() {
|
||||
|
||||
// Stable Diffusion 机器人
|
||||
fx.Provide(sd.NewServicePool),
|
||||
fx.Invoke(func(pool *sd.ServicePool) {
|
||||
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
|
||||
pool.InitServices(config.SdConfigs)
|
||||
if pool.HasAvailableService() {
|
||||
pool.CheckTaskNotify()
|
||||
pool.CheckTaskStatus()
|
||||
}
|
||||
}),
|
||||
|
||||
fx.Provide(suno.NewService),
|
||||
fx.Invoke(func(s *suno.Service) {
|
||||
s.Run()
|
||||
s.SyncTaskProgress()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadImages()
|
||||
}),
|
||||
|
||||
fx.Provide(payment.NewAlipayService),
|
||||
fx.Provide(payment.NewHuPiPay),
|
||||
fx.Provide(payment.NewPayJS),
|
||||
fx.Provide(payment.NewJPayService),
|
||||
fx.Provide(payment.NewWechatService),
|
||||
fx.Provide(service.NewSnowflake),
|
||||
fx.Provide(service.NewXXLJobExecutor),
|
||||
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
|
||||
@@ -212,6 +249,8 @@ func main() {
|
||||
group.POST("password", h.UpdatePass)
|
||||
group.POST("bind/username", h.BindUsername)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
group.GET("clogin", h.CLogin)
|
||||
group.GET("clogin/callback", h.CLoginCallback)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) {
|
||||
group := s.Engine.Group("/api/chat/")
|
||||
@@ -227,7 +266,7 @@ func main() {
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
|
||||
s.Engine.POST("/api/upload", h.Upload)
|
||||
s.Engine.GET("/api/upload/list", h.List)
|
||||
s.Engine.POST("/api/upload/list", h.List)
|
||||
s.Engine.GET("/api/upload/remove", h.Remove)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
|
||||
@@ -253,8 +292,8 @@ func main() {
|
||||
group.POST("variation", h.Variation)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.POST("remove", h.Remove)
|
||||
group.POST("publish", h.Publish)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||
group := s.Engine.Group("/api/sd")
|
||||
@@ -262,19 +301,24 @@ func main() {
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.POST("remove", h.Remove)
|
||||
group.POST("publish", h.Publish)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
|
||||
group := s.Engine.Group("/api/config/")
|
||||
group.GET("get", h.Get)
|
||||
group.GET("license", h.License)
|
||||
}),
|
||||
|
||||
// 管理后台控制器
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||
group := s.Engine.Group("/api/admin/config/")
|
||||
group.POST("update", h.Update)
|
||||
group.GET("get", h.Get)
|
||||
group := s.Engine.Group("/api/admin/")
|
||||
group.POST("config/update", h.Update)
|
||||
group.GET("config/get", h.Get)
|
||||
group.POST("active", h.Active)
|
||||
group.GET("config/get/license", h.GetLicense)
|
||||
group.GET("config/get/app", h.GetAppConfig)
|
||||
group.POST("config/update/draw", h.SaveDrawingConfig)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
||||
group := s.Engine.Group("/api/admin/")
|
||||
@@ -292,7 +336,7 @@ func main() {
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("set", h.Set)
|
||||
group.POST("remove", h.Remove)
|
||||
group.GET("remove", h.Remove)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
|
||||
group := s.Engine.Group("/api/admin/user/")
|
||||
@@ -308,7 +352,7 @@ func main() {
|
||||
group.POST("save", h.Save)
|
||||
group.POST("sort", h.Sort)
|
||||
group.POST("set", h.Set)
|
||||
group.POST("remove", h.Remove)
|
||||
group.GET("remove", h.Remove)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
|
||||
group := s.Engine.Group("/api/admin/reward/")
|
||||
@@ -335,12 +379,12 @@ func main() {
|
||||
group := s.Engine.Group("/api/payment/")
|
||||
group.GET("doPay", h.DoPay)
|
||||
group.GET("payWays", h.GetPayWays)
|
||||
group.POST("query", h.OrderQuery)
|
||||
group.POST("qrcode", h.PayQrcode)
|
||||
group.POST("mobile", h.Mobile)
|
||||
group.POST("alipay/notify", h.AlipayNotify)
|
||||
group.POST("hupipay/notify", h.HuPiPayNotify)
|
||||
group.POST("payjs/notify", h.PayJsNotify)
|
||||
group.POST("wechat/notify", h.WechatPayNotify)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
||||
group := s.Engine.Group("/api/admin/product/")
|
||||
@@ -357,7 +401,8 @@ func main() {
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
|
||||
group := s.Engine.Group("/api/order/")
|
||||
group.POST("list", h.List)
|
||||
group.GET("list", h.List)
|
||||
group.GET("query", h.Query)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
|
||||
group := s.Engine.Group("/api/product/")
|
||||
@@ -382,13 +427,6 @@ func main() {
|
||||
group.GET("token", h.GenToken)
|
||||
}),
|
||||
|
||||
// 验证码
|
||||
fx.Provide(admin.NewCaptchaHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
|
||||
group := s.Engine.Group("/api/admin/login/")
|
||||
group.GET("captcha", h.GetCaptcha)
|
||||
}),
|
||||
|
||||
fx.Provide(admin.NewUploadHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
|
||||
s.Engine.POST("/api/admin/upload", h.Upload)
|
||||
@@ -417,12 +455,57 @@ func main() {
|
||||
group := s.Engine.Group("/api/admin/powerLog/")
|
||||
group.POST("list", h.List)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||
err := s.Run(db)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fx.Provide(admin.NewMenuHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
|
||||
group := s.Engine.Group("/api/admin/menu/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
}),
|
||||
fx.Provide(handler.NewMenuHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
|
||||
group := s.Engine.Group("/api/menu/")
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
fx.Provide(handler.NewMarkMapHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
|
||||
group := s.Engine.Group("/api/markMap/")
|
||||
group.Any("client", h.Client)
|
||||
}),
|
||||
fx.Provide(handler.NewDallJobHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
|
||||
group := s.Engine.Group("/api/dall")
|
||||
group.Any("client", h.Client)
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
fx.Provide(handler.NewSunoHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
|
||||
group := s.Engine.Group("/api/suno")
|
||||
group.Any("client", h.Client)
|
||||
group.POST("create", h.Create)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("detail", h.Detail)
|
||||
group.GET("play", h.Play)
|
||||
group.POST("lyric", h.Lyric)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||
go func() {
|
||||
err := s.Run(db)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
}),
|
||||
fx.Provide(NewAppLifeCycle),
|
||||
// 注册生命周期回调函数
|
||||
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
|
||||
lifecycle.Append(fx.Hook{
|
||||
|
||||
@@ -1,19 +1,26 @@
|
||||
package service
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"github.com/imroc/req/v3"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CaptchaService struct {
|
||||
config types.ChatPlusApiConfig
|
||||
config types.ApiConfig
|
||||
client *req.Client
|
||||
}
|
||||
|
||||
func NewCaptchaService(config types.ChatPlusApiConfig) *CaptchaService {
|
||||
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
|
||||
return &CaptchaService{
|
||||
config: config,
|
||||
client: req.C().SetTimeout(10 * time.Second),
|
||||
|
||||
314
api/service/dalle/service.go
Normal file
314
api/service/dalle/service.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package dalle
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
// DALL-E 绘画服务
|
||||
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploadManager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (s *Service) PushTask(task types.DallTask) {
|
||||
logger.Infof("add a new DALL-E task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Info("Starting DALL-E job consumer...")
|
||||
go func() {
|
||||
for {
|
||||
var task types.DallTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("handle a new DALL-E task: %+v", task)
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type imgReq struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Size string `json:"size"`
|
||||
Quality string `json:"quality"`
|
||||
Style string `json:"style"`
|
||||
}
|
||||
|
||||
type imgRes struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []struct {
|
||||
RevisedPrompt string `json:"revised_prompt"`
|
||||
Url string `json:"url"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type ErrRes struct {
|
||||
Error struct {
|
||||
Code interface{} `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Param interface{} `json:"param"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
logger.Debugf("绘画参数:%+v", task)
|
||||
prompt := task.Prompt
|
||||
// translate prompt
|
||||
if utils.HasChinese(prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
prompt = content
|
||||
logger.Debugf("重写后提示词:%s", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
var user model.User
|
||||
s.db.Where("id", task.UserId).First(&user)
|
||||
if user.Power < task.Power {
|
||||
return "", errors.New("insufficient of power")
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
tx := s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
|
||||
// 记录算力变化日志
|
||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||
var u model.User
|
||||
s.db.Where("id", user.Id).First(&u)
|
||||
s.db.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: types.PowerConsume,
|
||||
Amount: task.Power,
|
||||
Balance: u.Power,
|
||||
Mark: types.PowerSub,
|
||||
Model: "dall-e-3",
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// get image generation API KEY
|
||||
var apiKey model.ApiKey
|
||||
tx = s.db.Where("type", "dalle").
|
||||
Where("enabled", true).
|
||||
Order("last_used_at ASC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
|
||||
}
|
||||
|
||||
var res imgRes
|
||||
var errRes ErrRes
|
||||
if len(apiKey.ProxyURL) > 5 {
|
||||
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
|
||||
reqBody := imgReq{
|
||||
Model: "dall-e-3",
|
||||
Prompt: prompt,
|
||||
N: 1,
|
||||
Size: task.Size,
|
||||
Style: task.Style,
|
||||
Quality: task.Quality,
|
||||
}
|
||||
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
|
||||
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
SetErrorResult(&errRes).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with send request: %v", err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
|
||||
}
|
||||
// update the api key last use time
|
||||
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
// update task progress
|
||||
tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
|
||||
"progress": 100,
|
||||
"org_url": res.Data[0].Url,
|
||||
"prompt": prompt,
|
||||
})
|
||||
if tx.Error != nil {
|
||||
return "", fmt.Errorf("err with update database: %v", tx.Error)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
|
||||
var content string
|
||||
if sync {
|
||||
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", prompt, imgURL)
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running DALL-E task notify checking ...")
|
||||
for {
|
||||
var message sd.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadImages() {
|
||||
go func() {
|
||||
var items []model.DallJob
|
||||
for {
|
||||
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// download images
|
||||
for _, v := range items {
|
||||
if v.OrgURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try to download image: %s", v.OrgURL)
|
||||
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: %s, error: %v", imgURL, err)
|
||||
continue
|
||||
} else {
|
||||
logger.Infof("download image %s successfully.", v.OrgURL)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
|
||||
// sava image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// update img_url
|
||||
res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL)
|
||||
if res.Error != nil {
|
||||
return "", err
|
||||
}
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished})
|
||||
return imgURL, nil
|
||||
}
|
||||
|
||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||
func (s *Service) CheckTaskStatus() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||
for {
|
||||
var jobs []model.DallJob
|
||||
res := s.db.Where("progress < ?", 100).Find(&jobs)
|
||||
if res.Error != nil {
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
// 5 分钟还没完成的任务直接删除
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
|
||||
s.db.Delete(&job)
|
||||
var user model.User
|
||||
s.db.Where("id = ?", job.UserId).First(&user)
|
||||
// 退回绘图次数
|
||||
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
||||
if res.Error == nil && res.RowsAffected > 0 {
|
||||
s.db.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: types.PowerConsume,
|
||||
Amount: job.Power,
|
||||
Balance: user.Power + job.Power,
|
||||
Mark: types.PowerAdd,
|
||||
Model: "dall-e-3",
|
||||
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d", job.Id),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
197
api/service/license_service.go
Normal file
197
api/service/license_service.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package service
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type LicenseService struct {
|
||||
config types.ApiConfig
|
||||
levelDB *store.LevelDB
|
||||
license *types.License
|
||||
urlWhiteList []string
|
||||
machineId string
|
||||
}
|
||||
|
||||
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
|
||||
var license types.License
|
||||
return &LicenseService{
|
||||
config: server.Config.ApiConfig,
|
||||
levelDB: levelDB,
|
||||
license: &license,
|
||||
machineId: "",
|
||||
}
|
||||
}
|
||||
|
||||
type License struct {
|
||||
Name string `json:"name"`
|
||||
License string `json:"license"`
|
||||
MachineId string `json:"mid"`
|
||||
ActiveAt int64 `json:"active_at"`
|
||||
ExpiredAt int64 `json:"expired_at"`
|
||||
UserNum int `json:"user_num"`
|
||||
Configs types.LicenseConfig `json:"configs"`
|
||||
}
|
||||
|
||||
// ActiveLicense 激活 License
|
||||
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
|
||||
var res struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data License `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
|
||||
response, err := req.C().R().
|
||||
SetBody(map[string]string{"license": license, "machine_id": machineId}).
|
||||
SetSuccessResult(&res).Post(apiURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送激活请求失败: %v", err)
|
||||
}
|
||||
|
||||
if response.IsErrorState() {
|
||||
return fmt.Errorf("发送激活请求失败:%v", response.Status)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
return fmt.Errorf("激活失败:%v", res.Message)
|
||||
}
|
||||
|
||||
s.license = &types.License{
|
||||
Key: license,
|
||||
MachineId: machineId,
|
||||
Configs: res.Data.Configs,
|
||||
ExpiredAt: res.Data.ExpiredAt,
|
||||
IsActive: true,
|
||||
}
|
||||
err = s.levelDB.Put(types.LicenseKey, s.license)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存许可证书失败:%v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncLicense 定期同步 License
|
||||
func (s *LicenseService) SyncLicense() {
|
||||
go func() {
|
||||
retryCounter := 0
|
||||
for {
|
||||
license, err := s.fetchLicense()
|
||||
if err != nil {
|
||||
retryCounter++
|
||||
if retryCounter < 5 {
|
||||
logger.Warn(err)
|
||||
}
|
||||
s.license.IsActive = false
|
||||
} else {
|
||||
s.license = license
|
||||
}
|
||||
|
||||
urls, err := s.fetchUrlWhiteList()
|
||||
if err == nil {
|
||||
s.urlWhiteList = urls
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *LicenseService) fetchLicense() (*types.License, error) {
|
||||
//var res struct {
|
||||
// Code types.BizCode `json:"code"`
|
||||
// Message string `json:"message"`
|
||||
// Data License `json:"data"`
|
||||
//}
|
||||
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
|
||||
//response, err := req.C().R().
|
||||
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||
// SetSuccessResult(&res).Post(apiURL)
|
||||
//if err != nil {
|
||||
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
|
||||
//}
|
||||
//if response.IsErrorState() {
|
||||
// return nil, fmt.Errorf("激活失败:%v", response.Status)
|
||||
//}
|
||||
//if res.Code != types.Success {
|
||||
// return nil, fmt.Errorf("激活失败:%v", res.Message)
|
||||
//}
|
||||
|
||||
return &types.License{
|
||||
Key: "abc",
|
||||
MachineId: "abc",
|
||||
Configs: types.LicenseConfig{
|
||||
UserNum: 10000,
|
||||
DeCopy: false,
|
||||
},
|
||||
ExpiredAt: 0,
|
||||
IsActive: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
||||
var res struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data []string `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
|
||||
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求失败: %v", err)
|
||||
}
|
||||
if response.IsErrorState() {
|
||||
return nil, fmt.Errorf("发送请求失败:%v", response.Status)
|
||||
}
|
||||
if res.Code != types.Success {
|
||||
return nil, fmt.Errorf("获取白名单失败:%v", res.Message)
|
||||
}
|
||||
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
// GetLicense 获取许可信息
|
||||
func (s *LicenseService) GetLicense() *types.License {
|
||||
return s.license
|
||||
}
|
||||
|
||||
// IsValidApiURL 判断是否合法的中转 URL
|
||||
func (s *LicenseService) IsValidApiURL(uri string) error {
|
||||
// 获得许可授权的直接放行
|
||||
return nil
|
||||
//if s.license.IsActive {
|
||||
// if s.license.MachineId != s.machineId {
|
||||
// return errors.New("系统使用了盗版的许可证书")
|
||||
// }
|
||||
//
|
||||
// if time.Now().Unix() > s.license.ExpiredAt {
|
||||
// return errors.New("系统许可证书已经过期")
|
||||
// }
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
//if len(s.urlWhiteList) == 0 {
|
||||
// urls, err := s.fetchUrlWhiteList()
|
||||
// if err == nil {
|
||||
// s.urlWhiteList = urls
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//for _, v := range s.urlWhiteList {
|
||||
// if strings.HasPrefix(uri, v) {
|
||||
// return nil
|
||||
// }
|
||||
//}
|
||||
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||
}
|
||||
@@ -1,6 +1,13 @@
|
||||
package mj
|
||||
|
||||
import "chatplus/core/types"
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import "geekai/core/types"
|
||||
|
||||
type Client interface {
|
||||
Imagine(task types.MjTask) (ImageRes, error)
|
||||
|
||||
@@ -1,32 +1,59 @@
|
||||
package mj
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/utils"
|
||||
"github.com/imroc/req/v3"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PlusClient MidJourney Plus ProxyClient
|
||||
type PlusClient struct {
|
||||
Config types.MjPlusConfig
|
||||
apiURL string
|
||||
Config types.MjPlusConfig
|
||||
apiURL string
|
||||
client *req.Client
|
||||
licenseService *service.LicenseService
|
||||
}
|
||||
|
||||
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
|
||||
return &PlusClient{Config: config, apiURL: config.ApiURL}
|
||||
func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
|
||||
return &PlusClient{
|
||||
Config: config,
|
||||
apiURL: config.ApiURL,
|
||||
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
|
||||
licenseService: licenseService,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PlusClient) preCheck() error {
|
||||
return c.licenseService.IsValidApiURL(c.Config.ApiURL)
|
||||
}
|
||||
|
||||
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
if err := c.preCheck(); err != nil {
|
||||
return ImageRes{}, err
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
|
||||
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
||||
if task.NegPrompt != "" {
|
||||
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
|
||||
}
|
||||
body := ImageReq{
|
||||
BotType: "MID_JOURNEY",
|
||||
Prompt: task.Prompt,
|
||||
Prompt: prompt,
|
||||
Base64Array: make([]string, 0),
|
||||
}
|
||||
// 生成图片 Base64 编码
|
||||
@@ -39,30 +66,17 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
}
|
||||
|
||||
}
|
||||
logger.Info("API URL: ", apiURL)
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
errStr, _ := io.ReadAll(r.Body)
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
return c.doRequest(body, apiURL)
|
||||
}
|
||||
|
||||
// Blend 融图
|
||||
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
|
||||
if err := c.preCheck(); err != nil {
|
||||
return ImageRes{}, err
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
|
||||
logger.Info("API URL: ", apiURL)
|
||||
body := ImageReq{
|
||||
BotType: "MID_JOURNEY",
|
||||
Dimensions: "SQUARE",
|
||||
@@ -79,27 +93,15 @@ func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
return c.doRequest(body, apiURL)
|
||||
}
|
||||
|
||||
// SwapFace 换脸
|
||||
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||
if err := c.preCheck(); err != nil {
|
||||
return ImageRes{}, err
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
||||
// 生成图片 Base64 编码
|
||||
if len(task.ImgArr) != 2 {
|
||||
@@ -128,60 +130,42 @@ func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||
},
|
||||
"state": "",
|
||||
}
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
return c.doRequest(body, apiURL)
|
||||
}
|
||||
|
||||
// Upscale 放大指定的图片
|
||||
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
|
||||
if err := c.preCheck(); err != nil {
|
||||
return ImageRes{}, err
|
||||
}
|
||||
|
||||
body := map[string]string{
|
||||
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
||||
"taskId": task.MessageId,
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
|
||||
return c.doRequest(body, apiURL)
|
||||
}
|
||||
|
||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
|
||||
if err := c.preCheck(); err != nil {
|
||||
return ImageRes{}, err
|
||||
}
|
||||
|
||||
body := map[string]string{
|
||||
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
||||
"taskId": task.MessageId,
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
|
||||
|
||||
return c.doRequest(body, apiURL)
|
||||
}
|
||||
|
||||
func (c *PlusClient) doRequest(body interface{}, apiURL string) (ImageRes, error) {
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
logger.Info("API URL: ", apiURL)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetBody(body).
|
||||
@@ -202,7 +186,7 @@ func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
|
||||
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
|
||||
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
||||
var res QueryRes
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetSuccessResult(&res).
|
||||
Get(apiURL)
|
||||
|
||||
|
||||
@@ -1,13 +1,23 @@
|
||||
package mj
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -21,41 +31,15 @@ type ServicePool struct {
|
||||
db *gorm.DB
|
||||
uploaderManager *oss.UploaderManager
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
licenseService *service.LicenseService
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool {
|
||||
services := make([]*Service, 0)
|
||||
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
||||
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
||||
|
||||
for k, config := range appConfig.MjPlusConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
cli := NewPlusClient(config)
|
||||
name := fmt.Sprintf("mj-plus-service-%d", k)
|
||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
services = append(services, service)
|
||||
}
|
||||
|
||||
for k, config := range appConfig.MjProxyConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
cli := NewProxyClient(config)
|
||||
name := fmt.Sprintf("mj-proxy-service-%d", k)
|
||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
services = append(services, service)
|
||||
}
|
||||
|
||||
return &ServicePool{
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
@@ -63,22 +47,59 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
uploaderManager: manager,
|
||||
db: db,
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
licenseService: licenseService,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
|
||||
// stop old service
|
||||
for _, s := range p.services {
|
||||
s.Stop()
|
||||
}
|
||||
p.services = make([]*Service, 0)
|
||||
|
||||
for _, config := range plusConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
|
||||
cli := NewPlusClient(config, p.licenseService)
|
||||
name := utils.Md5(config.ApiURL)
|
||||
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
|
||||
go func() {
|
||||
plusService.Run()
|
||||
}()
|
||||
p.services = append(p.services, plusService)
|
||||
}
|
||||
|
||||
// for mid-journey proxy
|
||||
for _, config := range proxyConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
cli := NewProxyClient(config)
|
||||
name := utils.Md5(config.ApiURL)
|
||||
proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
|
||||
go func() {
|
||||
proxyService.Run()
|
||||
}()
|
||||
p.services = append(p.services, proxyService)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ServicePool) CheckTaskNotify() {
|
||||
go func() {
|
||||
for {
|
||||
var userId uint
|
||||
err := p.notifyQueue.LPop(&userId)
|
||||
var message sd.NotifyMessage
|
||||
err := p.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
cli := p.Clients.Get(userId)
|
||||
cli := p.Clients.Get(uint(message.UserId))
|
||||
if cli == nil {
|
||||
continue
|
||||
}
|
||||
err = cli.Send([]byte("Task Updated"))
|
||||
err = cli.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -102,17 +123,23 @@ func (p *ServicePool) DownloadImages() {
|
||||
}
|
||||
|
||||
logger.Infof("try to download image: %s", v.OrgURL)
|
||||
var imgURL string
|
||||
var err error
|
||||
if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
|
||||
task, _ := servicePlus.Client.QueryTask(v.TaskId)
|
||||
if len(task.Buttons) > 0 {
|
||||
v.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
|
||||
} else {
|
||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
|
||||
mjService := p.getService(v.ChannelId)
|
||||
if mjService == nil {
|
||||
logger.Errorf("Invalid task: %+v", v)
|
||||
continue
|
||||
}
|
||||
|
||||
task, _ := mjService.Client.QueryTask(v.TaskId)
|
||||
if len(task.Buttons) > 0 {
|
||||
v.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
// 如果是返回的是 discord 图片地址,则使用代理下载
|
||||
proxy := false
|
||||
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
|
||||
proxy = true
|
||||
}
|
||||
imgURL, err := p.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
|
||||
continue
|
||||
@@ -127,7 +154,7 @@ func (p *ServicePool) DownloadImages() {
|
||||
if cli == nil {
|
||||
continue
|
||||
}
|
||||
err = cli.Send([]byte("Task Updated"))
|
||||
err = cli.Send([]byte(sd.Finished))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -152,43 +179,20 @@ func (p *ServicePool) HasAvailableService() bool {
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
func (p *ServicePool) SyncTaskProgress() {
|
||||
go func() {
|
||||
var items []model.MidJourneyJob
|
||||
var jobs []model.MidJourneyJob
|
||||
for {
|
||||
res := p.db.Where("progress < ?", 100).Find(&items)
|
||||
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range items {
|
||||
// 失败或者 30 分钟还没完成的任务删除并退回算力
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
|
||||
// 删除任务
|
||||
p.db.Delete(&job)
|
||||
// 退回算力
|
||||
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||
var user model.User
|
||||
p.db.Where("id = ?", job.UserId).First(&user)
|
||||
p.db.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: types.PowerConsume,
|
||||
Amount: job.Power,
|
||||
Balance: user.Power + job.Power,
|
||||
Mark: types.PowerAdd,
|
||||
Model: "mid-journey",
|
||||
Remark: fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
|
||||
_ = servicePlus.Notify(job)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package mj
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"github.com/imroc/req/v3"
|
||||
"io"
|
||||
)
|
||||
@@ -22,8 +29,12 @@ func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
|
||||
|
||||
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
|
||||
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
||||
if task.NegPrompt != "" {
|
||||
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
|
||||
}
|
||||
body := ImageReq{
|
||||
Prompt: task.Prompt,
|
||||
Prompt: prompt,
|
||||
Base64Array: make([]string, 0),
|
||||
}
|
||||
// 生成图片 Base64 编码
|
||||
@@ -46,8 +57,6 @@ func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
all, err := io.ReadAll(r.Body)
|
||||
logger.Info(string(all))
|
||||
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
package mj
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -16,41 +23,32 @@ import (
|
||||
|
||||
// Service MJ 绘画服务
|
||||
type Service struct {
|
||||
Name string // service Name
|
||||
Client Client // MJ Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
maxHandleTaskNum int32 // max task number current service can handle
|
||||
HandledTaskNum int32 // already handled task number
|
||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
||||
taskTimeout int64
|
||||
Name string // service Name
|
||||
Client Client // MJ Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
running bool
|
||||
retryCount map[uint]int
|
||||
}
|
||||
|
||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service {
|
||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
|
||||
return &Service{
|
||||
Name: name,
|
||||
db: db,
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
Client: cli,
|
||||
taskTimeout: timeout,
|
||||
maxHandleTaskNum: maxTaskNum,
|
||||
taskStartTimes: make(map[int]time.Time, 0),
|
||||
Name: name,
|
||||
db: db,
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
Client: cli,
|
||||
running: true,
|
||||
retryCount: make(map[uint]int),
|
||||
}
|
||||
}
|
||||
|
||||
const failedProgress = 101
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
||||
for {
|
||||
s.checkTasks()
|
||||
if !s.canHandleTask() {
|
||||
// current service is full, can not handle more task
|
||||
// waiting for running task finish
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
for s.running {
|
||||
var task types.MjTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
@@ -61,21 +59,42 @@ func (s *Service) Run() {
|
||||
// 如果配置了多个中转平台的 API KEY
|
||||
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
|
||||
if task.ChannelId != "" && task.ChannelId != s.Name {
|
||||
if s.retryCount[task.Id] > 5 {
|
||||
s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{})
|
||||
continue
|
||||
}
|
||||
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
|
||||
s.taskQueue.RPush(task)
|
||||
s.retryCount[task.Id]++
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果是 mj-proxy 则自动翻译提示词
|
||||
if utils.HasChinese(task.Prompt) && strings.HasPrefix(s.Name, "mj-proxy-service") {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
// translate negative prompt
|
||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
task.NegPrompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var job model.MidJourneyJob
|
||||
tx := s.db.Where("id = ?", task.Id).First(&job)
|
||||
if tx.Error != nil {
|
||||
logger.Error("任务不存在,任务ID:", task.TaskId)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
|
||||
var res ImageRes
|
||||
@@ -97,46 +116,34 @@ func (s *Service) Run() {
|
||||
break
|
||||
}
|
||||
|
||||
var job model.MidJourneyJob
|
||||
s.db.Where("id = ?", task.Id).First(&job)
|
||||
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||||
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
|
||||
var errMsg string
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
} else {
|
||||
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
|
||||
}
|
||||
|
||||
logger.Error("绘画任务执行失败:", errMsg)
|
||||
job.Progress = -1
|
||||
job.Progress = failedProgress
|
||||
job.ErrMsg = errMsg
|
||||
// update the task progress
|
||||
s.db.Updates(&job)
|
||||
// 任务失败,通知前端
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
|
||||
continue
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
// lock the task until the execute timeout
|
||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
||||
atomic.AddInt32(&s.HandledTaskNum, 1)
|
||||
// 更新任务 ID/频道
|
||||
job.TaskId = res.Result
|
||||
job.MessageId = res.Result
|
||||
job.ChannelId = s.Name
|
||||
s.db.Updates(&job)
|
||||
}
|
||||
}
|
||||
|
||||
// check if current service instance can handle more task
|
||||
func (s *Service) canHandleTask() bool {
|
||||
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
|
||||
return handledNum < s.maxHandleTaskNum
|
||||
}
|
||||
|
||||
// remove the expired tasks
|
||||
func (s *Service) checkTasks() {
|
||||
for k, t := range s.taskStartTimes {
|
||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||
delete(s.taskStartTimes, k)
|
||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||
// delete task from database
|
||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
||||
}
|
||||
}
|
||||
func (s *Service) Stop() {
|
||||
s.running = false
|
||||
}
|
||||
|
||||
type CBReq struct {
|
||||
@@ -166,9 +173,10 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||
// 任务执行失败了
|
||||
if task.FailReason != "" {
|
||||
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"progress": failedProgress,
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
|
||||
return fmt.Errorf("task failed: %v", task.FailReason)
|
||||
}
|
||||
|
||||
@@ -181,18 +189,17 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||
if task.ImageUrl != "" {
|
||||
job.OrgURL = task.ImageUrl
|
||||
}
|
||||
job.MessageId = task.Id
|
||||
tx := s.db.Updates(&job)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("error with update database: %v", tx.Error)
|
||||
}
|
||||
if task.Status == "SUCCESS" {
|
||||
// release lock task
|
||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||
}
|
||||
// 通知前端更新任务进度
|
||||
if oldProgress != job.Progress {
|
||||
s.notifyQueue.RPush(job.UserId)
|
||||
message := sd.Running
|
||||
if job.Progress == 100 {
|
||||
message = sd.Finished
|
||||
}
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package oss
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -77,25 +84,25 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
var imageData []byte
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
|
||||
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
|
||||
} else {
|
||||
imageData, err = utils.DownloadImage(imageURL, "")
|
||||
fileData, err = utils.DownloadImage(fileURL, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
parse, err := url.Parse(imageURL)
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := utils.GetImgExt(parse.Path)
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
package oss
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -50,8 +57,8 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
parse, err := url.Parse(imageURL)
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
@@ -62,9 +69,9 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
}
|
||||
|
||||
if useProxy {
|
||||
err = utils.DownloadFile(imageURL, filePath, s.proxyURL)
|
||||
err = utils.DownloadFile(fileURL, filePath, s.proxyURL)
|
||||
} else {
|
||||
err = utils.DownloadFile(imageURL, filePath, "")
|
||||
err = utils.DownloadFile(fileURL, filePath, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package oss
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -37,18 +44,18 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
var imageData []byte
|
||||
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
|
||||
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
|
||||
} else {
|
||||
imageData, err = utils.DownloadImage(imageURL, "")
|
||||
fileData, err = utils.DownloadImage(fileURL, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
parse, err := url.Parse(imageURL)
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
@@ -58,8 +65,8 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
filename,
|
||||
strings.NewReader(string(imageData)),
|
||||
int64(len(imageData)),
|
||||
strings.NewReader(string(fileData)),
|
||||
int64(len(fileData)),
|
||||
minio.PutObjectOptions{ContentType: "image/png"})
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package oss
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -86,18 +93,18 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
var imageData []byte
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
|
||||
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
|
||||
} else {
|
||||
imageData, err = utils.DownloadImage(imageURL, "")
|
||||
fileData, err = utils.DownloadImage(fileURL, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
parse, err := url.Parse(imageURL)
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
@@ -106,7 +113,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(imageData), int64(len(imageData)), &extra)
|
||||
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(fileData), int64(len(fileData)), &extra)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package oss
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
const Local = "LOCAL"
|
||||
@@ -16,7 +23,7 @@ type File struct {
|
||||
}
|
||||
type Uploader interface {
|
||||
PutFile(ctx *gin.Context, name string) (File, error)
|
||||
PutImg(imageURL string, useProxy bool) (string, error)
|
||||
PutUrlFile(url string, useProxy bool) (string, error)
|
||||
PutBase64(imageData string) (string, error)
|
||||
Delete(fileURL string) error
|
||||
}
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
package oss
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"geekai/core/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/smartwalle/alipay/v3"
|
||||
"log"
|
||||
"net/url"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/alipay"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
@@ -28,93 +36,90 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
|
||||
return nil, fmt.Errorf("error with read App Private key: %v", err)
|
||||
}
|
||||
|
||||
xClient, err := alipay.New(config.AppId, priKey, !config.SandBox)
|
||||
client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with initialize alipay service: %v", err)
|
||||
}
|
||||
|
||||
if err = xClient.LoadAppCertPublicKeyFromFile(config.PublicKey); err != nil {
|
||||
return nil, fmt.Errorf("error with loading App PublicKey: %v", err)
|
||||
}
|
||||
if err = xClient.LoadAliPayRootCertFromFile(config.RootCert); err != nil {
|
||||
return nil, fmt.Errorf("error with loading alipay RootCert: %v", err)
|
||||
}
|
||||
if err = xClient.LoadAlipayCertPublicKeyFromFile(config.AlipayPublicKey); err != nil {
|
||||
return nil, fmt.Errorf("error with loading Alipay PublicKey: %v", err)
|
||||
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
|
||||
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
|
||||
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
|
||||
SetSignType(alipay.RSA2). // 设置签名类型,不设置默认 RSA2
|
||||
SetReturnUrl(config.ReturnURL). // 设置返回URL
|
||||
SetNotifyUrl(config.NotifyURL)
|
||||
|
||||
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
|
||||
return nil, fmt.Errorf("error with load payment public key: %v", err)
|
||||
}
|
||||
|
||||
return &AlipayService{config: &config, client: xClient}, nil
|
||||
return &AlipayService{config: &config, client: client}, nil
|
||||
}
|
||||
|
||||
func (s *AlipayService) PayUrlMobile(outTradeNo string, notifyURL string, returnURL string, Amount string, subject string) (string, error) {
|
||||
var p = alipay.TradeWapPay{}
|
||||
p.NotifyURL = notifyURL
|
||||
p.ReturnURL = returnURL
|
||||
p.Subject = subject
|
||||
p.OutTradeNo = outTradeNo
|
||||
p.TotalAmount = Amount
|
||||
p.ProductCode = "QUICK_WAP_WAY"
|
||||
res, err := s.client.TradeWapPay(p)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return res.String(), err
|
||||
func (s *AlipayService) PayUrlMobile(outTradeNo string, amount string, subject string) (string, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("subject", subject)
|
||||
bm.Set("out_trade_no", outTradeNo)
|
||||
bm.Set("quit_url", s.config.ReturnURL)
|
||||
bm.Set("total_amount", amount)
|
||||
bm.Set("product_code", "QUICK_WAP_WAY")
|
||||
return s.client.TradeWapPay(context.Background(), bm)
|
||||
}
|
||||
|
||||
func (s *AlipayService) PayUrlPc(outTradeNo string, notifyURL string, returnURL string, amount string, subject string) (string, error) {
|
||||
var p = alipay.TradePagePay{}
|
||||
p.NotifyURL = notifyURL
|
||||
p.ReturnURL = returnURL
|
||||
p.Subject = subject
|
||||
p.OutTradeNo = outTradeNo
|
||||
p.TotalAmount = amount
|
||||
p.ProductCode = "FAST_INSTANT_TRADE_PAY"
|
||||
res, err := s.client.TradePagePay(p)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return res.String(), err
|
||||
func (s *AlipayService) PayUrlPc(outTradeNo string, amount string, subject string) (string, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("subject", subject)
|
||||
bm.Set("out_trade_no", outTradeNo)
|
||||
bm.Set("total_amount", amount)
|
||||
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
|
||||
return s.client.TradePagePay(context.Background(), bm)
|
||||
}
|
||||
|
||||
// TradeVerify 交易验证
|
||||
func (s *AlipayService) TradeVerify(reqForm url.Values) NotifyVo {
|
||||
err := s.client.VerifySign(reqForm)
|
||||
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
|
||||
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
|
||||
if err != nil {
|
||||
log.Println("异步通知验证签名发生错误", err)
|
||||
return NotifyVo{
|
||||
Status: 0,
|
||||
Message: "异步通知验证签名发生错误",
|
||||
Status: Failure,
|
||||
Message: "error with parse notify request: " + err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
return s.TradeQuery(reqForm.Get("out_trade_no"))
|
||||
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
|
||||
if err != nil {
|
||||
return NotifyVo{
|
||||
Status: Failure,
|
||||
Message: "error with verify sign: " + err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
return s.TradeQuery(request.Form.Get("out_trade_no"))
|
||||
}
|
||||
|
||||
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
|
||||
var p = alipay.TradeQuery{}
|
||||
p.OutTradeNo = outTradeNo
|
||||
rsp, err := s.client.TradeQuery(p)
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("out_trade_no", outTradeNo)
|
||||
|
||||
//查询订单
|
||||
rsp, err := s.client.TradeQuery(context.Background(), bm)
|
||||
if err != nil {
|
||||
return NotifyVo{
|
||||
Status: 0,
|
||||
Status: Failure,
|
||||
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
if rsp.IsSuccess() == true && rsp.TradeStatus == "TRADE_SUCCESS" {
|
||||
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
|
||||
return NotifyVo{
|
||||
Status: 1,
|
||||
OutTradeNo: rsp.OutTradeNo,
|
||||
TradeNo: rsp.TradeNo,
|
||||
Amount: rsp.TotalAmount,
|
||||
Subject: rsp.Subject,
|
||||
Status: Success,
|
||||
OutTradeNo: rsp.Response.OutTradeNo,
|
||||
TradeId: rsp.Response.TradeNo,
|
||||
Amount: rsp.Response.TotalAmount,
|
||||
Subject: rsp.Response.Subject,
|
||||
Message: "OK",
|
||||
}
|
||||
} else {
|
||||
return NotifyVo{
|
||||
Status: 0,
|
||||
Status: Failure,
|
||||
Message: "异步查询验证订单信息发生错误" + outTradeNo,
|
||||
}
|
||||
}
|
||||
@@ -127,16 +132,3 @@ func readKey(filename string) (string, error) {
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
type NotifyVo struct {
|
||||
Status int
|
||||
OutTradeNo string
|
||||
TradeNo string
|
||||
Amount string
|
||||
Message string
|
||||
Subject string
|
||||
}
|
||||
|
||||
func (v NotifyVo) Success() bool {
|
||||
return v.Status == 1
|
||||
}
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -42,6 +49,8 @@ type HuPiPayReq struct {
|
||||
CallbackURL string `json:"callback_url"`
|
||||
Time string `json:"time"`
|
||||
NonceStr string `json:"nonce_str"`
|
||||
Type string `json:"type"`
|
||||
WapUrl string `json:"wap_url"`
|
||||
}
|
||||
|
||||
type HuPiResp struct {
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -14,12 +21,12 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PayJS struct {
|
||||
type JPayService struct {
|
||||
config *types.JPayConfig
|
||||
}
|
||||
|
||||
func NewPayJS(appConfig *types.AppConfig) *PayJS {
|
||||
return &PayJS{
|
||||
func NewJPayService(appConfig *types.AppConfig) *JPayService {
|
||||
return &JPayService{
|
||||
config: &appConfig.JPayConfig,
|
||||
}
|
||||
}
|
||||
@@ -46,7 +53,7 @@ func (r JPayReps) IsOK() bool {
|
||||
return r.ReturnMsg == "SUCCESS"
|
||||
}
|
||||
|
||||
func (js *PayJS) Pay(param JPayReq) JPayReps {
|
||||
func (js *JPayService) Pay(param JPayReq) JPayReps {
|
||||
param.NotifyURL = js.config.NotifyURL
|
||||
var p = url.Values{}
|
||||
encode := utils.JsonEncode(param)
|
||||
@@ -79,13 +86,13 @@ func (js *PayJS) Pay(param JPayReq) JPayReps {
|
||||
return data
|
||||
}
|
||||
|
||||
func (js *PayJS) PayH5(p url.Values) string {
|
||||
func (js *JPayService) PayH5(p url.Values) string {
|
||||
p.Add("mchid", js.config.AppId)
|
||||
p.Add("sign", js.sign(p))
|
||||
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
|
||||
}
|
||||
|
||||
func (js *PayJS) sign(params url.Values) string {
|
||||
func (js *JPayService) sign(params url.Values) string {
|
||||
params.Del(`sign`)
|
||||
var keys = make([]string, 0, 0)
|
||||
for key := range params {
|
||||
@@ -110,20 +117,18 @@ func (js *PayJS) sign(params url.Values) string {
|
||||
return strings.ToUpper(md5res)
|
||||
}
|
||||
|
||||
// Check 查询订单支付状态
|
||||
// TradeVerify 查询订单支付状态
|
||||
// @param tradeNo 支付平台交易 ID
|
||||
func (js *PayJS) Check(tradeNo string) error {
|
||||
func (js *JPayService) TradeVerify(tradeNo string) error {
|
||||
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
|
||||
params := url.Values{}
|
||||
params.Add("payjs_order_id", tradeNo)
|
||||
params.Add("sign", js.sign(params))
|
||||
data := strings.NewReader(params.Encode())
|
||||
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
|
||||
defer resp.Body.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with http reqeust: %v", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
|
||||
19
api/service/payment/types.go
Normal file
19
api/service/payment/types.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package payment
|
||||
|
||||
type NotifyVo struct {
|
||||
Status int
|
||||
OutTradeNo string // 商户订单号
|
||||
TradeId string // 交易ID
|
||||
Amount string // 交易金额
|
||||
Message string
|
||||
Subject string
|
||||
}
|
||||
|
||||
func (v NotifyVo) Success() bool {
|
||||
return v.Status == Success
|
||||
}
|
||||
|
||||
const (
|
||||
Success = 0
|
||||
Failure = 1
|
||||
)
|
||||
135
api/service/payment/wepay_service.go
Normal file
135
api/service/payment/wepay_service.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/wechat/v3"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WechatPayService struct {
|
||||
config *types.WechatPayConfig
|
||||
client *wechat.ClientV3
|
||||
}
|
||||
|
||||
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
|
||||
config := appConfig.WechatPayConfig
|
||||
if !config.Enabled {
|
||||
logger.Info("Disabled WechatPay service")
|
||||
return nil, nil
|
||||
}
|
||||
priKey, err := readKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with read App Private key: %v", err)
|
||||
}
|
||||
|
||||
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
|
||||
}
|
||||
err = client.AutoVerifySign()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
|
||||
}
|
||||
//client.DebugSwitch = gopay.DebugOn
|
||||
|
||||
return &WechatPayService{config: &config, client: client}, nil
|
||||
}
|
||||
|
||||
func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject string) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", subject).
|
||||
Set("out_trade_no", outTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", s.config.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", amount).
|
||||
Set("currency", "CNY")
|
||||
})
|
||||
|
||||
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.CodeUrl, nil
|
||||
}
|
||||
|
||||
func (s *WechatPayService) PayUrlH5(outTradeNo string, amount int, subject string, ip string) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", subject).
|
||||
Set("out_trade_no", outTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", s.config.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", amount).
|
||||
Set("currency", "CNY")
|
||||
}).
|
||||
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("payer_client_ip", ip).
|
||||
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("type", "Wap")
|
||||
})
|
||||
})
|
||||
|
||||
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.H5Url, nil
|
||||
}
|
||||
|
||||
type NotifyResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `xml:"message"`
|
||||
}
|
||||
|
||||
// TradeVerify 交易验证
|
||||
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
|
||||
notifyReq, err := wechat.V3ParseNotify(request)
|
||||
if err != nil {
|
||||
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
|
||||
}
|
||||
|
||||
// TODO: 这里验签程序有 Bug,一直报错:crypto/rsa: verification error,先暂时取消验签
|
||||
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
|
||||
//if err != nil {
|
||||
// return fmt.Errorf("error with client v3 verify sign: %v", err)
|
||||
//}
|
||||
|
||||
// 解密支付密文,验证订单信息
|
||||
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
|
||||
if err != nil {
|
||||
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
|
||||
}
|
||||
|
||||
return NotifyVo{
|
||||
Status: Success,
|
||||
OutTradeNo: result.OutTradeNo,
|
||||
TradeId: result.TransactionId,
|
||||
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,18 @@
|
||||
package sd
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
@@ -18,28 +25,14 @@ type ServicePool struct {
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
uploader *oss.UploaderManager
|
||||
levelDB *store.LevelDB
|
||||
}
|
||||
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
|
||||
services := make([]*Service, 0)
|
||||
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
||||
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
||||
// create mj client and service
|
||||
for _, config := range appConfig.SdConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
|
||||
// create sd service
|
||||
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
|
||||
service := NewService(name, config, taskQueue, notifyQueue, db, manager)
|
||||
// run sd service
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
|
||||
services = append(services, service)
|
||||
}
|
||||
|
||||
return &ServicePool{
|
||||
taskQueue: taskQueue,
|
||||
@@ -47,6 +40,32 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
services: services,
|
||||
db: db,
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploader: manager,
|
||||
levelDB: levelDB,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
|
||||
// stop old service
|
||||
for _, s := range p.services {
|
||||
s.Stop()
|
||||
}
|
||||
p.services = make([]*Service, 0)
|
||||
|
||||
for k, config := range configs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
|
||||
// create sd service
|
||||
name := fmt.Sprintf(" sd-service-%d", k)
|
||||
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
|
||||
// run sd service
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
|
||||
p.services = append(p.services, service)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,16 +79,16 @@ func (p *ServicePool) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||
for {
|
||||
var userId uint
|
||||
err := p.notifyQueue.LPop(&userId)
|
||||
var message NotifyMessage
|
||||
err := p.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := p.Clients.Get(userId)
|
||||
client := p.Clients.Get(uint(message.UserId))
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte("Task Updated"))
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -113,7 +132,7 @@ func (p *ServicePool) CheckTaskStatus() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1,17 +1,25 @@
|
||||
package sd
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SD 绘画服务
|
||||
@@ -24,9 +32,11 @@ type Service struct {
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
name string // service name
|
||||
leveldb *store.LevelDB
|
||||
running bool // 运行状态
|
||||
}
|
||||
|
||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
|
||||
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
|
||||
return &Service{
|
||||
name: name,
|
||||
@@ -35,23 +45,39 @@ func NewService(name string, config types.StableDiffusionConfig, taskQueue *stor
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
db: db,
|
||||
leveldb: levelDB,
|
||||
uploadManager: manager,
|
||||
running: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
for {
|
||||
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
|
||||
for s.running {
|
||||
var task types.SdTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
// 翻译提示词
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Params.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
task.Params.Prompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// translate negative prompt
|
||||
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
task.Params.NegPrompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,12 +91,16 @@ func (s *Service) Run() {
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
// 通知前端,任务失败
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Stop() {
|
||||
s.running = false
|
||||
}
|
||||
|
||||
// Txt2ImgReq 文生图请求实体
|
||||
type Txt2ImgReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
@@ -81,6 +111,7 @@ type Txt2ImgReq struct {
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
SamplerName string `json:"sampler_name"`
|
||||
Scheduler string `json:"scheduler"`
|
||||
EnableHr bool `json:"enable_hr,omitempty"`
|
||||
HrScale int `json:"hr_scale,omitempty"`
|
||||
HrUpscaler string `json:"hr_upscaler,omitempty"`
|
||||
@@ -108,12 +139,14 @@ type TaskProgressResp struct {
|
||||
func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
body := Txt2ImgReq{
|
||||
Prompt: task.Params.Prompt,
|
||||
NegativePrompt: task.Params.NegativePrompt,
|
||||
NegativePrompt: task.Params.NegPrompt,
|
||||
Steps: task.Params.Steps,
|
||||
CfgScale: task.Params.CfgScale,
|
||||
Width: task.Params.Width,
|
||||
Height: task.Params.Height,
|
||||
SamplerName: task.Params.Sampler,
|
||||
Scheduler: task.Params.Scheduler,
|
||||
ForceTaskId: task.Params.TaskId,
|
||||
}
|
||||
if task.Params.Seed > 0 {
|
||||
body.Seed = task.Params.Seed
|
||||
@@ -129,8 +162,13 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
var errChan = make(chan error)
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||
logger.Debugf("send image request to %s", apiURL)
|
||||
// send a request to sd api endpoint
|
||||
go func() {
|
||||
response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
|
||||
response, err := s.httpClient.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
@@ -154,27 +192,35 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
return
|
||||
}
|
||||
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params), Prompt: task.Params.Prompt})
|
||||
errChan <- nil
|
||||
}()
|
||||
|
||||
// waiting for task finish
|
||||
for {
|
||||
select {
|
||||
case err := <-errChan: // 任务完成
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
|
||||
// 从 leveldb 中删除预览图片数据
|
||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||
return nil
|
||||
default:
|
||||
err, resp := s.checkTaskProgress()
|
||||
// 更新任务进度
|
||||
if err == nil && resp.Progress > 0 {
|
||||
logger.Debugf("Check task progress: %+v", resp.Progress)
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
|
||||
// 保存预览图片数据
|
||||
if resp.CurrentImage != "" {
|
||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
@@ -186,7 +232,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
||||
var res TaskProgressResp
|
||||
response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
|
||||
response, err := s.httpClient.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
SetSuccessResult(&res).
|
||||
Get(apiURL)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
|
||||
@@ -1,47 +1,24 @@
|
||||
package sd
|
||||
|
||||
import logger2 "chatplus/logger"
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import logger2 "geekai/logger"
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type TaskInfo struct {
|
||||
UserId uint `json:"user_id"`
|
||||
SessionId string `json:"session_id"`
|
||||
JobId int `json:"job_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
Data []interface{} `json:"data"`
|
||||
EventData interface{} `json:"event_data"`
|
||||
FnIndex int `json:"fn_index"`
|
||||
SessionHash string `json:"session_hash"`
|
||||
type NotifyMessage struct {
|
||||
UserId int `json:"user_id"`
|
||||
JobId int `json:"job_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type CBReq struct {
|
||||
UserId uint
|
||||
SessionId string
|
||||
JobId int
|
||||
TaskId string
|
||||
ImageName string
|
||||
ImageData string
|
||||
Progress int
|
||||
Seed int64
|
||||
Success bool
|
||||
Message string
|
||||
}
|
||||
|
||||
var ParamKeys = map[string]int{
|
||||
"task_id": 0,
|
||||
"prompt": 1,
|
||||
"negative_prompt": 2,
|
||||
"steps": 4,
|
||||
"sampler": 5,
|
||||
"face_fix": 7, // 面部修复
|
||||
"cfg_scale": 8,
|
||||
"seed": 27,
|
||||
"height": 10,
|
||||
"width": 9,
|
||||
"hd_fix": 11,
|
||||
"hd_redraw_rate": 12, //高清修复重绘幅度
|
||||
"hd_scale": 13, // 高清修复放大倍数
|
||||
"hd_scale_alg": 14, // 高清修复放大算法
|
||||
"hd_sample_num": 15, // 高清修复采样次数
|
||||
}
|
||||
const (
|
||||
Running = "RUNNING"
|
||||
Finished = "FINISH"
|
||||
Failed = "FAIL"
|
||||
)
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
package sms
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
package sms
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package sms
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
const Ali = "ALI"
|
||||
const Bao = "BAO"
|
||||
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
package sms
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
package service
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"chatplus/core/types"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"mime"
|
||||
"net/smtp"
|
||||
"net/textproto"
|
||||
)
|
||||
|
||||
type SmtpService struct {
|
||||
@@ -19,12 +28,18 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
|
||||
}
|
||||
|
||||
func (s *SmtpService) SendVerifyCode(to string, code int) error {
|
||||
subject := "ChatPlus注册验证码"
|
||||
body := fmt.Sprintf("您正在注册 ChatPlus AI 助手账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", code)
|
||||
subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
|
||||
body := fmt.Sprintf("您正在注册 %s 账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", s.config.AppName, code)
|
||||
|
||||
// 设置SMTP客户端配置
|
||||
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
|
||||
if s.config.UseTls {
|
||||
return s.sendTLS(auth, to, subject, body)
|
||||
} else {
|
||||
return s.send(auth, to, subject, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SmtpService) send(auth smtp.Auth, to string, subject string, body string) error {
|
||||
// 对主题进行MIME编码
|
||||
encodedSubject := mime.QEncoding.Encode("UTF-8", subject)
|
||||
// 组装邮件
|
||||
@@ -34,11 +49,83 @@ func (s *SmtpService) SendVerifyCode(to string, code int) error {
|
||||
message.WriteString(fmt.Sprintf("Subject: %s\r\n", encodedSubject))
|
||||
message.WriteString("\r\n" + body)
|
||||
|
||||
// 发送邮件
|
||||
// 发送邮件
|
||||
err := smtp.SendMail(s.config.Host+":"+fmt.Sprint(s.config.Port), auth, s.config.From, []string{to}, message.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending email: %v", err)
|
||||
}
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func (s *SmtpService) sendTLS(auth smtp.Auth, to string, subject string, body string) error {
|
||||
// TLS配置
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: s.config.Host,
|
||||
}
|
||||
|
||||
// 建立TLS连接
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port), tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to SMTP server: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client, err := smtp.NewClient(conn, s.config.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating SMTP client: %v", err)
|
||||
}
|
||||
defer client.Quit()
|
||||
|
||||
// 身份验证
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("error authenticating: %v", err)
|
||||
}
|
||||
|
||||
// 设置寄件人
|
||||
if err = client.Mail(s.config.From); err != nil {
|
||||
return fmt.Errorf("error setting sender: %v", err)
|
||||
}
|
||||
|
||||
// 设置收件人
|
||||
if err = client.Rcpt(to); err != nil {
|
||||
return fmt.Errorf("error setting recipient: %v", err)
|
||||
}
|
||||
|
||||
// 发送邮件内容
|
||||
wc, err := client.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting data writer: %v", err)
|
||||
}
|
||||
defer wc.Close()
|
||||
|
||||
header := make(textproto.MIMEHeader)
|
||||
header.Set("From", s.config.From)
|
||||
header.Set("To", to)
|
||||
header.Set("Subject", subject)
|
||||
|
||||
// 将邮件头写入
|
||||
for key, values := range header {
|
||||
for _, value := range values {
|
||||
_, err = fmt.Fprintf(wc, "%s: %s\r\n", key, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending email header: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
_, _ = fmt.Fprintln(wc)
|
||||
// 将邮件内容写入
|
||||
_, err = fmt.Fprintf(wc, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending email: %v", err)
|
||||
}
|
||||
|
||||
// 发送完毕
|
||||
err = wc.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error closing data writer: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package service
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
355
api/service/suno/service.go
Normal file
355
api/service/suno/service.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package suno
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploadManager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.SunoTask) {
|
||||
logger.Infof("add a new Suno task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
// 将数据库中未提交的人物加载到队列
|
||||
var jobs []model.SunoJob
|
||||
s.db.Where("task_id", "").Find(&jobs)
|
||||
for _, v := range jobs {
|
||||
s.PushTask(types.SunoTask{
|
||||
Id: v.Id,
|
||||
Channel: v.Channel,
|
||||
UserId: v.UserId,
|
||||
Type: v.Type,
|
||||
Title: v.Title,
|
||||
RefTaskId: v.RefTaskId,
|
||||
RefSongId: v.RefSongId,
|
||||
Prompt: v.Prompt,
|
||||
Tags: v.Tags,
|
||||
Model: v.ModelName,
|
||||
Instrumental: v.Instrumental,
|
||||
ExtendSecs: v.ExtendSecs,
|
||||
})
|
||||
}
|
||||
logger.Info("Starting Suno job consumer...")
|
||||
go func() {
|
||||
for {
|
||||
var task types.SunoTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
r, err := s.Create(task)
|
||||
if err != nil {
|
||||
logger.Errorf("create task with error: %v", err)
|
||||
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"err_msg": err.Error(),
|
||||
"progress": 101,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务信息
|
||||
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"task_id": r.Data,
|
||||
"channel": r.Channel,
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type RespVo struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
|
||||
if task.Channel != "" {
|
||||
session = session.Where("api_url", task.Channel)
|
||||
}
|
||||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return RespVo{}, errors.New("no available API KEY for Suno")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"task_id": task.RefTaskId,
|
||||
"continue_clip_id": task.RefSongId,
|
||||
"continue_at": task.ExtendSecs,
|
||||
"make_instrumental": task.Instrumental,
|
||||
}
|
||||
// 灵感模式
|
||||
if task.Type == 1 {
|
||||
reqBody["gpt_description_prompt"] = task.Prompt
|
||||
} else { // 自定义模式
|
||||
reqBody["prompt"] = task.Prompt
|
||||
reqBody["tags"] = task.Tags
|
||||
reqBody["mv"] = task.Model
|
||||
reqBody["title"] = task.Title
|
||||
}
|
||||
|
||||
var res RespVo
|
||||
apiURL := fmt.Sprintf("%s/task/suno/v1/submit/music", apiKey.ApiURL)
|
||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
if res.Code != "success" {
|
||||
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
|
||||
}
|
||||
res.Channel = apiKey.ApiURL
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Suno task notify checking ...")
|
||||
for {
|
||||
var message sd.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadImages() {
|
||||
go func() {
|
||||
var items []model.SunoJob
|
||||
for {
|
||||
res := s.db.Where("progress", 102).Find(&items)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
// 下载图片和音频
|
||||
logger.Infof("try download cover image: %s", v.CoverURL)
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try download audio: %s", v.AudioURL)
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download audio with error: %v", err)
|
||||
continue
|
||||
}
|
||||
v.CoverURL = coverURL
|
||||
v.AudioURL = audioURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: sd.Finished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
func (s *Service) SyncTaskProgress() {
|
||||
go func() {
|
||||
var jobs []model.SunoJob
|
||||
for {
|
||||
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
task, err := s.QueryTask(job.TaskId, job.Channel)
|
||||
if err != nil {
|
||||
logger.Errorf("query task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if task.Code != "success" {
|
||||
logger.Errorf("query task with error: %v", task.Message)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("task: %+v", task.Data.Status)
|
||||
// 任务完成,删除旧任务插入两条新任务
|
||||
if task.Data.Status == "SUCCESS" {
|
||||
var jobId = job.Id
|
||||
var flag = false
|
||||
tx := s.db.Begin()
|
||||
for _, v := range task.Data.Data {
|
||||
job.Id = 0
|
||||
job.Progress = 102 // 102 表示资源未下载完成
|
||||
job.Title = v.Title
|
||||
job.SongId = v.Id
|
||||
job.Duration = int(v.Metadata.Duration)
|
||||
job.Prompt = v.Metadata.Prompt
|
||||
job.Tags = v.Metadata.Tags
|
||||
job.ModelName = v.ModelName
|
||||
job.RawData = utils.JsonEncode(v)
|
||||
job.CoverURL = v.ImageLargeUrl
|
||||
job.AudioURL = v.AudioUrl
|
||||
|
||||
if err = tx.Create(&job).Error; err != nil {
|
||||
logger.Error("create job with error: %v", err)
|
||||
tx.Rollback()
|
||||
break
|
||||
}
|
||||
flag = true
|
||||
}
|
||||
|
||||
// 删除旧任务
|
||||
if flag {
|
||||
if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
|
||||
logger.Error("create job with error: %v", err)
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
} else if task.Data.FailReason != "" {
|
||||
job.Progress = 101
|
||||
job.ErrMsg = task.Data.FailReason
|
||||
s.db.Updates(&job)
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type QueryRespVo struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Action string `json:"action"`
|
||||
Status string `json:"status"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
SubmitTime int `json:"submit_time"`
|
||||
StartTime int `json:"start_time"`
|
||||
FinishTime int `json:"finish_time"`
|
||||
Progress string `json:"progress"`
|
||||
Data []struct {
|
||||
Id string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Status string `json:"status"`
|
||||
Metadata struct {
|
||||
Tags string `json:"tags"`
|
||||
Type string `json:"type"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stream bool `json:"stream"`
|
||||
Duration float64 `json:"duration"`
|
||||
ErrorMessage interface{} `json:"error_message"`
|
||||
} `json:"metadata"`
|
||||
AudioUrl string `json:"audio_url"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
ModelName string `json:"model_name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
ImageLargeUrl string `json:"image_large_url"`
|
||||
MajorModelVersion string `json:"major_model_version"`
|
||||
} `json:"data"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
|
||||
Where("api_url", channel).
|
||||
Where("enabled", true).
|
||||
Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return QueryRespVo{}, errors.New("no available API KEY for Suno")
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/task/suno/v1/fetch/%s", apiKey.ApiURL, taskId)
|
||||
var res QueryRespVo
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
||||
|
||||
if err != nil {
|
||||
return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
|
||||
}
|
||||
|
||||
defer r.Body.Close()
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return QueryRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user