mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-10-31 14:23:43 +08:00 
			
		
		
		
	Compare commits
	
		
			242 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 54b45ec2ff | ||
|  | c434f85045 | ||
|  | 4d10279870 | ||
|  | 9de9489673 | ||
|  | 9814fec930 | ||
|  | 53ba731159 | ||
|  | b2f57aa483 | ||
|  | 4c2dba1004 | ||
|  | 79adc871ef | ||
|  | 8144fada25 | ||
|  | 754ba02263 | ||
|  | 7ddf57ae06 | ||
|  | cc5180a6f7 | ||
|  | 96f1126d02 | ||
|  | 7f9b8d8246 | ||
|  | 5132d52a44 | ||
|  | abdf5298fe | ||
|  | 2129f7a8b7 | ||
|  | f6f8748521 | ||
|  | 59301df073 | ||
|  | e17dcf4d5f | ||
|  | 09f44e6d9b | ||
|  | 59824bffc5 | ||
|  | cb0dacd5e0 | ||
|  | 7463cfc66c | ||
|  | b248560ba2 | ||
|  | 37368fe13f | ||
|  | 246b023624 | ||
|  | 9f44c34d34 | ||
|  | a6b9f57a50 | ||
|  | 42bc23cacf | ||
|  | 282f55c7a3 | ||
|  | 44798f89ba | ||
|  | 596cb2b206 | ||
|  | d1965deff1 | ||
|  | b793b81768 | ||
|  | a5ef4299ec | ||
|  | cdb1a8bde1 | ||
|  | 233f6e00f0 | ||
|  | b7dba68549 | ||
|  | 64e5fc48ba | ||
|  | a692cf1338 | ||
|  | 6998dd7af4 | ||
|  | 9343c73e0f | ||
|  | 739cd46539 | ||
|  | f8fed83507 | ||
|  | d63536d5ef | ||
|  | 4905fb28d4 | ||
|  | a3a2a8abcb | ||
|  | 839dd8dbf4 | ||
|  | 0375164f40 | ||
|  | 691294b444 | ||
|  | bdea12c51a | ||
|  | a27d9ea259 | ||
|  | 7cd824c284 | ||
|  | e27d95e2b5 | ||
|  | c24b4d7074 | ||
|  | 6839827db0 | ||
|  | ab24398748 | ||
|  | 6110522b54 | ||
|  | bcdf5e3776 | ||
|  | 2207830db9 | ||
|  | d52dfbfef4 | ||
|  | d6a04f96fe | ||
|  | 66ccb387e8 | ||
|  | 5f820b9dc1 | ||
|  | 3cc2263dc7 | ||
|  | f0a3c5d8ae | ||
|  | 2a4ef27774 | ||
|  | 2b057f32aa | ||
|  | bc6451026f | ||
|  | 99fd596862 | ||
|  | f0959b5df6 | ||
|  | 6788edbe9d | ||
|  | 3895305882 | ||
|  | 1b0938b33f | ||
|  | c2acbaaa94 | ||
|  | 02faff461a | ||
|  | e18e5a38c6 | ||
|  | 2f9b1b7835 | ||
|  | 717b137a6d | ||
|  | f755bdccae | ||
|  | 4bba77ab47 | ||
|  | 6944a32ff3 | ||
|  | 5742b40aee | ||
|  | 7f1ec90748 | ||
|  | 4a99be2f15 | ||
|  | bee19392c1 | ||
|  | 27c816cf3b | ||
|  | 0d81776212 | ||
|  | 00d31a2379 | ||
|  | cccab31c0f | ||
|  | 5d65505ab7 | ||
|  | 3dc7d0516a | ||
|  | 50335ebc2d | ||
|  | bcadee7290 | ||
|  | cac3194d5b | ||
|  | 4ddf3bf2bf | ||
|  | d45f9fbad6 | ||
|  | d98b08d7cd | ||
|  | 5a8fe5a6cf | ||
|  | 36c27d6092 | ||
|  | 3ab29da8f0 | ||
|  | 3699f024f1 | ||
|  | 3d37a3d367 | ||
|  | 73d8236697 | ||
|  | 114d0088dc | ||
|  | 43b6665370 | ||
|  | 5fb9f84182 | ||
|  | e35c34ad9a | ||
|  | 1a4d798f8b | ||
|  | afb91a7023 | ||
|  | dc4c1f7877 | ||
|  | bbc8fe2b40 | ||
|  | 3c34e8e0e7 | ||
|  | 57c932f07c | ||
|  | 922202734a | ||
|  | 8b3b0139b0 | ||
|  | 31828a3336 | ||
|  | b270960a04 | ||
|  | 5c4899df6e | ||
|  | 9a797bb4a5 | ||
|  | b0c9ffc5a6 | ||
|  | f527cc5b98 | ||
|  | debe8dc209 | ||
|  | 2f0215ac87 | ||
|  | dd5cc206e5 | ||
|  | 142cd553a3 | ||
|  | 657ecccee3 | ||
|  | 1232c3cd9c | ||
|  | 3ac04a3938 | ||
|  | b7abc42209 | ||
|  | a48179ce0e | ||
|  | e589f25a05 | ||
|  | cc1a3ce343 | ||
|  | 7bb76d581c | ||
|  | 0d733c0be0 | ||
|  | 8b40ac5b5c | ||
|  | 24479814e9 | ||
|  | 99df028237 | ||
|  | b354b88876 | ||
|  | 5e0be4d10e | ||
|  | 468b48151f | ||
|  | fa5c036041 | ||
|  | 0fdc588167 | ||
|  | 2e023cb8dc | ||
|  | e933f32d9c | ||
|  | bd4b0c4d65 | ||
|  | 0b2501c1d8 | ||
|  | 9d28e62142 | ||
|  | c1d892069e | ||
|  | 61b2dbc9f1 | ||
|  | be3245666e | ||
|  | dacdd6fe74 | ||
|  | 6807f7e88a | ||
|  | 087f5ab2d1 | ||
|  | 47c5a0387b | ||
|  | f9da18ad52 | ||
|  | 5c9025ca22 | ||
|  | d02cb573fd | ||
|  | caa538a1d0 | ||
|  | b584b4bfb6 | ||
|  | bda335212d | ||
|  | 06f4cdc649 | ||
|  | 336a7d5b56 | ||
|  | a0f464830f | ||
|  | 9bf7fa4081 | ||
|  | 96ead65774 | ||
|  | 7ad41927aa | ||
|  | 4ca9dfd9c0 | ||
|  | 8a9f386d8f | ||
|  | adfee8bf58 | ||
|  | fbfa2a71a9 | ||
|  | 9a1368ef17 | ||
|  | 31b02b97d3 | ||
|  | 42da38c5c3 | ||
|  | 0a01b55713 | ||
|  | 3b292c2a12 | ||
|  | db0ba0d9a0 | ||
|  | 3a23ff6b42 | ||
|  | 1e9c5adb0a | ||
|  | abab76ccc6 | ||
|  | 6efd92806f | ||
|  | cfe333e89f | ||
|  | a7237fe62f | ||
|  | c3c454b7d7 | ||
|  | d4d708d44b | ||
|  | 7f0b6a3a46 | ||
|  | c2a7c089d2 | ||
|  | df5bd4df60 | ||
|  | 79b6010104 | ||
|  | 97b0a98793 | ||
|  | 5230f90540 | ||
|  | 803db4e895 | ||
|  | 7cee9f2ebb | ||
|  | 8be9a21efd | ||
|  | 6a3e26b566 | ||
|  | 0355c37bef | ||
|  | 9b7ee538c4 | ||
|  | d900a3d08e | ||
|  | cdf5b66729 | ||
|  | 1cff4b63cd | ||
|  | da14309ef9 | ||
|  | fbb216fe3b | ||
|  | 95efbd5659 | ||
|  | 4596c1049c | ||
|  | b35d95f0c7 | ||
|  | 01419df998 | ||
|  | a6c00c42fa | ||
|  | 4cc9db7115 | ||
|  | 4f1ed54059 | ||
|  | 8227a73e35 | ||
|  | adfd8c1939 | ||
|  | be8a0ec184 | ||
|  | b02e3aad95 | ||
|  | 08eca511ad | ||
|  | c34e911596 | ||
|  | 8a452c3072 | ||
|  | 13bfb14107 | ||
|  | 4188b0969e | ||
|  | 0c27795a10 | ||
|  | d05693c5c1 | ||
|  | c0b2063b38 | ||
|  | 4d183747b1 | ||
|  | 08fe1b2f75 | ||
|  | db3e8a267e | ||
|  | 8fc62682c4 | ||
|  | 75031914a3 | ||
|  | a4c9fdd95a | ||
|  | 6a9bfeb5aa | ||
|  | e654766f60 | ||
|  | 0ef6955f96 | ||
|  | b4501557c9 | ||
|  | a2ed99e6cb | ||
|  | 6bd6bb3885 | ||
|  | 399cf65fc9 | ||
|  | 24906a6df1 | ||
|  | d772bbebe6 | ||
|  | 14988853a3 | ||
|  | 7b3f16ac9f | ||
|  | 82b2755c18 | ||
|  | 4e4dc4cb73 | 
							
								
								
									
										2
									
								
								.github/ISSUE_TEMPLATE/1.bug.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ISSUE_TEMPLATE/1.bug.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,5 +1,5 @@ | ||||
| name: Bug 报告 🐛 | ||||
| description: 为 chatgpt-plus 提交错误报告 | ||||
| description: 为 geekai 提交错误报告 | ||||
| labels: ['Bug'] | ||||
| body: | ||||
|   - type: checkboxes | ||||
|   | ||||
							
								
								
									
										2
									
								
								.github/ISSUE_TEMPLATE/2.feature.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ISSUE_TEMPLATE/2.feature.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,5 +1,5 @@ | ||||
| name: 功能优化 🚀 | ||||
| description: 为 chatgpt-plus 提交优化建议 | ||||
| description: 为 geekai 提交优化建议 | ||||
| labels: ['feature'] | ||||
| body: | ||||
|   - type: checkboxes | ||||
|   | ||||
							
								
								
									
										129
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										129
									
								
								CHANGELOG.md
									
									
									
									
									
								
							| @@ -1,20 +1,133 @@ | ||||
| # 更新日志 | ||||
| ## v4.1.1 | ||||
| * Bug修复:修复 GPT 模型 function call 调用后没有输出的问题 | ||||
| * 功能新增:允许获取 License 授权用户可以自定义版权信息 | ||||
| * 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件 | ||||
| * 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求 | ||||
| * 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用 | ||||
| * 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息 | ||||
| * 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式 | ||||
| * 功能新增:允许在管理后台设置首页显示的导航菜单 | ||||
| * Bug修复:修复注册页面先显示关闭注册组件,然后再显示注册组件 | ||||
| * 功能新增:增加 Suno 文生歌曲功能 | ||||
| * 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加 | ||||
| * 功能优化:在所有列表页面增加返回顶部按钮 | ||||
|  | ||||
| ## v4.1.0 | ||||
| * bug修复:修复移动端修改聊天标题不生效的问题 | ||||
| * Bug修复:修复用户注册不显示用户名的问题 | ||||
| * Bug修复:修复管理后台拖动排序不生效的问题 | ||||
| * 功能优化:允许用户设置自定义首页背景图片 | ||||
| * 功能新增:**支持AI解读 PDF, Word, Excel等文件** | ||||
| * 功能优化:优化聊天界面的用户上传文件的列表样式 | ||||
| * 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换 | ||||
| * 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。 | ||||
|  | ||||
|  | ||||
| ## v4.0.9 | ||||
| * 环境升级:升级 Golang 到 go1.22.4 | ||||
| * 功能增加:接入微信商户号支付渠道 | ||||
| * Bug修复:修复前端页面菜单把页面撑开,底部留白问题 | ||||
| * 功能优化:聊天页面自动根据内容调整输入框的高度 | ||||
| * Bug修复:修复Dalle绘图失败退回算力的问题 | ||||
| * 功能优化:邀请码注册时被邀请人也可以获得赠送的算力 | ||||
| * 功能优化:允许设置邮件验证码的抬头 | ||||
| * Bug修复:修复免费模型不会记录聊天记录的bug | ||||
| * Bug修复:修复聊天输入公式显示异常的Bug | ||||
|  | ||||
| ## v4.0.8 | ||||
| * 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题 | ||||
| * 功能优化:当数据库更新失败的时候记录错误日志 | ||||
| * 功能优化:聊天输入框会随着输入内容的增多自动调整高度 | ||||
| * Bug修复:修复移动端聊天页面模型切换不生效的Bug | ||||
| * 功能优化:给PC端扫码支付增加签名验证和有效期验证 | ||||
| * Bug修复:修复支付码生成API权限控制的问题 | ||||
| * Bug修复:模型算力设置为0时,不扣减用户算力,并且不记录算力消费日志 | ||||
| * 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片 | ||||
| * 功能新增:H5端支持 Dalle 绘图 | ||||
|  | ||||
| ## v4.0.7 | ||||
|  | ||||
| * 功能优化:升级quic-go,支持 Go1.21 | ||||
| * 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片 | ||||
| * Bug修复:修复弹窗等于图形验证码一直验证失败的问题 | ||||
| * 功能重构:重构前端 UI 页面,增加顶部导航 | ||||
| * 功能优化:优化 Vue 非父子组件之间的通信方式 | ||||
| * 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量 | ||||
|  | ||||
| ## v4.0.6 | ||||
|  | ||||
| * Bug修复:修复PC端画廊页面的瀑布流组件样式错乱问题 | ||||
| * 功能新增:给思维导图增加 ToolBar,实现思维导图的放大缩小和定位 | ||||
| * Bug修复:修复思维导图不扣费的Bug | ||||
| * Bug修复:修复管理后台角色删除失败的Bug | ||||
| * Bug修复:兼容最新版秋叶SD懒人包的 SD API,新增 scheduler 参数 | ||||
| * 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY | ||||
| * Bug修复:修复注册用户提示注册人数达到上限的 Bug | ||||
| * 功能优化:将MJ,SD,Dall绘画页面的任务列表全改成瀑布流组件 | ||||
|  | ||||
| ## v4.0.5 | ||||
|  | ||||
| * 功能优化:已授权系统在后台显示授权信息 | ||||
| * 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误 | ||||
| * 功能优化:优化首页登录注册页面的 UI | ||||
| * BUG修复:修复License验证的逻辑漏洞 | ||||
| * Bug修复:后台添加用户的时候密码规则限制跟前台注册保持一致 | ||||
| * 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题 | ||||
| * 功能新增:移动端新增 DALL-E 绘画功能 | ||||
| * 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题 | ||||
| * 功能新增:移动支持免登录预览功能 | ||||
| * Bug修复:解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题 | ||||
| * Bug修复:修复部分中转 API 模型会出现第一输出的字符被淹没的Bug | ||||
|  | ||||
| ## v4.0.4 | ||||
|  | ||||
| * Bug修复:修复统一千问第二句不回复的问题 | ||||
| * 功能优化:MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度 | ||||
| * 功能新增:Dalle AI 绘画功能实现 | ||||
| * Bug修复:修复思维导图格式乱码问题 | ||||
| * 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题 | ||||
| * 功能新增:支持从应用列表直接和某个应用对话 | ||||
| * 功能优化:优化算力日志的页面和首页的UI | ||||
| * 功能新增:支持思维导图导出 PNG 图片下载 | ||||
|  | ||||
| ## v4.0.3 | ||||
|  | ||||
| * 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型 | ||||
| * Bug修复:兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug | ||||
| * Bug修复:修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题 | ||||
| * 功能新增:支持上传图片和视觉模型 | ||||
| * 功能优化:优化聊天页面的复制代码按钮样式乱码 | ||||
| * 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图 | ||||
| * 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用GPT3.5或者 GPT4 | ||||
| * 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。 | ||||
| * 功能新增:支持管理后台 Logo 修改 | ||||
|  | ||||
| ## 4.0.2 | ||||
|  | ||||
| * 功能新增:支持前端菜单可以配置 | ||||
| * 功能优化:手机端支持免登录预览功能 | ||||
| * 功能优化:在登录和注册界面标题显示软件版本号 | ||||
| * 功能优化:MJ 绘画支持 --sref 和 --cref 图片一致性参数 | ||||
| * 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题 | ||||
| * Bug修复:解决因为图片上传使用相对路径而导致融图失败的问题。 | ||||
| * 功能新增:手机端支持 Stable-Diffusion 绘画 | ||||
| * 功能新增:管理后台登录页面增加行为验证码,防止爆破 | ||||
|  | ||||
| ## v4.0.1 | ||||
| * 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion 发行版,稳定性更强一些 | ||||
| * 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容 MJ-Plus 中转 | ||||
|  | ||||
| * 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion | ||||
|   发行版,稳定性更强一些 | ||||
| * 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容 | ||||
|   MJ-Plus 中转 | ||||
| * 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力 | ||||
| * Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换 | ||||
| * Bug修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug | ||||
|  | ||||
| ## v4.0.0 | ||||
|  | ||||
| 非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。 | ||||
| 只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力... | ||||
| 只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ | ||||
| 对话消耗15个算力... | ||||
|  | ||||
| * 功能重构:重构整体系统,全部采用算力来进行结算 | ||||
| * 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽 | ||||
| @@ -29,6 +142,7 @@ | ||||
| * 功能新增:管理后台新增7日内新增用户和新增订单统计 | ||||
|  | ||||
| ## v3.2.7 | ||||
|  | ||||
| * 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能 | ||||
| * 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能 | ||||
| * Bug修复:修复 issue [ | ||||
| @@ -43,6 +157,7 @@ | ||||
| * 功能新增:后台管理新怎对话查看和检索功能 | ||||
|  | ||||
| ## v3.2.6 | ||||
|  | ||||
| * 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号 | ||||
| * 功能优化:兼用旧版本微信收款消息解析 | ||||
| * 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源 | ||||
| @@ -56,16 +171,18 @@ | ||||
| * 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug | ||||
| * 功能新增:新增短信宝短信平台发送平台集成 | ||||
|  | ||||
|  | ||||
| ## v3.2.5 | ||||
|  | ||||
| * 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。 | ||||
| * 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅 Plus 账号了!!! | ||||
| * 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅 | ||||
|   Plus 账号了!!! | ||||
| * 功能优化:增强 markdown 图片和引用块解析。 | ||||
| * 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。 | ||||
| * 功能优化:function call 兼用中转 API。 | ||||
| * Bug修复:修复部分已知的 Bug。 | ||||
|  | ||||
| ## v3.2.4.1 | ||||
|  | ||||
| * 功能新增:新增 PayJs 支付通道 | ||||
| * Bug修复:紧急修复后台添加用户失败问题 | ||||
| * Bug修复:紧急修复使用中转 API-KEY 无法绘图的问题 | ||||
|   | ||||
							
								
								
									
										214
									
								
								LICENSE
									
									
									
									
									
								
							
							
						
						
									
										214
									
								
								LICENSE
									
									
									
									
									
								
							| @@ -1,21 +1,201 @@ | ||||
| MIT License | ||||
|                                  Apache License | ||||
|                            Version 2.0, January 2004 | ||||
|                         http://www.apache.org/licenses/ | ||||
|  | ||||
| Copyright (c) 2023 RockYang | ||||
|    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|    1. Definitions. | ||||
|  | ||||
| The above copyright notice and this permission notice shall be included in all | ||||
| copies or substantial portions of the Software. | ||||
|       "License" shall mean the terms and conditions for use, reproduction, | ||||
|       and distribution as defined by Sections 1 through 9 of this document. | ||||
|  | ||||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| SOFTWARE. | ||||
|       "Licensor" shall mean the copyright owner or entity authorized by | ||||
|       the copyright owner that is granting the License. | ||||
|  | ||||
|       "Legal Entity" shall mean the union of the acting entity and all | ||||
|       other entities that control, are controlled by, or are under common | ||||
|       control with that entity. For the purposes of this definition, | ||||
|       "control" means (i) the power, direct or indirect, to cause the | ||||
|       direction or management of such entity, whether by contract or | ||||
|       otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||||
|       outstanding shares, or (iii) beneficial ownership of such entity. | ||||
|  | ||||
|       "You" (or "Your") shall mean an individual or Legal Entity | ||||
|       exercising permissions granted by this License. | ||||
|  | ||||
|       "Source" form shall mean the preferred form for making modifications, | ||||
|       including but not limited to software source code, documentation | ||||
|       source, and configuration files. | ||||
|  | ||||
|       "Object" form shall mean any form resulting from mechanical | ||||
|       transformation or translation of a Source form, including but | ||||
|       not limited to compiled object code, generated documentation, | ||||
|       and conversions to other media types. | ||||
|  | ||||
|       "Work" shall mean the work of authorship, whether in Source or | ||||
|       Object form, made available under the License, as indicated by a | ||||
|       copyright notice that is included in or attached to the work | ||||
|       (an example is provided in the Appendix below). | ||||
|  | ||||
|       "Derivative Works" shall mean any work, whether in Source or Object | ||||
|       form, that is based on (or derived from) the Work and for which the | ||||
|       editorial revisions, annotations, elaborations, or other modifications | ||||
|       represent, as a whole, an original work of authorship. For the purposes | ||||
|       of this License, Derivative Works shall not include works that remain | ||||
|       separable from, or merely link (or bind by name) to the interfaces of, | ||||
|       the Work and Derivative Works thereof. | ||||
|  | ||||
|       "Contribution" shall mean any work of authorship, including | ||||
|       the original version of the Work and any modifications or additions | ||||
|       to that Work or Derivative Works thereof, that is intentionally | ||||
|       submitted to Licensor for inclusion in the Work by the copyright owner | ||||
|       or by an individual or Legal Entity authorized to submit on behalf of | ||||
|       the copyright owner. For the purposes of this definition, "submitted" | ||||
|       means any form of electronic, verbal, or written communication sent | ||||
|       to the Licensor or its representatives, including but not limited to | ||||
|       communication on electronic mailing lists, source code control systems, | ||||
|       and issue tracking systems that are managed by, or on behalf of, the | ||||
|       Licensor for the purpose of discussing and improving the Work, but | ||||
|       excluding communication that is conspicuously marked or otherwise | ||||
|       designated in writing by the copyright owner as "Not a Contribution." | ||||
|  | ||||
|       "Contributor" shall mean Licensor and any individual or Legal Entity | ||||
|       on behalf of whom a Contribution has been received by Licensor and | ||||
|       subsequently incorporated within the Work. | ||||
|  | ||||
|    2. Grant of Copyright License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       copyright license to reproduce, prepare Derivative Works of, | ||||
|       publicly display, publicly perform, sublicense, and distribute the | ||||
|       Work and such Derivative Works in Source or Object form. | ||||
|  | ||||
|    3. Grant of Patent License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       (except as stated in this section) patent license to make, have made, | ||||
|       use, offer to sell, sell, import, and otherwise transfer the Work, | ||||
|       where such license applies only to those patent claims licensable | ||||
|       by such Contributor that are necessarily infringed by their | ||||
|       Contribution(s) alone or by combination of their Contribution(s) | ||||
|       with the Work to which such Contribution(s) was submitted. If You | ||||
|       institute patent litigation against any entity (including a | ||||
|       cross-claim or counterclaim in a lawsuit) alleging that the Work | ||||
|       or a Contribution incorporated within the Work constitutes direct | ||||
|       or contributory patent infringement, then any patent licenses | ||||
|       granted to You under this License for that Work shall terminate | ||||
|       as of the date such litigation is filed. | ||||
|  | ||||
|    4. Redistribution. You may reproduce and distribute copies of the | ||||
|       Work or Derivative Works thereof in any medium, with or without | ||||
|       modifications, and in Source or Object form, provided that You | ||||
|       meet the following conditions: | ||||
|  | ||||
|       (a) You must give any other recipients of the Work or | ||||
|           Derivative Works a copy of this License; and | ||||
|  | ||||
|       (b) You must cause any modified files to carry prominent notices | ||||
|           stating that You changed the files; and | ||||
|  | ||||
|       (c) You must retain, in the Source form of any Derivative Works | ||||
|           that You distribute, all copyright, patent, trademark, and | ||||
|           attribution notices from the Source form of the Work, | ||||
|           excluding those notices that do not pertain to any part of | ||||
|           the Derivative Works; and | ||||
|  | ||||
|       (d) If the Work includes a "NOTICE" text file as part of its | ||||
|           distribution, then any Derivative Works that You distribute must | ||||
|           include a readable copy of the attribution notices contained | ||||
|           within such NOTICE file, excluding those notices that do not | ||||
|           pertain to any part of the Derivative Works, in at least one | ||||
|           of the following places: within a NOTICE text file distributed | ||||
|           as part of the Derivative Works; within the Source form or | ||||
|           documentation, if provided along with the Derivative Works; or, | ||||
|           within a display generated by the Derivative Works, if and | ||||
|           wherever such third-party notices normally appear. The contents | ||||
|           of the NOTICE file are for informational purposes only and | ||||
|           do not modify the License. You may add Your own attribution | ||||
|           notices within Derivative Works that You distribute, alongside | ||||
|           or as an addendum to the NOTICE text from the Work, provided | ||||
|           that such additional attribution notices cannot be construed | ||||
|           as modifying the License. | ||||
|  | ||||
|       You may add Your own copyright statement to Your modifications and | ||||
|       may provide additional or different license terms and conditions | ||||
|       for use, reproduction, or distribution of Your modifications, or | ||||
|       for any such Derivative Works as a whole, provided Your use, | ||||
|       reproduction, and distribution of the Work otherwise complies with | ||||
|       the conditions stated in this License. | ||||
|  | ||||
|    5. Submission of Contributions. Unless You explicitly state otherwise, | ||||
|       any Contribution intentionally submitted for inclusion in the Work | ||||
|       by You to the Licensor shall be under the terms and conditions of | ||||
|       this License, without any additional terms or conditions. | ||||
|       Notwithstanding the above, nothing herein shall supersede or modify | ||||
|       the terms of any separate license agreement you may have executed | ||||
|       with Licensor regarding such Contributions. | ||||
|  | ||||
|    6. Trademarks. This License does not grant permission to use the trade | ||||
|       names, trademarks, service marks, or product names of the Licensor, | ||||
|       except as required for reasonable and customary use in describing the | ||||
|       origin of the Work and reproducing the content of the NOTICE file. | ||||
|  | ||||
|    7. Disclaimer of Warranty. Unless required by applicable law or | ||||
|       agreed to in writing, Licensor provides the Work (and each | ||||
|       Contributor provides its Contributions) on an "AS IS" BASIS, | ||||
|       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
|       implied, including, without limitation, any warranties or conditions | ||||
|       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||||
|       PARTICULAR PURPOSE. You are solely responsible for determining the | ||||
|       appropriateness of using or redistributing the Work and assume any | ||||
|       risks associated with Your exercise of permissions under this License. | ||||
|  | ||||
|    8. Limitation of Liability. In no event and under no legal theory, | ||||
|       whether in tort (including negligence), contract, or otherwise, | ||||
|       unless required by applicable law (such as deliberate and grossly | ||||
|       negligent acts) or agreed to in writing, shall any Contributor be | ||||
|       liable to You for damages, including any direct, indirect, special, | ||||
|       incidental, or consequential damages of any character arising as a | ||||
|       result of this License or out of the use or inability to use the | ||||
|       Work (including but not limited to damages for loss of goodwill, | ||||
|       work stoppage, computer failure or malfunction, or any and all | ||||
|       other commercial damages or losses), even if such Contributor | ||||
|       has been advised of the possibility of such damages. | ||||
|  | ||||
|    9. Accepting Warranty or Additional Liability. While redistributing | ||||
|       the Work or Derivative Works thereof, You may choose to offer, | ||||
|       and charge a fee for, acceptance of support, warranty, indemnity, | ||||
|       or other liability obligations and/or rights consistent with this | ||||
|       License. However, in accepting such obligations, You may act only | ||||
|       on Your own behalf and on Your sole responsibility, not on behalf | ||||
|       of any other Contributor, and only if You agree to indemnify, | ||||
|       defend, and hold each Contributor harmless for any liability | ||||
|       incurred by, or claims asserted against, such Contributor by reason | ||||
|       of your accepting any such warranty or additional liability. | ||||
|  | ||||
|    END OF TERMS AND CONDITIONS | ||||
|  | ||||
|    APPENDIX: How to apply the Apache License to your work. | ||||
|  | ||||
|       To apply the Apache License to your work, attach the following | ||||
|       boilerplate notice, with the fields enclosed by brackets "[]" | ||||
|       replaced with your own identifying information. (Don't include | ||||
|       the brackets!)  The text should be enclosed in the appropriate | ||||
|       comment syntax for the file format. We also recommend that a | ||||
|       file or class name and description of purpose be included on the | ||||
|       same "printed page" as the copyright notice for easier | ||||
|       identification within third-party archives. | ||||
|  | ||||
|    Copyright [yyyy] [name of copyright owner] | ||||
|  | ||||
|    Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|    you may not use this file except in compliance with the License. | ||||
|    You may obtain a copy of the License at | ||||
|  | ||||
|        http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
|    Unless required by applicable law or agreed to in writing, software | ||||
|    distributed under the License is distributed on an "AS IS" BASIS, | ||||
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|    See the License for the specific language governing permissions and | ||||
|    limitations under the License. | ||||
|   | ||||
							
								
								
									
										123
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										123
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,124 +1,67 @@ | ||||
| # ChatGPT-Plus | ||||
| # GeekAI | ||||
| > 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 | ||||
|  | ||||
| **ChatGPT-PLUS** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure, | ||||
| ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。主要有如下特性: | ||||
| **GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure, | ||||
| ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。 | ||||
|  | ||||
| * 完整的开源系统,前端应用和后台管理系统皆可开箱即用。 | ||||
| * 基于 Websocket 实现,完美的打字机体验。 | ||||
| * 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。 | ||||
| * 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。 | ||||
| * 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。 | ||||
| * 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。 | ||||
| * 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。 | ||||
| * 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI | ||||
| 主要特性: | ||||
|  | ||||
| - 完整的开源系统,前端应用和后台管理系统皆可开箱即用。 | ||||
| - 基于 Websocket 实现,完美的打字机体验。 | ||||
| - 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。 | ||||
| - 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。 | ||||
| - 支持 Suno 文生音乐 | ||||
| - 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。 | ||||
| - 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。 | ||||
| - 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。 | ||||
| - 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI | ||||
|   绘画函数插件。 | ||||
|  | ||||
| ### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus) | ||||
|  | ||||
| - [x] 更友好的 UI 界面 | ||||
| - [x] 支持 Dall-E 文生图功能 | ||||
| - [x] 支持文生思维导图 | ||||
| - [x] 支持为模型绑定指定的 API KEY,支持为角色绑定指定的模型等功能 | ||||
| - [x] 支持网站 Logo 版权等信息的修改 | ||||
|  | ||||
| ## 功能截图 | ||||
|  | ||||
| ### PC 端聊天界面 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### AI 对话界面 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### MidJourney 专业绘画界面 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### Stable-Diffusion 专业绘画页面 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 绘图作品展 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### AI应用列表 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 会员充值 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 自动调用函数插件 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 管理后台 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 移动端 Web 页面 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。 | ||||
|  | ||||
| ### 体验地址 | ||||
|  | ||||
| > 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/> | ||||
| > 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/> | ||||
| > **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!** | ||||
|  | ||||
| ## 快速部署 | ||||
|  | ||||
| **演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。 | ||||
|  | ||||
| ```shell | ||||
| bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.7-6c232bdaf8.sh)" | ||||
| ``` | ||||
|  | ||||
| 最新版本的一键部署脚本请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/install/)。 | ||||
|  | ||||
| 目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址 | ||||
|  | ||||
| * 前端访问地址:http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。 | ||||
| * 后台管理地址:http://localhost:8080/admin | ||||
| * 移动端地址:http://localhost:8080/mobile | ||||
| * 初始后台管理账号:admin/admin123 | ||||
| * 初始前端体验账号:18575670125/12345678 | ||||
|  | ||||
| 服务启动成功之后不能立刻使用,需要先登录管理后台 -> API-KEY 去添加一个 OpenAI 或者文心一言,科大讯飞等至少一个平台的 API | ||||
| KEY。 | ||||
|  | ||||
|  | ||||
|  | ||||
| 另外,如果您目前还没有 OpenAI 的 API KEY的,推荐您去 https://gpt.bemore.lol 购买,**无需魔法,高速稳定,且价格还远低于 OpenAI | ||||
| 官方**。 | ||||
| 请参考文档 [**GeekAI 快速部署**](https://docs.geekai.me/install/)。 | ||||
|  | ||||
| ## 使用须知 | ||||
|  | ||||
| 1. 本项目基于 MIT 协议,免费开放全部源代码,可以作为个人学习使用或者商用。 | ||||
| 1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。 | ||||
| 2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。 | ||||
|  | ||||
| ## 项目地址 | ||||
|  | ||||
| * Github 地址:https://github.com/yangjian102621/chatgpt-plus | ||||
| * 码云地址:https://gitee.com/blackfox/chatgpt-plus | ||||
| * Github 地址:https://github.com/yangjian102621/geekai | ||||
| * 码云地址:https://gitee.com/blackfox/geekai | ||||
|  | ||||
| ## 客户端下载 | ||||
|  | ||||
| 目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/chatgpt-plus/releases/tag/v3.1.2 | ||||
| 目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/geekai/releases/tag/v3.1.2 | ||||
|  | ||||
| ## TODOLIST | ||||
|  | ||||
| * [ ] 支持基于知识库的 AI 问答 | ||||
| * [ ] 会员邀请注册推广功能 | ||||
| * [ ] 文生视频,文生歌曲功能 | ||||
| * [ ] 微信支付功能 | ||||
|  | ||||
| ## 项目文档 | ||||
|  | ||||
| 最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/) | ||||
|  | ||||
| 详细的部署和开发文档请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/)。 | ||||
| 详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。 | ||||
|  | ||||
| 加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自Github!!!)。** | ||||
|  | ||||
| @@ -146,4 +89,4 @@ KEY。 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										1
									
								
								api/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								api/.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -18,3 +18,4 @@ data | ||||
| config.toml | ||||
| static/upload  | ||||
| storage.json  | ||||
| res/certs/wechat/apiclient_key.pem | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| SHELL=/usr/bin/env bash | ||||
| NAME := chatgpt-plus | ||||
| NAME := geekai | ||||
| all: amd64 arm64 | ||||
|  | ||||
| amd64: | ||||
|   | ||||
| @@ -5,6 +5,7 @@ StaticDir = "./static" # 静态资源的目录 | ||||
| StaticUrl = "/static" # 静态资源访问 URL | ||||
| AesEncryptKey = "" | ||||
| WeChatBot = false | ||||
| TikaHost = "http://tika:9998" | ||||
|  | ||||
| [Session] | ||||
|   SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换 | ||||
| @@ -17,7 +18,7 @@ WeChatBot = false | ||||
|   DB = 0 | ||||
|  | ||||
| [ApiConfig] # 微博热搜,今日头条等函数服务 API 配置,此为第三方插件服务,如需使用请联系作者开通 | ||||
|   ApiURL = "" | ||||
|   ApiURL = "https://sapi.geekai.me" | ||||
|   AppId = "" | ||||
|   Token = "" | ||||
|  | ||||
| @@ -108,7 +109,8 @@ WeChatBot = false | ||||
|   ApiURL = "https://api.xunhupay.com" | ||||
|   NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify" | ||||
|  | ||||
| [SmtpConfig] # 注意,阿里云服务器禁用了25号端口,所以如果需要使用邮件功能,请别用阿里云服务器 | ||||
| [SmtpConfig] # 注意,阿里云服务器禁用了25号端口,请使用 465 端口,并开启 TLS 连接 | ||||
|   UseTls = false | ||||
|   Host = "smtp.163.com" | ||||
|   Port = 25 | ||||
|   AppName = "极客学长" | ||||
| @@ -122,3 +124,15 @@ WeChatBot = false | ||||
|   PrivateKey = "" # 秘钥 | ||||
|   ApiURL = "https://payjs.cn" | ||||
|   NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的 | ||||
|  | ||||
| # 微信商户支付 | ||||
| [WechatPayConfig] | ||||
|   Enabled = false | ||||
|   AppId = "" # 商户应用ID | ||||
|   MchId = "" # 商户号 | ||||
|   SerialNo = "" # API 证书序列号 | ||||
|   PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内 | ||||
|   ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置的 | ||||
|   NotifyURL = "https://ai.r9it.com/api/payment/wechat/notify" # 支付成功异步回调地址,域名改成自己的 | ||||
|   ReturnURL = "" # 支付成功同步回调地址 | ||||
|  | ||||
|   | ||||
| @@ -1,22 +1,29 @@ | ||||
| package core | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"github.com/nfnt/resize" | ||||
| 	"golang.org/x/image/webp" | ||||
| 	"gorm.io/gorm" | ||||
| 	"image" | ||||
| 	"image/jpeg" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"runtime/debug" | ||||
| @@ -28,15 +35,7 @@ type AppServer struct { | ||||
| 	Debug     bool | ||||
| 	Config    *types.AppConfig | ||||
| 	Engine    *gin.Engine | ||||
| 	ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message | ||||
|  | ||||
| 	SysConfig *types.SystemConfig // system config cache | ||||
|  | ||||
| 	// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次 | ||||
| 	// 防止第三方直接连接 socket 调用 OpenAI API | ||||
| 	ChatSession   *types.LMap[string, *types.ChatSession] //map[sessionId]UserId | ||||
| 	ChatClients   *types.LMap[string, *types.WsClient]    // map[sessionId]Websocket 连接集合 | ||||
| 	ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function | ||||
| } | ||||
|  | ||||
| func NewServer(appConfig *types.AppConfig) *AppServer { | ||||
| @@ -46,10 +45,6 @@ func NewServer(appConfig *types.AppConfig) *AppServer { | ||||
| 		Debug:  false, | ||||
| 		Config: appConfig, | ||||
| 		Engine: gin.Default(), | ||||
| 		ChatContexts:  types.NewLMap[string, []types.Message](), | ||||
| 		ChatSession:   types.NewLMap[string, *types.ChatSession](), | ||||
| 		ChatClients:   types.NewLMap[string, *types.WsClient](), | ||||
| 		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -88,7 +83,7 @@ func errorHandler(c *gin.Context) { | ||||
| 		if r := recover(); r != nil { | ||||
| 			logger.Errorf("Handler Panic: %v", r) | ||||
| 			debug.PrintStack() | ||||
| 			c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) | ||||
| 			c.JSON(http.StatusBadRequest, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) | ||||
| 			c.Abort() | ||||
| 		} | ||||
| 	}() | ||||
| @@ -144,7 +139,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { | ||||
|  | ||||
| 		if tokenString == "" { | ||||
| 			if needLogin(c) { | ||||
| 				resp.ERROR(c, "You should put Authorization in request headers") | ||||
| 				resp.NotAuth(c, "You should put Authorization in request headers") | ||||
| 				c.Abort() | ||||
| 				return | ||||
| 			} else { // 直接放行 | ||||
| @@ -200,10 +195,13 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { | ||||
|  | ||||
| func needLogin(c *gin.Context) bool { | ||||
| 	if c.Request.URL.Path == "/api/user/login" || | ||||
| 		c.Request.URL.Path == "/api/user/logout" || | ||||
| 		c.Request.URL.Path == "/api/user/resetPass" || | ||||
| 		c.Request.URL.Path == "/api/admin/login" || | ||||
| 		c.Request.URL.Path == "/api/admin/logout" || | ||||
| 		c.Request.URL.Path == "/api/admin/login/captcha" || | ||||
| 		c.Request.URL.Path == "/api/user/register" || | ||||
| 		c.Request.URL.Path == "/api/user/session" || | ||||
| 		c.Request.URL.Path == "/api/chat/history" || | ||||
| 		c.Request.URL.Path == "/api/chat/detail" || | ||||
| 		c.Request.URL.Path == "/api/chat/list" || | ||||
| @@ -215,13 +213,26 @@ func needLogin(c *gin.Context) bool { | ||||
| 		c.Request.URL.Path == "/api/invite/hits" || | ||||
| 		c.Request.URL.Path == "/api/sd/imgWall" || | ||||
| 		c.Request.URL.Path == "/api/sd/client" || | ||||
| 		c.Request.URL.Path == "/api/config/get" || | ||||
| 		c.Request.URL.Path == "/api/dall/imgWall" || | ||||
| 		c.Request.URL.Path == "/api/dall/client" || | ||||
| 		c.Request.URL.Path == "/api/product/list" || | ||||
| 		c.Request.URL.Path == "/api/menu/list" || | ||||
| 		c.Request.URL.Path == "/api/markMap/client" || | ||||
| 		c.Request.URL.Path == "/api/payment/alipay/notify" || | ||||
| 		c.Request.URL.Path == "/api/payment/hupipay/notify" || | ||||
| 		c.Request.URL.Path == "/api/payment/payjs/notify" || | ||||
| 		c.Request.URL.Path == "/api/payment/wechat/notify" || | ||||
| 		c.Request.URL.Path == "/api/payment/doPay" || | ||||
| 		c.Request.URL.Path == "/api/payment/payWays" || | ||||
| 		c.Request.URL.Path == "/api/suno/client" || | ||||
| 		c.Request.URL.Path == "/api/suno/Detail" || | ||||
| 		c.Request.URL.Path == "/api/suno/play" || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/test") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/config/") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/function/") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/payment/") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/static/") { | ||||
| 		return false | ||||
| 	} | ||||
| @@ -326,6 +337,10 @@ func staticResourceMiddleware() gin.HandlerFunc { | ||||
|  | ||||
| 			// 解码图片 | ||||
| 			img, _, err := image.Decode(file) | ||||
| 			// for .webp image | ||||
| 			if err != nil { | ||||
| 				img, err = webp.Decode(file) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				c.String(http.StatusInternalServerError, "Error decoding image") | ||||
| 				return | ||||
| @@ -342,7 +357,9 @@ func staticResourceMiddleware() gin.HandlerFunc { | ||||
| 			var buffer bytes.Buffer | ||||
| 			err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality}) | ||||
| 			if err != nil { | ||||
| 				log.Fatal(err) | ||||
| 				logger.Error(err) | ||||
| 				c.String(http.StatusInternalServerError, err.Error()) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// 设置图片缓存有效期为一年 (365天) | ||||
|   | ||||
| @@ -1,10 +1,17 @@ | ||||
| package core | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/utils" | ||||
| 	"geekai/core/types" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"geekai/utils" | ||||
| 	"os" | ||||
|  | ||||
| 	"github.com/BurntSushi/toml" | ||||
| @@ -23,7 +30,7 @@ func NewDefaultConfig() *types.AppConfig { | ||||
| 			SecretKey: utils.RandString(64), | ||||
| 			MaxAge:    86400, | ||||
| 		}, | ||||
| 		ApiConfig: types.ChatPlusApiConfig{}, | ||||
| 		ApiConfig: types.ApiConfig{}, | ||||
| 		OSS: types.OSSConfig{ | ||||
| 			Active: "local", | ||||
| 			Local: types.LocalStorageConfig{ | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| // ApiRequest API 请求实体 | ||||
| type ApiRequest struct { | ||||
| 	Model       string        `json:"model,omitempty"` // 兼容百度文心一言 | ||||
| @@ -8,7 +15,7 @@ type ApiRequest struct { | ||||
| 	Stream      bool          `json:"stream"` | ||||
| 	Messages    []interface{} `json:"messages,omitempty"` | ||||
| 	Prompt      []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM | ||||
| 	Tools       []interface{} `json:"tools,omitempty"` | ||||
| 	Tools       []Tool        `json:"tools,omitempty"` | ||||
| 	Functions   []interface{} `json:"functions,omitempty"` // 兼容中转平台 | ||||
|  | ||||
| 	ToolChoice string `json:"tool_choice,omitempty"` | ||||
| @@ -46,22 +53,22 @@ type Delta struct { | ||||
| // ChatSession 聊天会话对象 | ||||
| type ChatSession struct { | ||||
| 	SessionId string    `json:"session_id"` | ||||
| 	UserId    uint      `json:"user_id"` | ||||
| 	ClientIP  string    `json:"client_ip"` // 客户端 IP | ||||
| 	Username  string    `json:"username"`  // 当前登录的 username | ||||
| 	UserId    uint      `json:"user_id"`   // 当前登录的 user ID | ||||
| 	ChatId    string    `json:"chat_id"`   // 客户端聊天会话 ID, 多会话模式专用字段 | ||||
| 	Model     ChatModel `json:"model"`     // GPT 模型 | ||||
| } | ||||
|  | ||||
| type ChatModel struct { | ||||
| 	Id          uint    `json:"id"` | ||||
| 	Platform    Platform `json:"platform"` | ||||
| 	Platform    string  `json:"platform"` | ||||
| 	Name        string  `json:"name"` | ||||
| 	Value       string  `json:"value"` | ||||
| 	Power       int     `json:"power"` | ||||
| 	MaxTokens   int     `json:"max_tokens"`  // 最大响应长度 | ||||
| 	MaxContext  int     `json:"max_context"` // 最大上下文长度 | ||||
| 	Temperature float32 `json:"temperature"` // 模型温度 | ||||
| 	KeyId       int     `json:"key_id"`      // 绑定 API KEY | ||||
| } | ||||
|  | ||||
| type ApiError struct { | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"github.com/gorilla/websocket" | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| ) | ||||
| @@ -14,7 +21,7 @@ type AppConfig struct { | ||||
| 	StaticDir      string                  // 静态资源目录 | ||||
| 	StaticUrl      string                  // 静态资源 URL | ||||
| 	Redis          RedisConfig             // redis 连接信息 | ||||
| 	ApiConfig      ChatPlusApiConfig       // ChatPlus API authorization configs | ||||
| 	ApiConfig      ApiConfig               // ChatPlus API authorization configs | ||||
| 	SMS            SMSConfig               // send mobile message config | ||||
| 	OSS            OSSConfig               // OSS config | ||||
| 	MjProxyConfigs []MjProxyConfig         // MJ proxy config | ||||
| @@ -23,13 +30,16 @@ type AppConfig struct { | ||||
| 	SdConfigs      []StableDiffusionConfig // sd AI draw service pool | ||||
|  | ||||
| 	XXLConfig       XXLConfig | ||||
| 	AlipayConfig  AlipayConfig | ||||
| 	HuPiPayConfig HuPiPayConfig | ||||
| 	AlipayConfig    AlipayConfig    // 支付宝支付渠道配置 | ||||
| 	HuPiPayConfig   HuPiPayConfig   // 虎皮椒支付配置 | ||||
| 	SmtpConfig      SmtpConfig      // 邮件发送配置 | ||||
| 	JPayConfig      JPayConfig      // payjs 支付配置 | ||||
| 	WechatPayConfig WechatPayConfig // 微信支付渠道配置 | ||||
| 	TikaHost        string          // TiKa 服务器地址 | ||||
| } | ||||
|  | ||||
| type SmtpConfig struct { | ||||
| 	UseTls   bool // 是否使用 TLS 发送 | ||||
| 	Host     string | ||||
| 	Port     int | ||||
| 	AppName  string // 应用名称 | ||||
| @@ -37,7 +47,7 @@ type SmtpConfig struct { | ||||
| 	Password string // 发件人邮箱密码 | ||||
| } | ||||
|  | ||||
| type ChatPlusApiConfig struct { | ||||
| type ApiConfig struct { | ||||
| 	ApiURL string | ||||
| 	AppId  string | ||||
| 	Token  string | ||||
| @@ -77,6 +87,17 @@ type AlipayConfig struct { | ||||
| 	ReturnURL       string // 支付成功返回地址 | ||||
| } | ||||
|  | ||||
| type WechatPayConfig struct { | ||||
| 	Enabled    bool   // 是否启用该支付通道 | ||||
| 	AppId      string // 公众号的APPID,如:wxd678efh567hg6787 | ||||
| 	MchId      string // 直连商户的商户号,由微信支付生成并下发 | ||||
| 	SerialNo   string // 商户证书的证书序列号 | ||||
| 	PrivateKey string // 用户私钥文件路径 | ||||
| 	ApiV3Key   string // API V3 秘钥 | ||||
| 	NotifyURL  string // 异步通知回调 | ||||
| 	ReturnURL  string // 支付成功返回地址 | ||||
| } | ||||
|  | ||||
| type HuPiPayConfig struct { //虎皮椒第四方支付配置 | ||||
| 	Enabled   bool   // 是否启用该支付通道 | ||||
| 	Name      string // 支付名称,如:wechat/alipay | ||||
| @@ -114,29 +135,37 @@ type RedisConfig struct { | ||||
| 	DB       int | ||||
| } | ||||
|  | ||||
| // LicenseKey 存储许可证书的 KEY | ||||
| const LicenseKey = "Geek-AI-License" | ||||
|  | ||||
| type License struct { | ||||
| 	Key       string        `json:"key"`        // 许可证书密钥 | ||||
| 	MachineId string        `json:"machine_id"` // 机器码 | ||||
| 	ExpiredAt int64         `json:"expired_at"` // 过期时间 | ||||
| 	IsActive  bool          `json:"is_active"`  // 是否激活 | ||||
| 	Configs   LicenseConfig `json:"configs"` | ||||
| } | ||||
|  | ||||
| type LicenseConfig struct { | ||||
| 	UserNum int  `json:"user_num"` // 用户数量 | ||||
| 	DeCopy  bool `json:"de_copy"`  // 去版权 | ||||
| } | ||||
|  | ||||
| func (c RedisConfig) Url() string { | ||||
| 	return fmt.Sprintf("%s:%d", c.Host, c.Port) | ||||
| } | ||||
|  | ||||
| type Platform string | ||||
|  | ||||
| const OpenAI = Platform("OpenAI") | ||||
| const Azure = Platform("Azure") | ||||
| const ChatGLM = Platform("ChatGLM") | ||||
| const Baidu = Platform("Baidu") | ||||
| const XunFei = Platform("XunFei") | ||||
| const QWen = Platform("QWen") | ||||
|  | ||||
| type SystemConfig struct { | ||||
| 	Title         string `json:"title,omitempty"` | ||||
| 	AdminTitle    string `json:"admin_title,omitempty"` | ||||
| 	Title         string `json:"title,omitempty"`       // 网站标题 | ||||
| 	Slogan        string `json:"slogan,omitempty"`      // 网站 slogan | ||||
| 	AdminTitle    string `json:"admin_title,omitempty"` // 管理后台标题 | ||||
| 	Logo          string `json:"logo,omitempty"` | ||||
| 	InitPower     int    `json:"init_power,omitempty"`      // 新用户注册赠送算力值 | ||||
| 	DailyPower    int    `json:"daily_power,omitempty"`     // 每日赠送算力 | ||||
| 	InvitePower   int    `json:"invite_power,omitempty"`    // 邀请新用户赠送算力值 | ||||
| 	VipMonthPower int    `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值 | ||||
|  | ||||
| 	RegisterWays    []string `json:"register_ways,omitempty"`    // 注册方式:支持手机,邮箱注册,账号密码注册 | ||||
| 	RegisterWays    []string `json:"register_ways,omitempty"`    // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册 | ||||
| 	EnabledRegister bool     `json:"enabled_register,omitempty"` // 是否开放注册 | ||||
|  | ||||
| 	RewardImg     string  `json:"reward_img,omitempty"`     // 众筹收款二维码地址 | ||||
| @@ -144,16 +173,23 @@ type SystemConfig struct { | ||||
| 	PowerPrice    float64 `json:"power_price,omitempty"`    // 算力单价 | ||||
|  | ||||
| 	OrderPayTimeout int    `json:"order_pay_timeout,omitempty"` //订单支付超时时间 | ||||
| 	VipInfoText     string `json:"vip_info_text"`               // 会员页面充值说明 | ||||
| 	VipInfoText     string `json:"vip_info_text,omitempty"`     // 会员页面充值说明 | ||||
| 	DefaultModels   []int  `json:"default_models,omitempty"`    // 默认开通的 AI 模型 | ||||
|  | ||||
| 	MjPower       int `json:"mj_power,omitempty"`        // MJ 绘画消耗算力 | ||||
| 	MjActionPower int `json:"mj_action_power"`      // MJ 操作(放大,变换)消耗算力 | ||||
| 	MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力 | ||||
| 	SdPower       int `json:"sd_power,omitempty"`        // SD 绘画消耗算力 | ||||
| 	DallPower     int `json:"dall_power,omitempty"`      // DALLE3 绘图消耗算力 | ||||
| 	SunoPower     int `json:"suno_power,omitempty"`      // Suno 生成歌曲消耗算力 | ||||
|  | ||||
| 	WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址 | ||||
|  | ||||
| 	EnableContext bool `json:"enable_context,omitempty"` | ||||
| 	ContextDeep   int  `json:"context_deep,omitempty"` | ||||
|  | ||||
| 	SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词 | ||||
|  | ||||
| 	IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片 | ||||
| 	IndexNavs  []int  `json:"index_navs"`   // 首页显示的导航菜单 | ||||
| 	Copyright  string `json:"copyright"`    // 版权信息 | ||||
| } | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| type ToolCall struct { | ||||
| 	Type     string `json:"type"` | ||||
| 	Function struct { | ||||
| @@ -8,19 +15,13 @@ type ToolCall struct { | ||||
| 	} `json:"function"` | ||||
| } | ||||
|  | ||||
| type Tool struct { | ||||
| 	Type     string   `json:"type"` | ||||
| 	Function Function `json:"function"` | ||||
| } | ||||
|  | ||||
| type Function struct { | ||||
| 	Name        string                 `json:"name"` | ||||
| 	Description string                 `json:"description"` | ||||
| 	Parameters  Parameters `json:"parameters"` | ||||
| } | ||||
|  | ||||
| type Parameters struct { | ||||
| 	Type       string              `json:"type"` | ||||
| 	Required   []string            `json:"required"` | ||||
| 	Properties map[string]Property `json:"properties"` | ||||
| } | ||||
|  | ||||
| type Property struct { | ||||
| 	Type        string `json:"type"` | ||||
| 	Description string `json:"description"` | ||||
| 	Parameters  map[string]interface{} `json:"parameters"` | ||||
| } | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"sync" | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| type OrderStatus int | ||||
|  | ||||
| const ( | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| type OSSConfig struct { | ||||
| 	Active string | ||||
| 	Local  LocalStorageConfig | ||||
|   | ||||
| @@ -1,11 +1,17 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| const LoginUserID = "LOGIN_USER_ID" | ||||
| const LoginUserCache = "LOGIN_USER_CACHE" | ||||
|  | ||||
| const UserAuthHeader = "Authorization" | ||||
| const AdminAuthHeader = "Admin-Authorization" | ||||
| const ChatTokenHeader = "Chat-Token" | ||||
|  | ||||
| // Session configs struct | ||||
| type Session struct { | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| type SMSConfig struct { | ||||
| 	Active string | ||||
| 	Ali    SmsConfigAli | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| // TaskType 任务类别 | ||||
| type TaskType string | ||||
|  | ||||
| @@ -21,10 +28,11 @@ type MjTask struct { | ||||
| 	TaskId      string   `json:"task_id"` | ||||
| 	ImgArr      []string `json:"img_arr"` | ||||
| 	ChannelId   string   `json:"channel_id"` | ||||
| 	SessionId   string   `json:"session_id"` | ||||
| 	Type        TaskType `json:"type"` | ||||
| 	UserId      int      `json:"user_id"` | ||||
| 	Prompt      string   `json:"prompt,omitempty"` | ||||
| 	NegPrompt   string   `json:"neg_prompt,omitempty"` | ||||
| 	Params      string   `json:"full_prompt"` | ||||
| 	Index       int      `json:"index,omitempty"` | ||||
| 	MessageId   string   `json:"message_id,omitempty"` | ||||
| 	MessageHash string   `json:"message_hash,omitempty"` | ||||
| @@ -33,7 +41,6 @@ type MjTask struct { | ||||
|  | ||||
| type SdTask struct { | ||||
| 	Id         int          `json:"id"` // job 数据库ID | ||||
| 	SessionId  string       `json:"session_id"` | ||||
| 	Type       TaskType     `json:"type"` | ||||
| 	UserId     int          `json:"user_id"` | ||||
| 	Params     SdTaskParams `json:"params"` | ||||
| @@ -43,9 +50,10 @@ type SdTask struct { | ||||
| type SdTaskParams struct { | ||||
| 	TaskId       string  `json:"task_id"` | ||||
| 	Prompt       string  `json:"prompt"`     // 提示词 | ||||
| 	NegativePrompt string  `json:"negative_prompt"` // 反向提示词 | ||||
| 	NegPrompt    string  `json:"neg_prompt"` // 反向提示词 | ||||
| 	Steps        int     `json:"steps"`      // 迭代步数,默认20 | ||||
| 	Sampler      string  `json:"sampler"`    // 采样器 | ||||
| 	Scheduler    string  `json:"scheduler"`  // 采样调度 | ||||
| 	FaceFix      bool    `json:"face_fix"`   // 面部修复 | ||||
| 	CfgScale     float32 `json:"cfg_scale"`  //引导系数,默认 7 | ||||
| 	Seed         int64   `json:"seed"`       // 随机数种子 | ||||
| @@ -57,3 +65,32 @@ type SdTaskParams struct { | ||||
| 	HdScaleAlg   string  `json:"hd_scale_alg"`   // 放大算法 | ||||
| 	HdSteps      int     `json:"hd_steps"`       // 高清修复迭代步数 | ||||
| } | ||||
|  | ||||
| // DallTask DALL-E task | ||||
| type DallTask struct { | ||||
| 	JobId   uint   `json:"job_id"` | ||||
| 	UserId  uint   `json:"user_id"` | ||||
| 	Prompt  string `json:"prompt"` | ||||
| 	N       int    `json:"n"` | ||||
| 	Quality string `json:"quality"` | ||||
| 	Size    string `json:"size"` | ||||
| 	Style   string `json:"style"` | ||||
|  | ||||
| 	Power int `json:"power"` | ||||
| } | ||||
|  | ||||
| type SunoTask struct { | ||||
| 	Id           uint   `json:"id"` | ||||
| 	Channel      string `json:"channel"` | ||||
| 	UserId       int    `json:"user_id"` | ||||
| 	Type         int    `json:"type"` | ||||
| 	TaskId       string `json:"task_id"` | ||||
| 	Title        string `json:"title"` | ||||
| 	RefTaskId    string `json:"ref_task_id"` | ||||
| 	RefSongId    string `json:"ref_song_id"` | ||||
| 	Prompt       string `json:"prompt"` // 提示词/歌词 | ||||
| 	Tags         string `json:"tags"` | ||||
| 	Model        string `json:"model"` | ||||
| 	Instrumental bool   `json:"instrumental"` // 是否纯音乐 | ||||
| 	ExtendSecs   int    `json:"extend_secs"`  // 延长秒杀 | ||||
| } | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package types | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| // BizVo 业务返回 VO | ||||
| type BizVo struct { | ||||
| 	Code     BizCode     `json:"code"` | ||||
| @@ -15,13 +22,14 @@ type WsMessage struct { | ||||
| 	Type    WsMsgType   `json:"type"` // 消息类别,start, end, img | ||||
| 	Content interface{} `json:"content"` | ||||
| } | ||||
|  | ||||
| type WsMsgType string | ||||
|  | ||||
| const ( | ||||
| 	WsStart  = WsMsgType("start") | ||||
| 	WsMiddle = WsMsgType("middle") | ||||
| 	WsEnd    = WsMsgType("end") | ||||
| 	WsMjImg  = WsMsgType("mj") | ||||
| 	WsErr    = WsMsgType("error") | ||||
| ) | ||||
|  | ||||
| type BizCode int | ||||
| @@ -29,11 +37,9 @@ type BizCode int | ||||
| const ( | ||||
| 	Success       = BizCode(0) | ||||
| 	Failed        = BizCode(1) | ||||
| 	NotAuthorized = BizCode(400) // 未授权 | ||||
| 	NotPermission = BizCode(403) // 没有权限 | ||||
| 	NotAuthorized = BizCode(401) // 未授权 | ||||
|  | ||||
| 	OkMsg       = "Success" | ||||
| 	ErrorMsg    = "系统开小差了" | ||||
| 	InvalidArgs = "非法参数或参数解析失败" | ||||
| 	NoData      = "No Data" | ||||
| ) | ||||
|   | ||||
							
								
								
									
										59
									
								
								api/go.mod
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								api/go.mod
									
									
									
									
									
								
							| @@ -1,6 +1,8 @@ | ||||
| module chatplus | ||||
| module geekai | ||||
|  | ||||
| go 1.19 | ||||
| go 1.21 | ||||
|  | ||||
| toolchain go1.22.4 | ||||
|  | ||||
| require ( | ||||
| 	github.com/BurntSushi/toml v1.1.0 | ||||
| @@ -17,7 +19,6 @@ require ( | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 | ||||
| 	github.com/qiniu/go-sdk/v7 v7.17.1 | ||||
| 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e | ||||
| 	github.com/smartwalle/alipay/v3 v3.2.15 | ||||
| 	go.uber.org/zap v1.23.0 | ||||
| 	gopkg.in/natefinch/lumberjack.v2 v2.2.1 | ||||
| 	gorm.io/driver/mysql v1.4.7 | ||||
| @@ -26,19 +27,37 @@ require ( | ||||
| require github.com/xxl-job/xxl-job-executor-go v1.2.0 | ||||
|  | ||||
| require ( | ||||
| 	github.com/mojocn/base64Captcha v1.3.1 | ||||
| 	github.com/go-pay/gopay v1.5.101 | ||||
| 	github.com/google/go-tika v0.3.1 | ||||
| 	github.com/microcosm-cc/bluemonday v1.0.26 | ||||
| 	github.com/mojocn/base64Captcha v1.3.6 | ||||
| 	github.com/shirou/gopsutil v3.21.11+incompatible | ||||
| 	github.com/shopspring/decimal v1.3.1 | ||||
| 	github.com/syndtr/goleveldb v1.0.0 | ||||
| 	golang.org/x/image v0.15.0 | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	github.com/aymerick/douceur v0.2.0 // indirect | ||||
| 	github.com/go-ole/go-ole v1.2.6 // indirect | ||||
| 	github.com/go-pay/crypto v0.0.1 // indirect | ||||
| 	github.com/go-pay/errgroup v0.0.2 // indirect | ||||
| 	github.com/go-pay/util v0.0.2 // indirect | ||||
| 	github.com/go-pay/xlog v0.0.2 // indirect | ||||
| 	github.com/go-pay/xtime v0.0.2 // indirect | ||||
| 	github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect | ||||
| 	golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect | ||||
| 	github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect | ||||
| 	github.com/gorilla/css v1.0.0 // indirect | ||||
| 	github.com/tklauser/go-sysconf v0.3.13 // indirect | ||||
| 	github.com/tklauser/numcpus v0.7.0 // indirect | ||||
| 	github.com/yusufpapurcu/wmi v1.2.4 // indirect | ||||
| 	go.uber.org/mock v0.4.0 // indirect | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	github.com/andybalholm/brotli v1.0.4 // indirect | ||||
| 	github.com/bytedance/sonic v1.9.1 // indirect | ||||
| 	github.com/cespare/xxhash/v2 v2.1.2 // indirect | ||||
| 	github.com/cespare/xxhash/v2 v2.2.0 // indirect | ||||
| 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect | ||||
| 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | ||||
| 	github.com/dlclark/regexp2 v1.8.1 // indirect | ||||
| @@ -49,7 +68,6 @@ require ( | ||||
| 	github.com/go-sql-driver/mysql v1.7.0 // indirect | ||||
| 	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect | ||||
| 	github.com/goccy/go-json v0.10.2 // indirect | ||||
| 	github.com/golang/mock v1.6.0 // indirect | ||||
| 	github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect | ||||
| 	github.com/google/uuid v1.3.0 // indirect | ||||
| 	github.com/hashicorp/errwrap v1.1.0 // indirect | ||||
| @@ -66,26 +84,21 @@ require ( | ||||
| 	github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect | ||||
| 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect | ||||
| 	github.com/quic-go/qpack v0.4.0 // indirect | ||||
| 	github.com/quic-go/qtls-go1-19 v0.3.2 // indirect | ||||
| 	github.com/quic-go/qtls-go1-20 v0.2.2 // indirect | ||||
| 	github.com/quic-go/quic-go v0.35.1 // indirect | ||||
| 	github.com/quic-go/quic-go v0.45.0 // indirect | ||||
| 	github.com/refraction-networking/utls v1.3.2 // indirect | ||||
| 	github.com/rs/xid v1.5.0 // indirect | ||||
| 	github.com/sirupsen/logrus v1.9.3 // indirect | ||||
| 	github.com/smartwalle/ncrypto v1.0.2 // indirect | ||||
| 	github.com/smartwalle/ngx v1.0.6 // indirect | ||||
| 	github.com/smartwalle/nsign v1.0.8 // indirect | ||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||
| 	go.uber.org/dig v1.16.1 // indirect | ||||
| 	golang.org/x/arch v0.3.0 // indirect | ||||
| 	golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect | ||||
| 	golang.org/x/mod v0.11.0 // indirect | ||||
| 	golang.org/x/net v0.14.0 // indirect | ||||
| 	golang.org/x/sync v0.3.0 // indirect | ||||
| 	golang.org/x/text v0.12.0 // indirect | ||||
| 	golang.org/x/time v0.3.0 // indirect | ||||
| 	golang.org/x/tools v0.10.0 // indirect | ||||
| 	google.golang.org/protobuf v1.30.0 // indirect | ||||
| 	golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect | ||||
| 	golang.org/x/mod v0.17.0 // indirect | ||||
| 	golang.org/x/net v0.25.0 // indirect | ||||
| 	golang.org/x/sync v0.7.0 // indirect | ||||
| 	golang.org/x/text v0.15.0 // indirect | ||||
| 	golang.org/x/time v0.5.0 // indirect | ||||
| 	golang.org/x/tools v0.21.0 // indirect | ||||
| 	google.golang.org/protobuf v1.33.0 // indirect | ||||
| 	gopkg.in/ini.v1 v1.67.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
| @@ -104,7 +117,7 @@ require ( | ||||
| 	go.uber.org/atomic v1.9.0 // indirect | ||||
| 	go.uber.org/fx v1.19.3 | ||||
| 	go.uber.org/multierr v1.6.0 // indirect | ||||
| 	golang.org/x/crypto v0.12.0 | ||||
| 	golang.org/x/sys v0.11.0 // indirect | ||||
| 	golang.org/x/crypto v0.23.0 | ||||
| 	golang.org/x/sys v0.20.0 // indirect | ||||
| 	gorm.io/gorm v1.25.1 | ||||
| ) | ||||
|   | ||||
							
								
								
									
										170
									
								
								api/go.sum
									
									
									
									
									
								
							
							
						
						
									
										170
									
								
								api/go.sum
									
									
									
									
									
								
							| @@ -6,12 +6,15 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible h1:Sg/2xHwDrioHpxTN6WMiw | ||||
| github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8= | ||||
| github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= | ||||
| github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= | ||||
| github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= | ||||
| github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= | ||||
| github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= | ||||
| github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= | ||||
| github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= | ||||
| github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= | ||||
| github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | ||||
| github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= | ||||
| github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||
| github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= | ||||
| github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= | ||||
| @@ -27,7 +30,9 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp | ||||
| github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= | ||||
| github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU= | ||||
| github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8= | ||||
| github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= | ||||
| github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | ||||
| github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= | ||||
| github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= | ||||
| github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= | ||||
| github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk= | ||||
| @@ -39,8 +44,24 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU | ||||
| github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs= | ||||
| github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg= | ||||
| github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= | ||||
| github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= | ||||
| github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= | ||||
| github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= | ||||
| github.com/go-pay/crypto v0.0.1 h1:B6InT8CLfSLc6nGRVx9VMJRBBazFMjr293+jl0lLXUY= | ||||
| github.com/go-pay/crypto v0.0.1/go.mod h1:41oEIvHMKbNcYlWUlRWtsnC6+ASgh7u29z0gJXe5bes= | ||||
| github.com/go-pay/errgroup v0.0.2 h1:5mZMdm0TDClDm2S3G0/sm0f8AuQRtz0dOrTHDR9R8Cc= | ||||
| github.com/go-pay/errgroup v0.0.2/go.mod h1:0+4b8mvFMS71MIzsaC+gVvB4x37I93lRb2dqrwuU8x8= | ||||
| github.com/go-pay/gopay v1.5.101 h1:rVb+sfv6hiQtknAlZnTTLvU27NvFJ4p0yglN/vPpGXI= | ||||
| github.com/go-pay/gopay v1.5.101/go.mod h1:AW4Yj8jDZX9BM1/GTLTY1Gy5SHjiq8kQvG5sBTN2sxI= | ||||
| github.com/go-pay/util v0.0.2 h1:goJ4f6kNY5zzdtg1Cj8oWC+Cw7bfg/qq2rJangMAb9U= | ||||
| github.com/go-pay/util v0.0.2/go.mod h1:qM8VbyF1n7YAPZBSJONSPMPsPedhUTktewUAdf1AjPg= | ||||
| github.com/go-pay/xlog v0.0.2 h1:kUg5X8/5VZAPDg1J5eGjA3MG0/H5kK6Ew0dW/Bycsws= | ||||
| github.com/go-pay/xlog v0.0.2/go.mod h1:DbjMADPK4+Sjxj28ekK9goqn4zmyY4hql/zRiab+S9E= | ||||
| github.com/go-pay/xtime v0.0.2 h1:7YR4/iuELsEHpJ6LUO0SVK80hQxDO9MLCfuVYIiTCRM= | ||||
| github.com/go-pay/xtime v0.0.2/go.mod h1:W1yRbJaSt4CSBcdAtLBQ8xajiN/Pl5hquGczUcUE9xE= | ||||
| github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= | ||||
| github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= | ||||
| github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= | ||||
| github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= | ||||
| github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= | ||||
| github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= | ||||
| @@ -65,17 +86,22 @@ github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJ | ||||
| github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= | ||||
| github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= | ||||
| github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= | ||||
| github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= | ||||
| github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= | ||||
| github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | ||||
| github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= | ||||
| github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= | ||||
| github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= | ||||
| github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= | ||||
| github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= | ||||
| github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= | ||||
| github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= | ||||
| github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= | ||||
| github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA= | ||||
| github.com/google/go-tika v0.3.1/go.mod h1:DJh5N8qxXIl85QkqmXknd+PeeRkUOTbvwyYf7ieDz6c= | ||||
| github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | ||||
| github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs= | ||||
| github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= | ||||
| github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= | ||||
| github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= | ||||
| github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= | ||||
| github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= | ||||
| github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= | ||||
| github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||
| github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= | ||||
| @@ -83,6 +109,7 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY | ||||
| github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= | ||||
| github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= | ||||
| github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= | ||||
| github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= | ||||
| github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I= | ||||
| github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI= | ||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||
| @@ -116,6 +143,8 @@ github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259 | ||||
| github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs= | ||||
| github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= | ||||
| github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= | ||||
| github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58= | ||||
| github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs= | ||||
| github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= | ||||
| github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= | ||||
| github.com/minio/minio-go/v7 v7.0.62 h1:qNYsFZHEzl+NfH8UxW4jpmlKav1qUAgfY30YNRneVhc= | ||||
| @@ -128,15 +157,21 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ | ||||
| github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= | ||||
| github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= | ||||
| github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= | ||||
| github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0= | ||||
| github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY= | ||||
| github.com/mojocn/base64Captcha v1.3.6 h1:gZEKu1nsKpttuIAQgWHO+4Mhhls8cAKyiV2Ew03H+Tw= | ||||
| github.com/mojocn/base64Captcha v1.3.6/go.mod h1:i5CtHvm+oMbj1UzEPXaA8IH/xHFZ3DGY3Wh3dBpZ28E= | ||||
| github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= | ||||
| github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= | ||||
| github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= | ||||
| github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= | ||||
| github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= | ||||
| github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= | ||||
| github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= | ||||
| github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= | ||||
| github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs= | ||||
| github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE= | ||||
| github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= | ||||
| github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= | ||||
| github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4= | ||||
| github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A= | ||||
| github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU= | ||||
| github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= | ||||
| @@ -154,12 +189,8 @@ github.com/qiniu/go-sdk/v7 v7.17.1/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYX | ||||
| github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs= | ||||
| github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= | ||||
| github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= | ||||
| github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= | ||||
| github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= | ||||
| github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= | ||||
| github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= | ||||
| github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= | ||||
| github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= | ||||
| github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE= | ||||
| github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI= | ||||
| github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= | ||||
| github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= | ||||
| github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= | ||||
| @@ -167,20 +198,14 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA | ||||
| github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= | ||||
| github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= | ||||
| github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= | ||||
| github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= | ||||
| github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= | ||||
| github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= | ||||
| github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= | ||||
| github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= | ||||
| github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= | ||||
| github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= | ||||
| github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= | ||||
| github.com/smartwalle/alipay/v3 v3.2.15 h1:3fvFJnINKKAOXHR/Iv20k1Z7KJ+nOh3oK214lELPqG8= | ||||
| github.com/smartwalle/alipay/v3 v3.2.15/go.mod h1:niTNB609KyUYuAx9Bex/MawEjv2yPx4XOjxSAkqmGjE= | ||||
| github.com/smartwalle/ncrypto v1.0.2 h1:pTAhCqtPCMhpOwFXX+EcMdR6PNzruBNoGQrN2S1GbGI= | ||||
| github.com/smartwalle/ncrypto v1.0.2/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk= | ||||
| github.com/smartwalle/ngx v1.0.6 h1:JPNqNOIj+2nxxFtrSkJO+vKJfeNUSEQueck/Wworjps= | ||||
| github.com/smartwalle/ngx v1.0.6/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0= | ||||
| github.com/smartwalle/nsign v1.0.8 h1:78KWtwKPrdt4Xsn+tNEBVxaTLIJBX9YRX0ZSrMUeuHo= | ||||
| github.com/smartwalle/nsign v1.0.8/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc= | ||||
| github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||||
| github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= | ||||
| github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= | ||||
| @@ -193,6 +218,14 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o | ||||
| github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= | ||||
| github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= | ||||
| github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= | ||||
| github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= | ||||
| github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= | ||||
| github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0= | ||||
| github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU= | ||||
| github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY= | ||||
| github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY= | ||||
| github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY= | ||||
| github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= | ||||
| github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= | ||||
| github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= | ||||
| github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= | ||||
| @@ -203,8 +236,9 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d | ||||
| github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= | ||||
| github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ= | ||||
| github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo= | ||||
| github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= | ||||
| github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= | ||||
| github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= | ||||
| github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= | ||||
| go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= | ||||
| go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= | ||||
| go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= | ||||
| @@ -213,6 +247,9 @@ go.uber.org/dig v1.16.1/go.mod h1:557JTAUZT5bUK0SvCwikmLPPtdQhfvLYtO5tJgQSbnk= | ||||
| go.uber.org/fx v1.19.3 h1:YqMRE4+2IepTYCMOvXqQpRa+QAVdiSTnsHU4XNWBceA= | ||||
| go.uber.org/fx v1.19.3/go.mod h1:w2HrQg26ql9fLK7hlBiZ6JsRUKV+Lj/atT1KCjT8YhM= | ||||
| go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= | ||||
| go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= | ||||
| go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= | ||||
| go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= | ||||
| go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= | ||||
| go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= | ||||
| go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY= | ||||
| @@ -221,38 +258,43 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu | ||||
| golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | ||||
| golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||||
| golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= | ||||
| golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||
| golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||
| golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= | ||||
| golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= | ||||
| golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= | ||||
| golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= | ||||
| golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= | ||||
| golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ= | ||||
| golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= | ||||
| golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= | ||||
| golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= | ||||
| golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= | ||||
| golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= | ||||
| golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= | ||||
| golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= | ||||
| golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= | ||||
| golang.org/x/image v0.13.0/go.mod h1:6mmbMOeV28HuMTgA6OSRkdXKYw/t5W9Uwn2Yv1r3Yxk= | ||||
| golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= | ||||
| golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= | ||||
| golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= | ||||
| golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= | ||||
| golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= | ||||
| golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | ||||
| golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= | ||||
| golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= | ||||
| golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= | ||||
| golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= | ||||
| golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | ||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||
| golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= | ||||
| golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= | ||||
| golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= | ||||
| golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= | ||||
| golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= | ||||
| golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= | ||||
| golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= | ||||
| golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= | ||||
| golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= | ||||
| golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= | ||||
| golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= | ||||
| golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= | ||||
| golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= | ||||
| golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= | ||||
| golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||
| golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| @@ -261,45 +303,55 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc | ||||
| golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= | ||||
| golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= | ||||
| golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||||
| golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= | ||||
| golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= | ||||
| golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= | ||||
| golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= | ||||
| golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= | ||||
| golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= | ||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= | ||||
| golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= | ||||
| golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= | ||||
| golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= | ||||
| golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= | ||||
| golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= | ||||
| golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= | ||||
| golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= | ||||
| golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= | ||||
| golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | ||||
| golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= | ||||
| golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | ||||
| golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= | ||||
| golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= | ||||
| golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= | ||||
| golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= | ||||
| golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= | ||||
| golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= | ||||
| golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= | ||||
| golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= | ||||
| golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= | ||||
| golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= | ||||
| google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= | ||||
| google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= | ||||
| google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= | ||||
| gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= | ||||
| gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= | ||||
| gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= | ||||
| gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= | ||||
| gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= | ||||
| gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= | ||||
| gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= | ||||
| gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= | ||||
| gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= | ||||
| gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | ||||
| gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
| gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
| gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||||
|   | ||||
| @@ -1,19 +1,25 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"github.com/mojocn/base64Captcha" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -49,12 +55,6 @@ func (h *ManagerHandler) Login(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// add captcha | ||||
| 	if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) { | ||||
| 		resp.ERROR(c, "验证码错误!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var manager model.AdminUser | ||||
| 	res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager) | ||||
| 	if res.Error != nil { | ||||
|   | ||||
| @@ -1,13 +1,21 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| @@ -23,7 +31,6 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler { | ||||
| func (h *ApiKeyHandler) Save(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id       uint   `json:"id"` | ||||
| 		Platform string `json:"platform"` | ||||
| 		Name     string `json:"name"` | ||||
| 		Type     string `json:"type"` | ||||
| 		Value    string `json:"value"` | ||||
| @@ -40,7 +47,6 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { | ||||
| 	if data.Id > 0 { | ||||
| 		h.DB.Find(&apiKey, data.Id) | ||||
| 	} | ||||
| 	apiKey.Platform = data.Platform | ||||
| 	apiKey.Value = data.Value | ||||
| 	apiKey.Type = data.Type | ||||
| 	apiKey.ApiURL = data.ApiURL | ||||
| @@ -49,6 +55,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { | ||||
| 	apiKey.Name = data.Name | ||||
| 	res := h.DB.Save(&apiKey) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -65,14 +72,24 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func (h *ApiKeyHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	status := h.GetBool(c, "status") | ||||
| 	t := h.GetTrim(c, "type") | ||||
| 	platform := h.GetTrim(c, "platform") | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if status { | ||||
| 		session = session.Where("enabled", true) | ||||
| 	} | ||||
| 	if t != "" { | ||||
| 		session = session.Where("type", t) | ||||
| 	} | ||||
| 	if platform != "" { | ||||
| 		session = session.Where("platform", platform) | ||||
| 	} | ||||
|  | ||||
| 	var items []model.ApiKey | ||||
| 	var keys = make([]vo.ApiKey, 0) | ||||
| 	res := h.DB.Find(&items) | ||||
| 	res := session.Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var key vo.ApiKey | ||||
| @@ -104,6 +121,7 @@ func (h *ApiKeyHandler) Set(c *gin.Context) { | ||||
|  | ||||
| 	res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -111,19 +129,17 @@ func (h *ApiKeyHandler) Set(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func (h *ApiKeyHandler) Remove(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id uint | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	if id <= 0 { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	if data.Id > 0 { | ||||
| 		res := h.DB.Where("id = ?", data.Id).Delete(&model.ApiKey{}) | ||||
|  | ||||
| 	res := h.DB.Where("id", id).Delete(&model.ApiKey{}) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|   | ||||
| @@ -1,39 +0,0 @@ | ||||
| package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/mojocn/base64Captcha" | ||||
| ) | ||||
|  | ||||
| type CaptchaHandler struct { | ||||
| 	handler.BaseHandler | ||||
| } | ||||
|  | ||||
| func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler { | ||||
| 	return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}} | ||||
| } | ||||
|  | ||||
| type CaptchaVo struct { | ||||
| 	CaptchaId string `json:"captcha_id"` | ||||
| 	PicPath   string `json:"pic_path"` | ||||
| } | ||||
|  | ||||
| // GetCaptcha 获取验证码 | ||||
| func (h *CaptchaHandler) GetCaptcha(c *gin.Context) { | ||||
| 	var captchaVo CaptchaVo | ||||
| 	driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10) | ||||
| 	cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore) | ||||
| 	// b64s是图片的base64编码 | ||||
| 	id, b64s, err := cp.Generate() | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "生成验证码错误!") | ||||
| 		return | ||||
| 	} | ||||
| 	captchaVo.CaptchaId = id | ||||
| 	captchaVo.PicPath = b64s | ||||
|  | ||||
| 	resp.SUCCESS(c, captchaVo) | ||||
| } | ||||
| @@ -1,13 +1,20 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| @@ -33,11 +40,6 @@ type chatItemVo struct { | ||||
| } | ||||
|  | ||||
| func (h *ChatHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var data struct { | ||||
| 		Title    string   `json:"title"` | ||||
| 		UserId   uint     `json:"user_id"` | ||||
| @@ -259,6 +261,7 @@ func (h *ChatHandler) RemoveMessage(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}) | ||||
| 	if tx.Error != nil { | ||||
| 		logger.Error("error with update database:", tx.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -1,16 +1,23 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type ChatModelHandler struct { | ||||
| @@ -34,6 +41,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) { | ||||
| 		MaxTokens   int     `json:"max_tokens"`  // 最大响应长度 | ||||
| 		MaxContext  int     `json:"max_context"` // 最大上下文长度 | ||||
| 		Temperature float32 `json:"temperature"` // 模型温度 | ||||
| 		KeyId       int     `json:"key_id,omitempty"` | ||||
| 		CreatedAt   int64   `json:"created_at"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| @@ -41,24 +49,32 @@ func (h *ChatModelHandler) Save(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	item := model.ChatModel{ | ||||
| 		Platform:    data.Platform, | ||||
| 		Name:        data.Name, | ||||
| 		Value:       data.Value, | ||||
| 		Enabled:     data.Enabled, | ||||
| 		SortNum:     data.SortNum, | ||||
| 		Open:        data.Open, | ||||
| 		MaxTokens:   data.MaxTokens, | ||||
| 		MaxContext:  data.MaxContext, | ||||
| 		Temperature: data.Temperature, | ||||
| 		Power:       data.Power} | ||||
| 	item.Id = data.Id | ||||
| 	if item.Id > 0 { | ||||
| 		item.CreatedAt = time.Unix(data.CreatedAt, 0) | ||||
| 	item := model.ChatModel{} | ||||
| 	// 更新 | ||||
| 	if data.Id > 0 { | ||||
| 		h.DB.Where("id", data.Id).First(&item) | ||||
| 	} | ||||
|  | ||||
| 	item.Name = data.Name | ||||
| 	item.Value = data.Value | ||||
| 	item.Enabled = data.Enabled | ||||
| 	item.SortNum = data.SortNum | ||||
| 	item.Open = data.Open | ||||
| 	item.Power = data.Power | ||||
| 	item.MaxTokens = data.MaxTokens | ||||
| 	item.MaxContext = data.MaxContext | ||||
| 	item.Temperature = data.Temperature | ||||
| 	item.KeyId = data.KeyId | ||||
|  | ||||
| 	var res *gorm.DB | ||||
| 	if data.Id > 0 { | ||||
| 		res = h.DB.Save(&item) | ||||
| 	} else { | ||||
| 		res = h.DB.Create(&item) | ||||
| 	} | ||||
| 	res := h.DB.Save(&item) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -75,20 +91,34 @@ func (h *ChatModelHandler) Save(c *gin.Context) { | ||||
|  | ||||
| // List 模型列表 | ||||
| func (h *ChatModelHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	enable := h.GetBool(c, "enable") | ||||
| 	name := h.GetTrim(c, "name") | ||||
| 	if enable { | ||||
| 		session = session.Where("enabled", enable) | ||||
| 	} | ||||
| 	if name != "" { | ||||
| 		session = session.Where("name LIKE ?", name+"%") | ||||
| 	} | ||||
| 	var items []model.ChatModel | ||||
| 	var cms = make([]vo.ChatModel, 0) | ||||
| 	res := session.Order("sort_num ASC").Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 	if res.Error != nil { | ||||
| 		resp.SUCCESS(c, cms) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// initialize key name | ||||
| 	keyIds := make([]int, 0) | ||||
| 	for _, v := range items { | ||||
| 		keyIds = append(keyIds, v.KeyId) | ||||
| 	} | ||||
| 	var keys []model.ApiKey | ||||
| 	keyMap := make(map[uint]string) | ||||
| 	h.DB.Where("id IN ?", keyIds).Find(&keys) | ||||
| 	for _, v := range keys { | ||||
| 		keyMap[v.Id] = v.Name | ||||
| 	} | ||||
| 	for _, item := range items { | ||||
| 		var cm vo.ChatModel | ||||
| 		err := utils.CopyObject(item, &cm) | ||||
| @@ -96,12 +126,12 @@ func (h *ChatModelHandler) List(c *gin.Context) { | ||||
| 			cm.Id = item.Id | ||||
| 			cm.CreatedAt = item.CreatedAt.Unix() | ||||
| 			cm.UpdatedAt = item.UpdatedAt.Unix() | ||||
| 			cm.KeyName = keyMap[uint(item.KeyId)] | ||||
| 			cms = append(cms, cm) | ||||
| 		} else { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
| 	} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, cms) | ||||
| } | ||||
|  | ||||
| @@ -119,6 +149,7 @@ func (h *ChatModelHandler) Set(c *gin.Context) { | ||||
|  | ||||
| 	res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -139,6 +170,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) { | ||||
| 	for index, id := range data.Ids { | ||||
| 		res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| @@ -156,6 +188,7 @@ func (h *ChatModelHandler) Remove(c *gin.Context) { | ||||
|  | ||||
| 	res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -1,16 +1,24 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type ChatRoleHandler struct { | ||||
| @@ -40,6 +48,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) { | ||||
| 	} | ||||
| 	res := h.DB.Save(&role) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -50,11 +59,6 @@ func (h *ChatRoleHandler) Save(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func (h *ChatRoleHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var items []model.ChatRole | ||||
| 	var roles = make([]vo.ChatRole, 0) | ||||
| 	res := h.DB.Order("sort_num ASC").Find(&items) | ||||
| @@ -63,6 +67,25 @@ func (h *ChatRoleHandler) List(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// initialize model mane for role | ||||
| 	modelIds := make([]int, 0) | ||||
| 	for _, v := range items { | ||||
| 		if v.ModelId > 0 { | ||||
| 			modelIds = append(modelIds, v.ModelId) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	modelNameMap := make(map[int]string) | ||||
| 	if len(modelIds) > 0 { | ||||
| 		var models []model.ChatModel | ||||
| 		tx := h.DB.Where("id IN ?", modelIds).Find(&models) | ||||
| 		if tx.Error == nil { | ||||
| 			for _, m := range models { | ||||
| 				modelNameMap[int(m.Id)] = m.Name | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	for _, v := range items { | ||||
| 		var role vo.ChatRole | ||||
| 		err := utils.CopyObject(v, &role) | ||||
| @@ -70,6 +93,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) { | ||||
| 			role.Id = v.Id | ||||
| 			role.CreatedAt = v.CreatedAt.Unix() | ||||
| 			role.UpdatedAt = v.UpdatedAt.Unix() | ||||
| 			role.ModelName = modelNameMap[role.ModelId] | ||||
| 			roles = append(roles, role) | ||||
| 		} | ||||
| 	} | ||||
| @@ -92,6 +116,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) { | ||||
| 	for index, id := range data.Ids { | ||||
| 		res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| @@ -114,6 +139,7 @@ func (h *ChatRoleHandler) Set(c *gin.Context) { | ||||
|  | ||||
| 	res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -121,19 +147,15 @@ func (h *ChatRoleHandler) Set(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func (h *ChatRoleHandler) Remove(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id uint | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
|  | ||||
| 	if id <= 0 { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	if data.Id <= 0 { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{}) | ||||
| 	res := h.DB.Where("id", id).Delete(&model.ChatRole{}) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "删除失败!") | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -1,23 +1,45 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/mj" | ||||
| 	"geekai/service/sd" | ||||
| 	"geekai/store" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/shirou/gopsutil/host" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type ConfigHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	levelDB        *store.LevelDB | ||||
| 	licenseService *service.LicenseService | ||||
| 	mjServicePool  *mj.ServicePool | ||||
| 	sdServicePool  *sd.ServicePool | ||||
| } | ||||
|  | ||||
| func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler { | ||||
| 	return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler { | ||||
| 	return &ConfigHandler{ | ||||
| 		BaseHandler:    handler.BaseHandler{App: app, DB: db}, | ||||
| 		levelDB:        levelDB, | ||||
| 		mjServicePool:  mjPool, | ||||
| 		sdServicePool:  sdPool, | ||||
| 		licenseService: licenseService, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *ConfigHandler) Update(c *gin.Context) { | ||||
| @@ -28,6 +50,7 @@ func (h *ConfigHandler) Update(c *gin.Context) { | ||||
| 			Content string `json:"content,omitempty"` | ||||
| 			Updated bool   `json:"updated,omitempty"` | ||||
| 		} `json:"config"` | ||||
| 		ConfigBak types.SystemConfig `json:"config_bak,omitempty"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| @@ -35,6 +58,12 @@ func (h *ConfigHandler) Update(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// ONLY authorized user can change the copyright | ||||
| 	if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy { | ||||
| 		resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	value := utils.JsonEncode(&data.Config) | ||||
| 	config := model.Config{Key: data.Key, Config: value} | ||||
| 	res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key}) | ||||
| @@ -70,11 +99,6 @@ func (h *ConfigHandler) Update(c *gin.Context) { | ||||
|  | ||||
| // Get 获取指定的系统配置 | ||||
| func (h *ConfigHandler) Get(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	key := c.Query("key") | ||||
| 	var config model.Config | ||||
| 	res := h.DB.Where("marker", key).First(&config) | ||||
| @@ -92,3 +116,88 @@ func (h *ConfigHandler) Get(c *gin.Context) { | ||||
|  | ||||
| 	resp.SUCCESS(c, value) | ||||
| } | ||||
|  | ||||
| // Active 激活系统 | ||||
| func (h *ConfigHandler) Active(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		License string `json:"license"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	info, err := host.Info() | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.licenseService.ActiveLicense(data.License, info.HostID) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, info.HostID) | ||||
| } | ||||
|  | ||||
| // GetLicense 获取 License 信息 | ||||
| func (h *ConfigHandler) GetLicense(c *gin.Context) { | ||||
| 	license := h.licenseService.GetLicense() | ||||
| 	resp.SUCCESS(c, license) | ||||
| } | ||||
|  | ||||
| // GetAppConfig 获取内置配置 | ||||
| func (h *ConfigHandler) GetAppConfig(c *gin.Context) { | ||||
| 	resp.SUCCESS(c, gin.H{ | ||||
| 		"mj_plus":  h.App.Config.MjPlusConfigs, | ||||
| 		"mj_proxy": h.App.Config.MjProxyConfigs, | ||||
| 		"sd":       h.App.Config.SdConfigs, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // SaveDrawingConfig 保存AI绘画配置 | ||||
| func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Sd      []types.StableDiffusionConfig `json:"sd"` | ||||
| 		MjPlus  []types.MjPlusConfig          `json:"mj_plus"` | ||||
| 		MjProxy []types.MjProxyConfig         `json:"mj_proxy"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	changed := false | ||||
| 	if configChanged(data.Sd, h.App.Config.SdConfigs) { | ||||
| 		logger.Debugf("SD 配置变动了") | ||||
| 		h.App.Config.SdConfigs = data.Sd | ||||
| 		h.sdServicePool.InitServices(data.Sd) | ||||
| 		changed = true | ||||
| 	} | ||||
|  | ||||
| 	if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) { | ||||
| 		logger.Debugf("MidJourney 配置变动了") | ||||
| 		h.App.Config.MjPlusConfigs = data.MjPlus | ||||
| 		h.App.Config.MjProxyConfigs = data.MjProxy | ||||
| 		h.mjServicePool.InitServices(data.MjPlus, data.MjProxy) | ||||
| 		changed = true | ||||
| 	} | ||||
|  | ||||
| 	if changed { | ||||
| 		err := core.SaveConfig(h.App.Config) | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, "更新配置文档失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
|  | ||||
| } | ||||
|  | ||||
| func configChanged(c1 interface{}, c2 interface{}) bool { | ||||
| 	encode1 := utils.JsonEncode(c1) | ||||
| 	encode2 := utils.JsonEncode(c2) | ||||
| 	return utils.Md5(encode1) != utils.Md5(encode2) | ||||
| } | ||||
|   | ||||
| @@ -1,11 +1,18 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/shopspring/decimal" | ||||
| 	"gorm.io/gorm" | ||||
|   | ||||
| @@ -1,13 +1,20 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
|  | ||||
| @@ -64,6 +71,7 @@ func (h *FunctionHandler) Set(c *gin.Context) { | ||||
|  | ||||
| 	res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -71,11 +79,6 @@ func (h *FunctionHandler) Set(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func (h *FunctionHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var items []model.Function | ||||
| 	res := h.DB.Find(&items) | ||||
| 	if res.Error != nil { | ||||
| @@ -101,6 +104,7 @@ func (h *FunctionHandler) Remove(c *gin.Context) { | ||||
| 	if id > 0 { | ||||
| 		res := h.DB.Delete(&model.Function{Id: uint(id)}) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
|   | ||||
							
								
								
									
										132
									
								
								api/handler/admin/menu_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								api/handler/admin/menu_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,132 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type MenuHandler struct { | ||||
| 	handler.BaseHandler | ||||
| } | ||||
|  | ||||
| func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler { | ||||
| 	return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *MenuHandler) Save(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id      uint   `json:"id"` | ||||
| 		Name    string `json:"name"` | ||||
| 		Icon    string `json:"icon"` | ||||
| 		URL     string `json:"url"` | ||||
| 		SortNum int    `json:"sort_num"` | ||||
| 		Enabled bool   `json:"enabled"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Save(&model.Menu{ | ||||
| 		Id:      data.Id, | ||||
| 		Name:    data.Name, | ||||
| 		Icon:    data.Icon, | ||||
| 		URL:     data.URL, | ||||
| 		SortNum: data.SortNum, | ||||
| 		Enabled: data.Enabled, | ||||
| 	}) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // List 数据列表 | ||||
| func (h *MenuHandler) List(c *gin.Context) { | ||||
| 	var items []model.Menu | ||||
| 	var list = make([]vo.Menu, 0) | ||||
| 	res := h.DB.Order("sort_num ASC").Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var product vo.Menu | ||||
| 			err := utils.CopyObject(item, &product) | ||||
| 			if err == nil { | ||||
| 				list = append(list, product) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, list) | ||||
| } | ||||
|  | ||||
| func (h *MenuHandler) Enable(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id      uint `json:"id"` | ||||
| 		Enabled bool `json:"enabled"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *MenuHandler) Sort(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Ids   []uint `json:"ids"` | ||||
| 		Sorts []int  `json:"sorts"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for index, id := range data.Ids { | ||||
| 		res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index]) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *MenuHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
|  | ||||
| 	if id > 0 { | ||||
| 		res := h.DB.Where("id", id).Delete(&model.Menu{}) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
| @@ -1,13 +1,20 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| @@ -22,11 +29,6 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { | ||||
| } | ||||
|  | ||||
| func (h *OrderHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var data struct { | ||||
| 		OrderNo  string   `json:"order_no"` | ||||
| 		Status   int      `json:"status"` | ||||
| @@ -92,6 +94,7 @@ func (h *OrderHandler) Remove(c *gin.Context) { | ||||
|  | ||||
| 		res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{}) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
|   | ||||
| @@ -1,13 +1,20 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
|   | ||||
| @@ -1,13 +1,20 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"time" | ||||
| @@ -50,6 +57,7 @@ func (h *ProductHandler) Save(c *gin.Context) { | ||||
| 	} | ||||
| 	res := h.DB.Save(&item) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -65,21 +73,11 @@ func (h *ProductHandler) Save(c *gin.Context) { | ||||
| 	resp.SUCCESS(c, itemVo) | ||||
| } | ||||
|  | ||||
| // List 模型列表 | ||||
| // List 数据列表 | ||||
| func (h *ProductHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	enable := h.GetBool(c, "enable") | ||||
| 	if enable { | ||||
| 		session = session.Where("enabled", enable) | ||||
| 	} | ||||
| 	var items []model.Product | ||||
| 	var list = make([]vo.Product, 0) | ||||
| 	res := session.Order("sort_num ASC").Find(&items) | ||||
| 	res := h.DB.Order("sort_num ASC").Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var product vo.Product | ||||
| @@ -110,6 +108,7 @@ func (h *ProductHandler) Enable(c *gin.Context) { | ||||
|  | ||||
| 	res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -128,8 +127,9 @@ func (h *ProductHandler) Sort(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	for index, id := range data.Ids { | ||||
| 		res := h.DB.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) | ||||
| 		res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index]) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| @@ -142,8 +142,9 @@ func (h *ProductHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
|  | ||||
| 	if id > 0 { | ||||
| 		res := h.DB.Where("id = ?", id).Delete(&model.Product{}) | ||||
| 		res := h.DB.Where("id", id).Delete(&model.Product{}) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
|   | ||||
| @@ -1,13 +1,20 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| @@ -21,11 +28,6 @@ func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler { | ||||
| } | ||||
|  | ||||
| func (h *RewardHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var items []model.Reward | ||||
| 	res := h.DB.Order("id DESC").Find(&items) | ||||
| 	var rewards = make([]vo.Reward, 0) | ||||
| @@ -70,6 +72,7 @@ func (h *RewardHandler) Remove(c *gin.Context) { | ||||
| 	if data.Id > 0 { | ||||
| 		res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{}) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
|   | ||||
| @@ -1,11 +1,18 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"time" | ||||
|   | ||||
| @@ -1,14 +1,22 @@ | ||||
| package admin | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/service" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -17,19 +25,15 @@ import ( | ||||
|  | ||||
| type UserHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	licenseService *service.LicenseService | ||||
| } | ||||
|  | ||||
| func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler { | ||||
| 	return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler { | ||||
| 	return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService} | ||||
| } | ||||
|  | ||||
| // List 用户列表 | ||||
| func (h *UserHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	page := h.GetInt(c, "page", 1) | ||||
| 	pageSize := h.GetInt(c, "page_size", 20) | ||||
| 	username := h.GetTrim(c, "username") | ||||
| @@ -80,6 +84,13 @@ func (h *UserHandler) Save(c *gin.Context) { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	// 检测最大注册人数 | ||||
| 	var totalUser int64 | ||||
| 	h.DB.Model(&model.User{}).Count(&totalUser) | ||||
| 	if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum { | ||||
| 		resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") | ||||
| 		return | ||||
| 	} | ||||
| 	var user = model.User{} | ||||
| 	var res *gorm.DB | ||||
| 	var userVo vo.User | ||||
| @@ -100,7 +111,8 @@ func (h *UserHandler) Save(c *gin.Context) { | ||||
|  | ||||
| 		res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			logger.Error("error with update database:", res.Error) | ||||
| 			resp.ERROR(c, res.Error.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		// 记录算力日志 | ||||
| @@ -124,10 +136,16 @@ func (h *UserHandler) Save(c *gin.Context) { | ||||
| 			}) | ||||
| 		} | ||||
| 	} else { | ||||
| 		// 检查用户是否已经存在 | ||||
| 		h.DB.Where("username", data.Username).First(&user) | ||||
| 		if user.Id > 0 { | ||||
| 			resp.ERROR(c, "用户名已存在") | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		salt := utils.RandString(8) | ||||
| 		u := model.User{ | ||||
| 			Username:    data.Username, | ||||
| 			Nickname:    fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)), | ||||
| 			Password:    utils.GenPassword(data.Password, salt), | ||||
| 			Avatar:      "/images/avatar/user.png", | ||||
| 			Salt:        salt, | ||||
| @@ -137,6 +155,11 @@ func (h *UserHandler) Save(c *gin.Context) { | ||||
| 			ChatModels:  utils.JsonEncode(data.ChatModels), | ||||
| 			ExpiredTime: utils.Str2stamp(data.ExpiredTime), | ||||
| 		} | ||||
| 		if h.licenseService.GetLicense().Configs.DeCopy { | ||||
| 			u.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6)) | ||||
| 		} else { | ||||
| 			u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)) | ||||
| 		} | ||||
| 		res = h.DB.Create(&u) | ||||
| 		_ = utils.CopyObject(u, &userVo) | ||||
| 		userVo.Id = u.Id | ||||
| @@ -145,6 +168,7 @@ func (h *UserHandler) Save(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -1,11 +1,18 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
|   | ||||
| @@ -1,9 +1,16 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
|   | ||||
| @@ -1,11 +1,19 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
| @@ -23,9 +31,14 @@ func (h *ChatModelHandler) List(c *gin.Context) { | ||||
| 	var items []model.ChatModel | ||||
| 	var chatModels = make([]vo.ChatModel, 0) | ||||
| 	var res *gorm.DB | ||||
| 	session := h.DB.Session(&gorm.Session{}).Where("enabled", true) | ||||
| 	t := c.Query("type") | ||||
| 	if t != "" { | ||||
| 		session = session.Where("type", t) | ||||
| 	} | ||||
| 	// 如果用户没有登录,则加载所有开放模型 | ||||
| 	if !h.IsLogin(c) { | ||||
| 		res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items) | ||||
| 		res = session.Where("open", true).Order("sort_num ASC").Find(&items) | ||||
| 	} else { | ||||
| 		user, _ := h.GetLoginUser(c) | ||||
| 		var models []int | ||||
| @@ -36,7 +49,7 @@ func (h *ChatModelHandler) List(c *gin.Context) { | ||||
| 		} | ||||
| 		// 查询用户有权限访问的模型以及所有开放的模型 | ||||
| 		res = h.DB.Where("enabled = ?", true).Where( | ||||
| 			h.DB.Where("id IN ?", models).Or("open =?", true), | ||||
| 			h.DB.Where("id IN ?", models).Or("open", true), | ||||
| 		).Order("sort_num ASC").Find(&items) | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -1,12 +1,19 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| @@ -22,31 +29,11 @@ func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { | ||||
|  | ||||
| // List 获取用户聊天应用列表 | ||||
| func (h *ChatRoleHandler) List(c *gin.Context) { | ||||
| 	all := h.GetBool(c, "all") | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	var roles []model.ChatRole | ||||
| 	res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "No roles found,"+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取所有角色 | ||||
| 	if userId == 0 || all { | ||||
| 		// 转成 vo | ||||
| 		var roleVos = make([]vo.ChatRole, 0) | ||||
| 		for _, r := range roles { | ||||
| 			var v vo.ChatRole | ||||
| 			err := utils.CopyObject(r, &v) | ||||
| 			if err == nil { | ||||
| 				v.Id = r.Id | ||||
| 				roleVos = append(roleVos, v) | ||||
| 			} | ||||
| 		} | ||||
| 		resp.SUCCESS(c, roleVos) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	query := h.DB.Where("enable", true) | ||||
| 	if userId > 0 { | ||||
| 		var user model.User | ||||
| 		h.DB.First(&user, userId) | ||||
| 		var roleKeys []string | ||||
| @@ -55,12 +42,19 @@ func (h *ChatRoleHandler) List(c *gin.Context) { | ||||
| 			resp.ERROR(c, "角色解析失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	// 转成 vo | ||||
| 		query = query.Where("marker IN ?", roleKeys) | ||||
| 	} | ||||
| 	if id > 0 { | ||||
| 		query = query.Or("id", id) | ||||
| 	} | ||||
| 	res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var roleVos = make([]vo.ChatRole, 0) | ||||
| 	for _, r := range roles { | ||||
| 		if !utils.ContainsStr(roleKeys, r.Key) { | ||||
| 			continue | ||||
| 		} | ||||
| 		var v vo.ChatRole | ||||
| 		err := utils.CopyObject(r, &v) | ||||
| 		if err == nil { | ||||
| @@ -89,7 +83,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) { | ||||
|  | ||||
| 	res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("添加应用失败:", err) | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -1,208 +0,0 @@ | ||||
| package chatimpl | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| ) | ||||
|  | ||||
| // 微软 Azure 模型消息发送实现 | ||||
|  | ||||
| func (h *ChatHandler) sendAzureMessage( | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| 	session *types.ChatSession, | ||||
| 	role model.ChatRole, | ||||
| 	prompt string, | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| 		if strings.Contains(err.Error(), "context canceled") { | ||||
| 			logger.Info("用户取消了请求:", prompt) | ||||
| 			return nil | ||||
| 		} else if strings.Contains(err.Error(), "no available key") { | ||||
| 			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return err | ||||
| 	} else { | ||||
| 		defer response.Body.Close() | ||||
| 	} | ||||
|  | ||||
| 	contentType := response.Header.Get("Content-Type") | ||||
| 	if strings.Contains(contentType, "text/event-stream") { | ||||
| 		replyCreatedAt := time.Now() // 记录回复时间 | ||||
| 		// 循环读取 Chunk 消息 | ||||
| 		var message = types.Message{} | ||||
| 		var contents = make([]string, 0) | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
| 		for scanner.Scan() { | ||||
| 			line := scanner.Text() | ||||
| 			if !strings.Contains(line, "data:") || len(line) < 30 { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			var responseBody = types.ApiResponse{} | ||||
| 			err = json.Unmarshal([]byte(line[6:]), &responseBody) | ||||
| 			if err != nil { // 数据解析出错 | ||||
| 				logger.Error(err, line) | ||||
| 				utils.ReplyMessage(ws, ErrorMsg) | ||||
| 				utils.ReplyMessage(ws, ErrImg) | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			if len(responseBody.Choices) == 0 { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// 初始化 role | ||||
| 			if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { | ||||
| 				message.Role = responseBody.Choices[0].Delta.Role | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 				continue | ||||
| 			} else if responseBody.Choices[0].FinishReason != "" { | ||||
| 				break // 输出完成或者输出中断了 | ||||
| 			} else { | ||||
| 				content := responseBody.Choices[0].Delta.Content | ||||
| 				contents = append(contents, utils.InterfaceToString(content)) | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 					Type:    types.WsMiddle, | ||||
| 					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), | ||||
| 				}) | ||||
| 			} | ||||
| 		} // end for | ||||
|  | ||||
| 		if err := scanner.Err(); err != nil { | ||||
| 			if strings.Contains(err.Error(), "context canceled") { | ||||
| 				logger.Info("用户取消了请求:", prompt) | ||||
| 			} else { | ||||
| 				logger.Error("信息读取出错:", err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
|  | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| 			} | ||||
| 			message.Content = strings.Join(contents, "") | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.SysConfig.EnableContext { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 			replyTokens += getTotalTokens(req) | ||||
|  | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    message.Content, | ||||
| 				Tokens:     replyTokens, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| 				chatItem.RoleId = role.Id | ||||
| 				chatItem.ModelId = session.Model.Id | ||||
| 				if utf8.RuneCountInString(prompt) > 30 { | ||||
| 					chatItem.Title = string([]rune(prompt)[:30]) + "..." | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		body, err := io.ReadAll(response.Body) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with reading response: %v", err) | ||||
| 		} | ||||
| 		var res types.ApiError | ||||
| 		err = json.Unmarshal(body, &res) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with decode response: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		if strings.Contains(res.Error.Message, "maximum context length") { | ||||
| 			logger.Error(res.Error.Message) | ||||
| 			utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!") | ||||
| 			h.App.ChatContexts.Delete(session.ChatId) | ||||
| 			return h.sendMessage(ctx, session, role, prompt, ws) | ||||
| 		} else { | ||||
| 			utils.ReplyMessage(ws, "请求 Azure API 失败:"+res.Error.Message) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -1,273 +0,0 @@ | ||||
| package chatimpl | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| ) | ||||
|  | ||||
| type baiduResp struct { | ||||
| 	Id               string `json:"id"` | ||||
| 	Object           string `json:"object"` | ||||
| 	Created          int    `json:"created"` | ||||
| 	SentenceId       int    `json:"sentence_id"` | ||||
| 	IsEnd            bool   `json:"is_end"` | ||||
| 	IsTruncated      bool   `json:"is_truncated"` | ||||
| 	Result           string `json:"result"` | ||||
| 	NeedClearHistory bool   `json:"need_clear_history"` | ||||
| 	Usage            struct { | ||||
| 		PromptTokens     int `json:"prompt_tokens"` | ||||
| 		CompletionTokens int `json:"completion_tokens"` | ||||
| 		TotalTokens      int `json:"total_tokens"` | ||||
| 	} `json:"usage"` | ||||
| } | ||||
|  | ||||
| // 百度文心一言消息发送实现 | ||||
|  | ||||
| func (h *ChatHandler) sendBaiduMessage( | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| 	session *types.ChatSession, | ||||
| 	role model.ChatRole, | ||||
| 	prompt string, | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| 		if strings.Contains(err.Error(), "context canceled") { | ||||
| 			logger.Info("用户取消了请求:", prompt) | ||||
| 			return nil | ||||
| 		} else if strings.Contains(err.Error(), "no available key") { | ||||
| 			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return err | ||||
| 	} else { | ||||
| 		defer response.Body.Close() | ||||
| 	} | ||||
|  | ||||
| 	contentType := response.Header.Get("Content-Type") | ||||
| 	if strings.Contains(contentType, "text/event-stream") { | ||||
| 		replyCreatedAt := time.Now() // 记录回复时间 | ||||
| 		// 循环读取 Chunk 消息 | ||||
| 		var message = types.Message{} | ||||
| 		var contents = make([]string, 0) | ||||
| 		var content string | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
| 		for scanner.Scan() { | ||||
| 			line := scanner.Text() | ||||
| 			if len(line) < 5 || strings.HasPrefix(line, "id:") { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if strings.HasPrefix(line, "data:") { | ||||
| 				content = line[5:] | ||||
| 			} | ||||
|  | ||||
| 			// 处理代码换行 | ||||
| 			if len(content) == 0 { | ||||
| 				content = "\n" | ||||
| 			} | ||||
|  | ||||
| 			var resp baiduResp | ||||
| 			err := utils.JsonDecode(content, &resp) | ||||
| 			if err != nil { | ||||
| 				logger.Error("error with parse data line: ", err) | ||||
| 				utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			if len(contents) == 0 { | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 			} | ||||
| 			utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 				Type:    types.WsMiddle, | ||||
| 				Content: utils.InterfaceToString(resp.Result), | ||||
| 			}) | ||||
| 			contents = append(contents, resp.Result) | ||||
|  | ||||
| 			if resp.IsTruncated { | ||||
| 				utils.ReplyMessage(ws, "AI 输出异常中断") | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			if resp.IsEnd { | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 		} // end for | ||||
|  | ||||
| 		if err := scanner.Err(); err != nil { | ||||
| 			if strings.Contains(err.Error(), "context canceled") { | ||||
| 				logger.Info("用户取消了请求:", prompt) | ||||
| 			} else { | ||||
| 				logger.Error("信息读取出错:", err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| 			} | ||||
| 			message.Content = strings.Join(contents, "") | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.SysConfig.EnableContext { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// for reply | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 			totalTokens := replyTokens + getTotalTokens(req) | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    message.Content, | ||||
| 				Tokens:     totalTokens, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| 				chatItem.RoleId = role.Id | ||||
| 				chatItem.ModelId = session.Model.Id | ||||
| 				if utf8.RuneCountInString(prompt) > 30 { | ||||
| 					chatItem.Title = string([]rune(prompt)[:30]) + "..." | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		body, err := io.ReadAll(response.Body) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with reading response: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		var res struct { | ||||
| 			Code int    `json:"error_code"` | ||||
| 			Msg  string `json:"error_msg"` | ||||
| 		} | ||||
| 		err = json.Unmarshal(body, &res) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with decode response: %v", err) | ||||
| 		} | ||||
| 		utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) { | ||||
| 	ctx := context.Background() | ||||
| 	tokenString, err := h.redis.Get(ctx, apiKey).Result() | ||||
| 	if err == nil { | ||||
| 		return tokenString, nil | ||||
| 	} | ||||
|  | ||||
| 	expr := time.Hour * 24 * 20 // access_token 有效期 | ||||
| 	key := strings.Split(apiKey, "|") | ||||
| 	if len(key) != 2 { | ||||
| 		return "", fmt.Errorf("invalid api key: %s", apiKey) | ||||
| 	} | ||||
| 	url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1]) | ||||
| 	client := &http.Client{} | ||||
| 	req, err := http.NewRequest("POST", url, nil) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	req.Header.Add("Content-Type", "application/json") | ||||
| 	req.Header.Add("Accept", "application/json") | ||||
|  | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with send request: %w", err) | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with read response: %w", err) | ||||
| 	} | ||||
| 	var r map[string]interface{} | ||||
| 	err = json.Unmarshal(body, &r) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with parse response: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	if r["error"] != nil { | ||||
| 		return "", fmt.Errorf("error with api response: %s", r["error_description"]) | ||||
| 	} | ||||
|  | ||||
| 	tokenString = fmt.Sprintf("%s", r["access_token"]) | ||||
| 	h.redis.Set(ctx, apiKey, tokenString, expr) | ||||
| 	return tokenString, nil | ||||
| } | ||||
| @@ -1,25 +1,35 @@ | ||||
| package chatimpl | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"html/template" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| @@ -27,30 +37,25 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。" | ||||
|  | ||||
| var ErrImg = "" | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| type ChatHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	redis          *redis.Client | ||||
| 	uploadManager  *oss.UploaderManager | ||||
| 	licenseService *service.LicenseService | ||||
| 	ReqCancelFunc  *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function | ||||
| 	ChatContexts   *types.LMap[string, []types.Message]    // 聊天上下文 Map [chatId] => []Message | ||||
| } | ||||
|  | ||||
| func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler { | ||||
| func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler { | ||||
| 	return &ChatHandler{ | ||||
| 		BaseHandler:    handler.BaseHandler{App: app, DB: db}, | ||||
| 		redis:          redis, | ||||
| 		uploadManager:  manager, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *ChatHandler) Init() { | ||||
| 	// 如果后台有上传微信客服微信二维码,则覆盖 | ||||
| 	if h.App.SysConfig.WechatCardURL != "" { | ||||
| 		ErrImg = fmt.Sprintf("", h.App.SysConfig.WechatCardURL) | ||||
| 		licenseService: licenseService, | ||||
| 		ReqCancelFunc:  types.NewLMap[string, context.CancelFunc](), | ||||
| 		ChatContexts:   types.NewLMap[string, []types.Message](), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -68,30 +73,30 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
| 	modelId := h.GetInt(c, "model_id", 0) | ||||
|  | ||||
| 	client := types.NewWsClient(ws) | ||||
| 	var chatRole model.ChatRole | ||||
| 	res := h.DB.First(&chatRole, roleId) | ||||
| 	if res.Error != nil || !chatRole.Enable { | ||||
| 		utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
| 	// if the role bind a model_id, use role's bind model_id | ||||
| 	if chatRole.ModelId > 0 { | ||||
| 		modelId = chatRole.ModelId | ||||
| 	} | ||||
| 	// get model info | ||||
| 	var chatModel model.ChatModel | ||||
| 	res := h.DB.First(&chatModel, modelId) | ||||
| 	res = h.DB.First(&chatModel, modelId) | ||||
| 	if res.Error != nil || chatModel.Enabled == false { | ||||
| 		utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!") | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.App.ChatSession.Get(sessionId) | ||||
| 	if session == nil { | ||||
| 		user, err := h.GetLoginUser(c) | ||||
| 		if err != nil { | ||||
| 			logger.Info("用户未登录") | ||||
| 			c.Abort() | ||||
| 			return | ||||
| 		} | ||||
| 		session = &types.ChatSession{ | ||||
| 	session := &types.ChatSession{ | ||||
| 		SessionId: sessionId, | ||||
| 		ClientIP:  c.ClientIP(), | ||||
| 			Username:  user.Username, | ||||
| 			UserId:    user.Id, | ||||
| 		} | ||||
| 		h.App.ChatSession.Put(sessionId, session) | ||||
| 		UserId:    h.GetLoginUserId(c), | ||||
| 	} | ||||
|  | ||||
| 	// use old chat data override the chat model and role ID | ||||
| @@ -111,30 +116,19 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
| 		MaxTokens:   chatModel.MaxTokens, | ||||
| 		MaxContext:  chatModel.MaxContext, | ||||
| 		Temperature: chatModel.Temperature, | ||||
| 		Platform:    types.Platform(chatModel.Platform)} | ||||
| 	logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username) | ||||
| 	var chatRole model.ChatRole | ||||
| 	res = h.DB.First(&chatRole, roleId) | ||||
| 	if res.Error != nil || !chatRole.Enable { | ||||
| 		utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
| 		KeyId:       chatModel.KeyId} | ||||
| 	logger.Infof("New websocket connected, IP: %s", c.ClientIP()) | ||||
|  | ||||
| 	h.Init() | ||||
|  | ||||
| 	// 保存会话连接 | ||||
| 	h.App.ChatClients.Put(sessionId, client) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			_, msg, err := client.Receive() | ||||
| 			if err != nil { | ||||
| 				logger.Debugf("close connection: %s", client.Conn.RemoteAddr()) | ||||
| 				client.Close() | ||||
| 				h.App.ChatClients.Delete(sessionId) | ||||
| 				cancelFunc := h.App.ReqCancelFunc.Get(sessionId) | ||||
| 				cancelFunc := h.ReqCancelFunc.Get(sessionId) | ||||
| 				if cancelFunc != nil { | ||||
| 					cancelFunc() | ||||
| 					h.App.ReqCancelFunc.Delete(sessionId) | ||||
| 					h.ReqCancelFunc.Delete(sessionId) | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| @@ -154,12 +148,12 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
| 			logger.Info("Receive a message: ", message.Content) | ||||
|  | ||||
| 			ctx, cancel := context.WithCancel(context.Background()) | ||||
| 			h.App.ReqCancelFunc.Put(sessionId, cancel) | ||||
| 			h.ReqCancelFunc.Put(sessionId, cancel) | ||||
| 			// 回复消息 | ||||
| 			err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) | ||||
| 				utils.ReplyMessage(client, err.Error()) | ||||
| 			} else { | ||||
| 				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) | ||||
| 				logger.Infof("回答完毕: %v", message.Content) | ||||
| @@ -181,8 +175,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio | ||||
| 	var user model.User | ||||
| 	res := h.DB.Model(&model.User{}).First(&user, session.UserId) | ||||
| 	if res.Error != nil { | ||||
| 		utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作!") | ||||
| 		return res.Error | ||||
| 		return errors.New("未授权用户,您正在进行非法操作!") | ||||
| 	} | ||||
| 	var userVo vo.User | ||||
| 	err := utils.CopyObject(user, &userVo) | ||||
| @@ -192,92 +185,67 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio | ||||
| 	} | ||||
|  | ||||
| 	if userVo.Status == false { | ||||
| 		utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return nil | ||||
| 		return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!") | ||||
| 	} | ||||
|  | ||||
| 	if userVo.Power < session.Model.Power { | ||||
| 		utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型的单次对话需要消耗的算力(%d)!", userVo.Power, session.Model.Power)) | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return nil | ||||
| 		return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, session.Model.Power) | ||||
| 	} | ||||
|  | ||||
| 	if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() { | ||||
| 		utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return nil | ||||
| 		return errors.New("您的账号已经过期,请联系管理员!") | ||||
| 	} | ||||
|  | ||||
| 	// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度 | ||||
| 	promptTokens, err := utils.CalcTokens(prompt, session.Model.Value) | ||||
| 	if promptTokens > session.Model.MaxContext { | ||||
| 		utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!") | ||||
| 		return nil | ||||
|  | ||||
| 		return errors.New("对话内容超出了当前模型允许的最大上下文长度!") | ||||
| 	} | ||||
|  | ||||
| 	var req = types.ApiRequest{ | ||||
| 		Model:  session.Model.Value, | ||||
| 		Stream: true, | ||||
| 	} | ||||
| 	switch session.Model.Platform { | ||||
| 	case types.Azure, types.ChatGLM, types.Baidu, types.XunFei: | ||||
| 		req.Temperature = session.Model.Temperature | ||||
| 		req.MaxTokens = session.Model.MaxTokens | ||||
| 		break | ||||
| 	case types.OpenAI: | ||||
| 	req.Temperature = session.Model.Temperature | ||||
| 	req.MaxTokens = session.Model.MaxTokens | ||||
| 	// OpenAI 支持函数功能 | ||||
| 	var items []model.Function | ||||
| 		res := h.DB.Where("enabled", true).Find(&items) | ||||
| 		if res.Error != nil { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		var tools = make([]interface{}, 0) | ||||
| 	res = h.DB.Where("enabled", true).Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		var tools = make([]types.Tool, 0) | ||||
| 		for _, v := range items { | ||||
| 			var parameters map[string]interface{} | ||||
| 			err = utils.JsonDecode(v.Parameters, ¶meters) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			required := parameters["required"] | ||||
| 			delete(parameters, "required") | ||||
| 			tools = append(tools, gin.H{ | ||||
| 				"type": "function", | ||||
| 				"function": gin.H{ | ||||
| 					"name":        v.Name, | ||||
| 					"description": v.Description, | ||||
| 					"parameters":  parameters, | ||||
| 					"required":    required, | ||||
| 			tool := types.Tool{ | ||||
| 				Type: "function", | ||||
| 				Function: types.Function{ | ||||
| 					Name:        v.Name, | ||||
| 					Description: v.Description, | ||||
| 					Parameters:  parameters, | ||||
| 				}, | ||||
| 			}) | ||||
| 			} | ||||
| 			if v, ok := parameters["required"]; v == nil || !ok { | ||||
| 				tool.Function.Parameters["required"] = []string{} | ||||
| 			} | ||||
| 			tools = append(tools, tool) | ||||
| 		} | ||||
|  | ||||
| 		if len(tools) > 0 { | ||||
| 			req.Tools = tools | ||||
| 			req.ToolChoice = "auto" | ||||
| 		} | ||||
| 	case types.QWen: | ||||
| 		req.Parameters = map[string]interface{}{ | ||||
| 			"max_tokens":  session.Model.MaxTokens, | ||||
| 			"temperature": session.Model.Temperature, | ||||
| 		} | ||||
| 		break | ||||
|  | ||||
| 	default: | ||||
| 		utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// 加载聊天上下文 | ||||
| 	chatCtx := make([]types.Message, 0) | ||||
| 	messages := make([]types.Message, 0) | ||||
| 	if h.App.SysConfig.EnableContext { | ||||
| 		if h.App.ChatContexts.Has(session.ChatId) { | ||||
| 			messages = h.App.ChatContexts.Get(session.ChatId) | ||||
| 		if h.ChatContexts.Has(session.ChatId) { | ||||
| 			messages = h.ChatContexts.Get(session.ChatId) | ||||
| 		} else { | ||||
| 			_ = utils.JsonDecode(role.Context, &messages) | ||||
| 			if h.App.SysConfig.ContextDeep > 0 { | ||||
| @@ -325,37 +293,69 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio | ||||
| 		reqMgs = append(reqMgs, m) | ||||
| 	} | ||||
|  | ||||
| 	if session.Model.Platform == types.QWen { | ||||
| 		req.Input = map[string]interface{}{"prompt": prompt} | ||||
| 		if len(reqMgs) > 0 { | ||||
| 			req.Input["messages"] = reqMgs | ||||
| 		} | ||||
| 	fullPrompt := prompt | ||||
| 	text := prompt | ||||
| 	// extract files in prompt | ||||
| 	files := utils.ExtractFileURLs(prompt) | ||||
| 	logger.Debugf("detected FILES: %+v", files) | ||||
| 	// 如果不是逆向模型,则提取文件内容 | ||||
| 	if len(files) > 0 && !(session.Model.Value == "gpt-4-all" || | ||||
| 		strings.HasPrefix(session.Model.Value, "gpt-4-gizmo") || | ||||
| 		strings.HasSuffix(session.Model.Value, "claude-3")) { | ||||
| 		contents := make([]string, 0) | ||||
| 		var file model.File | ||||
| 		for _, v := range files { | ||||
| 			h.DB.Where("url = ?", v).First(&file) | ||||
| 			content, err := utils.ReadFileContent(v, h.App.Config.TikaHost) | ||||
| 			if err != nil { | ||||
| 				logger.Error("error with read file: ", err) | ||||
| 			} else { | ||||
| 		req.Messages = append(reqMgs, map[string]interface{}{ | ||||
| 			"role":    "user", | ||||
| 			"content": prompt, | ||||
| 		}) | ||||
| 				contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content)) | ||||
| 			} | ||||
| 			text = strings.Replace(text, v, "", 1) | ||||
| 		} | ||||
| 		if len(contents) > 0 { | ||||
| 			fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text) | ||||
| 		} | ||||
|  | ||||
| 	switch session.Model.Platform { | ||||
| 	case types.Azure: | ||||
| 		return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
| 	case types.OpenAI: | ||||
| 		return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
| 	case types.ChatGLM: | ||||
| 		return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
| 	case types.Baidu: | ||||
| 		return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
| 	case types.XunFei: | ||||
| 		return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
| 	case types.QWen: | ||||
| 		return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
| 		tokens, _ := utils.CalcTokens(fullPrompt, req.Model) | ||||
| 		if tokens > session.Model.MaxContext { | ||||
| 			return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。") | ||||
| 		} | ||||
| 	utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 		Type:    types.WsMiddle, | ||||
| 		Content: fmt.Sprintf("Not supported platform: %s", session.Model.Platform), | ||||
| 	} | ||||
| 	logger.Debug("最终Prompt:", fullPrompt) | ||||
|  | ||||
| 	// extract images from prompt | ||||
| 	imgURLs := utils.ExtractImgURLs(prompt) | ||||
| 	logger.Debugf("detected IMG: %+v", imgURLs) | ||||
| 	var content interface{} | ||||
| 	if len(imgURLs) > 0 { | ||||
| 		data := make([]interface{}, 0) | ||||
| 		for _, v := range imgURLs { | ||||
| 			text = strings.Replace(text, v, "", 1) | ||||
| 			data = append(data, gin.H{ | ||||
| 				"type": "image_url", | ||||
| 				"image_url": gin.H{ | ||||
| 					"url": v, | ||||
| 				}, | ||||
| 			}) | ||||
| 	return nil | ||||
| 		} | ||||
| 		data = append(data, gin.H{ | ||||
| 			"type": "text", | ||||
| 			"text": strings.TrimSpace(text), | ||||
| 		}) | ||||
| 		content = data | ||||
| 	} else { | ||||
| 		content = fullPrompt | ||||
| 	} | ||||
| 	req.Messages = append(reqMgs, map[string]interface{}{ | ||||
| 		"role":    "user", | ||||
| 		"content": content, | ||||
| 	}) | ||||
|  | ||||
| 	logger.Debugf("%+v", req.Messages) | ||||
|  | ||||
| 	return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
| } | ||||
|  | ||||
| // Tokens 统计 token 数量 | ||||
| @@ -415,55 +415,36 @@ func getTotalTokens(req types.ApiRequest) int { | ||||
| // StopGenerate 停止生成 | ||||
| func (h *ChatHandler) StopGenerate(c *gin.Context) { | ||||
| 	sessionId := c.Query("session_id") | ||||
| 	if h.App.ReqCancelFunc.Has(sessionId) { | ||||
| 		h.App.ReqCancelFunc.Get(sessionId)() | ||||
| 		h.App.ReqCancelFunc.Delete(sessionId) | ||||
| 	if h.ReqCancelFunc.Has(sessionId) { | ||||
| 		h.ReqCancelFunc.Get(sessionId)() | ||||
| 		h.ReqCancelFunc.Delete(sessionId) | ||||
| 	} | ||||
| 	resp.SUCCESS(c, types.OkMsg) | ||||
| } | ||||
|  | ||||
| // 发送请求到 OpenAI 服务器 | ||||
| // useOwnApiKey: 是否使用了用户自己的 API KEY | ||||
| func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) { | ||||
| 	res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) | ||||
| 	if res.Error != nil { | ||||
| func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) { | ||||
| 	// if the chat model bind a KEY, use it directly | ||||
| 	if session.Model.KeyId > 0 { | ||||
| 		h.DB.Where("id", session.Model.KeyId).Find(apiKey) | ||||
| 	} | ||||
| 	// use the last unused key | ||||
| 	if apiKey.Id == 0 { | ||||
| 		h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey) | ||||
| 	} | ||||
| 	if apiKey.Id == 0 { | ||||
| 		return nil, errors.New("no available key, please import key") | ||||
| 	} | ||||
| 	var apiURL string | ||||
| 	switch platform { | ||||
| 	case types.Azure: | ||||
| 		md := strings.Replace(req.Model, ".", "", 1) | ||||
| 		apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1) | ||||
| 		break | ||||
| 	case types.ChatGLM: | ||||
| 		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) | ||||
| 		req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段 | ||||
| 		req.Messages = nil | ||||
| 		break | ||||
| 	case types.Baidu: | ||||
| 		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) | ||||
| 		break | ||||
| 	case types.QWen: | ||||
| 		apiURL = apiKey.ApiURL | ||||
| 		req.Messages = nil | ||||
| 		break | ||||
| 	default: | ||||
| 		apiURL = apiKey.ApiURL | ||||
| 	} | ||||
| 	// 更新 API KEY 的最后使用时间 | ||||
| 	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
| 	// 百度文心,需要串接 access_token | ||||
| 	if platform == types.Baidu { | ||||
| 		token, err := h.getBaiduToken(apiKey.Value) | ||||
|  | ||||
| 	// ONLY allow apiURL in blank list | ||||
| 	err := h.licenseService.IsValidApiURL(apiKey.ApiURL) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 		logger.Info("百度文心 Access_Token:", token) | ||||
| 		apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token) | ||||
| 	} | ||||
|  | ||||
| 	logger.Debugf(utils.JsonEncode(req)) | ||||
|  | ||||
| 	apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL) | ||||
| 	// 创建 HttpClient 请求对象 | ||||
| 	var client *http.Client | ||||
| 	requestBody, err := json.Marshal(req) | ||||
| @@ -477,8 +458,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf | ||||
|  | ||||
| 	request = request.WithContext(ctx) | ||||
| 	request.Header.Set("Content-Type", "application/json") | ||||
| 	var proxyURL string | ||||
| 	if apiKey.ProxyURL != "" { // 使用代理 | ||||
| 	if len(apiKey.ProxyURL) > 5 { // 使用代理 | ||||
| 		proxy, _ := url.Parse(apiKey.ProxyURL) | ||||
| 		client = &http.Client{ | ||||
| 			Transport: &http.Transport{ | ||||
| @@ -488,28 +468,10 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf | ||||
| 	} else { | ||||
| 		client = http.DefaultClient | ||||
| 	} | ||||
| 	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model) | ||||
| 	switch platform { | ||||
| 	case types.Azure: | ||||
| 		request.Header.Set("api-key", apiKey.Value) | ||||
| 		break | ||||
| 	case types.ChatGLM: | ||||
| 		token, err := h.getChatGLMToken(apiKey.Value) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) | ||||
| 		break | ||||
| 	case types.Baidu: | ||||
| 		request.RequestURI = "" | ||||
| 	case types.OpenAI: | ||||
| 	logger.Debugf("Sending %s request, Channel:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model) | ||||
| 	request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) | ||||
| 		break | ||||
| 	case types.QWen: | ||||
| 		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) | ||||
| 		request.Header.Set("X-DashScope-SSE", "enable") | ||||
| 		break | ||||
| 	} | ||||
| 	// 更新API KEY 最后使用时间 | ||||
| 	h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
| 	return client.Do(request) | ||||
| } | ||||
|  | ||||
| @@ -539,6 +501,98 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p | ||||
|  | ||||
| } | ||||
|  | ||||
| func (h *ChatHandler) saveChatHistory( | ||||
| 	req types.ApiRequest, | ||||
| 	prompt string, | ||||
| 	contents []string, | ||||
| 	message types.Message, | ||||
| 	chatCtx []types.Message, | ||||
| 	session *types.ChatSession, | ||||
| 	role model.ChatRole, | ||||
| 	userVo vo.User, | ||||
| 	promptCreatedAt time.Time, | ||||
| 	replyCreatedAt time.Time) { | ||||
| 	if message.Role == "" { | ||||
| 		message.Role = "assistant" | ||||
| 	} | ||||
| 	message.Content = strings.Join(contents, "") | ||||
| 	useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 	// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 	if h.App.SysConfig.EnableContext { | ||||
| 		chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 		chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 		h.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 	} | ||||
|  | ||||
| 	// 追加聊天记录 | ||||
| 	// for prompt | ||||
| 	promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 	if err != nil { | ||||
| 		logger.Error(err) | ||||
| 	} | ||||
| 	historyUserMsg := model.ChatMessage{ | ||||
| 		UserId:     userVo.Id, | ||||
| 		ChatId:     session.ChatId, | ||||
| 		RoleId:     role.Id, | ||||
| 		Type:       types.PromptMsg, | ||||
| 		Icon:       userVo.Avatar, | ||||
| 		Content:    template.HTMLEscapeString(prompt), | ||||
| 		Tokens:     promptToken, | ||||
| 		UseContext: true, | ||||
| 		Model:      req.Model, | ||||
| 	} | ||||
| 	historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 	historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 	res := h.DB.Save(&historyUserMsg) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 	} | ||||
|  | ||||
| 	// for reply | ||||
| 	// 计算本次对话消耗的总 token 数量 | ||||
| 	replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 	totalTokens := replyTokens + getTotalTokens(req) | ||||
| 	historyReplyMsg := model.ChatMessage{ | ||||
| 		UserId:     userVo.Id, | ||||
| 		ChatId:     session.ChatId, | ||||
| 		RoleId:     role.Id, | ||||
| 		Type:       types.ReplyMsg, | ||||
| 		Icon:       role.Icon, | ||||
| 		Content:    message.Content, | ||||
| 		Tokens:     totalTokens, | ||||
| 		UseContext: true, | ||||
| 		Model:      req.Model, | ||||
| 	} | ||||
| 	historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 	historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 	res = h.DB.Create(&historyReplyMsg) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("failed to save reply history message: ", res.Error) | ||||
| 	} | ||||
|  | ||||
| 	// 更新用户算力 | ||||
| 	if session.Model.Power > 0 { | ||||
| 		h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
| 	} | ||||
| 	// 保存当前会话 | ||||
| 	var chatItem model.ChatItem | ||||
| 	res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 	if res.Error != nil { | ||||
| 		chatItem.ChatId = session.ChatId | ||||
| 		chatItem.UserId = userVo.Id | ||||
| 		chatItem.RoleId = role.Id | ||||
| 		chatItem.ModelId = session.Model.Id | ||||
| 		if utf8.RuneCountInString(prompt) > 30 { | ||||
| 			chatItem.Title = string([]rune(prompt)[:30]) + "..." | ||||
| 		} else { | ||||
| 			chatItem.Title = prompt | ||||
| 		} | ||||
| 		chatItem.Model = req.Model | ||||
| 		h.DB.Create(&chatItem) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 将AI回复消息中生成的图片链接下载到本地 | ||||
| func (h *ChatHandler) extractImgUrl(text string) string { | ||||
| 	pattern := `!\[([^\]]*)]\(([^)]+)\)` | ||||
| @@ -554,7 +608,7 @@ func (h *ChatHandler) extractImgUrl(text string) string { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false) | ||||
| 		newImgURL, err := h.uploadManager.GetUploadHandler().PutUrlFile(imageURL, false) | ||||
| 		if err != nil { | ||||
| 			logger.Error("error with download image: ", err) | ||||
| 			continue | ||||
|   | ||||
| @@ -1,11 +1,18 @@ | ||||
| package chatimpl | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| @@ -89,7 +96,7 @@ func (h *ChatHandler) Clear(c *gin.Context) { | ||||
| 	for _, chat := range chats { | ||||
| 		chatIds = append(chatIds, chat.ChatId) | ||||
| 		// 清空会话上下文 | ||||
| 		h.App.ChatContexts.Delete(chat.ChatId) | ||||
| 		h.ChatContexts.Delete(chat.ChatId) | ||||
| 	} | ||||
| 	err = h.DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{}) | ||||
| @@ -101,8 +108,6 @@ func (h *ChatHandler) Clear(c *gin.Context) { | ||||
| 		if res.Error != nil { | ||||
| 			return res.Error | ||||
| 		} | ||||
|  | ||||
| 		// TODO: 是否要删除 MidJourney 绘画记录和图片文件? | ||||
| 		return nil | ||||
| 	}) | ||||
|  | ||||
| @@ -168,7 +173,7 @@ func (h *ChatHandler) Remove(c *gin.Context) { | ||||
| 	// TODO: 是否要删除 MidJourney 绘画记录和图片文件? | ||||
|  | ||||
| 	// 清空会话上下文 | ||||
| 	h.App.ChatContexts.Delete(chatId) | ||||
| 	h.ChatContexts.Delete(chatId) | ||||
| 	resp.SUCCESS(c, types.OkMsg) | ||||
| } | ||||
|  | ||||
| @@ -187,12 +192,20 @@ func (h *ChatHandler) Detail(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 填充角色名称 | ||||
| 	var role model.ChatRole | ||||
| 	res = h.DB.Where("id", chatItem.RoleId).First(&role) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Role not found") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var chatItemVo vo.ChatItem | ||||
| 	err := utils.CopyObject(chatItem, &chatItemVo) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	chatItemVo.RoleName = role.Name | ||||
| 	resp.SUCCESS(c, chatItemVo) | ||||
| } | ||||
|   | ||||
| @@ -1,236 +0,0 @@ | ||||
| package chatimpl | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| ) | ||||
|  | ||||
| // 清华大学 ChatGML 消息发送实现 | ||||
|  | ||||
| func (h *ChatHandler) sendChatGLMMessage( | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| 	session *types.ChatSession, | ||||
| 	role model.ChatRole, | ||||
| 	prompt string, | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| 		if strings.Contains(err.Error(), "context canceled") { | ||||
| 			logger.Info("用户取消了请求:", prompt) | ||||
| 			return nil | ||||
| 		} else if strings.Contains(err.Error(), "no available key") { | ||||
| 			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return err | ||||
| 	} else { | ||||
| 		defer response.Body.Close() | ||||
| 	} | ||||
|  | ||||
| 	contentType := response.Header.Get("Content-Type") | ||||
| 	if strings.Contains(contentType, "text/event-stream") { | ||||
| 		replyCreatedAt := time.Now() // 记录回复时间 | ||||
| 		// 循环读取 Chunk 消息 | ||||
| 		var message = types.Message{} | ||||
| 		var contents = make([]string, 0) | ||||
| 		var event, content string | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
| 		for scanner.Scan() { | ||||
| 			line := scanner.Text() | ||||
| 			if len(line) < 5 || strings.HasPrefix(line, "id:") { | ||||
| 				continue | ||||
| 			} | ||||
| 			if strings.HasPrefix(line, "event:") { | ||||
| 				event = line[6:] | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if strings.HasPrefix(line, "data:") { | ||||
| 				content = line[5:] | ||||
| 			} | ||||
| 			// 处理代码换行 | ||||
| 			if len(content) == 0 { | ||||
| 				content = "\n" | ||||
| 			} | ||||
| 			switch event { | ||||
| 			case "add": | ||||
| 				if len(contents) == 0 { | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 				} | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 					Type:    types.WsMiddle, | ||||
| 					Content: utils.InterfaceToString(content), | ||||
| 				}) | ||||
| 				contents = append(contents, content) | ||||
| 			case "finish": | ||||
| 				break | ||||
| 			case "error": | ||||
| 				utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content)) | ||||
| 				break | ||||
| 			case "interrupted": | ||||
| 				utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**") | ||||
| 			} | ||||
|  | ||||
| 		} // end for | ||||
|  | ||||
| 		if err := scanner.Err(); err != nil { | ||||
| 			if strings.Contains(err.Error(), "context canceled") { | ||||
| 				logger.Info("用户取消了请求:", prompt) | ||||
| 			} else { | ||||
| 				logger.Error("信息读取出错:", err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| 			} | ||||
| 			message.Content = strings.Join(contents, "") | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.SysConfig.EnableContext { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// for reply | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 			totalTokens := replyTokens + getTotalTokens(req) | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    message.Content, | ||||
| 				Tokens:     totalTokens, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| 				chatItem.RoleId = role.Id | ||||
| 				chatItem.ModelId = session.Model.Id | ||||
| 				if utf8.RuneCountInString(prompt) > 30 { | ||||
| 					chatItem.Title = string([]rune(prompt)[:30]) + "..." | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		body, err := io.ReadAll(response.Body) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with reading response: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		var res struct { | ||||
| 			Code    int    `json:"code"` | ||||
| 			Success bool   `json:"success"` | ||||
| 			Msg     string `json:"msg"` | ||||
| 		} | ||||
| 		err = json.Unmarshal(body, &res) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with decode response: %v", err) | ||||
| 		} | ||||
| 		if !res.Success { | ||||
| 			utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) { | ||||
| 	ctx := context.Background() | ||||
| 	tokenString, err := h.redis.Get(ctx, apiKey).Result() | ||||
| 	if err == nil { | ||||
| 		return tokenString, nil | ||||
| 	} | ||||
|  | ||||
| 	expr := time.Hour * 2 | ||||
| 	key := strings.Split(apiKey, ".") | ||||
| 	if len(key) != 2 { | ||||
| 		return "", fmt.Errorf("invalid api key: %s", apiKey) | ||||
| 	} | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ | ||||
| 		"api_key":   key[0], | ||||
| 		"timestamp": time.Now().Unix(), | ||||
| 		"exp":       time.Now().Add(expr).Add(time.Second * 10).Unix(), | ||||
| 	}) | ||||
| 	token.Header["alg"] = "HS256" | ||||
| 	token.Header["sign_type"] = "SIGN" | ||||
| 	delete(token.Header, "typ") | ||||
| 	// Sign and get the complete encoded token as a string using the secret | ||||
| 	tokenString, err = token.SignedString([]byte(key[1])) | ||||
| 	h.redis.Set(ctx, apiKey, tokenString, expr) | ||||
| 	return tokenString, err | ||||
| } | ||||
| @@ -1,21 +1,26 @@ | ||||
| package chatimpl | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"html/template" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	req2 "github.com/imroc/req/v3" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
|  | ||||
| 	req2 "github.com/imroc/req/v3" | ||||
| ) | ||||
|  | ||||
| // OPenAI 消息发送实现 | ||||
| @@ -31,24 +36,13 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	response, err := h.doRequest(ctx, req, session, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| 		if strings.Contains(err.Error(), "context canceled") { | ||||
| 			logger.Info("用户取消了请求:", prompt) | ||||
| 			return nil | ||||
| 			return fmt.Errorf("用户取消了请求:%s", prompt) | ||||
| 		} else if strings.Contains(err.Error(), "no available key") { | ||||
| 			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		if response.Body != nil { | ||||
| 			all, _ := io.ReadAll(response.Body) | ||||
| 			logger.Error(string(all)) | ||||
| 			return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 		} | ||||
| 		return err | ||||
| 	} else { | ||||
| @@ -65,18 +59,26 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 		var toolCall = false | ||||
| 		var arguments = make([]string, 0) | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
| 		var isNew = true | ||||
| 		for scanner.Scan() { | ||||
| 			line := scanner.Text() | ||||
| 			if !strings.Contains(line, "data:") || len(line) < 30 { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			var responseBody = types.ApiResponse{} | ||||
| 			err = json.Unmarshal([]byte(line[6:]), &responseBody) | ||||
| 			if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 | ||||
| 				logger.Error(err, line) | ||||
| 				utils.ReplyMessage(ws, ErrorMsg) | ||||
| 				utils.ReplyMessage(ws, ErrImg) | ||||
| 			if err != nil { // 数据解析出错 | ||||
| 				return errors.New(line) | ||||
| 			} | ||||
| 			if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 | ||||
| 				continue | ||||
| 			} | ||||
| 			if responseBody.Choices[0].Delta.Content == nil && responseBody.Choices[0].Delta.ToolCalls == nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { | ||||
| 				utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。") | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| @@ -103,8 +105,10 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 				res := h.DB.Where("name = ?", tool.Function.Name).First(&function) | ||||
| 				if res.Error == nil { | ||||
| 					toolCall = true | ||||
| 					callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)}) | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg}) | ||||
| 					contents = append(contents, callMsg) | ||||
| 				} | ||||
| 				continue | ||||
| 			} | ||||
| @@ -114,16 +118,16 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			// 初始化 role | ||||
| 			if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { | ||||
| 				message.Role = responseBody.Choices[0].Delta.Role | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 				continue | ||||
| 			} else if responseBody.Choices[0].FinishReason != "" { | ||||
| 			// output stopped | ||||
| 			if responseBody.Choices[0].FinishReason != "" { | ||||
| 				break // 输出完成或者输出中断了 | ||||
| 			} else { | ||||
| 				content := responseBody.Choices[0].Delta.Content | ||||
| 				contents = append(contents, utils.InterfaceToString(content)) | ||||
| 				if isNew { | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 					isNew = false | ||||
| 				} | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 					Type:    types.WsMiddle, | ||||
| 					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), | ||||
| @@ -140,7 +144,7 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 		} | ||||
|  | ||||
| 		if toolCall { // 调用函数完成任务 | ||||
| 			var params map[string]interface{} | ||||
| 			params := make(map[string]interface{}) | ||||
| 			_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms) | ||||
| 			logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params) | ||||
| 			params["user_id"] = userVo.Id | ||||
| @@ -173,126 +177,11 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| 			} | ||||
| 			message.Content = strings.Join(contents, "") | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.SysConfig.EnableContext && toolCall == false { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			useContext := true | ||||
| 			if toolCall { | ||||
| 				useContext = false | ||||
| 			} | ||||
|  | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: useContext, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			var replyTokens = 0 | ||||
| 			if toolCall { // prompt + 函数名 + 参数 token | ||||
| 				tokens, _ := utils.CalcTokens(function.Name, req.Model) | ||||
| 				replyTokens += tokens | ||||
| 				tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) | ||||
| 				replyTokens += tokens | ||||
| 			} else { | ||||
| 				replyTokens, _ = utils.CalcTokens(message.Content, req.Model) | ||||
| 			} | ||||
| 			replyTokens += getTotalTokens(req) | ||||
|  | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    h.extractImgUrl(message.Content), | ||||
| 				Tokens:     replyTokens, | ||||
| 				UseContext: useContext, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| 				chatItem.RoleId = role.Id | ||||
| 				chatItem.ModelId = session.Model.Id | ||||
| 				if utf8.RuneCountInString(prompt) > 30 { | ||||
| 					chatItem.Title = string([]rune(prompt)[:30]) + "..." | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) | ||||
| 		} | ||||
| 	} else { | ||||
| 		body, err := io.ReadAll(response.Body) | ||||
| 		if err != nil { | ||||
| 			utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error()) | ||||
| 			return fmt.Errorf("error with reading response: %v", err) | ||||
| 		} | ||||
| 		var res types.ApiError | ||||
| 		err = json.Unmarshal(body, &res) | ||||
| 		if err != nil { | ||||
| 			utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```") | ||||
| 			return fmt.Errorf("error with decode response: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		// OpenAI API 调用异常处理 | ||||
| 		if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { | ||||
| 			utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") | ||||
| 			// 移除当前 API key | ||||
| 			h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{}) | ||||
| 		} else if strings.Contains(res.Error.Message, "You exceeded your current quota") { | ||||
| 			utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") | ||||
| 		} else if strings.Contains(res.Error.Message, "This model's maximum context length") { | ||||
| 			logger.Error(res.Error.Message) | ||||
| 			utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!") | ||||
| 			h.App.ChatContexts.Delete(session.ChatId) | ||||
| 			return h.sendMessage(ctx, session, role, prompt, ws) | ||||
| 		} else { | ||||
| 			utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message) | ||||
| 		} | ||||
| 		body, _ := io.ReadAll(response.Body) | ||||
| 		return fmt.Errorf("请求 OpenAI API 失败:%s", body) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
|   | ||||
| @@ -1,240 +0,0 @@ | ||||
| package chatimpl | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| ) | ||||
|  | ||||
| type qWenResp struct { | ||||
| 	Output struct { | ||||
| 		FinishReason string `json:"finish_reason"` | ||||
| 		Text         string `json:"text"` | ||||
| 	} `json:"output,omitempty"` | ||||
| 	Usage struct { | ||||
| 		TotalTokens  int `json:"total_tokens"` | ||||
| 		InputTokens  int `json:"input_tokens"` | ||||
| 		OutputTokens int `json:"output_tokens"` | ||||
| 	} `json:"usage,omitempty"` | ||||
| 	RequestID string `json:"request_id"` | ||||
|  | ||||
| 	Code    string `json:"code,omitempty"` | ||||
| 	Message string `json:"message,omitempty"` | ||||
| } | ||||
|  | ||||
| // 通义千问消息发送实现 | ||||
| func (h *ChatHandler) sendQWenMessage( | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| 	session *types.ChatSession, | ||||
| 	role model.ChatRole, | ||||
| 	prompt string, | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| 		if strings.Contains(err.Error(), "context canceled") { | ||||
| 			logger.Info("用户取消了请求:", prompt) | ||||
| 			return nil | ||||
| 		} else if strings.Contains(err.Error(), "no available key") { | ||||
| 			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return err | ||||
| 	} else { | ||||
| 		defer response.Body.Close() | ||||
| 	} | ||||
| 	contentType := response.Header.Get("Content-Type") | ||||
| 	if strings.Contains(contentType, "text/event-stream") { | ||||
| 		replyCreatedAt := time.Now() // 记录回复时间 | ||||
| 		// 循环读取 Chunk 消息 | ||||
| 		var message = types.Message{} | ||||
| 		var contents = make([]string, 0) | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
|  | ||||
| 		var content, lastText, newText string | ||||
| 		var outPutStart = false | ||||
|  | ||||
| 		for scanner.Scan() { | ||||
| 			line := scanner.Text() | ||||
| 			if len(line) < 5 || strings.HasPrefix(line, "id:") || | ||||
| 				strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if strings.HasPrefix(line, "data:") { | ||||
| 				content = line[5:] | ||||
| 			} | ||||
|  | ||||
| 			var resp qWenResp | ||||
| 			if len(contents) == 0 { // 发送消息头 | ||||
| 				if !outPutStart { | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 					outPutStart = true | ||||
| 					continue | ||||
| 				} else { | ||||
| 					// 处理代码换行 | ||||
| 					content = "\n" | ||||
| 				} | ||||
| 			} else { | ||||
| 				err := utils.JsonDecode(content, &resp) | ||||
| 				if err != nil { | ||||
| 					logger.Error("error with parse data line: ", content) | ||||
| 					utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) | ||||
| 					break | ||||
| 				} | ||||
| 				if resp.Message != "" { | ||||
| 					utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message)) | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			//通过比较 lastText(上一次的文本)和 currentText(当前的文本), | ||||
| 			//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。 | ||||
| 			//每次循环结束后,lastText 会更新为当前的完整文本,以便于下一次循环进行比较。 | ||||
| 			currentText := resp.Output.Text | ||||
| 			if currentText != lastText { | ||||
| 				// 提取新增文本 | ||||
| 				newText = strings.Replace(currentText, lastText, "", 1) | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 					Type:    types.WsMiddle, | ||||
| 					Content: utils.InterfaceToString(newText), | ||||
| 				}) | ||||
| 				lastText = currentText // 更新 lastText | ||||
| 			} | ||||
| 			contents = append(contents, newText) | ||||
|  | ||||
| 			if resp.Output.FinishReason == "stop" { | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 		} //end for | ||||
|  | ||||
| 		if err := scanner.Err(); err != nil { | ||||
| 			if strings.Contains(err.Error(), "context canceled") { | ||||
| 				logger.Info("用户取消了请求:", prompt) | ||||
| 			} else { | ||||
| 				logger.Error("信息读取出错:", err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| 			} | ||||
| 			message.Content = strings.Join(contents, "") | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.SysConfig.EnableContext { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// for reply | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 			totalTokens := replyTokens + getTotalTokens(req) | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    message.Content, | ||||
| 				Tokens:     totalTokens, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| 				chatItem.RoleId = role.Id | ||||
| 				chatItem.ModelId = session.Model.Id | ||||
| 				if utf8.RuneCountInString(prompt) > 30 { | ||||
| 					chatItem.Title = string([]rune(prompt)[:30]) + "..." | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		body, err := io.ReadAll(response.Body) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with reading response: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		var res struct { | ||||
| 			Code int    `json:"error_code"` | ||||
| 			Msg  string `json:"error_msg"` | ||||
| 		} | ||||
| 		err = json.Unmarshal(body, &res) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with decode response: %v", err) | ||||
| 		} | ||||
| 		utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -1,320 +0,0 @@ | ||||
| package chatimpl | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/sha256" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| ) | ||||
|  | ||||
| type xunFeiResp struct { | ||||
| 	Header struct { | ||||
| 		Code    int    `json:"code"` | ||||
| 		Message string `json:"message"` | ||||
| 		Sid     string `json:"sid"` | ||||
| 		Status  int    `json:"status"` | ||||
| 	} `json:"header"` | ||||
| 	Payload struct { | ||||
| 		Choices struct { | ||||
| 			Status int `json:"status"` | ||||
| 			Seq    int `json:"seq"` | ||||
| 			Text   []struct { | ||||
| 				Content string `json:"content"` | ||||
| 				Role    string `json:"role"` | ||||
| 				Index   int    `json:"index"` | ||||
| 			} `json:"text"` | ||||
| 		} `json:"choices"` | ||||
| 		Usage struct { | ||||
| 			Text struct { | ||||
| 				QuestionTokens   int `json:"question_tokens"` | ||||
| 				PromptTokens     int `json:"prompt_tokens"` | ||||
| 				CompletionTokens int `json:"completion_tokens"` | ||||
| 				TotalTokens      int `json:"total_tokens"` | ||||
| 			} `json:"text"` | ||||
| 		} `json:"usage"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| var Model2URL = map[string]string{ | ||||
| 	"general":     "v1.1", | ||||
| 	"generalv2":   "v2.1", | ||||
| 	"generalv3":   "v3.1", | ||||
| 	"generalv3.5": "v3.5", | ||||
| } | ||||
|  | ||||
| // 科大讯飞消息发送实现 | ||||
|  | ||||
| func (h *ChatHandler) sendXunFeiMessage( | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| 	session *types.ChatSession, | ||||
| 	role model.ChatRole, | ||||
| 	prompt string, | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	var apiKey model.ApiKey | ||||
| 	res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) | ||||
| 	if res.Error != nil { | ||||
| 		utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 		return nil | ||||
| 	} | ||||
| 	// 更新 API KEY 的最后使用时间 | ||||
| 	h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
|  | ||||
| 	d := websocket.Dialer{ | ||||
| 		HandshakeTimeout: 5 * time.Second, | ||||
| 	} | ||||
| 	key := strings.Split(apiKey.Value, "|") | ||||
| 	if len(key) != 3 { | ||||
| 		utils.ReplyMessage(ws, "非法的 API KEY!") | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1) | ||||
| 	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) | ||||
| 	wsURL, err := assembleAuthUrl(apiURL, key[1], key[2]) | ||||
| 	//握手并建立websocket 连接 | ||||
| 	conn, resp, err := d.Dial(wsURL, nil) | ||||
| 	if err != nil { | ||||
| 		logger.Error(readResp(resp) + err.Error()) | ||||
| 		utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error()) | ||||
| 		return nil | ||||
| 	} else if resp.StatusCode != 101 { | ||||
| 		utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error()) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	data := buildRequest(key[0], req) | ||||
| 	fmt.Printf("%+v", data) | ||||
| 	fmt.Println(apiURL) | ||||
| 	err = conn.WriteJSON(data) | ||||
| 	if err != nil { | ||||
| 		utils.ReplyMessage(ws, "发送消息失败:"+err.Error()) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	replyCreatedAt := time.Now() // 记录回复时间 | ||||
| 	// 循环读取 Chunk 消息 | ||||
| 	var message = types.Message{} | ||||
| 	var contents = make([]string, 0) | ||||
| 	var content string | ||||
| 	for { | ||||
| 		_, msg, err := conn.ReadMessage() | ||||
| 		if err != nil { | ||||
| 			logger.Error("error with read message:", err) | ||||
| 			utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err)) | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		// 解析数据 | ||||
| 		var result xunFeiResp | ||||
| 		err = json.Unmarshal(msg, &result) | ||||
| 		if err != nil { | ||||
| 			logger.Error("error with parsing JSON:", err) | ||||
| 			utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		if result.Header.Code != 0 { | ||||
| 			utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message)) | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		content = result.Payload.Choices.Text[0].Content | ||||
| 		// 处理代码换行 | ||||
| 		if len(content) == 0 { | ||||
| 			content = "\n" | ||||
| 		} | ||||
| 		contents = append(contents, content) | ||||
| 		// 第一个结果 | ||||
| 		if result.Payload.Choices.Status == 0 { | ||||
| 			utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 		} | ||||
| 		utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 			Type:    types.WsMiddle, | ||||
| 			Content: utils.InterfaceToString(content), | ||||
| 		}) | ||||
|  | ||||
| 		if result.Payload.Choices.Status == 2 { // 最终结果 | ||||
| 			_ = conn.Close() // 关闭连接 | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			utils.ReplyMessage(ws, "**用户取消了生成指令!**") | ||||
| 			return nil | ||||
| 		default: | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	// 消息发送成功 | ||||
| 	if len(contents) > 0 { | ||||
| 		if message.Role == "" { | ||||
| 			message.Role = "assistant" | ||||
| 		} | ||||
| 		message.Content = strings.Join(contents, "") | ||||
| 		useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 		// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 		if h.App.SysConfig.EnableContext { | ||||
| 			chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 			chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 			h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 		} | ||||
|  | ||||
| 		// 追加聊天记录 | ||||
| 		// for prompt | ||||
| 		promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 		if err != nil { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
| 		historyUserMsg := model.ChatMessage{ | ||||
| 			UserId:     userVo.Id, | ||||
| 			ChatId:     session.ChatId, | ||||
| 			RoleId:     role.Id, | ||||
| 			Type:       types.PromptMsg, | ||||
| 			Icon:       userVo.Avatar, | ||||
| 			Content:    template.HTMLEscapeString(prompt), | ||||
| 			Tokens:     promptToken, | ||||
| 			UseContext: true, | ||||
| 			Model:      req.Model, | ||||
| 		} | ||||
| 		historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 		historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 		res := h.DB.Save(&historyUserMsg) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 		} | ||||
|  | ||||
| 		// for reply | ||||
| 		// 计算本次对话消耗的总 token 数量 | ||||
| 		replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 		totalTokens := replyTokens + getTotalTokens(req) | ||||
| 		historyReplyMsg := model.ChatMessage{ | ||||
| 			UserId:     userVo.Id, | ||||
| 			ChatId:     session.ChatId, | ||||
| 			RoleId:     role.Id, | ||||
| 			Type:       types.ReplyMsg, | ||||
| 			Icon:       role.Icon, | ||||
| 			Content:    message.Content, | ||||
| 			Tokens:     totalTokens, | ||||
| 			UseContext: true, | ||||
| 			Model:      req.Model, | ||||
| 		} | ||||
| 		historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 		historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 		res = h.DB.Create(&historyReplyMsg) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("failed to save reply history message: ", res.Error) | ||||
| 		} | ||||
|  | ||||
| 		// 更新用户算力 | ||||
| 		h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 		// 保存当前会话 | ||||
| 		var chatItem model.ChatItem | ||||
| 		res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 		if res.Error != nil { | ||||
| 			chatItem.ChatId = session.ChatId | ||||
| 			chatItem.UserId = session.UserId | ||||
| 			chatItem.RoleId = role.Id | ||||
| 			chatItem.ModelId = session.Model.Id | ||||
| 			if utf8.RuneCountInString(prompt) > 30 { | ||||
| 				chatItem.Title = string([]rune(prompt)[:30]) + "..." | ||||
| 			} else { | ||||
| 				chatItem.Title = prompt | ||||
| 			} | ||||
| 			chatItem.Model = req.Model | ||||
| 			h.DB.Create(&chatItem) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // 构建 websocket 请求实体 | ||||
| func buildRequest(appid string, req types.ApiRequest) map[string]interface{} { | ||||
| 	return map[string]interface{}{ | ||||
| 		"header": map[string]interface{}{ | ||||
| 			"app_id": appid, | ||||
| 		}, | ||||
| 		"parameter": map[string]interface{}{ | ||||
| 			"chat": map[string]interface{}{ | ||||
| 				"domain":      req.Model, | ||||
| 				"temperature": req.Temperature, | ||||
| 				"top_k":       int64(6), | ||||
| 				"max_tokens":  int64(req.MaxTokens), | ||||
| 				"auditing":    "default", | ||||
| 			}, | ||||
| 		}, | ||||
| 		"payload": map[string]interface{}{ | ||||
| 			"message": map[string]interface{}{ | ||||
| 				"text": req.Messages, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 创建鉴权 URL | ||||
| func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) { | ||||
| 	ul, err := url.Parse(hostURL) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	date := time.Now().UTC().Format(time.RFC1123) | ||||
| 	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} | ||||
| 	//拼接签名字符串 | ||||
| 	signStr := strings.Join(signString, "\n") | ||||
| 	sha := hmacWithSha256(signStr, apiSecret) | ||||
|  | ||||
| 	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, | ||||
| 		"hmac-sha256", "host date request-line", sha) | ||||
| 	//将请求参数使用base64编码 | ||||
| 	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) | ||||
| 	v := url.Values{} | ||||
| 	v.Add("host", ul.Host) | ||||
| 	v.Add("date", date) | ||||
| 	v.Add("authorization", authorization) | ||||
| 	//将编码后的字符串url encode后添加到url后面 | ||||
| 	return hostURL + "?" + v.Encode(), nil | ||||
| } | ||||
|  | ||||
| // 使用 sha256 签名 | ||||
| func hmacWithSha256(data, key string) string { | ||||
| 	mac := hmac.New(sha256.New, []byte(key)) | ||||
| 	mac.Write([]byte(data)) | ||||
| 	encodeData := mac.Sum(nil) | ||||
| 	return base64.StdEncoding.EncodeToString(encodeData) | ||||
| } | ||||
|  | ||||
| // 读取响应 | ||||
| func readResp(resp *http.Response) string { | ||||
| 	if resp == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	b, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) | ||||
| } | ||||
| @@ -1,10 +1,18 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/service" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| @@ -12,10 +20,11 @@ import ( | ||||
|  | ||||
| type ConfigHandler struct { | ||||
| 	BaseHandler | ||||
| 	licenseService *service.LicenseService | ||||
| } | ||||
|  | ||||
| func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler { | ||||
| 	return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *ConfigHandler { | ||||
| 	return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService} | ||||
| } | ||||
|  | ||||
| // Get 获取指定的系统配置 | ||||
| @@ -37,3 +46,9 @@ func (h *ConfigHandler) Get(c *gin.Context) { | ||||
|  | ||||
| 	resp.SUCCESS(c, value) | ||||
| } | ||||
|  | ||||
| // License 获取 License 配置 | ||||
| func (h *ConfigHandler) License(c *gin.Context) { | ||||
| 	license := h.licenseService.GetLicense() | ||||
| 	resp.SUCCESS(c, license.Configs) | ||||
| } | ||||
|   | ||||
							
								
								
									
										255
									
								
								api/handler/dalle_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										255
									
								
								api/handler/dalle_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,255 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service/dalle" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gorilla/websocket" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type DallJobHandler struct { | ||||
| 	BaseHandler | ||||
| 	redis    *redis.Client | ||||
| 	service  *dalle.Service | ||||
| 	uploader *oss.UploaderManager | ||||
| } | ||||
|  | ||||
| func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler { | ||||
| 	return &DallJobHandler{ | ||||
| 		service:  service, | ||||
| 		uploader: manager, | ||||
| 		BaseHandler: BaseHandler{ | ||||
| 			App: app, | ||||
| 			DB:  db, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Client WebSocket 客户端,用于通知任务状态变更 | ||||
| func (h *DallJobHandler) Client(c *gin.Context) { | ||||
| 	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) | ||||
| 	if err != nil { | ||||
| 		logger.Error(err) | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	if userId == 0 { | ||||
| 		logger.Info("Invalid user ID") | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	client := types.NewWsClient(ws) | ||||
| 	h.service.Clients.Put(uint(userId), client) | ||||
| 	logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			_, msg, err := client.Receive() | ||||
| 			if err != nil { | ||||
| 				client.Close() | ||||
| 				h.service.Clients.Delete(uint(userId)) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			var message types.WsMessage | ||||
| 			err = utils.JsonDecode(string(msg), &message) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// 心跳消息 | ||||
| 			if message.Type == "heartbeat" { | ||||
| 				logger.Debug("收到 DallE 心跳消息:", message.Content) | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (h *DallJobHandler) preCheck(c *gin.Context) bool { | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return false | ||||
| 	} | ||||
| 	if user.Power < h.App.SysConfig.DallPower { | ||||
| 		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	return true | ||||
|  | ||||
| } | ||||
|  | ||||
| // Image 创建一个绘画任务 | ||||
| func (h *DallJobHandler) Image(c *gin.Context) { | ||||
| 	if !h.preCheck(c) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var data types.DallTask | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	idValue, _ := c.Get(types.LoginUserID) | ||||
| 	userId := utils.IntValue(utils.InterfaceToString(idValue), 0) | ||||
| 	job := model.DallJob{ | ||||
| 		UserId: uint(userId), | ||||
| 		Prompt: data.Prompt, | ||||
| 		Power:  h.App.SysConfig.DallPower, | ||||
| 	} | ||||
| 	res := h.DB.Create(&job) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "error with save job: "+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h.service.PushTask(types.DallTask{ | ||||
| 		JobId:   job.Id, | ||||
| 		UserId:  uint(userId), | ||||
| 		Prompt:  data.Prompt, | ||||
| 		Quality: data.Quality, | ||||
| 		Size:    data.Size, | ||||
| 		Style:   data.Style, | ||||
| 		Power:   job.Power, | ||||
| 	}) | ||||
|  | ||||
| 	client := h.service.Clients.Get(job.UserId) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // ImgWall 照片墙 | ||||
| func (h *DallJobHandler) ImgWall(c *gin.Context) { | ||||
| 	page := h.GetInt(c, "page", 0) | ||||
| 	pageSize := h.GetInt(c, "page_size", 0) | ||||
| 	err, jobs := h.getData(true, 0, page, pageSize, true) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, jobs) | ||||
| } | ||||
|  | ||||
| // JobList 获取 SD 任务列表 | ||||
| func (h *DallJobHandler) JobList(c *gin.Context) { | ||||
| 	finish := h.GetBool(c, "finish") | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	page := h.GetInt(c, "page", 0) | ||||
| 	pageSize := h.GetInt(c, "page_size", 0) | ||||
| 	publish := h.GetBool(c, "publish") | ||||
|  | ||||
| 	err, jobs := h.getData(finish, userId, page, pageSize, publish) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, jobs) | ||||
| } | ||||
|  | ||||
| // JobList 获取任务列表 | ||||
| func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) { | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if finish { | ||||
| 		session = session.Where("progress = ?", 100).Order("id DESC") | ||||
| 	} else { | ||||
| 		session = session.Where("progress < ?", 100).Order("id ASC") | ||||
| 	} | ||||
| 	if userId > 0 { | ||||
| 		session = session.Where("user_id = ?", userId) | ||||
| 	} | ||||
| 	if publish { | ||||
| 		session = session.Where("publish", publish) | ||||
| 	} | ||||
| 	if page > 0 && pageSize > 0 { | ||||
| 		offset := (page - 1) * pageSize | ||||
| 		session = session.Offset(offset).Limit(pageSize) | ||||
| 	} | ||||
|  | ||||
| 	var items []model.DallJob | ||||
| 	res := session.Find(&items) | ||||
| 	if res.Error != nil { | ||||
| 		return res.Error, nil | ||||
| 	} | ||||
|  | ||||
| 	var jobs = make([]vo.DallJob, 0) | ||||
| 	for _, item := range items { | ||||
| 		var job vo.DallJob | ||||
| 		err := utils.CopyObject(item, &job) | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		jobs = append(jobs, job) | ||||
| 	} | ||||
|  | ||||
| 	return nil, jobs | ||||
| } | ||||
|  | ||||
| // Remove remove task image | ||||
| func (h *DallJobHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	var job model.DallJob | ||||
| 	if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil { | ||||
| 		resp.ERROR(c, "记录不存在") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove job recode | ||||
| 	res := h.DB.Delete(&model.DallJob{Id: job.Id}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove image | ||||
| 	err := h.uploader.GetUploadHandler().Delete(job.ImgURL) | ||||
| 	if err != nil { | ||||
| 		logger.Error("remove image failed: ", err) | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // Publish 发布/取消发布图片到画廊显示 | ||||
| func (h *DallJobHandler) Publish(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享 | ||||
|  | ||||
| 	res := h.DB.Model(&model.DallJob{Id: uint(id), UserId: uint(userId)}).UpdateColumn("publish", action) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
| @@ -1,29 +1,44 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service/dalle" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"gorm.io/gorm" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type FunctionHandler struct { | ||||
| 	BaseHandler | ||||
| 	config        types.ChatPlusApiConfig | ||||
| 	config        types.ApiConfig | ||||
| 	uploadManager *oss.UploaderManager | ||||
| 	dallService   *dalle.Service | ||||
| } | ||||
|  | ||||
| func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler { | ||||
| func NewFunctionHandler( | ||||
| 	server *core.AppServer, | ||||
| 	db *gorm.DB, | ||||
| 	config *types.AppConfig, | ||||
| 	manager *oss.UploaderManager, | ||||
| 	dallService *dalle.Service) *FunctionHandler { | ||||
| 	return &FunctionHandler{ | ||||
| 		BaseHandler: BaseHandler{ | ||||
| 			App: server, | ||||
| @@ -31,6 +46,7 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo | ||||
| 		}, | ||||
| 		config:        config.ApiConfig, | ||||
| 		uploadManager: manager, | ||||
| 		dallService:   dallService, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -151,30 +167,6 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) { | ||||
| 	resp.SUCCESS(c, strings.Join(builder, "\n\n")) | ||||
| } | ||||
|  | ||||
| type imgReq struct { | ||||
| 	Model  string `json:"model"` | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	N      int    `json:"n"` | ||||
| 	Size   string `json:"size"` | ||||
| } | ||||
|  | ||||
| type imgRes struct { | ||||
| 	Created int64 `json:"created"` | ||||
| 	Data    []struct { | ||||
| 		RevisedPrompt string `json:"revised_prompt"` | ||||
| 		Url           string `json:"url"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type ErrRes struct { | ||||
| 	Error struct { | ||||
| 		Code    interface{} `json:"code"` | ||||
| 		Message string      `json:"message"` | ||||
| 		Param   interface{} `json:"param"` | ||||
| 		Type    string      `json:"type"` | ||||
| 	} `json:"error"` | ||||
| } | ||||
|  | ||||
| // Dall3 DallE3 AI 绘图 | ||||
| func (h *FunctionHandler) Dall3(c *gin.Context) { | ||||
| 	if err := h.checkAuth(c); err != nil { | ||||
| @@ -190,84 +182,44 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { | ||||
|  | ||||
| 	logger.Debugf("绘画参数:%+v", params) | ||||
| 	var user model.User | ||||
| 	tx := h.DB.Where("id = ?", params["user_id"]).First(&user) | ||||
| 	if tx.Error != nil { | ||||
| 	res := h.DB.Where("id = ?", params["user_id"]).First(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "当前用户不存在!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if user.Power < h.App.SysConfig.DallPower { | ||||
| 		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") | ||||
| 		resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// create dall task | ||||
| 	prompt := utils.InterfaceToString(params["prompt"]) | ||||
| 	// get image generation API KEY | ||||
| 	var apiKey model.ApiKey | ||||
| 	tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) | ||||
| 	if tx.Error != nil { | ||||
| 		resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// translate prompt | ||||
| 	const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" | ||||
| 	pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"])) | ||||
| 	if err == nil { | ||||
| 		logger.Debugf("翻译绘画提示词,原文:%s,译文:%s", prompt, pt) | ||||
| 		prompt = pt | ||||
| 	} | ||||
| 	var res imgRes | ||||
| 	var errRes ErrRes | ||||
| 	var request *req.Request | ||||
| 	if apiKey.ProxyURL != "" { | ||||
| 		request = req.C().SetProxyURL(apiKey.ProxyURL).R() | ||||
| 	} else { | ||||
| 		request = req.C().R() | ||||
| 	} | ||||
| 	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) | ||||
| 	r, err := request.SetHeader("Content-Type", "application/json"). | ||||
| 		SetHeader("Authorization", "Bearer "+apiKey.Value). | ||||
| 		SetBody(imgReq{ | ||||
| 			Model:  "dall-e-3", | ||||
| 			Prompt: prompt, | ||||
| 			N:      1, | ||||
| 			Size:   "1024x1024", | ||||
| 		}). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		SetSuccessResult(&res).Post(apiKey.ApiURL) | ||||
| 	if r.IsErrorState() { | ||||
| 		resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message) | ||||
| 		return | ||||
| 	} | ||||
| 	// 更新 API KEY 的最后使用时间 | ||||
| 	h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
| 	logger.Debugf("%+v", res) | ||||
| 	// 存储图片 | ||||
| 	imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "下载图片失败: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n\n", prompt, imgURL) | ||||
| 	// 更新用户算力 | ||||
| 	tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower)) | ||||
| 	// 记录算力变化日志 | ||||
| 	if tx.Error == nil && tx.RowsAffected > 0 { | ||||
| 		var u model.User | ||||
| 		h.DB.Where("id", user.Id).First(&u) | ||||
| 		h.DB.Create(&model.PowerLog{ | ||||
| 	job := model.DallJob{ | ||||
| 		UserId: user.Id, | ||||
| 			Username:  user.Username, | ||||
| 			Type:      types.PowerConsume, | ||||
| 			Amount:    h.App.SysConfig.DallPower, | ||||
| 			Balance:   u.Power, | ||||
| 			Mark:      types.PowerSub, | ||||
| 			Model:     "dall-e-3", | ||||
| 			Remark:    fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)), | ||||
| 			CreatedAt: time.Now(), | ||||
| 		}) | ||||
| 		Prompt: prompt, | ||||
| 		Power:  h.App.SysConfig.DallPower, | ||||
| 	} | ||||
| 	res = h.DB.Create(&job) | ||||
|  | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	content, err := h.dallService.Image(types.DallTask{ | ||||
| 		JobId:   job.Id, | ||||
| 		UserId:  user.Id, | ||||
| 		Prompt:  job.Prompt, | ||||
| 		N:       1, | ||||
| 		Quality: "standard", | ||||
| 		Size:    "1024x1024", | ||||
| 		Style:   "vivid", | ||||
| 		Power:   job.Power, | ||||
| 	}, true) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "任务执行失败:"+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, content) | ||||
|   | ||||
| @@ -1,12 +1,19 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"strings" | ||||
|   | ||||
							
								
								
									
										257
									
								
								api/handler/markmap_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										257
									
								
								api/handler/markmap_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,257 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"gorm.io/gorm" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // MarkMapHandler 生成思维导图 | ||||
| type MarkMapHandler struct { | ||||
| 	BaseHandler | ||||
| 	clients *types.LMap[int, *types.WsClient] | ||||
| } | ||||
|  | ||||
| func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { | ||||
| 	return &MarkMapHandler{ | ||||
| 		BaseHandler: BaseHandler{App: app, DB: db}, | ||||
| 		clients:     types.NewLMap[int, *types.WsClient](), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *MarkMapHandler) Client(c *gin.Context) { | ||||
| 	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) | ||||
| 	if err != nil { | ||||
| 		logger.Error(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	modelId := h.GetInt(c, "model_id", 0) | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
|  | ||||
| 	client := types.NewWsClient(ws) | ||||
| 	h.clients.Put(userId, client) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			_, msg, err := client.Receive() | ||||
| 			if err != nil { | ||||
| 				client.Close() | ||||
| 				h.clients.Delete(userId) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			var message types.WsMessage | ||||
| 			err = utils.JsonDecode(string(msg), &message) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// 心跳消息 | ||||
| 			if message.Type == "heartbeat" { | ||||
| 				logger.Debug("收到 MarkMap 心跳消息:", message.Content) | ||||
| 				continue | ||||
| 			} | ||||
| 			// change model | ||||
| 			if message.Type == "model_id" { | ||||
| 				modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0) | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			logger.Info("Receive a message: ", message.Content) | ||||
| 			err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()}) | ||||
| 			} | ||||
|  | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error { | ||||
| 	var user model.User | ||||
| 	res := h.DB.Model(&model.User{}).First(&user, userId) | ||||
| 	if res.Error != nil { | ||||
| 		return fmt.Errorf("error with query user info: %v", res.Error) | ||||
| 	} | ||||
| 	var chatModel model.ChatModel | ||||
| 	res = h.DB.Where("id", modelId).First(&chatModel) | ||||
| 	if res.Error != nil { | ||||
| 		return fmt.Errorf("error with query chat model: %v", res.Error) | ||||
| 	} | ||||
|  | ||||
| 	if user.Status == false { | ||||
| 		return errors.New("当前用户被禁用") | ||||
| 	} | ||||
|  | ||||
| 	if user.Power < chatModel.Power { | ||||
| 		return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power) | ||||
| 	} | ||||
|  | ||||
| 	messages := make([]interface{}, 0) | ||||
| 	messages = append(messages, types.Message{Role: "system", Content: ` | ||||
| 你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子: | ||||
| # Geek-AI 助手 | ||||
|  | ||||
| ## 完整的开源系统 | ||||
| ### 前端开源 | ||||
| ### 后端开源 | ||||
|  | ||||
| ## 支持各种大模型 | ||||
| ### OpenAI  | ||||
| ### Azure  | ||||
| ### 文心一言 | ||||
| ### 通义千问 | ||||
|  | ||||
| ## 集成多种收费方式 | ||||
| ### 支付宝 | ||||
| ### 微信 | ||||
|  | ||||
| 另外,除此之外不要任何解释性语句。 | ||||
| `}) | ||||
| 	messages = append(messages, types.Message{Role: "user", Content: prompt}) | ||||
| 	var req = types.ApiRequest{ | ||||
| 		Model:    chatModel.Value, | ||||
| 		Stream:   true, | ||||
| 		Messages: messages, | ||||
| 	} | ||||
|  | ||||
| 	var apiKey model.ApiKey | ||||
| 	response, err := h.doRequest(req, chatModel, &apiKey) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("请求 OpenAI API 失败: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	defer response.Body.Close() | ||||
|  | ||||
| 	contentType := response.Header.Get("Content-Type") | ||||
| 	if strings.Contains(contentType, "text/event-stream") { | ||||
| 		// 循环读取 Chunk 消息 | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
| 		var isNew = true | ||||
| 		for scanner.Scan() { | ||||
| 			line := scanner.Text() | ||||
| 			if !strings.Contains(line, "data:") || len(line) < 30 { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			var responseBody = types.ApiResponse{} | ||||
| 			err = json.Unmarshal([]byte(line[6:]), &responseBody) | ||||
| 			if err != nil { // 数据解析出错 | ||||
| 				return fmt.Errorf("error with decode data: %v", line) | ||||
| 			} | ||||
|  | ||||
| 			if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if responseBody.Choices[0].FinishReason == "stop" { | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			if isNew { | ||||
| 				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart}) | ||||
| 				isNew = false | ||||
| 			} | ||||
| 			utils.ReplyChunkMessage(client, types.WsMessage{ | ||||
| 				Type:    types.WsMiddle, | ||||
| 				Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), | ||||
| 			}) | ||||
| 		} // end for | ||||
|  | ||||
| 		utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) | ||||
|  | ||||
| 	} else { | ||||
| 		body, _ := io.ReadAll(response.Body) | ||||
| 		return fmt.Errorf("请求 OpenAI API 失败:%s", string(body)) | ||||
| 	} | ||||
|  | ||||
| 	// 扣减算力 | ||||
| 	if chatModel.Power > 0 { | ||||
| 		res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power)) | ||||
| 		if res.Error == nil { | ||||
| 			// 记录算力消费日志 | ||||
| 			var u model.User | ||||
| 			h.DB.Where("id", userId).First(&u) | ||||
| 			h.DB.Create(&model.PowerLog{ | ||||
| 				UserId:    u.Id, | ||||
| 				Username:  u.Username, | ||||
| 				Type:      types.PowerConsume, | ||||
| 				Amount:    chatModel.Power, | ||||
| 				Mark:      types.PowerSub, | ||||
| 				Balance:   u.Power, | ||||
| 				Model:     chatModel.Value, | ||||
| 				Remark:    fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value), | ||||
| 				CreatedAt: time.Now(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) { | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	// if the chat model bind a KEY, use it directly | ||||
| 	if chatModel.KeyId > 0 { | ||||
| 		session = session.Where("id", chatModel.KeyId) | ||||
| 	} else { // use the last unused key | ||||
| 		session = session.Where("type", "chat"). | ||||
| 			Where("enabled", true).Order("last_used_at ASC") | ||||
| 	} | ||||
|  | ||||
| 	res := session.First(apiKey) | ||||
| 	if res.Error != nil { | ||||
| 		return nil, errors.New("no available key, please import key") | ||||
| 	} | ||||
| 	apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL) | ||||
| 	// 更新 API KEY 的最后使用时间 | ||||
| 	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
|  | ||||
| 	// 创建 HttpClient 请求对象 | ||||
| 	var client *http.Client | ||||
| 	requestBody, err := json.Marshal(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	request.Header.Set("Content-Type", "application/json") | ||||
| 	if len(apiKey.ProxyURL) > 5 { // 使用代理 | ||||
| 		proxy, _ := url.Parse(apiKey.ProxyURL) | ||||
| 		client = &http.Client{ | ||||
| 			Transport: &http.Transport{ | ||||
| 				Proxy: http.ProxyURL(proxy), | ||||
| 			}, | ||||
| 		} | ||||
| 	} else { | ||||
| 		client = http.DefaultClient | ||||
| 	} | ||||
| 	request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) | ||||
| 	return client.Do(request) | ||||
| } | ||||
							
								
								
									
										49
									
								
								api/handler/menu_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								api/handler/menu_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"geekai/core" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type MenuHandler struct { | ||||
| 	BaseHandler | ||||
| } | ||||
|  | ||||
| func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler { | ||||
| 	return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // List 数据列表 | ||||
| func (h *MenuHandler) List(c *gin.Context) { | ||||
| 	index := h.GetBool(c, "index") | ||||
| 	var items []model.Menu | ||||
| 	var list = make([]vo.Menu, 0) | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	session = session.Where("enabled", true) | ||||
| 	if index { | ||||
| 		session = session.Where("id IN ?", h.App.SysConfig.IndexNavs) | ||||
| 	} | ||||
| 	res := session.Order("sort_num ASC").Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var product vo.Menu | ||||
| 			err := utils.CopyObject(item, &product) | ||||
| 			if err == nil { | ||||
| 				list = append(list, product) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, list) | ||||
| } | ||||
| @@ -1,17 +1,24 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/mj" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/mj" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -85,20 +92,22 @@ func (h *MidJourneyHandler) Client(c *gin.Context) { | ||||
| // Image 创建一个绘画任务 | ||||
| func (h *MidJourneyHandler) Image(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		SessionId string   `json:"session_id"` | ||||
| 		TaskType  string   `json:"task_type"` | ||||
| 		Prompt    string   `json:"prompt"` | ||||
| 		NegPrompt string   `json:"neg_prompt"` | ||||
| 		Rate      string   `json:"rate"` | ||||
| 		Model     string   `json:"model"` | ||||
| 		Chaos     int      `json:"chaos"` | ||||
| 		Raw       bool     `json:"raw"` | ||||
| 		Seed      int64    `json:"seed"` | ||||
| 		Stylize   int      `json:"stylize"` | ||||
| 		Model     string   `json:"model"`   // 模型 | ||||
| 		Chaos     int      `json:"chaos"`   // 创意度取值范围: 0-100 | ||||
| 		Raw       bool     `json:"raw"`     // 是否开启原始模型 | ||||
| 		Seed      int64    `json:"seed"`    // 随机数 | ||||
| 		Stylize   int      `json:"stylize"` // 风格化 | ||||
| 		ImgArr    []string `json:"img_arr"` | ||||
| 		Tile      bool     `json:"tile"` | ||||
| 		Quality   float32  `json:"quality"` | ||||
| 		Weight    float32  `json:"weight"` | ||||
| 		Tile      bool     `json:"tile"`    // 重复平铺 | ||||
| 		Quality   float32  `json:"quality"` // 画质 | ||||
| 		Iw        float32  `json:"iw"` | ||||
| 		CRef      string   `json:"cref"` //生成角色一致的图像 | ||||
| 		SRef      string   `json:"sref"` //生成风格一致的图像 | ||||
| 		Cw        int      `json:"cw"`   // 参考程度 | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| @@ -108,41 +117,57 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var prompt = data.Prompt | ||||
| 	if data.Rate != "" && !strings.Contains(prompt, "--ar") { | ||||
| 		prompt += " --ar " + data.Rate | ||||
| 	var params = "" | ||||
| 	if data.Rate != "" && !strings.Contains(params, "--ar") { | ||||
| 		params += " --ar " + data.Rate | ||||
| 	} | ||||
| 	if data.Seed > 0 && !strings.Contains(prompt, "--seed") { | ||||
| 		prompt += fmt.Sprintf(" --seed %d", data.Seed) | ||||
| 	if data.Seed > 0 && !strings.Contains(params, "--seed") { | ||||
| 		params += fmt.Sprintf(" --seed %d", data.Seed) | ||||
| 	} | ||||
| 	if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") { | ||||
| 		prompt += fmt.Sprintf(" --s %d", data.Stylize) | ||||
| 	if data.Stylize > 0 && !strings.Contains(params, "--s") && !strings.Contains(params, "--stylize") { | ||||
| 		params += fmt.Sprintf(" --s %d", data.Stylize) | ||||
| 	} | ||||
| 	if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") { | ||||
| 		prompt += fmt.Sprintf(" --c %d", data.Chaos) | ||||
| 	if data.Chaos > 0 && !strings.Contains(params, "--c") && !strings.Contains(params, "--chaos") { | ||||
| 		params += fmt.Sprintf(" --c %d", data.Chaos) | ||||
| 	} | ||||
| 	if data.Weight > 0 { | ||||
| 		prompt += fmt.Sprintf(" --iw %f", data.Weight) | ||||
| 	if len(data.ImgArr) > 0 && data.Iw > 0 { | ||||
| 		params += fmt.Sprintf(" --iw %.2f", data.Iw) | ||||
| 	} | ||||
| 	if data.Raw { | ||||
| 		prompt += " --style raw" | ||||
| 		params += " --style raw" | ||||
| 	} | ||||
| 	if data.Quality > 0 { | ||||
| 		prompt += fmt.Sprintf(" --q %.2f", data.Quality) | ||||
| 	} | ||||
| 	if data.NegPrompt != "" { | ||||
| 		prompt += fmt.Sprintf(" --no %s", data.NegPrompt) | ||||
| 		params += fmt.Sprintf(" --q %.2f", data.Quality) | ||||
| 	} | ||||
| 	if data.Tile { | ||||
| 		prompt += " --tile " | ||||
| 		params += " --tile " | ||||
| 	} | ||||
| 	if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { | ||||
| 		prompt += fmt.Sprintf(" %s", data.Model) | ||||
| 	if data.CRef != "" { | ||||
| 		params += fmt.Sprintf(" --cref %s", data.CRef) | ||||
| 		if data.Cw > 0 { | ||||
| 			params += fmt.Sprintf(" --cw %d", data.Cw) | ||||
| 		} else { | ||||
| 			params += " --cw 100" | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if data.SRef != "" { | ||||
| 		params += fmt.Sprintf(" --sref %s", data.SRef) | ||||
| 	} | ||||
| 	if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") { | ||||
| 		params += fmt.Sprintf(" %s", data.Model) | ||||
| 	} | ||||
|  | ||||
| 	// 处理融图和换脸的提示词 | ||||
| 	if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() { | ||||
| 		prompt = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ",")) | ||||
| 		params = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ",")) | ||||
| 	} | ||||
|  | ||||
| 	// 如果本地图片上传的是相对地址,处理成绝对地址 | ||||
| 	for k, v := range data.ImgArr { | ||||
| 		if !strings.HasPrefix(v, "http") { | ||||
| 			data.ImgArr[k] = fmt.Sprintf("http://localhost:5678/%s", strings.TrimLeft(v, "/")) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	idValue, _ := c.Get(types.LoginUserID) | ||||
| @@ -158,7 +183,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { | ||||
| 		UserId:    userId, | ||||
| 		TaskId:    taskId, | ||||
| 		Progress:  0, | ||||
| 		Prompt:    prompt, | ||||
| 		Prompt:    fmt.Sprintf("%s %s", data.Prompt, params), | ||||
| 		Power:     h.App.SysConfig.MjPower, | ||||
| 		CreatedAt: time.Now(), | ||||
| 	} | ||||
| @@ -179,9 +204,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { | ||||
| 	h.pool.PushTask(types.MjTask{ | ||||
| 		Id:        job.Id, | ||||
| 		TaskId:    taskId, | ||||
| 		SessionId: data.SessionId, | ||||
| 		Type:      types.TaskType(data.TaskType), | ||||
| 		Prompt:    prompt, | ||||
| 		Prompt:    data.Prompt, | ||||
| 		NegPrompt: data.NegPrompt, | ||||
| 		Params:    params, | ||||
| 		UserId:    userId, | ||||
| 		ImgArr:    data.ImgArr, | ||||
| 	}) | ||||
| @@ -216,17 +242,12 @@ type reqVo struct { | ||||
| 	ChannelId   string `json:"channel_id"` | ||||
| 	MessageId   string `json:"message_id"` | ||||
| 	MessageHash string `json:"message_hash"` | ||||
| 	SessionId   string `json:"session_id"` | ||||
| 	Prompt      string `json:"prompt"` | ||||
| 	ChatId      string `json:"chat_id"` | ||||
| 	RoleId      int    `json:"role_id"` | ||||
| 	Icon        string `json:"icon"` | ||||
| } | ||||
|  | ||||
| // Upscale send upscale command to MidJourney Bot | ||||
| func (h *MidJourneyHandler) Upscale(c *gin.Context) { | ||||
| 	var data reqVo | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" { | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| @@ -244,7 +265,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { | ||||
| 		UserId:      userId, | ||||
| 		TaskId:      taskId, | ||||
| 		Progress:    0, | ||||
| 		Prompt:      data.Prompt, | ||||
| 		Power:       h.App.SysConfig.MjActionPower, | ||||
| 		CreatedAt:   time.Now(), | ||||
| 	} | ||||
| @@ -255,9 +275,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { | ||||
|  | ||||
| 	h.pool.PushTask(types.MjTask{ | ||||
| 		Id:          job.Id, | ||||
| 		SessionId:   data.SessionId, | ||||
| 		Type:        types.TaskUpscale, | ||||
| 		Prompt:      data.Prompt, | ||||
| 		UserId:      userId, | ||||
| 		ChannelId:   data.ChannelId, | ||||
| 		Index:       data.Index, | ||||
| @@ -292,7 +310,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { | ||||
| // Variation send variation command to MidJourney Bot | ||||
| func (h *MidJourneyHandler) Variation(c *gin.Context) { | ||||
| 	var data reqVo | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" { | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| @@ -311,7 +329,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { | ||||
| 		UserId:      userId, | ||||
| 		TaskId:      taskId, | ||||
| 		Progress:    0, | ||||
| 		Prompt:      data.Prompt, | ||||
| 		Power:       h.App.SysConfig.MjActionPower, | ||||
| 		CreatedAt:   time.Now(), | ||||
| 	} | ||||
| @@ -322,9 +339,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { | ||||
|  | ||||
| 	h.pool.PushTask(types.MjTask{ | ||||
| 		Id:          job.Id, | ||||
| 		SessionId:   data.SessionId, | ||||
| 		Type:        types.TaskVariation, | ||||
| 		Prompt:      data.Prompt, | ||||
| 		UserId:      userId, | ||||
| 		Index:       data.Index, | ||||
| 		ChannelId:   data.ChannelId, | ||||
| @@ -372,13 +387,13 @@ func (h *MidJourneyHandler) ImgWall(c *gin.Context) { | ||||
|  | ||||
| // JobList 获取 MJ 任务列表 | ||||
| func (h *MidJourneyHandler) JobList(c *gin.Context) { | ||||
| 	status := h.GetBool(c, "status") | ||||
| 	finish := h.GetBool(c, "finish") | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	page := h.GetInt(c, "page", 0) | ||||
| 	pageSize := h.GetInt(c, "page_size", 0) | ||||
| 	publish := h.GetBool(c, "publish") | ||||
|  | ||||
| 	err, jobs := h.getData(status, userId, page, pageSize, publish) | ||||
| 	err, jobs := h.getData(finish, userId, page, pageSize, publish) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| @@ -391,7 +406,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { | ||||
| func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) { | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if finish { | ||||
| 		session = session.Where("progress = ?", 100).Order("id DESC") | ||||
| 		session = session.Where("progress >= ?", 100).Order("id DESC") | ||||
| 	} else { | ||||
| 		session = session.Where("progress < ?", 100).Order("id ASC") | ||||
| 	} | ||||
| @@ -421,15 +436,10 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize | ||||
| 		} | ||||
|  | ||||
| 		if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" { | ||||
| 			// discord 服务器图片需要使用代理转发图片数据流 | ||||
| 			if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") { | ||||
| 			image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) | ||||
| 			if err == nil { | ||||
| 				job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) | ||||
| 			} | ||||
| 			} else { | ||||
| 				job.ImgURL = job.OrgURL | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		jobs = append(jobs, job) | ||||
| @@ -439,30 +449,56 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize | ||||
|  | ||||
| // Remove remove task image | ||||
| func (h *MidJourneyHandler) Remove(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id     uint   `json:"id"` | ||||
| 		UserId uint   `json:"user_id"` | ||||
| 		ImgURL string `json:"img_url"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	var job model.MidJourneyJob | ||||
| 	if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil { | ||||
| 		resp.ERROR(c, "记录不存在") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove job recode | ||||
| 	res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 	tx := h.DB.Begin() | ||||
| 	if err := tx.Delete(&job).Error; err != nil { | ||||
| 		tx.Rollback() | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// refund power | ||||
| 	err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error | ||||
| 	if err != nil { | ||||
| 		tx.Rollback() | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	var user model.User | ||||
| 	h.DB.Where("id = ?", job.UserId).First(&user) | ||||
| 	err = tx.Create(&model.PowerLog{ | ||||
| 		UserId:    user.Id, | ||||
| 		Username:  user.Username, | ||||
| 		Type:      types.PowerConsume, | ||||
| 		Amount:    job.Power, | ||||
| 		Balance:   user.Power + job.Power, | ||||
| 		Mark:      types.PowerAdd, | ||||
| 		Model:     "mid-journey", | ||||
| 		Remark:    fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId), | ||||
| 		CreatedAt: time.Now(), | ||||
| 	}).Error | ||||
| 	if err != nil { | ||||
| 		tx.Rollback() | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	tx.Commit() | ||||
|  | ||||
| 	// remove image | ||||
| 	err := h.uploader.GetUploadHandler().Delete(data.ImgURL) | ||||
| 	err = h.uploader.GetUploadHandler().Delete(job.ImgURL) | ||||
| 	if err != nil { | ||||
| 		logger.Error("remove image failed: ", err) | ||||
| 	} | ||||
|  | ||||
| 	client := h.pool.Clients.Get(data.UserId) | ||||
| 	client := h.pool.Clients.Get(uint(job.UserId)) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 	} | ||||
| @@ -472,17 +508,12 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { | ||||
|  | ||||
| // Publish 发布图片到画廊显示 | ||||
| func (h *MidJourneyHandler) Publish(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id     uint `json:"id"` | ||||
| 		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享 | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action) | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享 | ||||
| 	res := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -1,12 +1,20 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| @@ -20,23 +28,18 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { | ||||
| 	return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // List 订单列表 | ||||
| func (h *OrderHandler) List(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Page     int `json:"page"` | ||||
| 		PageSize int `json:"page_size"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	page := h.GetInt(c, "page", 1) | ||||
| 	pageSize := h.GetInt(c, "page_size", 20) | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess) | ||||
| 	var total int64 | ||||
| 	session.Model(&model.Order{}).Count(&total) | ||||
| 	var items []model.Order | ||||
| 	var list = make([]vo.Order, 0) | ||||
| 	offset := (data.Page - 1) * data.PageSize | ||||
| 	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) | ||||
| 	offset := (page - 1) * pageSize | ||||
| 	res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var order vo.Order | ||||
| @@ -51,5 +54,35 @@ func (h *OrderHandler) List(c *gin.Context) { | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list)) | ||||
| } | ||||
|  | ||||
| // Query 查询订单状态 | ||||
| func (h *OrderHandler) Query(c *gin.Context) { | ||||
| 	orderNo := h.GetTrim(c, "order_no") | ||||
| 	var order model.Order | ||||
| 	res := h.DB.Where("order_no = ?", orderNo).First(&order) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Order not found") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if order.Status == types.OrderPaidSuccess { | ||||
| 		resp.SUCCESS(c, gin.H{"status": order.Status}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	counter := 0 | ||||
| 	for { | ||||
| 		time.Sleep(time.Second) | ||||
| 		var item model.Order | ||||
| 		h.DB.Where("order_no = ?", orderNo).First(&item) | ||||
| 		if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status { | ||||
| 			order.Status = item.Status | ||||
| 			break | ||||
| 		} | ||||
| 		counter++ | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, gin.H{"status": order.Status}) | ||||
| } | ||||
|   | ||||
| @@ -1,16 +1,23 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/payment" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"embed" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/payment" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/shopspring/decimal" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| @@ -22,10 +29,16 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	PayWayAlipay = "支付宝" | ||||
| 	PayWayXunHu  = "虎皮椒" | ||||
| 	PayWayJs     = "PayJS" | ||||
| type PayWay struct { | ||||
| 	Name  string `json:"name"` | ||||
| 	Value string `json:"value"` | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	PayWayAlipay = PayWay{Name: "支付宝", Value: "alipay"} | ||||
| 	PayWayXunHu  = PayWay{Name: "虎皮椒", Value: "hupi"} | ||||
| 	PayWayJs     = PayWay{Name: "PayJS", Value: "payjs"} | ||||
| 	PayWayWechat = PayWay{Name: "微信支付", Value: "wechat"} | ||||
| ) | ||||
|  | ||||
| // PaymentHandler 支付服务回调 handler | ||||
| @@ -33,24 +46,28 @@ type PaymentHandler struct { | ||||
| 	BaseHandler | ||||
| 	alipayService    *payment.AlipayService | ||||
| 	huPiPayService   *payment.HuPiPayService | ||||
| 	js             *payment.PayJS | ||||
| 	jsPayService     *payment.JPayService | ||||
| 	wechatPayService *payment.WechatPayService | ||||
| 	snowflake        *service.Snowflake | ||||
| 	fs               embed.FS | ||||
| 	lock             sync.Mutex | ||||
| 	signKey          string // 用来签名的随机秘钥 | ||||
| } | ||||
|  | ||||
| func NewPaymentHandler( | ||||
| 	server *core.AppServer, | ||||
| 	alipayService *payment.AlipayService, | ||||
| 	huPiPayService *payment.HuPiPayService, | ||||
| 	js *payment.PayJS, | ||||
| 	jsPayService *payment.JPayService, | ||||
| 	wechatPayService *payment.WechatPayService, | ||||
| 	db *gorm.DB, | ||||
| 	snowflake *service.Snowflake, | ||||
| 	fs embed.FS) *PaymentHandler { | ||||
| 	return &PaymentHandler{ | ||||
| 		alipayService:    alipayService, | ||||
| 		huPiPayService:   huPiPayService, | ||||
| 		js:             js, | ||||
| 		jsPayService:     jsPayService, | ||||
| 		wechatPayService: wechatPayService, | ||||
| 		snowflake:        snowflake, | ||||
| 		fs:               fs, | ||||
| 		lock:             sync.Mutex{}, | ||||
| @@ -58,12 +75,27 @@ func NewPaymentHandler( | ||||
| 			App: server, | ||||
| 			DB:  db, | ||||
| 		}, | ||||
| 		signKey: utils.RandString(32), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *PaymentHandler) DoPay(c *gin.Context) { | ||||
| 	orderNo := h.GetTrim(c, "order_no") | ||||
| 	payWay := h.GetTrim(c, "pay_way") | ||||
| 	t := h.GetInt(c, "t", 0) | ||||
| 	sign := h.GetTrim(c, "sign") | ||||
| 	signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, payWay, t, h.signKey) | ||||
| 	newSign := utils.Sha256(signStr) | ||||
| 	if newSign != sign { | ||||
| 		resp.ERROR(c, "订单签名错误!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 检查二维码是否过期 | ||||
| 	if time.Now().Unix()-int64(t) > int64(h.App.SysConfig.OrderPayTimeout) { | ||||
| 		resp.ERROR(c, "支付二维码已过期,请重新生成!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if orderNo == "" { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| @@ -79,19 +111,16 @@ func (h *PaymentHandler) DoPay(c *gin.Context) { | ||||
|  | ||||
| 	// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回 | ||||
| 	if order.Status == types.OrderPaidSuccess { | ||||
| 		resp.ERROR(c, "This order had been paid, please do not pay twice") | ||||
| 		resp.ERROR(c, "订单已支付成功,无需重复支付!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 更新扫码状态 | ||||
| 	h.DB.Model(&order).UpdateColumn("status", types.OrderScanned) | ||||
| 	if payWay == "alipay" { // 支付宝 | ||||
| 		// 生成支付链接 | ||||
| 		notifyURL := h.App.Config.AlipayConfig.NotifyURL | ||||
| 		returnURL := "" // 关闭同步回跳 | ||||
| 		amount := fmt.Sprintf("%.2f", order.Amount) | ||||
|  | ||||
| 		uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject) | ||||
| 	if payWay == "alipay" { // 支付宝 | ||||
| 		amount := fmt.Sprintf("%.2f", order.Amount) | ||||
| 		uri, err := h.alipayService.PayUrlMobile(order.OrderNo, amount, order.Subject) | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, "error with generate pay url: "+err.Error()) | ||||
| 			return | ||||
| @@ -119,49 +148,11 @@ func (h *PaymentHandler) DoPay(c *gin.Context) { | ||||
| 	resp.ERROR(c, "Invalid operations") | ||||
| } | ||||
|  | ||||
| // OrderQuery 查询订单状态 | ||||
| func (h *PaymentHandler) OrderQuery(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		OrderNo string `json:"order_no"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var order model.Order | ||||
| 	res := h.DB.Where("order_no = ?", data.OrderNo).First(&order) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Order not found") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if order.Status == types.OrderPaidSuccess { | ||||
| 		resp.SUCCESS(c, gin.H{"status": order.Status}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	counter := 0 | ||||
| 	for { | ||||
| 		time.Sleep(time.Second) | ||||
| 		var item model.Order | ||||
| 		h.DB.Where("order_no = ?", data.OrderNo).First(&item) | ||||
| 		if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status { | ||||
| 			order.Status = item.Status | ||||
| 			break | ||||
| 		} | ||||
| 		counter++ | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, gin.H{"status": order.Status}) | ||||
| } | ||||
|  | ||||
| // PayQrcode 生成支付 URL 二维码 | ||||
| func (h *PaymentHandler) PayQrcode(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		PayWay    string `json:"pay_way"` // 支付方式 | ||||
| 		ProductId uint   `json:"product_id"` | ||||
| 		UserId    int    `json:"user_id"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| @@ -180,10 +171,9 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) { | ||||
| 		resp.ERROR(c, "error with generate trade no: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	var user model.User | ||||
| 	res = h.DB.First(&user, data.UserId) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Invalid user ID") | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -191,14 +181,21 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) { | ||||
| 	var notifyURL string | ||||
| 	switch data.PayWay { | ||||
| 	case "hupi": | ||||
| 		payWay = PayWayXunHu | ||||
| 		payWay = PayWayXunHu.Value | ||||
| 		notifyURL = h.App.Config.HuPiPayConfig.NotifyURL | ||||
| 		break | ||||
| 	case "payjs": | ||||
| 		payWay = PayWayJs | ||||
| 		payWay = PayWayJs.Value | ||||
| 		notifyURL = h.App.Config.JPayConfig.NotifyURL | ||||
| 	default: | ||||
| 		payWay = PayWayAlipay | ||||
| 		break | ||||
| 	case "alipay": | ||||
| 		payWay = PayWayAlipay.Value | ||||
| 		notifyURL = h.App.Config.AlipayConfig.NotifyURL | ||||
| 		break | ||||
| 	default: | ||||
| 		payWay = PayWayWechat.Value | ||||
| 		notifyURL = h.App.Config.WechatPayConfig.NotifyURL | ||||
|  | ||||
| 	} | ||||
| 	// 创建订单 | ||||
| 	remark := types.OrderRemark{ | ||||
| @@ -234,7 +231,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) { | ||||
| 			OutTradeNo: order.OrderNo, | ||||
| 			Subject:    product.Name, | ||||
| 		} | ||||
| 		r := h.js.Pay(params) | ||||
| 		r := h.jsPayService.Pay(params) | ||||
| 		if r.IsOK() { | ||||
| 			resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode}) | ||||
| 			return | ||||
| @@ -253,6 +250,8 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) { | ||||
| 		} else { | ||||
| 			logo = "res/img/alipay.jpg" | ||||
| 		} | ||||
| 	} else if data.PayWay == "wechat" { | ||||
| 		logo = "res/img/wechat-pay.jpg" | ||||
| 	} | ||||
|  | ||||
| 	file, err := h.fs.Open(logo) | ||||
| @@ -266,8 +265,21 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s", parse.Scheme, parse.Host, orderNo, data.PayWay) | ||||
| 	timestamp := time.Now().Unix() | ||||
| 	signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, data.PayWay, timestamp, h.signKey) | ||||
| 	sign := utils.Sha256(signStr) | ||||
| 	var imageURL string | ||||
| 	if data.PayWay == "wechat" { | ||||
| 		payUrl, err := h.wechatPayService.PayUrlNative(order.OrderNo, int(math.Floor(order.Amount*100)), product.Name) | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, "error with generating wechat payment qrcode: "+err.Error()) | ||||
| 			return | ||||
| 		} else { | ||||
| 			imageURL = payUrl | ||||
| 		} | ||||
| 	} else { | ||||
| 		imageURL = fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign) | ||||
| 	} | ||||
| 	imgData, err := utils.GenQrcode(imageURL, 400, file) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| @@ -282,7 +294,6 @@ func (h *PaymentHandler) Mobile(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		PayWay    string `json:"pay_way"` // 支付方式 | ||||
| 		ProductId uint   `json:"product_id"` | ||||
| 		UserId    int    `json:"user_id"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| @@ -301,10 +312,9 @@ func (h *PaymentHandler) Mobile(c *gin.Context) { | ||||
| 		resp.ERROR(c, "error with generate trade no: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	var user model.User | ||||
| 	res = h.DB.First(&user, data.UserId) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Invalid user ID") | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -314,9 +324,11 @@ func (h *PaymentHandler) Mobile(c *gin.Context) { | ||||
| 	var payURL string | ||||
| 	switch data.PayWay { | ||||
| 	case "hupi": | ||||
| 		payWay = PayWayXunHu | ||||
| 		payWay = PayWayXunHu.Name | ||||
| 		notifyURL = h.App.Config.HuPiPayConfig.NotifyURL | ||||
| 		returnURL = h.App.Config.HuPiPayConfig.ReturnURL | ||||
| 		parse, _ := url.Parse(h.App.Config.HuPiPayConfig.ReturnURL) | ||||
| 		baseURL := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host) | ||||
| 		params := payment.HuPiPayReq{ | ||||
| 			Version:      "1.1", | ||||
| 			TradeOrderId: orderNo, | ||||
| @@ -326,16 +338,19 @@ func (h *PaymentHandler) Mobile(c *gin.Context) { | ||||
| 			ReturnURL:    returnURL, | ||||
| 			CallbackURL:  returnURL, | ||||
| 			WapName:      "极客学长", | ||||
| 			WapUrl:       baseURL, | ||||
| 			Type:         "WAP", | ||||
| 		} | ||||
| 		r, err := h.huPiPayService.Pay(params) | ||||
| 		if err != nil { | ||||
| 			logger.Error("error with generating Pay URL: ", err.Error()) | ||||
| 			resp.ERROR(c, "error with generating Pay URL: "+err.Error()) | ||||
| 			errMsg := "error with generating Pay Hupi URL: " + err.Error() | ||||
| 			logger.Error(errMsg) | ||||
| 			resp.ERROR(c, errMsg) | ||||
| 			return | ||||
| 		} | ||||
| 		payURL = r.URL | ||||
| 	case "payjs": | ||||
| 		payWay = PayWayJs | ||||
| 		payWay = PayWayJs.Name | ||||
| 		notifyURL = h.App.Config.JPayConfig.NotifyURL | ||||
| 		returnURL = h.App.Config.JPayConfig.ReturnURL | ||||
| 		totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart() | ||||
| @@ -345,14 +360,22 @@ func (h *PaymentHandler) Mobile(c *gin.Context) { | ||||
| 		params.Add("body", product.Name) | ||||
| 		params.Add("notify_url", notifyURL) | ||||
| 		params.Add("auto", "0") | ||||
| 		payURL = h.js.PayH5(params) | ||||
| 		payURL = h.jsPayService.PayH5(params) | ||||
| 	case "alipay": | ||||
| 		payWay = PayWayAlipay | ||||
| 		notifyURL = h.App.Config.AlipayConfig.NotifyURL | ||||
| 		returnURL = h.App.Config.AlipayConfig.ReturnURL | ||||
| 		payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name) | ||||
| 		payWay = PayWayAlipay.Name | ||||
| 		payURL, err = h.alipayService.PayUrlMobile(orderNo, fmt.Sprintf("%.2f", amount), product.Name) | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, "error with generating Pay URL: "+err.Error()) | ||||
| 			errMsg := "error with generating Alipay URL: " + err.Error() | ||||
| 			resp.ERROR(c, errMsg) | ||||
| 			return | ||||
| 		} | ||||
| 	case "wechat": | ||||
| 		payWay = PayWayWechat.Name | ||||
| 		payURL, err = h.wechatPayService.PayUrlH5(orderNo, int(amount*100), product.Name, c.ClientIP()) | ||||
| 		if err != nil { | ||||
| 			errMsg := "error with generating Wechat URL: " + err.Error() | ||||
| 			logger.Error(errMsg) | ||||
| 			resp.ERROR(c, errMsg) | ||||
| 			return | ||||
| 		} | ||||
| 	default: | ||||
| @@ -385,7 +408,7 @@ func (h *PaymentHandler) Mobile(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, payURL) | ||||
| 	resp.SUCCESS(c, gin.H{"url": payURL, "order_no": orderNo}) | ||||
| } | ||||
|  | ||||
| // 异步通知回调公共逻辑 | ||||
| @@ -424,29 +447,23 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { | ||||
|  | ||||
| 	var opt string | ||||
| 	var power int | ||||
| 	if user.Vip { // 已经是 VIP 用户 | ||||
| 		if remark.Days > 0 { // 只延期 VIP,不增加调用次数 | ||||
| 	if remark.Days > 0 { // VIP 充值 | ||||
| 		if user.ExpiredTime >= time.Now().Unix() { | ||||
| 			user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix() | ||||
| 			opt = "VIP充值,VIP 没到期,只延期不增加算力" | ||||
| 		} else { | ||||
| 			user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix() | ||||
| 			user.Power += h.App.SysConfig.VipMonthPower | ||||
| 			power = h.App.SysConfig.VipMonthPower | ||||
| 			opt = "VIP充值" | ||||
| 		} | ||||
| 		user.Vip = true | ||||
| 	} else { // 充值点卡,直接增加次数即可 | ||||
| 		user.Power += remark.Power | ||||
| 		opt = "点卡充值" | ||||
| 		power = remark.Power | ||||
| 	} | ||||
|  | ||||
| 	} else {                 // 非 VIP 用户 | ||||
| 		if remark.Days > 0 { // vip 套餐:days > 0, power == 0 | ||||
| 			user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix() | ||||
| 			user.Power += h.App.SysConfig.VipMonthPower | ||||
| 			user.Vip = true | ||||
| 			opt = "VIP充值" | ||||
| 			power = h.App.SysConfig.VipMonthPower | ||||
| 		} else { //点卡:days == 0, calls > 0 | ||||
| 			user.Power += remark.Power | ||||
| 			opt = "点卡充值" | ||||
| 			power = remark.Power | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 更新用户信息 | ||||
| 	res = h.DB.Updates(&user) | ||||
| 	if res.Error != nil { | ||||
| @@ -470,7 +487,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { | ||||
| 	h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1)) | ||||
|  | ||||
| 	// 记录算力充值日志 | ||||
| 	if opt != "" { | ||||
| 	if power > 0 { | ||||
| 		h.DB.Create(&model.PowerLog{ | ||||
| 			UserId:    user.Id, | ||||
| 			Username:  user.Username, | ||||
| @@ -499,6 +516,9 @@ func (h *PaymentHandler) GetPayWays(c *gin.Context) { | ||||
| 	if h.App.Config.JPayConfig.Enabled { | ||||
| 		data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name} | ||||
| 	} | ||||
| 	if h.App.Config.WechatPayConfig.Enabled { | ||||
| 		data["wechat"] = gin.H{"name": "wechat"} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, data) | ||||
| } | ||||
|  | ||||
| @@ -537,7 +557,7 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	// TODO:验证交易签名 | ||||
| 	res := h.alipayService.TradeVerify(c.Request.Form) | ||||
| 	res := h.alipayService.TradeVerify(c.Request) | ||||
| 	logger.Infof("验证支付结果:%+v", res) | ||||
| 	if !res.Success() { | ||||
| 		logger.Error("订单校验失败:", res.Message) | ||||
| @@ -565,7 +585,7 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) { | ||||
|  | ||||
| 	orderNo := c.Request.Form.Get("out_trade_no") | ||||
| 	returnCode := c.Request.Form.Get("return_code") | ||||
| 	logger.Infof("收到订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode) | ||||
| 	logger.Infof("收到PayJs订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode) | ||||
| 	// 支付失败 | ||||
| 	if returnCode != "1" { | ||||
| 		return | ||||
| @@ -573,7 +593,7 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) { | ||||
|  | ||||
| 	// 校验订单支付状态 | ||||
| 	tradeNo := c.Request.Form.Get("payjs_order_id") | ||||
| 	err = h.js.Check(tradeNo) | ||||
| 	err = h.jsPayService.TradeVerify(tradeNo) | ||||
| 	if err != nil { | ||||
| 		logger.Error("订单校验失败:", err) | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| @@ -588,3 +608,30 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) { | ||||
|  | ||||
| 	c.String(http.StatusOK, "success") | ||||
| } | ||||
|  | ||||
| // WechatPayNotify 微信商户支付异步回调 | ||||
| func (h *PaymentHandler) WechatPayNotify(c *gin.Context) { | ||||
| 	err := c.Request.ParseForm() | ||||
| 	if err != nil { | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	result := h.wechatPayService.TradeVerify(c.Request) | ||||
| 	if !result.Success() { | ||||
| 		logger.Error("订单校验失败:", err) | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{ | ||||
| 			"code":    "FAIL", | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.notify(result.OutTradeNo, result.TradeId) | ||||
| 	if err != nil { | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	c.String(http.StatusOK, "success") | ||||
| } | ||||
|   | ||||
| @@ -1,12 +1,19 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
|   | ||||
| @@ -1,11 +1,18 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|   | ||||
| @@ -1,13 +1,20 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"math" | ||||
| @@ -50,12 +57,12 @@ func (h *RewardHandler) Verify(c *gin.Context) { | ||||
| 	var item model.Reward | ||||
| 	res := h.DB.Where("tx_id = ?", data.TxId).First(&item) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "无效的众筹交易流水号!") | ||||
| 		resp.ERROR(c, "无效的交易流水号!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if item.Status { | ||||
| 		resp.ERROR(c, "当前众筹交易流水号已经被核销,请不要重复核销!") | ||||
| 		resp.ERROR(c, "当前交易流水号已经被核销,请不要重复核销!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -66,6 +73,7 @@ func (h *RewardHandler) Verify(c *gin.Context) { | ||||
| 	res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power)) | ||||
| 	if res.Error != nil { | ||||
| 		tx.Rollback() | ||||
| 		logger.Error("添加应用失败:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -77,6 +85,7 @@ func (h *RewardHandler) Verify(c *gin.Context) { | ||||
| 	res = tx.Updates(&item) | ||||
| 	if res.Error != nil { | ||||
| 		tx.Rollback() | ||||
| 		logger.Error("添加应用失败:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -90,7 +99,7 @@ func (h *RewardHandler) Verify(c *gin.Context) { | ||||
| 		Balance:   user.Power + exchange.Power, | ||||
| 		Mark:      types.PowerAdd, | ||||
| 		Model:     "众筹支付", | ||||
| 		Remark:    fmt.Sprintf("众筹充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice), | ||||
| 		Remark:    fmt.Sprintf("充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice), | ||||
| 		CreatedAt: time.Now(), | ||||
| 	}) | ||||
| 	tx.Commit() | ||||
|   | ||||
| @@ -1,16 +1,24 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/service/sd" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/service/sd" | ||||
| 	"geekai/store" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
|  | ||||
| @@ -26,12 +34,16 @@ type SdJobHandler struct { | ||||
| 	redis     *redis.Client | ||||
| 	pool      *sd.ServicePool | ||||
| 	uploader  *oss.UploaderManager | ||||
| 	snowflake *service.Snowflake | ||||
| 	leveldb   *store.LevelDB | ||||
| } | ||||
|  | ||||
| func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler { | ||||
| func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler { | ||||
| 	return &SdJobHandler{ | ||||
| 		pool:      pool, | ||||
| 		uploader:  manager, | ||||
| 		snowflake: snowflake, | ||||
| 		leveldb:   levelDB, | ||||
| 		BaseHandler: BaseHandler{ | ||||
| 			App: app, | ||||
| 			DB:  db, | ||||
| @@ -60,7 +72,7 @@ func (h *SdJobHandler) Client(c *gin.Context) { | ||||
| 	logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) | ||||
| } | ||||
|  | ||||
| func (h *SdJobHandler) checkLimits(c *gin.Context) bool { | ||||
| func (h *SdJobHandler) preCheck(c *gin.Context) bool { | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| @@ -83,14 +95,11 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool { | ||||
|  | ||||
| // Image 创建一个绘画任务 | ||||
| func (h *SdJobHandler) Image(c *gin.Context) { | ||||
| 	if !h.checkLimits(c) { | ||||
| 	if !h.preCheck(c) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var data struct { | ||||
| 		SessionId string `json:"session_id"` | ||||
| 		types.SdTaskParams | ||||
| 	} | ||||
| 	var data types.SdTaskParams | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| @@ -116,10 +125,15 @@ func (h *SdJobHandler) Image(c *gin.Context) { | ||||
| 	} | ||||
| 	idValue, _ := c.Get(types.LoginUserID) | ||||
| 	userId := utils.IntValue(utils.InterfaceToString(idValue), 0) | ||||
| 	taskId, err := h.snowflake.Next(true) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "error with generate task id: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	params := types.SdTaskParams{ | ||||
| 		TaskId:         fmt.Sprintf("task(%s)", utils.RandString(15)), | ||||
| 		TaskId:       taskId, | ||||
| 		Prompt:       data.Prompt, | ||||
| 		NegativePrompt: data.NegativePrompt, | ||||
| 		NegPrompt:    data.NegPrompt, | ||||
| 		Steps:        data.Steps, | ||||
| 		Sampler:      data.Sampler, | ||||
| 		FaceFix:      data.FaceFix, | ||||
| @@ -152,7 +166,6 @@ func (h *SdJobHandler) Image(c *gin.Context) { | ||||
|  | ||||
| 	h.pool.PushTask(types.SdTask{ | ||||
| 		Id:     int(job.Id), | ||||
| 		SessionId: data.SessionId, | ||||
| 		Type:   types.TaskImage, | ||||
| 		Params: params, | ||||
| 		UserId: userId, | ||||
| @@ -199,13 +212,13 @@ func (h *SdJobHandler) ImgWall(c *gin.Context) { | ||||
|  | ||||
| // JobList 获取 SD 任务列表 | ||||
| func (h *SdJobHandler) JobList(c *gin.Context) { | ||||
| 	status := h.GetBool(c, "status") | ||||
| 	finish := h.GetBool(c, "finish") | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	page := h.GetInt(c, "page", 0) | ||||
| 	pageSize := h.GetInt(c, "page_size", 0) | ||||
| 	publish := h.GetBool(c, "publish") | ||||
|  | ||||
| 	err, jobs := h.getData(status, userId, page, pageSize, publish) | ||||
| 	err, jobs := h.getData(finish, userId, page, pageSize, publish) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| @@ -249,10 +262,11 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, | ||||
| 		} | ||||
|  | ||||
| 		if item.Progress < 100 { | ||||
| 			// 正在运行中任务使用代理访问图片 | ||||
| 			image, err := utils.DownloadImage(item.ImgURL, "") | ||||
| 			// 从 leveldb 中获取图片预览数据 | ||||
| 			var imageData string | ||||
| 			err = h.leveldb.Get(item.TaskId, &imageData) | ||||
| 			if err == nil { | ||||
| 				job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) | ||||
| 				job.ImgURL = "data:image/png;base64," + imageData | ||||
| 			} | ||||
| 		} | ||||
| 		jobs = append(jobs, job) | ||||
| @@ -263,32 +277,30 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, | ||||
|  | ||||
| // Remove remove task image | ||||
| func (h *SdJobHandler) Remove(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id     uint   `json:"id"` | ||||
| 		UserId uint   `json:"user_id"` | ||||
| 		ImgURL string `json:"img_url"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	var job model.SdJob | ||||
| 	if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil { | ||||
| 		resp.ERROR(c, "记录不存在") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove job recode | ||||
| 	res := h.DB.Delete(&model.SdJob{Id: data.Id}) | ||||
| 	res := h.DB.Delete(&model.SdJob{Id: job.Id}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove image | ||||
| 	err := h.uploader.GetUploadHandler().Delete(data.ImgURL) | ||||
| 	err := h.uploader.GetUploadHandler().Delete(job.ImgURL) | ||||
| 	if err != nil { | ||||
| 		logger.Error("remove image failed: ", err) | ||||
| 	} | ||||
|  | ||||
| 	client := h.pool.Clients.Get(data.UserId) | ||||
| 	client := h.pool.Clients.Get(uint(job.UserId)) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 		_ = client.Send([]byte(sd.Finished)) | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| @@ -296,17 +308,13 @@ func (h *SdJobHandler) Remove(c *gin.Context) { | ||||
|  | ||||
| // Publish 发布/取消发布图片到画廊显示 | ||||
| func (h *SdJobHandler) Publish(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id     uint `json:"id"` | ||||
| 		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享 | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享 | ||||
|  | ||||
| 	res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true) | ||||
| 	res := h.DB.Model(&model.SdJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -1,12 +1,19 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/sms" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/sms" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -42,14 +49,20 @@ func (h *SmsHandler) SendCode(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Receiver string `json:"receiver"` // 接收者 | ||||
| 		Key      string `json:"key"` | ||||
| 		Dots     string `json:"dots"` | ||||
| 		Dots     string `json:"dots,omitempty"` | ||||
| 		X        int    `json:"x,omitempty"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !h.captcha.Check(data) { | ||||
| 	var check bool | ||||
| 	if data.X != 0 { | ||||
| 		check = h.captcha.SlideCheck(data) | ||||
| 	} else { | ||||
| 		check = h.captcha.Check(data) | ||||
| 	} | ||||
| 	if !check { | ||||
| 		resp.ERROR(c, "验证码错误,请先完人机验证") | ||||
| 		return | ||||
| 	} | ||||
| @@ -57,13 +70,13 @@ func (h *SmsHandler) SendCode(c *gin.Context) { | ||||
| 	code := utils.RandomNumber(6) | ||||
| 	var err error | ||||
| 	if strings.Contains(data.Receiver, "@") { // email | ||||
| 		if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") { | ||||
| 		if !utils.Contains(h.App.SysConfig.RegisterWays, "email") { | ||||
| 			resp.ERROR(c, "系统已禁用邮箱注册!") | ||||
| 			return | ||||
| 		} | ||||
| 		err = h.smtp.SendVerifyCode(data.Receiver, code) | ||||
| 	} else { | ||||
| 		if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") { | ||||
| 		if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") { | ||||
| 			resp.ERROR(c, "系统已禁用手机号注册!") | ||||
| 			return | ||||
| 		} | ||||
| @@ -82,5 +95,9 @@ func (h *SmsHandler) SendCode(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if h.App.Debug { | ||||
| 		resp.SUCCESS(c, code) | ||||
| 	} else { | ||||
| 		resp.SUCCESS(c) | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										345
									
								
								api/handler/suno_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										345
									
								
								api/handler/suno_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,345 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/service/suno" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"gorm.io/gorm" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type SunoHandler struct { | ||||
| 	BaseHandler | ||||
| 	service  *suno.Service | ||||
| 	uploader *oss.UploaderManager | ||||
| } | ||||
|  | ||||
| func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler { | ||||
| 	return &SunoHandler{ | ||||
| 		BaseHandler: BaseHandler{ | ||||
| 			App: app, | ||||
| 			DB:  db, | ||||
| 		}, | ||||
| 		service:  service, | ||||
| 		uploader: uploader, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Client WebSocket 客户端,用于通知任务状态变更 | ||||
| func (h *SunoHandler) Client(c *gin.Context) { | ||||
| 	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) | ||||
| 	if err != nil { | ||||
| 		logger.Error(err) | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	if userId == 0 { | ||||
| 		logger.Info("Invalid user ID") | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	client := types.NewWsClient(ws) | ||||
| 	h.service.Clients.Put(uint(userId), client) | ||||
| 	logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) | ||||
| } | ||||
|  | ||||
| func (h *SunoHandler) Create(c *gin.Context) { | ||||
|  | ||||
| 	var data struct { | ||||
| 		Prompt       string `json:"prompt"` | ||||
| 		Instrumental bool   `json:"instrumental"` | ||||
| 		Lyrics       string `json:"lyrics"` | ||||
| 		Model        string `json:"model"` | ||||
| 		Tags         string `json:"tags"` | ||||
| 		Title        string `json:"title"` | ||||
| 		Type         int    `json:"type"` | ||||
| 		RefTaskId    string `json:"ref_task_id"` // 续写的任务id | ||||
| 		ExtendSecs   int    `json:"extend_secs"` // 续写秒数 | ||||
| 		RefSongId    string `json:"ref_song_id"` // 续写的歌曲id | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 插入数据库 | ||||
| 	job := model.SunoJob{ | ||||
| 		UserId:       int(h.GetLoginUserId(c)), | ||||
| 		Prompt:       data.Prompt, | ||||
| 		Instrumental: data.Instrumental, | ||||
| 		ModelName:    data.Model, | ||||
| 		Tags:         data.Tags, | ||||
| 		Title:        data.Title, | ||||
| 		Type:         data.Type, | ||||
| 		RefSongId:    data.RefSongId, | ||||
| 		RefTaskId:    data.RefTaskId, | ||||
| 		ExtendSecs:   data.ExtendSecs, | ||||
| 		Power:        h.App.SysConfig.SunoPower, | ||||
| 	} | ||||
| 	if data.Lyrics != "" { | ||||
| 		job.Prompt = data.Lyrics | ||||
| 	} | ||||
| 	tx := h.DB.Create(&job) | ||||
| 	if tx.Error != nil { | ||||
| 		resp.ERROR(c, tx.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 创建任务 | ||||
| 	h.service.PushTask(types.SunoTask{ | ||||
| 		Id:           job.Id, | ||||
| 		UserId:       job.UserId, | ||||
| 		Type:         job.Type, | ||||
| 		Title:        job.Title, | ||||
| 		RefTaskId:    data.RefTaskId, | ||||
| 		RefSongId:    data.RefSongId, | ||||
| 		ExtendSecs:   data.ExtendSecs, | ||||
| 		Prompt:       job.Prompt, | ||||
| 		Tags:         data.Tags, | ||||
| 		Model:        data.Model, | ||||
| 		Instrumental: data.Instrumental, | ||||
| 	}) | ||||
|  | ||||
| 	// update user's power | ||||
| 	tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) | ||||
| 	// 记录算力变化日志 | ||||
| 	if tx.Error == nil && tx.RowsAffected > 0 { | ||||
| 		user, _ := h.GetLoginUser(c) | ||||
| 		h.DB.Create(&model.PowerLog{ | ||||
| 			UserId:    user.Id, | ||||
| 			Username:  user.Username, | ||||
| 			Type:      types.PowerConsume, | ||||
| 			Amount:    job.Power, | ||||
| 			Balance:   user.Power - job.Power, | ||||
| 			Mark:      types.PowerSub, | ||||
| 			Model:     job.ModelName, | ||||
| 			Remark:    fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName), | ||||
| 			CreatedAt: time.Now(), | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	client := h.service.Clients.Get(uint(job.UserId)) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *SunoHandler) List(c *gin.Context) { | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	page := h.GetInt(c, "page", 0) | ||||
| 	pageSize := h.GetInt(c, "page_size", 0) | ||||
| 	session := h.DB.Session(&gorm.Session{}).Where("user_id", userId) | ||||
|  | ||||
| 	// 统计总数 | ||||
| 	var total int64 | ||||
| 	session.Model(&model.SunoJob{}).Count(&total) | ||||
|  | ||||
| 	if page > 0 && pageSize > 0 { | ||||
| 		offset := (page - 1) * pageSize | ||||
| 		session = session.Offset(offset).Limit(pageSize) | ||||
| 	} | ||||
| 	var list []model.SunoJob | ||||
| 	err := session.Order("id desc").Find(&list).Error | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	// 初始化续写关系 | ||||
| 	songIds := make([]string, 0) | ||||
| 	for _, v := range list { | ||||
| 		if v.RefTaskId != "" { | ||||
| 			songIds = append(songIds, v.RefSongId) | ||||
| 		} | ||||
| 	} | ||||
| 	var tasks []model.SunoJob | ||||
| 	h.DB.Where("song_id IN ?", songIds).Find(&tasks) | ||||
| 	songMap := make(map[string]model.SunoJob) | ||||
| 	for _, t := range tasks { | ||||
| 		songMap[t.SongId] = t | ||||
| 	} | ||||
| 	// 转换为 VO | ||||
| 	items := make([]vo.SunoJob, 0) | ||||
| 	for _, v := range list { | ||||
| 		var item vo.SunoJob | ||||
| 		err = utils.CopyObject(v, &item) | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		item.CreatedAt = v.CreatedAt.Unix() | ||||
| 		if s, ok := songMap[v.RefSongId]; ok { | ||||
| 			item.RefSong = map[string]interface{}{ | ||||
| 				"id":    s.Id, | ||||
| 				"title": s.Title, | ||||
| 				"cover": s.CoverURL, | ||||
| 				"audio": s.AudioURL, | ||||
| 			} | ||||
| 		} | ||||
| 		items = append(items, item) | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items)) | ||||
| } | ||||
|  | ||||
| func (h *SunoHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	var job model.SunoJob | ||||
| 	err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	// 删除任务 | ||||
| 	h.DB.Delete(&job) | ||||
| 	// 删除文件 | ||||
| 	_ = h.uploader.GetUploadHandler().Delete(job.CoverURL) | ||||
| 	_ = h.uploader.GetUploadHandler().Delete(job.AudioURL) | ||||
| } | ||||
|  | ||||
| func (h *SunoHandler) Publish(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	publish := h.GetBool(c, "publish") | ||||
| 	err := h.DB.Model(&model.SunoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *SunoHandler) Update(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id    int    `json:"id"` | ||||
| 		Title string `json:"title"` | ||||
| 		Cover string `json:"cover"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if data.Id == 0 || data.Title == "" || data.Cover == "" { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	var item model.SunoJob | ||||
| 	if err := h.DB.Where("id", data.Id).Where("user_id", userId).First(&item).Error; err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	item.Title = data.Title | ||||
| 	item.CoverURL = data.Cover | ||||
|  | ||||
| 	if err := h.DB.Updates(&item).Error; err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // Detail 歌曲详情 | ||||
| func (h *SunoHandler) Detail(c *gin.Context) { | ||||
| 	songId := c.Query("song_id") | ||||
| 	if songId == "" { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	var item model.SunoJob | ||||
| 	if err := h.DB.Where("song_id", songId).First(&item).Error; err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 读取用户信息 | ||||
| 	var user model.User | ||||
| 	if err := h.DB.Where("id", item.UserId).First(&user).Error; err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var itemVo vo.SunoJob | ||||
| 	if err := utils.CopyObject(item, &itemVo); err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	itemVo.CreatedAt = item.CreatedAt.Unix() | ||||
| 	itemVo.User = map[string]interface{}{ | ||||
| 		"nickname": user.Nickname, | ||||
| 		"avatar":   user.Avatar, | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, itemVo) | ||||
| } | ||||
|  | ||||
| // Play 增加歌曲播放次数 | ||||
| func (h *SunoHandler) Play(c *gin.Context) { | ||||
| 	songId := c.Query("song_id") | ||||
| 	if songId == "" { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1)) | ||||
| } | ||||
|  | ||||
| const genLyricTemplate = ` | ||||
| 你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。 | ||||
| 请以【%s】为主题创作一首歌曲,歌曲时间不要太短,3分钟左右,不要输出任何解释性的内容。 | ||||
| 输出格式如下: | ||||
| 歌曲名称 | ||||
| 第一节: | ||||
| {{歌词内容}} | ||||
| 副歌: | ||||
| {{歌词内容}} | ||||
|  | ||||
| 第二节: | ||||
| {{歌词内容}} | ||||
| 副歌: | ||||
| {{歌词内容}} | ||||
|  | ||||
| 尾声: | ||||
| {{歌词内容}} | ||||
| ` | ||||
|  | ||||
| // Lyric 生成歌词 | ||||
| func (h *SunoHandler) Lyric(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Prompt string `json:"prompt"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini") | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, content) | ||||
| } | ||||
| @@ -1,17 +1,17 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/payment" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/payment" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type TestHandler struct { | ||||
| 	db        *gorm.DB | ||||
| 	snowflake *service.Snowflake | ||||
| 	js        *payment.PayJS | ||||
| 	js        *payment.JPayService | ||||
| } | ||||
|  | ||||
| func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler { | ||||
| func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.JPayService) *TestHandler { | ||||
| 	return &TestHandler{db: db, snowflake: snowflake, js: js} | ||||
| } | ||||
|   | ||||
| @@ -1,12 +1,20 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"time" | ||||
| @@ -28,6 +36,12 @@ func (h *UploadHandler) Upload(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	logger.Info("upload file: ", file.Name) | ||||
| 	// cut the file name if it's too long | ||||
| 	if len(file.Name) > 100 { | ||||
| 		file.Name = file.Name[:90] + file.Ext | ||||
| 	} | ||||
|  | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	res := h.DB.Create(&model.File{ | ||||
| 		UserId:    int(userId), | ||||
| @@ -47,10 +61,23 @@ func (h *UploadHandler) Upload(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func (h *UploadHandler) List(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Urls []string `json:"urls,omitempty"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	var items []model.File | ||||
| 	var files = make([]vo.File, 0) | ||||
| 	h.DB.Where("user_id = ?", userId).Find(&items) | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	session = session.Where("user_id = ?", userId) | ||||
| 	if len(data.Urls) > 0 { | ||||
| 		session = session.Where("url IN ?", data.Urls) | ||||
| 	} | ||||
| 	session.Find(&items) | ||||
| 	if len(items) > 0 { | ||||
| 		for _, v := range items { | ||||
| 			var file vo.File | ||||
|   | ||||
| @@ -1,13 +1,22 @@ | ||||
| package handler | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/store/vo" | ||||
| 	"geekai/utils" | ||||
| 	"geekai/utils/resp" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| @@ -23,14 +32,21 @@ type UserHandler struct { | ||||
| 	BaseHandler | ||||
| 	searcher       *xdb.Searcher | ||||
| 	redis          *redis.Client | ||||
| 	licenseService *service.LicenseService | ||||
| } | ||||
|  | ||||
| func NewUserHandler( | ||||
| 	app *core.AppServer, | ||||
| 	db *gorm.DB, | ||||
| 	searcher *xdb.Searcher, | ||||
| 	client *redis.Client) *UserHandler { | ||||
| 	return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client} | ||||
| 	client *redis.Client, | ||||
| 	licenseService *service.LicenseService) *UserHandler { | ||||
| 	return &UserHandler{ | ||||
| 		BaseHandler:    BaseHandler{DB: db, App: app}, | ||||
| 		searcher:       searcher, | ||||
| 		redis:          client, | ||||
| 		licenseService: licenseService, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Register user register | ||||
| @@ -53,9 +69,17 @@ func (h *UserHandler) Register(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 检测最大注册人数 | ||||
| 	var totalUser int64 | ||||
| 	h.DB.Model(&model.User{}).Count(&totalUser) | ||||
| 	if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum { | ||||
| 		resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 检查验证码 | ||||
| 	var key string | ||||
| 	if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" { | ||||
| 	if data.RegWay == "email" || data.RegWay == "mobile" { | ||||
| 		key = CodeStorePrefix + data.Username | ||||
| 		code, err := h.redis.Get(c, key).Result() | ||||
| 		if err != nil || code != data.Code { | ||||
| @@ -74,7 +98,7 @@ func (h *UserHandler) Register(c *gin.Context) { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// check if the username is exists | ||||
| 	// check if the username is existing | ||||
| 	var item model.User | ||||
| 	res := h.DB.Where("username = ?", data.Username).First(&item) | ||||
| 	if item.Id > 0 { | ||||
| @@ -86,7 +110,6 @@ func (h *UserHandler) Register(c *gin.Context) { | ||||
| 	user := model.User{ | ||||
| 		Username:   data.Username, | ||||
| 		Password:   utils.GenPassword(data.Password, salt), | ||||
| 		Nickname:   fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)), | ||||
| 		Avatar:     "/images/avatar/user.png", | ||||
| 		Salt:       salt, | ||||
| 		Status:     true, | ||||
| @@ -95,6 +118,16 @@ func (h *UserHandler) Register(c *gin.Context) { | ||||
| 		Power:      h.App.SysConfig.InitPower, | ||||
| 	} | ||||
|  | ||||
| 	// 被邀请人也获得赠送算力 | ||||
| 	if data.InviteCode != "" { | ||||
| 		user.Power += h.App.SysConfig.InvitePower | ||||
| 	} | ||||
| 	if h.licenseService.GetLicense().Configs.DeCopy { | ||||
| 		user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6)) | ||||
| 	} else { | ||||
| 		user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)) | ||||
| 	} | ||||
|  | ||||
| 	res = h.DB.Create(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "保存数据失败") | ||||
| @@ -152,7 +185,7 @@ func (h *UserHandler) Register(c *gin.Context) { | ||||
| 		resp.ERROR(c, "error with save token: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c, tokenString) | ||||
| 	resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username}) | ||||
| } | ||||
|  | ||||
| // Login 用户登录 | ||||
| @@ -211,26 +244,142 @@ func (h *UserHandler) Login(c *gin.Context) { | ||||
| 		resp.ERROR(c, "error with save token: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c, tokenString) | ||||
| 	resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username}) | ||||
| } | ||||
|  | ||||
| // Logout 注 销 | ||||
| func (h *UserHandler) Logout(c *gin.Context) { | ||||
| 	sessionId := c.GetHeader(types.ChatTokenHeader) | ||||
| 	key := h.GetUserKey(c) | ||||
| 	if _, err := h.redis.Del(c, key).Result(); err != nil { | ||||
| 		logger.Error("error with delete session: ", err) | ||||
| 	} | ||||
| 	// 删除 websocket 会话列表 | ||||
| 	h.App.ChatSession.Delete(sessionId) | ||||
| 	// 关闭 socket 连接 | ||||
| 	client := h.App.ChatClients.Get(sessionId) | ||||
| 	if client != nil { | ||||
| 		client.Close() | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // CLogin 第三方登录请求二维码 | ||||
| func (h *UserHandler) CLogin(c *gin.Context) { | ||||
| 	returnURL := h.GetTrim(c, "return_url") | ||||
| 	var res types.BizVo | ||||
| 	apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL) | ||||
| 	r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}). | ||||
| 		SetHeader("AppId", h.App.Config.ApiConfig.AppId). | ||||
| 		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)). | ||||
| 		SetSuccessResult(&res). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if r.IsErrorState() { | ||||
| 		resp.ERROR(c, "error with login http status: "+r.Status) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if res.Code != types.Success { | ||||
| 		resp.ERROR(c, "error with http response: "+res.Message) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, res.Data) | ||||
| } | ||||
|  | ||||
| // CLoginCallback 第三方登录回调 | ||||
| func (h *UserHandler) CLoginCallback(c *gin.Context) { | ||||
| 	loginType := h.GetTrim(c, "login_type") | ||||
| 	code := h.GetTrim(c, "code") | ||||
|  | ||||
| 	var res types.BizVo | ||||
| 	apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL) | ||||
| 	r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}). | ||||
| 		SetHeader("AppId", h.App.Config.ApiConfig.AppId). | ||||
| 		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)). | ||||
| 		SetSuccessResult(&res). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if r.IsErrorState() { | ||||
| 		resp.ERROR(c, "error with login http status: "+r.Status) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if res.Code != types.Success { | ||||
| 		resp.ERROR(c, "error with http response: "+res.Message) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// login successfully | ||||
| 	data := res.Data.(map[string]interface{}) | ||||
| 	session := gin.H{} | ||||
| 	var user model.User | ||||
| 	tx := h.DB.Debug().Where("openid", data["openid"]).First(&user) | ||||
| 	if tx.Error != nil { // user not exist, create new user | ||||
| 		// 检测最大注册人数 | ||||
| 		var totalUser int64 | ||||
| 		h.DB.Model(&model.User{}).Count(&totalUser) | ||||
| 		if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum { | ||||
| 			resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		salt := utils.RandString(8) | ||||
| 		password := fmt.Sprintf("%d", utils.RandomNumber(8)) | ||||
| 		user = model.User{ | ||||
| 			Username:   fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)), | ||||
| 			Password:   utils.GenPassword(password, salt), | ||||
| 			Avatar:     fmt.Sprintf("%s", data["avatar"]), | ||||
| 			Salt:       salt, | ||||
| 			Status:     true, | ||||
| 			ChatRoles:  utils.JsonEncode([]string{"gpt"}),               // 默认只订阅通用助手角色 | ||||
| 			ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型 | ||||
| 			Power:      h.App.SysConfig.InitPower, | ||||
| 			OpenId:     fmt.Sprintf("%s", data["openid"]), | ||||
| 			Nickname:   fmt.Sprintf("%s", data["nickname"]), | ||||
| 		} | ||||
|  | ||||
| 		tx = h.DB.Create(&user) | ||||
| 		if tx.Error != nil { | ||||
| 			resp.ERROR(c, "保存数据失败") | ||||
| 			logger.Error(tx.Error) | ||||
| 			return | ||||
| 		} | ||||
| 		session["username"] = user.Username | ||||
| 		session["password"] = password | ||||
| 	} else { // login directly | ||||
| 		// 更新最后登录时间和IP | ||||
| 		user.LastLoginIp = c.ClientIP() | ||||
| 		user.LastLoginAt = time.Now().Unix() | ||||
| 		h.DB.Model(&user).Updates(user) | ||||
|  | ||||
| 		h.DB.Create(&model.UserLoginLog{ | ||||
| 			UserId:       user.Id, | ||||
| 			Username:     user.Username, | ||||
| 			LoginIp:      c.ClientIP(), | ||||
| 			LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()), | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	// 创建 token | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ | ||||
| 		"user_id": user.Id, | ||||
| 		"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(), | ||||
| 	}) | ||||
| 	tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "Failed to generate token, "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	// 保存到 redis | ||||
| 	key := fmt.Sprintf("users/%d", user.Id) | ||||
| 	if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil { | ||||
| 		resp.ERROR(c, "error with save token: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	session["token"] = tokenString | ||||
| 	resp.SUCCESS(c, session) | ||||
| } | ||||
|  | ||||
| // Session 获取/验证会话 | ||||
| func (h *UserHandler) Session(c *gin.Context) { | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| @@ -334,7 +483,7 @@ func (h *UserHandler) UpdatePass(c *gin.Context) { | ||||
| 	newPass := utils.GenPassword(data.Password, user.Salt) | ||||
| 	res := h.DB.Model(&user).UpdateColumn("password", newPass) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("更新数据库失败: ", res.Error) | ||||
| 		logger.Error("error with update database:", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| 		return | ||||
| 	} | ||||
| @@ -415,6 +564,7 @@ func (h *UserHandler) BindUsername(c *gin.Context) { | ||||
|  | ||||
| 	res = h.DB.Model(&user).UpdateColumn("username", data.Username) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error(res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package logger | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"go.uber.org/zap" | ||||
| 	"go.uber.org/zap/zapcore" | ||||
|   | ||||
							
								
								
									
										159
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										159
									
								
								api/main.go
									
									
									
									
									
								
							| @@ -1,22 +1,31 @@ | ||||
| package main | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/handler/admin" | ||||
| 	"chatplus/handler/chatimpl" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/mj" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/service/payment" | ||||
| 	"chatplus/service/sd" | ||||
| 	"chatplus/service/sms" | ||||
| 	"chatplus/service/wx" | ||||
| 	"chatplus/store" | ||||
| 	"context" | ||||
| 	"embed" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/handler" | ||||
| 	"geekai/handler/admin" | ||||
| 	"geekai/handler/chatimpl" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/dalle" | ||||
| 	"geekai/service/mj" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/service/payment" | ||||
| 	"geekai/service/sd" | ||||
| 	"geekai/service/sms" | ||||
| 	"geekai/service/suno" | ||||
| 	"geekai/service/wx" | ||||
| 	"geekai/store" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"os" | ||||
| @@ -43,16 +52,20 @@ type AppLifecycle struct { | ||||
|  | ||||
| // OnStart 应用程序启动时执行 | ||||
| func (l *AppLifecycle) OnStart(context.Context) error { | ||||
| 	log.Println("AppLifecycle OnStart") | ||||
| 	logger.Info("AppLifecycle OnStart") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // OnStop 应用程序停止时执行 | ||||
| func (l *AppLifecycle) OnStop(context.Context) error { | ||||
| 	log.Println("AppLifecycle OnStop") | ||||
| 	logger.Info("AppLifecycle OnStop") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func NewAppLifeCycle() *AppLifecycle { | ||||
| 	return &AppLifecycle{} | ||||
| } | ||||
|  | ||||
| func main() { | ||||
| 	configFile := os.Getenv("CONFIG_FILE") | ||||
| 	if configFile == "" { | ||||
| @@ -92,6 +105,7 @@ func main() { | ||||
| 		fx.Provide(store.NewGormConfig), | ||||
| 		fx.Provide(store.NewMysql), | ||||
| 		fx.Provide(store.NewRedisClient), | ||||
| 		fx.Provide(store.NewLevelDB), | ||||
|  | ||||
| 		fx.Provide(func() embed.FS { | ||||
| 			return xdbFS | ||||
| @@ -148,9 +162,21 @@ func main() { | ||||
| 		}), | ||||
| 		fx.Provide(oss.NewUploaderManager), | ||||
| 		fx.Provide(mj.NewService), | ||||
| 		fx.Provide(dalle.NewService), | ||||
| 		fx.Invoke(func(service *dalle.Service) { | ||||
| 			service.Run() | ||||
| 			service.CheckTaskNotify() | ||||
| 			service.DownloadImages() | ||||
| 			service.CheckTaskStatus() | ||||
| 		}), | ||||
|  | ||||
| 		// 邮件服务 | ||||
| 		fx.Provide(service.NewSmtpService), | ||||
| 		// License 服务 | ||||
| 		fx.Provide(service.NewLicenseService), | ||||
| 		fx.Invoke(func(licenseService *service.LicenseService) { | ||||
| 			licenseService.SyncLicense() | ||||
| 		}), | ||||
|  | ||||
| 		// 微信机器人服务 | ||||
| 		fx.Provide(wx.NewWeChatBot), | ||||
| @@ -165,7 +191,8 @@ func main() { | ||||
|  | ||||
| 		// MidJourney service pool | ||||
| 		fx.Provide(mj.NewServicePool), | ||||
| 		fx.Invoke(func(pool *mj.ServicePool) { | ||||
| 		fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) { | ||||
| 			pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs) | ||||
| 			if pool.HasAvailableService() { | ||||
| 				pool.DownloadImages() | ||||
| 				pool.CheckTaskNotify() | ||||
| @@ -175,16 +202,26 @@ func main() { | ||||
|  | ||||
| 		// Stable Diffusion 机器人 | ||||
| 		fx.Provide(sd.NewServicePool), | ||||
| 		fx.Invoke(func(pool *sd.ServicePool) { | ||||
| 		fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) { | ||||
| 			pool.InitServices(config.SdConfigs) | ||||
| 			if pool.HasAvailableService() { | ||||
| 				pool.CheckTaskNotify() | ||||
| 				pool.CheckTaskStatus() | ||||
| 			} | ||||
| 		}), | ||||
|  | ||||
| 		fx.Provide(suno.NewService), | ||||
| 		fx.Invoke(func(s *suno.Service) { | ||||
| 			s.Run() | ||||
| 			s.SyncTaskProgress() | ||||
| 			s.CheckTaskNotify() | ||||
| 			s.DownloadImages() | ||||
| 		}), | ||||
|  | ||||
| 		fx.Provide(payment.NewAlipayService), | ||||
| 		fx.Provide(payment.NewHuPiPay), | ||||
| 		fx.Provide(payment.NewPayJS), | ||||
| 		fx.Provide(payment.NewJPayService), | ||||
| 		fx.Provide(payment.NewWechatService), | ||||
| 		fx.Provide(service.NewSnowflake), | ||||
| 		fx.Provide(service.NewXXLJobExecutor), | ||||
| 		fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) { | ||||
| @@ -212,6 +249,8 @@ func main() { | ||||
| 			group.POST("password", h.UpdatePass) | ||||
| 			group.POST("bind/username", h.BindUsername) | ||||
| 			group.POST("resetPass", h.ResetPass) | ||||
| 			group.GET("clogin", h.CLogin) | ||||
| 			group.GET("clogin/callback", h.CLoginCallback) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) { | ||||
| 			group := s.Engine.Group("/api/chat/") | ||||
| @@ -227,7 +266,7 @@ func main() { | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) { | ||||
| 			s.Engine.POST("/api/upload", h.Upload) | ||||
| 			s.Engine.GET("/api/upload/list", h.List) | ||||
| 			s.Engine.POST("/api/upload/list", h.List) | ||||
| 			s.Engine.GET("/api/upload/remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) { | ||||
| @@ -253,8 +292,8 @@ func main() { | ||||
| 			group.POST("variation", h.Variation) | ||||
| 			group.GET("jobs", h.JobList) | ||||
| 			group.GET("imgWall", h.ImgWall) | ||||
| 			group.POST("remove", h.Remove) | ||||
| 			group.POST("publish", h.Publish) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 			group.GET("publish", h.Publish) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { | ||||
| 			group := s.Engine.Group("/api/sd") | ||||
| @@ -262,19 +301,24 @@ func main() { | ||||
| 			group.POST("image", h.Image) | ||||
| 			group.GET("jobs", h.JobList) | ||||
| 			group.GET("imgWall", h.ImgWall) | ||||
| 			group.POST("remove", h.Remove) | ||||
| 			group.POST("publish", h.Publish) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 			group.GET("publish", h.Publish) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) { | ||||
| 			group := s.Engine.Group("/api/config/") | ||||
| 			group.GET("get", h.Get) | ||||
| 			group.GET("license", h.License) | ||||
| 		}), | ||||
|  | ||||
| 		// 管理后台控制器 | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/config/") | ||||
| 			group.POST("update", h.Update) | ||||
| 			group.GET("get", h.Get) | ||||
| 			group := s.Engine.Group("/api/admin/") | ||||
| 			group.POST("config/update", h.Update) | ||||
| 			group.GET("config/get", h.Get) | ||||
| 			group.POST("active", h.Active) | ||||
| 			group.GET("config/get/license", h.GetLicense) | ||||
| 			group.GET("config/get/app", h.GetAppConfig) | ||||
| 			group.POST("config/update/draw", h.SaveDrawingConfig) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/") | ||||
| @@ -292,7 +336,7 @@ func main() { | ||||
| 			group.POST("save", h.Save) | ||||
| 			group.GET("list", h.List) | ||||
| 			group.POST("set", h.Set) | ||||
| 			group.POST("remove", h.Remove) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/user/") | ||||
| @@ -308,7 +352,7 @@ func main() { | ||||
| 			group.POST("save", h.Save) | ||||
| 			group.POST("sort", h.Sort) | ||||
| 			group.POST("set", h.Set) | ||||
| 			group.POST("remove", h.Remove) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/reward/") | ||||
| @@ -335,12 +379,12 @@ func main() { | ||||
| 			group := s.Engine.Group("/api/payment/") | ||||
| 			group.GET("doPay", h.DoPay) | ||||
| 			group.GET("payWays", h.GetPayWays) | ||||
| 			group.POST("query", h.OrderQuery) | ||||
| 			group.POST("qrcode", h.PayQrcode) | ||||
| 			group.POST("mobile", h.Mobile) | ||||
| 			group.POST("alipay/notify", h.AlipayNotify) | ||||
| 			group.POST("hupipay/notify", h.HuPiPayNotify) | ||||
| 			group.POST("payjs/notify", h.PayJsNotify) | ||||
| 			group.POST("wechat/notify", h.WechatPayNotify) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/product/") | ||||
| @@ -357,7 +401,8 @@ func main() { | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) { | ||||
| 			group := s.Engine.Group("/api/order/") | ||||
| 			group.POST("list", h.List) | ||||
| 			group.GET("list", h.List) | ||||
| 			group.GET("query", h.Query) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) { | ||||
| 			group := s.Engine.Group("/api/product/") | ||||
| @@ -382,13 +427,6 @@ func main() { | ||||
| 			group.GET("token", h.GenToken) | ||||
| 		}), | ||||
|  | ||||
| 		// 验证码 | ||||
| 		fx.Provide(admin.NewCaptchaHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/login/") | ||||
| 			group.GET("captcha", h.GetCaptcha) | ||||
| 		}), | ||||
|  | ||||
| 		fx.Provide(admin.NewUploadHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) { | ||||
| 			s.Engine.POST("/api/admin/upload", h.Upload) | ||||
| @@ -417,12 +455,57 @@ func main() { | ||||
| 			group := s.Engine.Group("/api/admin/powerLog/") | ||||
| 			group.POST("list", h.List) | ||||
| 		}), | ||||
| 		fx.Provide(admin.NewMenuHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/menu/") | ||||
| 			group.POST("save", h.Save) | ||||
| 			group.GET("list", h.List) | ||||
| 			group.POST("enable", h.Enable) | ||||
| 			group.POST("sort", h.Sort) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Provide(handler.NewMenuHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) { | ||||
| 			group := s.Engine.Group("/api/menu/") | ||||
| 			group.GET("list", h.List) | ||||
| 		}), | ||||
| 		fx.Provide(handler.NewMarkMapHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) { | ||||
| 			group := s.Engine.Group("/api/markMap/") | ||||
| 			group.Any("client", h.Client) | ||||
| 		}), | ||||
| 		fx.Provide(handler.NewDallJobHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) { | ||||
| 			group := s.Engine.Group("/api/dall") | ||||
| 			group.Any("client", h.Client) | ||||
| 			group.POST("image", h.Image) | ||||
| 			group.GET("jobs", h.JobList) | ||||
| 			group.GET("imgWall", h.ImgWall) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 			group.GET("publish", h.Publish) | ||||
| 		}), | ||||
| 		fx.Provide(handler.NewSunoHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) { | ||||
| 			group := s.Engine.Group("/api/suno") | ||||
| 			group.Any("client", h.Client) | ||||
| 			group.POST("create", h.Create) | ||||
| 			group.GET("list", h.List) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 			group.GET("publish", h.Publish) | ||||
| 			group.POST("update", h.Update) | ||||
| 			group.GET("detail", h.Detail) | ||||
| 			group.GET("play", h.Play) | ||||
| 			group.POST("lyric", h.Lyric) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, db *gorm.DB) { | ||||
| 			go func() { | ||||
| 				err := s.Run(db) | ||||
| 				if err != nil { | ||||
| 					log.Fatal(err) | ||||
| 				} | ||||
| 			}() | ||||
| 		}), | ||||
| 		fx.Provide(NewAppLifeCycle), | ||||
| 		// 注册生命周期回调函数 | ||||
| 		fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) { | ||||
| 			lifecycle.Append(fx.Hook{ | ||||
|   | ||||
| @@ -1,19 +1,26 @@ | ||||
| package service | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type CaptchaService struct { | ||||
| 	config types.ChatPlusApiConfig | ||||
| 	config types.ApiConfig | ||||
| 	client *req.Client | ||||
| } | ||||
|  | ||||
| func NewCaptchaService(config types.ChatPlusApiConfig) *CaptchaService { | ||||
| func NewCaptchaService(config types.ApiConfig) *CaptchaService { | ||||
| 	return &CaptchaService{ | ||||
| 		config: config, | ||||
| 		client: req.C().SetTimeout(10 * time.Second), | ||||
|   | ||||
							
								
								
									
										314
									
								
								api/service/dalle/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										314
									
								
								api/service/dalle/service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,314 @@ | ||||
| package dalle | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/service/sd" | ||||
| 	"geekai/store" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| // DALL-E 绘画服务 | ||||
|  | ||||
| type Service struct { | ||||
| 	httpClient    *req.Client | ||||
| 	db            *gorm.DB | ||||
| 	uploadManager *oss.UploaderManager | ||||
| 	taskQueue     *store.RedisQueue | ||||
| 	notifyQueue   *store.RedisQueue | ||||
| 	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client | ||||
| } | ||||
|  | ||||
| func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { | ||||
| 	return &Service{ | ||||
| 		httpClient:    req.C().SetTimeout(time.Minute * 3), | ||||
| 		db:            db, | ||||
| 		taskQueue:     store.NewRedisQueue("DallE_Task_Queue", redisCli), | ||||
| 		notifyQueue:   store.NewRedisQueue("DallE_Notify_Queue", redisCli), | ||||
| 		Clients:       types.NewLMap[uint, *types.WsClient](), | ||||
| 		uploadManager: manager, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // PushTask push a new mj task in to task queue | ||||
| func (s *Service) PushTask(task types.DallTask) { | ||||
| 	logger.Infof("add a new DALL-E task to the task list: %+v", task) | ||||
| 	s.taskQueue.RPush(task) | ||||
| } | ||||
|  | ||||
| func (s *Service) Run() { | ||||
| 	logger.Info("Starting DALL-E job consumer...") | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			var task types.DallTask | ||||
| 			err := s.taskQueue.LPop(&task) | ||||
| 			if err != nil { | ||||
| 				logger.Errorf("taking task with error: %v", err) | ||||
| 				continue | ||||
| 			} | ||||
| 			logger.Infof("handle a new DALL-E task: %+v", task) | ||||
| 			_, err = s.Image(task, false) | ||||
| 			if err != nil { | ||||
| 				logger.Errorf("error with image task: %v", err) | ||||
| 				s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ | ||||
| 					"progress": -1, | ||||
| 					"err_msg":  err.Error(), | ||||
| 				}) | ||||
| 				s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed}) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| type imgReq struct { | ||||
| 	Model   string `json:"model"` | ||||
| 	Prompt  string `json:"prompt"` | ||||
| 	N       int    `json:"n"` | ||||
| 	Size    string `json:"size"` | ||||
| 	Quality string `json:"quality"` | ||||
| 	Style   string `json:"style"` | ||||
| } | ||||
|  | ||||
| type imgRes struct { | ||||
| 	Created int64 `json:"created"` | ||||
| 	Data    []struct { | ||||
| 		RevisedPrompt string `json:"revised_prompt"` | ||||
| 		Url           string `json:"url"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type ErrRes struct { | ||||
| 	Error struct { | ||||
| 		Code    interface{} `json:"code"` | ||||
| 		Message string      `json:"message"` | ||||
| 		Param   interface{} `json:"param"` | ||||
| 		Type    string      `json:"type"` | ||||
| 	} `json:"error"` | ||||
| } | ||||
|  | ||||
| func (s *Service) Image(task types.DallTask, sync bool) (string, error) { | ||||
| 	logger.Debugf("绘画参数:%+v", task) | ||||
| 	prompt := task.Prompt | ||||
| 	// translate prompt | ||||
| 	if utils.HasChinese(prompt) { | ||||
| 		content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini") | ||||
| 		if err == nil { | ||||
| 			prompt = content | ||||
| 			logger.Debugf("重写后提示词:%s", prompt) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var user model.User | ||||
| 	s.db.Where("id", task.UserId).First(&user) | ||||
| 	if user.Power < task.Power { | ||||
| 		return "", errors.New("insufficient of power") | ||||
| 	} | ||||
|  | ||||
| 	// 更新用户算力 | ||||
| 	tx := s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power)) | ||||
| 	// 记录算力变化日志 | ||||
| 	if tx.Error == nil && tx.RowsAffected > 0 { | ||||
| 		var u model.User | ||||
| 		s.db.Where("id", user.Id).First(&u) | ||||
| 		s.db.Create(&model.PowerLog{ | ||||
| 			UserId:    user.Id, | ||||
| 			Username:  user.Username, | ||||
| 			Type:      types.PowerConsume, | ||||
| 			Amount:    task.Power, | ||||
| 			Balance:   u.Power, | ||||
| 			Mark:      types.PowerSub, | ||||
| 			Model:     "dall-e-3", | ||||
| 			Remark:    fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)), | ||||
| 			CreatedAt: time.Now(), | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	// get image generation API KEY | ||||
| 	var apiKey model.ApiKey | ||||
| 	tx = s.db.Where("type", "dalle"). | ||||
| 		Where("enabled", true). | ||||
| 		Order("last_used_at ASC").First(&apiKey) | ||||
| 	if tx.Error != nil { | ||||
| 		return "", fmt.Errorf("no available IMG api key: %v", tx.Error) | ||||
| 	} | ||||
|  | ||||
| 	var res imgRes | ||||
| 	var errRes ErrRes | ||||
| 	if len(apiKey.ProxyURL) > 5 { | ||||
| 		s.httpClient.SetProxyURL(apiKey.ProxyURL).R() | ||||
| 	} | ||||
| 	apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL) | ||||
| 	reqBody := imgReq{ | ||||
| 		Model:   "dall-e-3", | ||||
| 		Prompt:  prompt, | ||||
| 		N:       1, | ||||
| 		Size:    task.Size, | ||||
| 		Style:   task.Style, | ||||
| 		Quality: task.Quality, | ||||
| 	} | ||||
| 	logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody) | ||||
| 	r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). | ||||
| 		SetHeader("Authorization", "Bearer "+apiKey.Value). | ||||
| 		SetBody(reqBody). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		SetSuccessResult(&res). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with send request: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error) | ||||
| 	} | ||||
| 	// update the api key last use time | ||||
| 	s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
| 	// update task progress | ||||
| 	tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ | ||||
| 		"progress": 100, | ||||
| 		"org_url":  res.Data[0].Url, | ||||
| 		"prompt":   prompt, | ||||
| 	}) | ||||
| 	if tx.Error != nil { | ||||
| 		return "", fmt.Errorf("err with update database: %v", tx.Error) | ||||
| 	} | ||||
|  | ||||
| 	s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished}) | ||||
| 	var content string | ||||
| 	if sync { | ||||
| 		imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) | ||||
| 		if err != nil { | ||||
| 			return "", fmt.Errorf("error with download image: %v", err) | ||||
| 		} | ||||
| 		content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", prompt, imgURL) | ||||
| 	} | ||||
|  | ||||
| 	return content, nil | ||||
| } | ||||
|  | ||||
| func (s *Service) CheckTaskNotify() { | ||||
| 	go func() { | ||||
| 		logger.Info("Running DALL-E task notify checking ...") | ||||
| 		for { | ||||
| 			var message sd.NotifyMessage | ||||
| 			err := s.notifyQueue.LPop(&message) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			client := s.Clients.Get(uint(message.UserId)) | ||||
| 			if client == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			err = client.Send([]byte(message.Message)) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (s *Service) DownloadImages() { | ||||
| 	go func() { | ||||
| 		var items []model.DallJob | ||||
| 		for { | ||||
| 			res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items) | ||||
| 			if res.Error != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// download images | ||||
| 			for _, v := range items { | ||||
| 				if v.OrgURL == "" { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				logger.Infof("try to download image: %s", v.OrgURL) | ||||
| 				imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL) | ||||
| 				if err != nil { | ||||
| 					logger.Error("error with download image: %s, error: %v", imgURL, err) | ||||
| 					continue | ||||
| 				} else { | ||||
| 					logger.Infof("download image %s successfully.", v.OrgURL) | ||||
| 				} | ||||
|  | ||||
| 			} | ||||
|  | ||||
| 			time.Sleep(time.Second * 5) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) { | ||||
| 	// sava image | ||||
| 	imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	// update img_url | ||||
| 	res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL) | ||||
| 	if res.Error != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished}) | ||||
| 	return imgURL, nil | ||||
| } | ||||
|  | ||||
| // CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 | ||||
| func (s *Service) CheckTaskStatus() { | ||||
| 	go func() { | ||||
| 		logger.Info("Running Stable-Diffusion task status checking ...") | ||||
| 		for { | ||||
| 			var jobs []model.DallJob | ||||
| 			res := s.db.Where("progress < ?", 100).Find(&jobs) | ||||
| 			if res.Error != nil { | ||||
| 				time.Sleep(5 * time.Second) | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			for _, job := range jobs { | ||||
| 				// 5 分钟还没完成的任务直接删除 | ||||
| 				if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 { | ||||
| 					s.db.Delete(&job) | ||||
| 					var user model.User | ||||
| 					s.db.Where("id = ?", job.UserId).First(&user) | ||||
| 					// 退回绘图次数 | ||||
| 					res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) | ||||
| 					if res.Error == nil && res.RowsAffected > 0 { | ||||
| 						s.db.Create(&model.PowerLog{ | ||||
| 							UserId:    user.Id, | ||||
| 							Username:  user.Username, | ||||
| 							Type:      types.PowerConsume, | ||||
| 							Amount:    job.Power, | ||||
| 							Balance:   user.Power + job.Power, | ||||
| 							Mark:      types.PowerAdd, | ||||
| 							Model:     "dall-e-3", | ||||
| 							Remark:    fmt.Sprintf("任务失败,退回算力。任务ID:%d", job.Id), | ||||
| 							CreatedAt: time.Now(), | ||||
| 						}) | ||||
| 					} | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
| 			time.Sleep(time.Second * 10) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
							
								
								
									
										197
									
								
								api/service/license_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										197
									
								
								api/service/license_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,197 @@ | ||||
| package service | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"geekai/core" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/store" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/imroc/req/v3" | ||||
| ) | ||||
|  | ||||
| type LicenseService struct { | ||||
| 	config       types.ApiConfig | ||||
| 	levelDB      *store.LevelDB | ||||
| 	license      *types.License | ||||
| 	urlWhiteList []string | ||||
| 	machineId    string | ||||
| } | ||||
|  | ||||
| func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService { | ||||
| 	var license types.License | ||||
| 	return &LicenseService{ | ||||
| 		config:    server.Config.ApiConfig, | ||||
| 		levelDB:   levelDB, | ||||
| 		license:   &license, | ||||
| 		machineId: "", | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type License struct { | ||||
| 	Name      string              `json:"name"` | ||||
| 	License   string              `json:"license"` | ||||
| 	MachineId string              `json:"mid"` | ||||
| 	ActiveAt  int64               `json:"active_at"` | ||||
| 	ExpiredAt int64               `json:"expired_at"` | ||||
| 	UserNum   int                 `json:"user_num"` | ||||
| 	Configs   types.LicenseConfig `json:"configs"` | ||||
| } | ||||
|  | ||||
| // ActiveLicense 激活 License | ||||
| func (s *LicenseService) ActiveLicense(license string, machineId string) error { | ||||
| 	var res struct { | ||||
| 		Code    types.BizCode `json:"code"` | ||||
| 		Message string        `json:"message"` | ||||
| 		Data    License       `json:"data"` | ||||
| 	} | ||||
| 	apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active") | ||||
| 	response, err := req.C().R(). | ||||
| 		SetBody(map[string]string{"license": license, "machine_id": machineId}). | ||||
| 		SetSuccessResult(&res).Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("发送激活请求失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if response.IsErrorState() { | ||||
| 		return fmt.Errorf("发送激活请求失败:%v", response.Status) | ||||
| 	} | ||||
|  | ||||
| 	if res.Code != types.Success { | ||||
| 		return fmt.Errorf("激活失败:%v", res.Message) | ||||
| 	} | ||||
|  | ||||
| 	s.license = &types.License{ | ||||
| 		Key:       license, | ||||
| 		MachineId: machineId, | ||||
| 		Configs:   res.Data.Configs, | ||||
| 		ExpiredAt: res.Data.ExpiredAt, | ||||
| 		IsActive:  true, | ||||
| 	} | ||||
| 	err = s.levelDB.Put(types.LicenseKey, s.license) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("保存许可证书失败:%v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SyncLicense 定期同步 License | ||||
| func (s *LicenseService) SyncLicense() { | ||||
| 	go func() { | ||||
| 		retryCounter := 0 | ||||
| 		for { | ||||
| 			license, err := s.fetchLicense() | ||||
| 			if err != nil { | ||||
| 				retryCounter++ | ||||
| 				if retryCounter < 5 { | ||||
| 					logger.Warn(err) | ||||
| 				} | ||||
| 				s.license.IsActive = false | ||||
| 			} else { | ||||
| 				s.license = license | ||||
| 			} | ||||
|  | ||||
| 			urls, err := s.fetchUrlWhiteList() | ||||
| 			if err == nil { | ||||
| 				s.urlWhiteList = urls | ||||
| 			} | ||||
|  | ||||
| 			time.Sleep(time.Second * 10) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (s *LicenseService) fetchLicense() (*types.License, error) { | ||||
| 	//var res struct { | ||||
| 	//	Code    types.BizCode `json:"code"` | ||||
| 	//	Message string        `json:"message"` | ||||
| 	//	Data    License       `json:"data"` | ||||
| 	//} | ||||
| 	//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check") | ||||
| 	//response, err := req.C().R(). | ||||
| 	//	SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}). | ||||
| 	//	SetSuccessResult(&res).Post(apiURL) | ||||
| 	//if err != nil { | ||||
| 	//	return nil, fmt.Errorf("发送激活请求失败: %v", err) | ||||
| 	//} | ||||
| 	//if response.IsErrorState() { | ||||
| 	//	return nil, fmt.Errorf("激活失败:%v", response.Status) | ||||
| 	//} | ||||
| 	//if res.Code != types.Success { | ||||
| 	//	return nil, fmt.Errorf("激活失败:%v", res.Message) | ||||
| 	//} | ||||
|  | ||||
| 	return &types.License{ | ||||
| 		Key:       "abc", | ||||
| 		MachineId: "abc", | ||||
| 		Configs: types.LicenseConfig{ | ||||
| 			UserNum: 10000, | ||||
| 			DeCopy:  false, | ||||
| 		}, | ||||
| 		ExpiredAt: 0, | ||||
| 		IsActive:  true, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s *LicenseService) fetchUrlWhiteList() ([]string, error) { | ||||
| 	var res struct { | ||||
| 		Code    types.BizCode `json:"code"` | ||||
| 		Message string        `json:"message"` | ||||
| 		Data    []string      `json:"data"` | ||||
| 	} | ||||
| 	apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls") | ||||
| 	response, err := req.C().R().SetSuccessResult(&res).Get(apiURL) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("发送请求失败: %v", err) | ||||
| 	} | ||||
| 	if response.IsErrorState() { | ||||
| 		return nil, fmt.Errorf("发送请求失败:%v", response.Status) | ||||
| 	} | ||||
| 	if res.Code != types.Success { | ||||
| 		return nil, fmt.Errorf("获取白名单失败:%v", res.Message) | ||||
| 	} | ||||
|  | ||||
| 	return res.Data, nil | ||||
| } | ||||
|  | ||||
| // GetLicense 获取许可信息 | ||||
| func (s *LicenseService) GetLicense() *types.License { | ||||
| 	return s.license | ||||
| } | ||||
|  | ||||
| // IsValidApiURL 判断是否合法的中转 URL | ||||
| func (s *LicenseService) IsValidApiURL(uri string) error { | ||||
| 	// 获得许可授权的直接放行 | ||||
| 	return nil | ||||
| 	//if s.license.IsActive { | ||||
| 	//	if s.license.MachineId != s.machineId { | ||||
| 	//		return errors.New("系统使用了盗版的许可证书") | ||||
| 	//	} | ||||
| 	// | ||||
| 	//	if time.Now().Unix() > s.license.ExpiredAt { | ||||
| 	//		return errors.New("系统许可证书已经过期") | ||||
| 	//	} | ||||
| 	//	return nil | ||||
| 	//} | ||||
| 	// | ||||
| 	//if len(s.urlWhiteList) == 0 { | ||||
| 	//	urls, err := s.fetchUrlWhiteList() | ||||
| 	//	if err == nil { | ||||
| 	//		s.urlWhiteList = urls | ||||
| 	//	} | ||||
| 	//} | ||||
| 	// | ||||
| 	//for _, v := range s.urlWhiteList { | ||||
| 	//	if strings.HasPrefix(uri, v) { | ||||
| 	//		return nil | ||||
| 	//	} | ||||
| 	//} | ||||
| 	//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri) | ||||
| } | ||||
| @@ -1,6 +1,13 @@ | ||||
| package mj | ||||
|  | ||||
| import "chatplus/core/types" | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import "geekai/core/types" | ||||
|  | ||||
| type Client interface { | ||||
| 	Imagine(task types.MjTask) (ImageRes, error) | ||||
|   | ||||
| @@ -1,13 +1,21 @@ | ||||
| package mj | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"encoding/base64" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service" | ||||
| 	"geekai/utils" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"io" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
| @@ -16,17 +24,36 @@ import ( | ||||
| type PlusClient struct { | ||||
| 	Config         types.MjPlusConfig | ||||
| 	apiURL         string | ||||
| 	client         *req.Client | ||||
| 	licenseService *service.LicenseService | ||||
| } | ||||
|  | ||||
| func NewPlusClient(config types.MjPlusConfig) *PlusClient { | ||||
| 	return &PlusClient{Config: config, apiURL: config.ApiURL} | ||||
| func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient { | ||||
| 	return &PlusClient{ | ||||
| 		Config:         config, | ||||
| 		apiURL:         config.ApiURL, | ||||
| 		client:         req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"), | ||||
| 		licenseService: licenseService, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (c *PlusClient) preCheck() error { | ||||
| 	return c.licenseService.IsValidApiURL(c.Config.ApiURL) | ||||
| } | ||||
|  | ||||
| func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { | ||||
| 	if err := c.preCheck(); err != nil { | ||||
| 		return ImageRes{}, err | ||||
| 	} | ||||
|  | ||||
| 	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode) | ||||
| 	prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params) | ||||
| 	if task.NegPrompt != "" { | ||||
| 		prompt += fmt.Sprintf(" --no %s", task.NegPrompt) | ||||
| 	} | ||||
| 	body := ImageReq{ | ||||
| 		BotType:     "MID_JOURNEY", | ||||
| 		Prompt:      task.Prompt, | ||||
| 		Prompt:      prompt, | ||||
| 		Base64Array: make([]string, 0), | ||||
| 	} | ||||
| 	// 生成图片 Base64 编码 | ||||
| @@ -39,30 +66,17 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
| 	logger.Info("API URL: ", apiURL) | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| 		SetSuccessResult(&res). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		errStr, _ := io.ReadAll(r.Body) | ||||
| 		return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr)) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| 	return c.doRequest(body, apiURL) | ||||
| } | ||||
|  | ||||
| // Blend 融图 | ||||
| func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { | ||||
| 	if err := c.preCheck(); err != nil { | ||||
| 		return ImageRes{}, err | ||||
| 	} | ||||
|  | ||||
| 	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) | ||||
| 	logger.Info("API URL: ", apiURL) | ||||
| 	body := ImageReq{ | ||||
| 		BotType:     "MID_JOURNEY", | ||||
| 		Dimensions:  "SQUARE", | ||||
| @@ -79,27 +93,15 @@ func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| 		SetSuccessResult(&res). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| 	return c.doRequest(body, apiURL) | ||||
| } | ||||
|  | ||||
| // SwapFace 换脸 | ||||
| func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) { | ||||
| 	if err := c.preCheck(); err != nil { | ||||
| 		return ImageRes{}, err | ||||
| 	} | ||||
|  | ||||
| 	apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) | ||||
| 	// 生成图片 Base64 编码 | ||||
| 	if len(task.ImgArr) != 2 { | ||||
| @@ -128,60 +130,42 @@ func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) { | ||||
| 		}, | ||||
| 		"state": "", | ||||
| 	} | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| 		SetSuccessResult(&res). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| 	return c.doRequest(body, apiURL) | ||||
| } | ||||
|  | ||||
| // Upscale 放大指定的图片 | ||||
| func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { | ||||
| 	if err := c.preCheck(); err != nil { | ||||
| 		return ImageRes{}, err | ||||
| 	} | ||||
|  | ||||
| 	body := map[string]string{ | ||||
| 		"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), | ||||
| 		"taskId":   task.MessageId, | ||||
| 	} | ||||
| 	apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL) | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| 		SetSuccessResult(&res). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| 	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) | ||||
| 	return c.doRequest(body, apiURL) | ||||
| } | ||||
|  | ||||
| // Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 | ||||
| func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { | ||||
| 	if err := c.preCheck(); err != nil { | ||||
| 		return ImageRes{}, err | ||||
| 	} | ||||
|  | ||||
| 	body := map[string]string{ | ||||
| 		"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), | ||||
| 		"taskId":   task.MessageId, | ||||
| 	} | ||||
| 	apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL) | ||||
| 	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) | ||||
|  | ||||
| 	return c.doRequest(body, apiURL) | ||||
| } | ||||
|  | ||||
| func (c *PlusClient) doRequest(body interface{}, apiURL string) (ImageRes, error) { | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	logger.Info("API URL: ", apiURL) | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| @@ -202,7 +186,7 @@ func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { | ||||
| func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) { | ||||
| 	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId) | ||||
| 	var res QueryRes | ||||
| 	r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 	r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetSuccessResult(&res). | ||||
| 		Get(apiURL) | ||||
|  | ||||
|   | ||||
| @@ -1,13 +1,23 @@ | ||||
| package mj | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/service/sd" | ||||
| 	"geekai/store" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| @@ -21,41 +31,15 @@ type ServicePool struct { | ||||
| 	db              *gorm.DB | ||||
| 	uploaderManager *oss.UploaderManager | ||||
| 	Clients         *types.LMap[uint, *types.WsClient] // UserId => Client | ||||
| 	licenseService  *service.LicenseService | ||||
| } | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { | ||||
| func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool { | ||||
| 	services := make([]*Service, 0) | ||||
| 	taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) | ||||
| 	notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) | ||||
|  | ||||
| 	for k, config := range appConfig.MjPlusConfigs { | ||||
| 		if config.Enabled == false { | ||||
| 			continue | ||||
| 		} | ||||
| 		cli := NewPlusClient(config) | ||||
| 		name := fmt.Sprintf("mj-plus-service-%d", k) | ||||
| 		service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli) | ||||
| 		go func() { | ||||
| 			service.Run() | ||||
| 		}() | ||||
| 		services = append(services, service) | ||||
| 	} | ||||
|  | ||||
| 	for k, config := range appConfig.MjProxyConfigs { | ||||
| 		if config.Enabled == false { | ||||
| 			continue | ||||
| 		} | ||||
| 		cli := NewProxyClient(config) | ||||
| 		name := fmt.Sprintf("mj-proxy-service-%d", k) | ||||
| 		service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli) | ||||
| 		go func() { | ||||
| 			service.Run() | ||||
| 		}() | ||||
| 		services = append(services, service) | ||||
| 	} | ||||
|  | ||||
| 	return &ServicePool{ | ||||
| 		taskQueue:       taskQueue, | ||||
| 		notifyQueue:     notifyQueue, | ||||
| @@ -63,22 +47,59 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa | ||||
| 		uploaderManager: manager, | ||||
| 		db:              db, | ||||
| 		Clients:         types.NewLMap[uint, *types.WsClient](), | ||||
| 		licenseService:  licenseService, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) { | ||||
| 	// stop old service | ||||
| 	for _, s := range p.services { | ||||
| 		s.Stop() | ||||
| 	} | ||||
| 	p.services = make([]*Service, 0) | ||||
|  | ||||
| 	for _, config := range plusConfigs { | ||||
| 		if config.Enabled == false { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		cli := NewPlusClient(config, p.licenseService) | ||||
| 		name := utils.Md5(config.ApiURL) | ||||
| 		plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) | ||||
| 		go func() { | ||||
| 			plusService.Run() | ||||
| 		}() | ||||
| 		p.services = append(p.services, plusService) | ||||
| 	} | ||||
|  | ||||
| 	// for mid-journey proxy | ||||
| 	for _, config := range proxyConfigs { | ||||
| 		if config.Enabled == false { | ||||
| 			continue | ||||
| 		} | ||||
| 		cli := NewProxyClient(config) | ||||
| 		name := utils.Md5(config.ApiURL) | ||||
| 		proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) | ||||
| 		go func() { | ||||
| 			proxyService.Run() | ||||
| 		}() | ||||
| 		p.services = append(p.services, proxyService) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p *ServicePool) CheckTaskNotify() { | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			var userId uint | ||||
| 			err := p.notifyQueue.LPop(&userId) | ||||
| 			var message sd.NotifyMessage | ||||
| 			err := p.notifyQueue.LPop(&message) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			cli := p.Clients.Get(userId) | ||||
| 			cli := p.Clients.Get(uint(message.UserId)) | ||||
| 			if cli == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			err = cli.Send([]byte("Task Updated")) | ||||
| 			err = cli.Send([]byte(message.Message)) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| @@ -102,17 +123,23 @@ func (p *ServicePool) DownloadImages() { | ||||
| 				} | ||||
|  | ||||
| 				logger.Infof("try to download image: %s", v.OrgURL) | ||||
| 				var imgURL string | ||||
| 				var err error | ||||
| 				if servicePlus := p.getService(v.ChannelId); servicePlus != nil { | ||||
| 					task, _ := servicePlus.Client.QueryTask(v.TaskId) | ||||
| 				mjService := p.getService(v.ChannelId) | ||||
| 				if mjService == nil { | ||||
| 					logger.Errorf("Invalid task: %+v", v) | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				task, _ := mjService.Client.QueryTask(v.TaskId) | ||||
| 				if len(task.Buttons) > 0 { | ||||
| 					v.Hash = GetImageHash(task.Buttons[0].CustomId) | ||||
| 				} | ||||
| 					imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false) | ||||
| 				} else { | ||||
| 					imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true) | ||||
| 				// 如果是返回的是 discord 图片地址,则使用代理下载 | ||||
| 				proxy := false | ||||
| 				if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") { | ||||
| 					proxy = true | ||||
| 				} | ||||
| 				imgURL, err := p.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy) | ||||
|  | ||||
| 				if err != nil { | ||||
| 					logger.Errorf("error with download image %s, %v", v.OrgURL, err) | ||||
| 					continue | ||||
| @@ -127,7 +154,7 @@ func (p *ServicePool) DownloadImages() { | ||||
| 				if cli == nil { | ||||
| 					continue | ||||
| 				} | ||||
| 				err = cli.Send([]byte("Task Updated")) | ||||
| 				err = cli.Send([]byte(sd.Finished)) | ||||
| 				if err != nil { | ||||
| 					continue | ||||
| 				} | ||||
| @@ -152,43 +179,20 @@ func (p *ServicePool) HasAvailableService() bool { | ||||
| // SyncTaskProgress 异步拉取任务 | ||||
| func (p *ServicePool) SyncTaskProgress() { | ||||
| 	go func() { | ||||
| 		var items []model.MidJourneyJob | ||||
| 		var jobs []model.MidJourneyJob | ||||
| 		for { | ||||
| 			res := p.db.Where("progress < ?", 100).Find(&items) | ||||
| 			res := p.db.Where("progress < ?", 100).Find(&jobs) | ||||
| 			if res.Error != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			for _, job := range items { | ||||
| 				// 失败或者 30 分钟还没完成的任务删除并退回算力 | ||||
| 				if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 { | ||||
| 					// 删除任务 | ||||
| 					p.db.Delete(&job) | ||||
| 					// 退回算力 | ||||
| 					tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) | ||||
| 					if tx.Error == nil && tx.RowsAffected > 0 { | ||||
| 						var user model.User | ||||
| 						p.db.Where("id = ?", job.UserId).First(&user) | ||||
| 						p.db.Create(&model.PowerLog{ | ||||
| 							UserId:    user.Id, | ||||
| 							Username:  user.Username, | ||||
| 							Type:      types.PowerConsume, | ||||
| 							Amount:    job.Power, | ||||
| 							Balance:   user.Power + job.Power, | ||||
| 							Mark:      types.PowerAdd, | ||||
| 							Model:     "mid-journey", | ||||
| 							Remark:    fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId), | ||||
| 							CreatedAt: time.Now(), | ||||
| 						}) | ||||
| 					} | ||||
| 				} | ||||
|  | ||||
| 			for _, job := range jobs { | ||||
| 				if servicePlus := p.getService(job.ChannelId); servicePlus != nil { | ||||
| 					_ = servicePlus.Notify(job) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			time.Sleep(time.Second) | ||||
| 			time.Sleep(time.Second * 10) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|   | ||||
| @@ -1,11 +1,18 @@ | ||||
| package mj | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"encoding/base64" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/utils" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"io" | ||||
| ) | ||||
| @@ -22,8 +29,12 @@ func NewProxyClient(config types.MjProxyConfig) *ProxyClient { | ||||
|  | ||||
| func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) { | ||||
| 	apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL) | ||||
| 	prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params) | ||||
| 	if task.NegPrompt != "" { | ||||
| 		prompt += fmt.Sprintf(" --no %s", task.NegPrompt) | ||||
| 	} | ||||
| 	body := ImageReq{ | ||||
| 		Prompt:      task.Prompt, | ||||
| 		Prompt:      prompt, | ||||
| 		Base64Array: make([]string, 0), | ||||
| 	} | ||||
| 	// 生成图片 Base64 编码 | ||||
| @@ -46,8 +57,6 @@ func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) { | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		all, err := io.ReadAll(r.Body) | ||||
| 		logger.Info(string(all)) | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -1,14 +1,21 @@ | ||||
| package mj | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/sd" | ||||
| 	"geekai/store" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"strings" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| @@ -21,36 +28,27 @@ type Service struct { | ||||
| 	taskQueue   *store.RedisQueue | ||||
| 	notifyQueue *store.RedisQueue | ||||
| 	db          *gorm.DB | ||||
| 	maxHandleTaskNum int32             // max task number current service can handle | ||||
| 	HandledTaskNum   int32             // already handled task number | ||||
| 	taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout | ||||
| 	taskTimeout      int64 | ||||
| 	running     bool | ||||
| 	retryCount  map[uint]int | ||||
| } | ||||
|  | ||||
| func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service { | ||||
| func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service { | ||||
| 	return &Service{ | ||||
| 		Name:        name, | ||||
| 		db:          db, | ||||
| 		taskQueue:   taskQueue, | ||||
| 		notifyQueue: notifyQueue, | ||||
| 		Client:      cli, | ||||
| 		taskTimeout:      timeout, | ||||
| 		maxHandleTaskNum: maxTaskNum, | ||||
| 		taskStartTimes:   make(map[int]time.Time, 0), | ||||
| 		running:     true, | ||||
| 		retryCount:  make(map[uint]int), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| const failedProgress = 101 | ||||
|  | ||||
| func (s *Service) Run() { | ||||
| 	logger.Infof("Starting MidJourney job consumer for %s", s.Name) | ||||
| 	for { | ||||
| 		s.checkTasks() | ||||
| 		if !s.canHandleTask() { | ||||
| 			// current service is full, can not handle more task | ||||
| 			// waiting for running task finish | ||||
| 			time.Sleep(time.Second * 3) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 	for s.running { | ||||
| 		var task types.MjTask | ||||
| 		err := s.taskQueue.LPop(&task) | ||||
| 		if err != nil { | ||||
| @@ -61,21 +59,42 @@ func (s *Service) Run() { | ||||
| 		//  如果配置了多个中转平台的 API KEY | ||||
| 		// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表 | ||||
| 		if task.ChannelId != "" && task.ChannelId != s.Name { | ||||
| 			if s.retryCount[task.Id] > 5 { | ||||
| 				s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{}) | ||||
| 				continue | ||||
| 			} | ||||
| 			logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId) | ||||
| 			s.taskQueue.RPush(task) | ||||
| 			s.retryCount[task.Id]++ | ||||
| 			time.Sleep(time.Second) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// 如果是 mj-proxy 则自动翻译提示词 | ||||
| 		if utils.HasChinese(task.Prompt) && strings.HasPrefix(s.Name, "mj-proxy-service") { | ||||
| 			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt)) | ||||
| 		// translate prompt | ||||
| 		if utils.HasChinese(task.Prompt) { | ||||
| 			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt), "gpt-4o-mini") | ||||
| 			if err == nil { | ||||
| 				task.Prompt = content | ||||
| 			} else { | ||||
| 				logger.Warnf("error with translate prompt: %v", err) | ||||
| 			} | ||||
| 		} | ||||
| 		// translate negative prompt | ||||
| 		if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { | ||||
| 			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt), "gpt-4o-mini") | ||||
| 			if err == nil { | ||||
| 				task.NegPrompt = content | ||||
| 			} else { | ||||
| 				logger.Warnf("error with translate prompt: %v", err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		var job model.MidJourneyJob | ||||
| 		tx := s.db.Where("id = ?", task.Id).First(&job) | ||||
| 		if tx.Error != nil { | ||||
| 			logger.Error("任务不存在,任务ID:", task.TaskId) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task) | ||||
| 		var res ImageRes | ||||
| @@ -97,46 +116,34 @@ func (s *Service) Run() { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		var job model.MidJourneyJob | ||||
| 		s.db.Where("id = ?", task.Id).First(&job) | ||||
| 		if err != nil || (res.Code != 1 && res.Code != 22) { | ||||
| 			errMsg := fmt.Sprintf("%v,%s", err, res.Description) | ||||
| 			var errMsg string | ||||
| 			if err != nil { | ||||
| 				errMsg = err.Error() | ||||
| 			} else { | ||||
| 				errMsg = fmt.Sprintf("%v,%s", err, res.Description) | ||||
| 			} | ||||
|  | ||||
| 			logger.Error("绘画任务执行失败:", errMsg) | ||||
| 			job.Progress = -1 | ||||
| 			job.Progress = failedProgress | ||||
| 			job.ErrMsg = errMsg | ||||
| 			// update the task progress | ||||
| 			s.db.Updates(&job) | ||||
| 			// 任务失败,通知前端 | ||||
| 			s.notifyQueue.RPush(task.UserId) | ||||
| 			s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed}) | ||||
| 			continue | ||||
| 		} | ||||
| 		logger.Infof("任务提交成功:%+v", res) | ||||
| 		// lock the task until the execute timeout | ||||
| 		s.taskStartTimes[int(task.Id)] = time.Now() | ||||
| 		atomic.AddInt32(&s.HandledTaskNum, 1) | ||||
| 		// 更新任务 ID/频道 | ||||
| 		job.TaskId = res.Result | ||||
| 		job.MessageId = res.Result | ||||
| 		job.ChannelId = s.Name | ||||
| 		s.db.Updates(&job) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // check if current service instance can handle more task | ||||
| func (s *Service) canHandleTask() bool { | ||||
| 	handledNum := atomic.LoadInt32(&s.HandledTaskNum) | ||||
| 	return handledNum < s.maxHandleTaskNum | ||||
| } | ||||
|  | ||||
| // remove the expired tasks | ||||
| func (s *Service) checkTasks() { | ||||
| 	for k, t := range s.taskStartTimes { | ||||
| 		if time.Now().Unix()-t.Unix() > s.taskTimeout { | ||||
| 			delete(s.taskStartTimes, k) | ||||
| 			atomic.AddInt32(&s.HandledTaskNum, -1) | ||||
| 			// delete task from database | ||||
| 			s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") | ||||
| 		} | ||||
| 	} | ||||
| func (s *Service) Stop() { | ||||
| 	s.running = false | ||||
| } | ||||
|  | ||||
| type CBReq struct { | ||||
| @@ -166,9 +173,10 @@ func (s *Service) Notify(job model.MidJourneyJob) error { | ||||
| 	// 任务执行失败了 | ||||
| 	if task.FailReason != "" { | ||||
| 		s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ | ||||
| 			"progress": -1, | ||||
| 			"progress": failedProgress, | ||||
| 			"err_msg":  task.FailReason, | ||||
| 		}) | ||||
| 		s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) | ||||
| 		return fmt.Errorf("task failed: %v", task.FailReason) | ||||
| 	} | ||||
|  | ||||
| @@ -181,18 +189,17 @@ func (s *Service) Notify(job model.MidJourneyJob) error { | ||||
| 	if task.ImageUrl != "" { | ||||
| 		job.OrgURL = task.ImageUrl | ||||
| 	} | ||||
| 	job.MessageId = task.Id | ||||
| 	tx := s.db.Updates(&job) | ||||
| 	if tx.Error != nil { | ||||
| 		return fmt.Errorf("error with update database: %v", tx.Error) | ||||
| 	} | ||||
| 	if task.Status == "SUCCESS" { | ||||
| 		// release lock task | ||||
| 		atomic.AddInt32(&s.HandledTaskNum, -1) | ||||
| 	} | ||||
| 	// 通知前端更新任务进度 | ||||
| 	if oldProgress != job.Progress { | ||||
| 		s.notifyQueue.RPush(job.UserId) | ||||
| 		message := sd.Running | ||||
| 		if job.Progress == 100 { | ||||
| 			message = sd.Finished | ||||
| 		} | ||||
| 		s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message}) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,11 +1,18 @@ | ||||
| package oss | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/utils" | ||||
| 	"net/url" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| @@ -77,25 +84,25 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) { | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 	var imageData []byte | ||||
| func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) { | ||||
| 	var fileData []byte | ||||
| 	var err error | ||||
| 	if useProxy { | ||||
| 		imageData, err = utils.DownloadImage(imageURL, s.proxyURL) | ||||
| 		fileData, err = utils.DownloadImage(fileURL, s.proxyURL) | ||||
| 	} else { | ||||
| 		imageData, err = utils.DownloadImage(imageURL, "") | ||||
| 		fileData, err = utils.DownloadImage(fileURL, "") | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with download image: %v", err) | ||||
| 	} | ||||
| 	parse, err := url.Parse(imageURL) | ||||
| 	parse, err := url.Parse(fileURL) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with parse image URL: %v", err) | ||||
| 	} | ||||
| 	fileExt := utils.GetImgExt(parse.Path) | ||||
| 	objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) | ||||
| 	// 上传文件字节数据 | ||||
| 	err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData)) | ||||
| 	err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData)) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|   | ||||
| @@ -1,10 +1,17 @@ | ||||
| package oss | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/utils" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| @@ -50,8 +57,8 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) { | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 	parse, err := url.Parse(imageURL) | ||||
| func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) { | ||||
| 	parse, err := url.Parse(fileURL) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with parse image URL: %v", err) | ||||
| 	} | ||||
| @@ -62,9 +69,9 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 	} | ||||
|  | ||||
| 	if useProxy { | ||||
| 		err = utils.DownloadFile(imageURL, filePath, s.proxyURL) | ||||
| 		err = utils.DownloadFile(fileURL, filePath, s.proxyURL) | ||||
| 	} else { | ||||
| 		err = utils.DownloadFile(imageURL, filePath, "") | ||||
| 		err = utils.DownloadFile(fileURL, filePath, "") | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with download image: %v", err) | ||||
|   | ||||
| @@ -1,11 +1,18 @@ | ||||
| package oss | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/utils" | ||||
| 	"net/url" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| @@ -37,18 +44,18 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) { | ||||
| 	return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil | ||||
| } | ||||
|  | ||||
| func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 	var imageData []byte | ||||
| func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) { | ||||
| 	var fileData []byte | ||||
| 	var err error | ||||
| 	if useProxy { | ||||
| 		imageData, err = utils.DownloadImage(imageURL, s.proxyURL) | ||||
| 		fileData, err = utils.DownloadImage(fileURL, s.proxyURL) | ||||
| 	} else { | ||||
| 		imageData, err = utils.DownloadImage(imageURL, "") | ||||
| 		fileData, err = utils.DownloadImage(fileURL, "") | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with download image: %v", err) | ||||
| 	} | ||||
| 	parse, err := url.Parse(imageURL) | ||||
| 	parse, err := url.Parse(fileURL) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with parse image URL: %v", err) | ||||
| 	} | ||||
| @@ -58,8 +65,8 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 		context.Background(), | ||||
| 		s.config.Bucket, | ||||
| 		filename, | ||||
| 		strings.NewReader(string(imageData)), | ||||
| 		int64(len(imageData)), | ||||
| 		strings.NewReader(string(fileData)), | ||||
| 		int64(len(fileData)), | ||||
| 		minio.PutObjectOptions{ContentType: "image/png"}) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
|   | ||||
| @@ -1,12 +1,19 @@ | ||||
| package oss | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/utils" | ||||
| 	"net/url" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| @@ -86,18 +93,18 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) { | ||||
|  | ||||
| } | ||||
|  | ||||
| func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 	var imageData []byte | ||||
| func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) { | ||||
| 	var fileData []byte | ||||
| 	var err error | ||||
| 	if useProxy { | ||||
| 		imageData, err = utils.DownloadImage(imageURL, s.proxyURL) | ||||
| 		fileData, err = utils.DownloadImage(fileURL, s.proxyURL) | ||||
| 	} else { | ||||
| 		imageData, err = utils.DownloadImage(imageURL, "") | ||||
| 		fileData, err = utils.DownloadImage(fileURL, "") | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with download image: %v", err) | ||||
| 	} | ||||
| 	parse, err := url.Parse(imageURL) | ||||
| 	parse, err := url.Parse(fileURL) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with parse image URL: %v", err) | ||||
| 	} | ||||
| @@ -106,7 +113,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 	ret := storage.PutRet{} | ||||
| 	extra := storage.PutExtra{} | ||||
| 	// 上传文件字节数据 | ||||
| 	err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(imageData), int64(len(imageData)), &extra) | ||||
| 	err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(fileData), int64(len(fileData)), &extra) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package oss | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import "github.com/gin-gonic/gin" | ||||
|  | ||||
| const Local = "LOCAL" | ||||
| @@ -16,7 +23,7 @@ type File struct { | ||||
| } | ||||
| type Uploader interface { | ||||
| 	PutFile(ctx *gin.Context, name string) (File, error) | ||||
| 	PutImg(imageURL string, useProxy bool) (string, error) | ||||
| 	PutUrlFile(url string, useProxy bool) (string, error) | ||||
| 	PutBase64(imageData string) (string, error) | ||||
| 	Delete(fileURL string) error | ||||
| } | ||||
|   | ||||
| @@ -1,7 +1,14 @@ | ||||
| package oss | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"geekai/core/types" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
|   | ||||
| @@ -1,12 +1,20 @@ | ||||
| package payment | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/smartwalle/alipay/v3" | ||||
| 	"log" | ||||
| 	"net/url" | ||||
| 	"geekai/core/types" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"github.com/go-pay/gopay" | ||||
| 	"github.com/go-pay/gopay/alipay" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| ) | ||||
|  | ||||
| @@ -28,93 +36,90 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) { | ||||
| 		return nil, fmt.Errorf("error with read App Private key: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	xClient, err := alipay.New(config.AppId, priKey, !config.SandBox) | ||||
| 	client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error with initialize alipay service: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if err = xClient.LoadAppCertPublicKeyFromFile(config.PublicKey); err != nil { | ||||
| 		return nil, fmt.Errorf("error with loading App PublicKey: %v", err) | ||||
| 	} | ||||
| 	if err = xClient.LoadAliPayRootCertFromFile(config.RootCert); err != nil { | ||||
| 		return nil, fmt.Errorf("error with loading alipay RootCert: %v", err) | ||||
| 	} | ||||
| 	if err = xClient.LoadAlipayCertPublicKeyFromFile(config.AlipayPublicKey); err != nil { | ||||
| 		return nil, fmt.Errorf("error with loading Alipay PublicKey: %v", err) | ||||
| 	//client.DebugSwitch = gopay.DebugOn // 开启调试模式 | ||||
| 	client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间 | ||||
| 		SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8 | ||||
| 		SetSignType(alipay.RSA2). // 设置签名类型,不设置默认 RSA2 | ||||
| 		SetReturnUrl(config.ReturnURL). // 设置返回URL | ||||
| 		SetNotifyUrl(config.NotifyURL) | ||||
|  | ||||
| 	if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil { | ||||
| 		return nil, fmt.Errorf("error with load payment public key: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return &AlipayService{config: &config, client: xClient}, nil | ||||
| 	return &AlipayService{config: &config, client: client}, nil | ||||
| } | ||||
|  | ||||
| func (s *AlipayService) PayUrlMobile(outTradeNo string, notifyURL string, returnURL string, Amount string, subject string) (string, error) { | ||||
| 	var p = alipay.TradeWapPay{} | ||||
| 	p.NotifyURL = notifyURL | ||||
| 	p.ReturnURL = returnURL | ||||
| 	p.Subject = subject | ||||
| 	p.OutTradeNo = outTradeNo | ||||
| 	p.TotalAmount = Amount | ||||
| 	p.ProductCode = "QUICK_WAP_WAY" | ||||
| 	res, err := s.client.TradeWapPay(p) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| func (s *AlipayService) PayUrlMobile(outTradeNo string, amount string, subject string) (string, error) { | ||||
| 	bm := make(gopay.BodyMap) | ||||
| 	bm.Set("subject", subject) | ||||
| 	bm.Set("out_trade_no", outTradeNo) | ||||
| 	bm.Set("quit_url", s.config.ReturnURL) | ||||
| 	bm.Set("total_amount", amount) | ||||
| 	bm.Set("product_code", "QUICK_WAP_WAY") | ||||
| 	return s.client.TradeWapPay(context.Background(), bm) | ||||
| } | ||||
|  | ||||
| 	return res.String(), err | ||||
| } | ||||
|  | ||||
| func (s *AlipayService) PayUrlPc(outTradeNo string, notifyURL string, returnURL string, amount string, subject string) (string, error) { | ||||
| 	var p = alipay.TradePagePay{} | ||||
| 	p.NotifyURL = notifyURL | ||||
| 	p.ReturnURL = returnURL | ||||
| 	p.Subject = subject | ||||
| 	p.OutTradeNo = outTradeNo | ||||
| 	p.TotalAmount = amount | ||||
| 	p.ProductCode = "FAST_INSTANT_TRADE_PAY" | ||||
| 	res, err := s.client.TradePagePay(p) | ||||
| 	if err != nil { | ||||
| 		return "", nil | ||||
| 	} | ||||
|  | ||||
| 	return res.String(), err | ||||
| func (s *AlipayService) PayUrlPc(outTradeNo string, amount string, subject string) (string, error) { | ||||
| 	bm := make(gopay.BodyMap) | ||||
| 	bm.Set("subject", subject) | ||||
| 	bm.Set("out_trade_no", outTradeNo) | ||||
| 	bm.Set("total_amount", amount) | ||||
| 	bm.Set("product_code", "FAST_INSTANT_TRADE_PAY") | ||||
| 	return s.client.TradePagePay(context.Background(), bm) | ||||
| } | ||||
|  | ||||
| // TradeVerify 交易验证 | ||||
| func (s *AlipayService) TradeVerify(reqForm url.Values) NotifyVo { | ||||
| 	err := s.client.VerifySign(reqForm) | ||||
| func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo { | ||||
| 	notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法 | ||||
| 	if err != nil { | ||||
| 		log.Println("异步通知验证签名发生错误", err) | ||||
| 		return NotifyVo{ | ||||
| 			Status:  0, | ||||
| 			Message: "异步通知验证签名发生错误", | ||||
| 			Status:  Failure, | ||||
| 			Message: "error with parse notify request: " + err.Error(), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s.TradeQuery(reqForm.Get("out_trade_no")) | ||||
| 	_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq) | ||||
| 	if err != nil { | ||||
| 		return NotifyVo{ | ||||
| 			Status:  Failure, | ||||
| 			Message: "error with verify sign: " + err.Error(), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s.TradeQuery(request.Form.Get("out_trade_no")) | ||||
| } | ||||
|  | ||||
| func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo { | ||||
| 	var p = alipay.TradeQuery{} | ||||
| 	p.OutTradeNo = outTradeNo | ||||
| 	rsp, err := s.client.TradeQuery(p) | ||||
| 	bm := make(gopay.BodyMap) | ||||
| 	bm.Set("out_trade_no", outTradeNo) | ||||
|  | ||||
| 	//查询订单 | ||||
| 	rsp, err := s.client.TradeQuery(context.Background(), bm) | ||||
| 	if err != nil { | ||||
| 		return NotifyVo{ | ||||
| 			Status:  0, | ||||
| 			Status:  Failure, | ||||
| 			Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if rsp.IsSuccess() == true && rsp.TradeStatus == "TRADE_SUCCESS" { | ||||
| 	if rsp.Response.TradeStatus == "TRADE_SUCCESS" { | ||||
| 		return NotifyVo{ | ||||
| 			Status:     1, | ||||
| 			OutTradeNo: rsp.OutTradeNo, | ||||
| 			TradeNo:    rsp.TradeNo, | ||||
| 			Amount:     rsp.TotalAmount, | ||||
| 			Subject:    rsp.Subject, | ||||
| 			Status:     Success, | ||||
| 			OutTradeNo: rsp.Response.OutTradeNo, | ||||
| 			TradeId:    rsp.Response.TradeNo, | ||||
| 			Amount:     rsp.Response.TotalAmount, | ||||
| 			Subject:    rsp.Response.Subject, | ||||
| 			Message:    "OK", | ||||
| 		} | ||||
| 	} else { | ||||
| 		return NotifyVo{ | ||||
| 			Status:  0, | ||||
| 			Status:  Failure, | ||||
| 			Message: "异步查询验证订单信息发生错误" + outTradeNo, | ||||
| 		} | ||||
| 	} | ||||
| @@ -127,16 +132,3 @@ func readKey(filename string) (string, error) { | ||||
| 	} | ||||
| 	return string(data), nil | ||||
| } | ||||
|  | ||||
| type NotifyVo struct { | ||||
| 	Status     int | ||||
| 	OutTradeNo string | ||||
| 	TradeNo    string | ||||
| 	Amount     string | ||||
| 	Message    string | ||||
| 	Subject    string | ||||
| } | ||||
|  | ||||
| func (v NotifyVo) Success() bool { | ||||
| 	return v.Status == 1 | ||||
| } | ||||
|   | ||||
| @@ -1,12 +1,19 @@ | ||||
| package payment | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"crypto/md5" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/utils" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| @@ -42,6 +49,8 @@ type HuPiPayReq struct { | ||||
| 	CallbackURL  string `json:"callback_url"` | ||||
| 	Time         string `json:"time"` | ||||
| 	NonceStr     string `json:"nonce_str"` | ||||
| 	Type         string `json:"type"` | ||||
| 	WapUrl       string `json:"wap_url"` | ||||
| } | ||||
|  | ||||
| type HuPiResp struct { | ||||
|   | ||||
| @@ -1,12 +1,19 @@ | ||||
| package payment | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"crypto/md5" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/utils" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| @@ -14,12 +21,12 @@ import ( | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type PayJS struct { | ||||
| type JPayService struct { | ||||
| 	config *types.JPayConfig | ||||
| } | ||||
|  | ||||
| func NewPayJS(appConfig *types.AppConfig) *PayJS { | ||||
| 	return &PayJS{ | ||||
| func NewJPayService(appConfig *types.AppConfig) *JPayService { | ||||
| 	return &JPayService{ | ||||
| 		config: &appConfig.JPayConfig, | ||||
| 	} | ||||
| } | ||||
| @@ -46,7 +53,7 @@ func (r JPayReps) IsOK() bool { | ||||
| 	return r.ReturnMsg == "SUCCESS" | ||||
| } | ||||
|  | ||||
| func (js *PayJS) Pay(param JPayReq) JPayReps { | ||||
| func (js *JPayService) Pay(param JPayReq) JPayReps { | ||||
| 	param.NotifyURL = js.config.NotifyURL | ||||
| 	var p = url.Values{} | ||||
| 	encode := utils.JsonEncode(param) | ||||
| @@ -79,13 +86,13 @@ func (js *PayJS) Pay(param JPayReq) JPayReps { | ||||
| 	return data | ||||
| } | ||||
|  | ||||
| func (js *PayJS) PayH5(p url.Values) string { | ||||
| func (js *JPayService) PayH5(p url.Values) string { | ||||
| 	p.Add("mchid", js.config.AppId) | ||||
| 	p.Add("sign", js.sign(p)) | ||||
| 	return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode()) | ||||
| } | ||||
|  | ||||
| func (js *PayJS) sign(params url.Values) string { | ||||
| func (js *JPayService) sign(params url.Values) string { | ||||
| 	params.Del(`sign`) | ||||
| 	var keys = make([]string, 0, 0) | ||||
| 	for key := range params { | ||||
| @@ -110,20 +117,18 @@ func (js *PayJS) sign(params url.Values) string { | ||||
| 	return strings.ToUpper(md5res) | ||||
| } | ||||
|  | ||||
| // Check 查询订单支付状态 | ||||
| // TradeVerify 查询订单支付状态 | ||||
| // @param tradeNo 支付平台交易 ID | ||||
| func (js *PayJS) Check(tradeNo string) error { | ||||
| func (js *JPayService) TradeVerify(tradeNo string) error { | ||||
| 	apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL) | ||||
| 	params := url.Values{} | ||||
| 	params.Add("payjs_order_id", tradeNo) | ||||
| 	params.Add("sign", js.sign(params)) | ||||
| 	data := strings.NewReader(params.Encode()) | ||||
| 	resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data) | ||||
| 	defer resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error with http reqeust: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
|   | ||||
							
								
								
									
										19
									
								
								api/service/payment/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								api/service/payment/types.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| package payment | ||||
|  | ||||
| type NotifyVo struct { | ||||
| 	Status     int | ||||
| 	OutTradeNo string // 商户订单号 | ||||
| 	TradeId    string // 交易ID | ||||
| 	Amount     string // 交易金额 | ||||
| 	Message    string | ||||
| 	Subject    string | ||||
| } | ||||
|  | ||||
| func (v NotifyVo) Success() bool { | ||||
| 	return v.Status == Success | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	Success = 0 | ||||
| 	Failure = 1 | ||||
| ) | ||||
							
								
								
									
										135
									
								
								api/service/payment/wepay_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								api/service/payment/wepay_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,135 @@ | ||||
| package payment | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"github.com/go-pay/gopay" | ||||
| 	"github.com/go-pay/gopay/wechat/v3" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type WechatPayService struct { | ||||
| 	config *types.WechatPayConfig | ||||
| 	client *wechat.ClientV3 | ||||
| } | ||||
|  | ||||
| func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) { | ||||
| 	config := appConfig.WechatPayConfig | ||||
| 	if !config.Enabled { | ||||
| 		logger.Info("Disabled WechatPay service") | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	priKey, err := readKey(config.PrivateKey) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error with read App Private key: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error with initialize WechatPay service: %v", err) | ||||
| 	} | ||||
| 	err = client.AutoVerifySign() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error with autoVerifySign: %v", err) | ||||
| 	} | ||||
| 	//client.DebugSwitch = gopay.DebugOn | ||||
|  | ||||
| 	return &WechatPayService{config: &config, client: client}, nil | ||||
| } | ||||
|  | ||||
| func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject string) (string, error) { | ||||
| 	expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339) | ||||
| 	// 初始化 BodyMap | ||||
| 	bm := make(gopay.BodyMap) | ||||
| 	bm.Set("appid", s.config.AppId). | ||||
| 		Set("mchid", s.config.MchId). | ||||
| 		Set("description", subject). | ||||
| 		Set("out_trade_no", outTradeNo). | ||||
| 		Set("time_expire", expire). | ||||
| 		Set("notify_url", s.config.NotifyURL). | ||||
| 		SetBodyMap("amount", func(bm gopay.BodyMap) { | ||||
| 			bm.Set("total", amount). | ||||
| 				Set("currency", "CNY") | ||||
| 		}) | ||||
|  | ||||
| 	wxRsp, err := s.client.V3TransactionNative(context.Background(), bm) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with client v3 transaction Native: %v", err) | ||||
| 	} | ||||
| 	if wxRsp.Code != wechat.Success { | ||||
| 		return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error) | ||||
| 	} | ||||
| 	return wxRsp.Response.CodeUrl, nil | ||||
| } | ||||
|  | ||||
| func (s *WechatPayService) PayUrlH5(outTradeNo string, amount int, subject string, ip string) (string, error) { | ||||
| 	expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339) | ||||
| 	// 初始化 BodyMap | ||||
| 	bm := make(gopay.BodyMap) | ||||
| 	bm.Set("appid", s.config.AppId). | ||||
| 		Set("mchid", s.config.MchId). | ||||
| 		Set("description", subject). | ||||
| 		Set("out_trade_no", outTradeNo). | ||||
| 		Set("time_expire", expire). | ||||
| 		Set("notify_url", s.config.NotifyURL). | ||||
| 		SetBodyMap("amount", func(bm gopay.BodyMap) { | ||||
| 			bm.Set("total", amount). | ||||
| 				Set("currency", "CNY") | ||||
| 		}). | ||||
| 		SetBodyMap("scene_info", func(bm gopay.BodyMap) { | ||||
| 			bm.Set("payer_client_ip", ip). | ||||
| 				SetBodyMap("h5_info", func(bm gopay.BodyMap) { | ||||
| 					bm.Set("type", "Wap") | ||||
| 				}) | ||||
| 		}) | ||||
|  | ||||
| 	wxRsp, err := s.client.V3TransactionH5(context.Background(), bm) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with client v3 transaction H5: %v", err) | ||||
| 	} | ||||
| 	if wxRsp.Code != wechat.Success { | ||||
| 		return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error) | ||||
| 	} | ||||
| 	return wxRsp.Response.H5Url, nil | ||||
| } | ||||
|  | ||||
| type NotifyResponse struct { | ||||
| 	Code    string `json:"code"` | ||||
| 	Message string `xml:"message"` | ||||
| } | ||||
|  | ||||
| // TradeVerify 交易验证 | ||||
| func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo { | ||||
| 	notifyReq, err := wechat.V3ParseNotify(request) | ||||
| 	if err != nil { | ||||
| 		return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)} | ||||
| 	} | ||||
|  | ||||
| 	// TODO: 这里验签程序有 Bug,一直报错:crypto/rsa: verification error,先暂时取消验签 | ||||
| 	//err = notifyReq.VerifySignByPK(s.client.WxPublicKey()) | ||||
| 	//if err != nil { | ||||
| 	//	return fmt.Errorf("error with client v3 verify sign: %v", err) | ||||
| 	//} | ||||
|  | ||||
| 	// 解密支付密文,验证订单信息 | ||||
| 	result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key) | ||||
| 	if err != nil { | ||||
| 		return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)} | ||||
| 	} | ||||
|  | ||||
| 	return NotifyVo{ | ||||
| 		Status:     Success, | ||||
| 		OutTradeNo: result.OutTradeNo, | ||||
| 		TradeId:    result.TransactionId, | ||||
| 		Amount:     fmt.Sprintf("%.2f", float64(result.Amount.Total)/100), | ||||
| 	} | ||||
| } | ||||
| @@ -1,11 +1,18 @@ | ||||
| package sd | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/store" | ||||
| 	"geekai/store/model" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| @@ -18,28 +25,14 @@ type ServicePool struct { | ||||
| 	notifyQueue *store.RedisQueue | ||||
| 	db          *gorm.DB | ||||
| 	Clients     *types.LMap[uint, *types.WsClient] // UserId => Client | ||||
| 	uploader    *oss.UploaderManager | ||||
| 	levelDB     *store.LevelDB | ||||
| } | ||||
|  | ||||
| func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { | ||||
| func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool { | ||||
| 	services := make([]*Service, 0) | ||||
| 	taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) | ||||
| 	notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli) | ||||
| 	// create mj client and service | ||||
| 	for _, config := range appConfig.SdConfigs { | ||||
| 		if config.Enabled == false { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// create sd service | ||||
| 		name := fmt.Sprintf("StableDifffusion Service-%s", config.Model) | ||||
| 		service := NewService(name, config, taskQueue, notifyQueue, db, manager) | ||||
| 		// run sd service | ||||
| 		go func() { | ||||
| 			service.Run() | ||||
| 		}() | ||||
|  | ||||
| 		services = append(services, service) | ||||
| 	} | ||||
|  | ||||
| 	return &ServicePool{ | ||||
| 		taskQueue:   taskQueue, | ||||
| @@ -47,6 +40,32 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa | ||||
| 		services:    services, | ||||
| 		db:          db, | ||||
| 		Clients:     types.NewLMap[uint, *types.WsClient](), | ||||
| 		uploader:    manager, | ||||
| 		levelDB:     levelDB, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) { | ||||
| 	// stop old service | ||||
| 	for _, s := range p.services { | ||||
| 		s.Stop() | ||||
| 	} | ||||
| 	p.services = make([]*Service, 0) | ||||
|  | ||||
| 	for k, config := range configs { | ||||
| 		if config.Enabled == false { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// create sd service | ||||
| 		name := fmt.Sprintf(" sd-service-%d", k) | ||||
| 		service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB) | ||||
| 		// run sd service | ||||
| 		go func() { | ||||
| 			service.Run() | ||||
| 		}() | ||||
|  | ||||
| 		p.services = append(p.services, service) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -60,16 +79,16 @@ func (p *ServicePool) CheckTaskNotify() { | ||||
| 	go func() { | ||||
| 		logger.Info("Running Stable-Diffusion task notify checking ...") | ||||
| 		for { | ||||
| 			var userId uint | ||||
| 			err := p.notifyQueue.LPop(&userId) | ||||
| 			var message NotifyMessage | ||||
| 			err := p.notifyQueue.LPop(&message) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			client := p.Clients.Get(userId) | ||||
| 			client := p.Clients.Get(uint(message.UserId)) | ||||
| 			if client == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			err = client.Send([]byte("Task Updated")) | ||||
| 			err = client.Send([]byte(message.Message)) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| @@ -113,7 +132,7 @@ func (p *ServicePool) CheckTaskStatus() { | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			time.Sleep(time.Second * 5) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|   | ||||
| @@ -1,17 +1,25 @@ | ||||
| package sd | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"fmt" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"gorm.io/gorm" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/service" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/store" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // SD 绘画服务 | ||||
| @@ -24,9 +32,11 @@ type Service struct { | ||||
| 	db            *gorm.DB | ||||
| 	uploadManager *oss.UploaderManager | ||||
| 	name          string // service name | ||||
| 	leveldb       *store.LevelDB | ||||
| 	running       bool // 运行状态 | ||||
| } | ||||
|  | ||||
| func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service { | ||||
| func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service { | ||||
| 	config.ApiURL = strings.TrimRight(config.ApiURL, "/") | ||||
| 	return &Service{ | ||||
| 		name:          name, | ||||
| @@ -35,23 +45,39 @@ func NewService(name string, config types.StableDiffusionConfig, taskQueue *stor | ||||
| 		taskQueue:     taskQueue, | ||||
| 		notifyQueue:   notifyQueue, | ||||
| 		db:            db, | ||||
| 		leveldb:       levelDB, | ||||
| 		uploadManager: manager, | ||||
| 		running:       true, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Service) Run() { | ||||
| 	for { | ||||
| 	logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name) | ||||
| 	for s.running { | ||||
| 		var task types.SdTask | ||||
| 		err := s.taskQueue.LPop(&task) | ||||
| 		if err != nil { | ||||
| 			logger.Errorf("taking task with error: %v", err) | ||||
| 			continue | ||||
| 		} | ||||
| 		// 翻译提示词 | ||||
|  | ||||
| 		// translate prompt | ||||
| 		if utils.HasChinese(task.Params.Prompt) { | ||||
| 			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt)) | ||||
| 			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini") | ||||
| 			if err == nil { | ||||
| 				task.Params.Prompt = content | ||||
| 			} else { | ||||
| 				logger.Warnf("error with translate prompt: %v", err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// translate negative prompt | ||||
| 		if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) { | ||||
| 			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini") | ||||
| 			if err == nil { | ||||
| 				task.Params.NegPrompt = content | ||||
| 			} else { | ||||
| 				logger.Warnf("error with translate prompt: %v", err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @@ -65,12 +91,16 @@ func (s *Service) Run() { | ||||
| 				"err_msg":  err.Error(), | ||||
| 			}) | ||||
| 			// 通知前端,任务失败 | ||||
| 			s.notifyQueue.RPush(task.UserId) | ||||
| 			s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed}) | ||||
| 			continue | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Service) Stop() { | ||||
| 	s.running = false | ||||
| } | ||||
|  | ||||
| // Txt2ImgReq 文生图请求实体 | ||||
| type Txt2ImgReq struct { | ||||
| 	Prompt            string  `json:"prompt"` | ||||
| @@ -81,6 +111,7 @@ type Txt2ImgReq struct { | ||||
| 	Width             int     `json:"width"` | ||||
| 	Height            int     `json:"height"` | ||||
| 	SamplerName       string  `json:"sampler_name"` | ||||
| 	Scheduler         string  `json:"scheduler"` | ||||
| 	EnableHr          bool    `json:"enable_hr,omitempty"` | ||||
| 	HrScale           int     `json:"hr_scale,omitempty"` | ||||
| 	HrUpscaler        string  `json:"hr_upscaler,omitempty"` | ||||
| @@ -108,12 +139,14 @@ type TaskProgressResp struct { | ||||
| func (s *Service) Txt2Img(task types.SdTask) error { | ||||
| 	body := Txt2ImgReq{ | ||||
| 		Prompt:         task.Params.Prompt, | ||||
| 		NegativePrompt: task.Params.NegativePrompt, | ||||
| 		NegativePrompt: task.Params.NegPrompt, | ||||
| 		Steps:          task.Params.Steps, | ||||
| 		CfgScale:       task.Params.CfgScale, | ||||
| 		Width:          task.Params.Width, | ||||
| 		Height:         task.Params.Height, | ||||
| 		SamplerName:    task.Params.Sampler, | ||||
| 		Scheduler:      task.Params.Scheduler, | ||||
| 		ForceTaskId:    task.Params.TaskId, | ||||
| 	} | ||||
| 	if task.Params.Seed > 0 { | ||||
| 		body.Seed = task.Params.Seed | ||||
| @@ -129,8 +162,13 @@ func (s *Service) Txt2Img(task types.SdTask) error { | ||||
| 	var errChan = make(chan error) | ||||
| 	apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL) | ||||
| 	logger.Debugf("send image request to %s", apiURL) | ||||
| 	// send a request to sd api endpoint | ||||
| 	go func() { | ||||
| 		response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL) | ||||
| 		response, err := s.httpClient.R(). | ||||
| 			SetHeader("Authorization", s.config.ApiKey). | ||||
| 			SetBody(body). | ||||
| 			SetSuccessResult(&res). | ||||
| 			Post(apiURL) | ||||
| 		if err != nil { | ||||
| 			errChan <- err | ||||
| 			return | ||||
| @@ -154,27 +192,35 @@ func (s *Service) Txt2Img(task types.SdTask) error { | ||||
| 			return | ||||
| 		} | ||||
| 		task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1)) | ||||
| 		s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)}) | ||||
| 		s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params), Prompt: task.Params.Prompt}) | ||||
| 		errChan <- nil | ||||
| 	}() | ||||
|  | ||||
| 	// waiting for task finish | ||||
| 	for { | ||||
| 		select { | ||||
| 		case err := <-errChan: // 任务完成 | ||||
| 		case err := <-errChan: | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|  | ||||
| 			// task finished | ||||
| 			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) | ||||
| 			s.notifyQueue.RPush(task.UserId) | ||||
| 			s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished}) | ||||
| 			// 从 leveldb 中删除预览图片数据 | ||||
| 			_ = s.leveldb.Delete(task.Params.TaskId) | ||||
| 			return nil | ||||
| 		default: | ||||
| 			err, resp := s.checkTaskProgress() | ||||
| 			// 更新任务进度 | ||||
| 			if err == nil && resp.Progress > 0 { | ||||
| 				logger.Debugf("Check task progress: %+v", resp.Progress) | ||||
| 				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) | ||||
| 				// 发送更新状态信号 | ||||
| 				s.notifyQueue.RPush(task.UserId) | ||||
| 				s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running}) | ||||
| 				// 保存预览图片数据 | ||||
| 				if resp.CurrentImage != "" { | ||||
| 					_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) | ||||
| 				} | ||||
| 			} | ||||
| 			time.Sleep(time.Second) | ||||
| 		} | ||||
| @@ -186,7 +232,10 @@ func (s *Service) Txt2Img(task types.SdTask) error { | ||||
| func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { | ||||
| 	apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL) | ||||
| 	var res TaskProgressResp | ||||
| 	response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL) | ||||
| 	response, err := s.httpClient.R(). | ||||
| 		SetHeader("Authorization", s.config.ApiKey). | ||||
| 		SetSuccessResult(&res). | ||||
| 		Get(apiURL) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
|   | ||||
| @@ -1,47 +1,24 @@ | ||||
| package sd | ||||
|  | ||||
| import logger2 "chatplus/logger" | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import logger2 "geekai/logger" | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| type TaskInfo struct { | ||||
| 	UserId      uint          `json:"user_id"` | ||||
| 	SessionId   string        `json:"session_id"` | ||||
| type NotifyMessage struct { | ||||
| 	UserId  int    `json:"user_id"` | ||||
| 	JobId   int    `json:"job_id"` | ||||
| 	TaskId      string        `json:"task_id"` | ||||
| 	Data        []interface{} `json:"data"` | ||||
| 	EventData   interface{}   `json:"event_data"` | ||||
| 	FnIndex     int           `json:"fn_index"` | ||||
| 	SessionHash string        `json:"session_hash"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type CBReq struct { | ||||
| 	UserId    uint | ||||
| 	SessionId string | ||||
| 	JobId     int | ||||
| 	TaskId    string | ||||
| 	ImageName string | ||||
| 	ImageData string | ||||
| 	Progress  int | ||||
| 	Seed      int64 | ||||
| 	Success   bool | ||||
| 	Message   string | ||||
| } | ||||
|  | ||||
| var ParamKeys = map[string]int{ | ||||
| 	"task_id":         0, | ||||
| 	"prompt":          1, | ||||
| 	"negative_prompt": 2, | ||||
| 	"steps":           4, | ||||
| 	"sampler":         5, | ||||
| 	"face_fix":        7, // 面部修复 | ||||
| 	"cfg_scale":       8, | ||||
| 	"seed":            27, | ||||
| 	"height":          10, | ||||
| 	"width":           9, | ||||
| 	"hd_fix":          11, | ||||
| 	"hd_redraw_rate":  12, //高清修复重绘幅度 | ||||
| 	"hd_scale":        13, // 高清修复放大倍数 | ||||
| 	"hd_scale_alg":    14, // 高清修复放大算法 | ||||
| 	"hd_sample_num":   15, // 高清修复采样次数 | ||||
| } | ||||
| const ( | ||||
| 	Running  = "RUNNING" | ||||
| 	Finished = "FINISH" | ||||
| 	Failed   = "FAIL" | ||||
| ) | ||||
|   | ||||
| @@ -1,8 +1,15 @@ | ||||
| package sms | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi" | ||||
| ) | ||||
|  | ||||
|   | ||||
| @@ -1,9 +1,16 @@ | ||||
| package sms | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"geekai/utils" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package sms | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| const Ali = "ALI" | ||||
| const Bao = "BAO" | ||||
|  | ||||
|   | ||||
| @@ -1,8 +1,15 @@ | ||||
| package sms | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"geekai/core/types" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
|   | ||||
| @@ -1,11 +1,20 @@ | ||||
| package service | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"chatplus/core/types" | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	"mime" | ||||
| 	"net/smtp" | ||||
| 	"net/textproto" | ||||
| ) | ||||
|  | ||||
| type SmtpService struct { | ||||
| @@ -19,12 +28,18 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService { | ||||
| } | ||||
|  | ||||
| func (s *SmtpService) SendVerifyCode(to string, code int) error { | ||||
| 	subject := "ChatPlus注册验证码" | ||||
| 	body := fmt.Sprintf("您正在注册 ChatPlus AI 助手账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", code) | ||||
| 	subject := fmt.Sprintf("%s 注册验证码", s.config.AppName) | ||||
| 	body := fmt.Sprintf("您正在注册 %s 账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", s.config.AppName, code) | ||||
|  | ||||
| 	// 设置SMTP客户端配置 | ||||
| 	auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host) | ||||
| 	if s.config.UseTls { | ||||
| 		return s.sendTLS(auth, to, subject, body) | ||||
| 	} else { | ||||
| 		return s.send(auth, to, subject, body) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *SmtpService) send(auth smtp.Auth, to string, subject string, body string) error { | ||||
| 	// 对主题进行MIME编码 | ||||
| 	encodedSubject := mime.QEncoding.Encode("UTF-8", subject) | ||||
| 	// 组装邮件 | ||||
| @@ -34,11 +49,83 @@ func (s *SmtpService) SendVerifyCode(to string, code int) error { | ||||
| 	message.WriteString(fmt.Sprintf("Subject: %s\r\n", encodedSubject)) | ||||
| 	message.WriteString("\r\n" + body) | ||||
|  | ||||
| 	// 发送邮件 | ||||
| 	// 发送邮件 | ||||
| 	err := smtp.SendMail(s.config.Host+":"+fmt.Sprint(s.config.Port), auth, s.config.From, []string{to}, message.Bytes()) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error sending email: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return err | ||||
|  | ||||
| } | ||||
|  | ||||
| func (s *SmtpService) sendTLS(auth smtp.Auth, to string, subject string, body string) error { | ||||
| 	// TLS配置 | ||||
| 	tlsConfig := &tls.Config{ | ||||
| 		ServerName: s.config.Host, | ||||
| 	} | ||||
|  | ||||
| 	// 建立TLS连接 | ||||
| 	conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port), tlsConfig) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error connecting to SMTP server: %v", err) | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
|  | ||||
| 	client, err := smtp.NewClient(conn, s.config.Host) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error creating SMTP client: %v", err) | ||||
| 	} | ||||
| 	defer client.Quit() | ||||
|  | ||||
| 	// 身份验证 | ||||
| 	if err = client.Auth(auth); err != nil { | ||||
| 		return fmt.Errorf("error authenticating: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 设置寄件人 | ||||
| 	if err = client.Mail(s.config.From); err != nil { | ||||
| 		return fmt.Errorf("error setting sender: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 设置收件人 | ||||
| 	if err = client.Rcpt(to); err != nil { | ||||
| 		return fmt.Errorf("error setting recipient: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 发送邮件内容 | ||||
| 	wc, err := client.Data() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error getting data writer: %v", err) | ||||
| 	} | ||||
| 	defer wc.Close() | ||||
|  | ||||
| 	header := make(textproto.MIMEHeader) | ||||
| 	header.Set("From", s.config.From) | ||||
| 	header.Set("To", to) | ||||
| 	header.Set("Subject", subject) | ||||
|  | ||||
| 	// 将邮件头写入 | ||||
| 	for key, values := range header { | ||||
| 		for _, value := range values { | ||||
| 			_, err = fmt.Fprintf(wc, "%s: %s\r\n", key, value) | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf("error sending email header: %v", err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	_, _ = fmt.Fprintln(wc) | ||||
| 	// 将邮件内容写入 | ||||
| 	_, err = fmt.Fprintf(wc, body) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error sending email: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 发送完毕 | ||||
| 	err = wc.Close() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error closing data writer: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,5 +1,12 @@ | ||||
| package service | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
|   | ||||
							
								
								
									
										355
									
								
								api/service/suno/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										355
									
								
								api/service/suno/service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,355 @@ | ||||
| package suno | ||||
|  | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
| // * Copyright 2023 The Geek-AI Authors. All rights reserved. | ||||
| // * Use of this source code is governed by a Apache-2.0 license | ||||
| // * that can be found in the LICENSE file. | ||||
| // * @Author yangjian102621@163.com | ||||
| // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"geekai/core/types" | ||||
| 	logger2 "geekai/logger" | ||||
| 	"geekai/service/oss" | ||||
| 	"geekai/service/sd" | ||||
| 	"geekai/store" | ||||
| 	"geekai/store/model" | ||||
| 	"geekai/utils" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"io" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| type Service struct { | ||||
| 	httpClient    *req.Client | ||||
| 	db            *gorm.DB | ||||
| 	uploadManager *oss.UploaderManager | ||||
| 	taskQueue     *store.RedisQueue | ||||
| 	notifyQueue   *store.RedisQueue | ||||
| 	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client | ||||
| } | ||||
|  | ||||
| func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { | ||||
| 	return &Service{ | ||||
| 		httpClient:    req.C().SetTimeout(time.Minute * 3), | ||||
| 		db:            db, | ||||
| 		taskQueue:     store.NewRedisQueue("Suno_Task_Queue", redisCli), | ||||
| 		notifyQueue:   store.NewRedisQueue("Suno_Notify_Queue", redisCli), | ||||
| 		Clients:       types.NewLMap[uint, *types.WsClient](), | ||||
| 		uploadManager: manager, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Service) PushTask(task types.SunoTask) { | ||||
| 	logger.Infof("add a new Suno task to the task list: %+v", task) | ||||
| 	s.taskQueue.RPush(task) | ||||
| } | ||||
|  | ||||
| func (s *Service) Run() { | ||||
| 	// 将数据库中未提交的人物加载到队列 | ||||
| 	var jobs []model.SunoJob | ||||
| 	s.db.Where("task_id", "").Find(&jobs) | ||||
| 	for _, v := range jobs { | ||||
| 		s.PushTask(types.SunoTask{ | ||||
| 			Id:           v.Id, | ||||
| 			Channel:      v.Channel, | ||||
| 			UserId:       v.UserId, | ||||
| 			Type:         v.Type, | ||||
| 			Title:        v.Title, | ||||
| 			RefTaskId:    v.RefTaskId, | ||||
| 			RefSongId:    v.RefSongId, | ||||
| 			Prompt:       v.Prompt, | ||||
| 			Tags:         v.Tags, | ||||
| 			Model:        v.ModelName, | ||||
| 			Instrumental: v.Instrumental, | ||||
| 			ExtendSecs:   v.ExtendSecs, | ||||
| 		}) | ||||
| 	} | ||||
| 	logger.Info("Starting Suno job consumer...") | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			var task types.SunoTask | ||||
| 			err := s.taskQueue.LPop(&task) | ||||
| 			if err != nil { | ||||
| 				logger.Errorf("taking task with error: %v", err) | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			r, err := s.Create(task) | ||||
| 			if err != nil { | ||||
| 				logger.Errorf("create task with error: %v", err) | ||||
| 				s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ | ||||
| 					"err_msg":  err.Error(), | ||||
| 					"progress": 101, | ||||
| 				}) | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// 更新任务信息 | ||||
| 			s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ | ||||
| 				"task_id": r.Data, | ||||
| 				"channel": r.Channel, | ||||
| 			}) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| type RespVo struct { | ||||
| 	Code    string `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Data    string `json:"data"` | ||||
| 	Channel string `json:"channel,omitempty"` | ||||
| } | ||||
|  | ||||
| func (s *Service) Create(task types.SunoTask) (RespVo, error) { | ||||
| 	// 读取 API KEY | ||||
| 	var apiKey model.ApiKey | ||||
| 	session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true) | ||||
| 	if task.Channel != "" { | ||||
| 		session = session.Where("api_url", task.Channel) | ||||
| 	} | ||||
| 	tx := session.Order("last_used_at DESC").First(&apiKey) | ||||
| 	if tx.Error != nil { | ||||
| 		return RespVo{}, errors.New("no available API KEY for Suno") | ||||
| 	} | ||||
|  | ||||
| 	reqBody := map[string]interface{}{ | ||||
| 		"task_id":           task.RefTaskId, | ||||
| 		"continue_clip_id":  task.RefSongId, | ||||
| 		"continue_at":       task.ExtendSecs, | ||||
| 		"make_instrumental": task.Instrumental, | ||||
| 	} | ||||
| 	// 灵感模式 | ||||
| 	if task.Type == 1 { | ||||
| 		reqBody["gpt_description_prompt"] = task.Prompt | ||||
| 	} else { // 自定义模式 | ||||
| 		reqBody["prompt"] = task.Prompt | ||||
| 		reqBody["tags"] = task.Tags | ||||
| 		reqBody["mv"] = task.Model | ||||
| 		reqBody["title"] = task.Title | ||||
| 	} | ||||
|  | ||||
| 	var res RespVo | ||||
| 	apiURL := fmt.Sprintf("%s/task/suno/v1/submit/music", apiKey.ApiURL) | ||||
| 	logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody) | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+apiKey.Value). | ||||
| 		SetBody(reqBody). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		return RespVo{}, fmt.Errorf("请求 API 出错:%v", err) | ||||
| 	} | ||||
|  | ||||
| 	body, _ := io.ReadAll(r.Body) | ||||
| 	err = json.Unmarshal(body, &res) | ||||
| 	if err != nil { | ||||
| 		return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) | ||||
| 	} | ||||
|  | ||||
| 	if res.Code != "success" { | ||||
| 		return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message) | ||||
| 	} | ||||
| 	res.Channel = apiKey.ApiURL | ||||
| 	return res, nil | ||||
| } | ||||
|  | ||||
| func (s *Service) CheckTaskNotify() { | ||||
| 	go func() { | ||||
| 		logger.Info("Running Suno task notify checking ...") | ||||
| 		for { | ||||
| 			var message sd.NotifyMessage | ||||
| 			err := s.notifyQueue.LPop(&message) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			client := s.Clients.Get(uint(message.UserId)) | ||||
| 			if client == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			err = client.Send([]byte(message.Message)) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (s *Service) DownloadImages() { | ||||
| 	go func() { | ||||
| 		var items []model.SunoJob | ||||
| 		for { | ||||
| 			res := s.db.Where("progress", 102).Find(&items) | ||||
| 			if res.Error != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			for _, v := range items { | ||||
| 				// 下载图片和音频 | ||||
| 				logger.Infof("try download cover image: %s", v.CoverURL) | ||||
| 				coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true) | ||||
| 				if err != nil { | ||||
| 					logger.Errorf("download image with error: %v", err) | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				logger.Infof("try download audio: %s", v.AudioURL) | ||||
| 				audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true) | ||||
| 				if err != nil { | ||||
| 					logger.Errorf("download audio with error: %v", err) | ||||
| 					continue | ||||
| 				} | ||||
| 				v.CoverURL = coverURL | ||||
| 				v.AudioURL = audioURL | ||||
| 				v.Progress = 100 | ||||
| 				s.db.Updates(&v) | ||||
| 				s.notifyQueue.RPush(sd.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: sd.Finished}) | ||||
| 			} | ||||
|  | ||||
| 			time.Sleep(time.Second * 10) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| // SyncTaskProgress 异步拉取任务 | ||||
| func (s *Service) SyncTaskProgress() { | ||||
| 	go func() { | ||||
| 		var jobs []model.SunoJob | ||||
| 		for { | ||||
| 			res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs) | ||||
| 			if res.Error != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			for _, job := range jobs { | ||||
| 				task, err := s.QueryTask(job.TaskId, job.Channel) | ||||
| 				if err != nil { | ||||
| 					logger.Errorf("query task with error: %v", err) | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				if task.Code != "success" { | ||||
| 					logger.Errorf("query task with error: %v", task.Message) | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				logger.Debugf("task: %+v", task.Data.Status) | ||||
| 				// 任务完成,删除旧任务插入两条新任务 | ||||
| 				if task.Data.Status == "SUCCESS" { | ||||
| 					var jobId = job.Id | ||||
| 					var flag = false | ||||
| 					tx := s.db.Begin() | ||||
| 					for _, v := range task.Data.Data { | ||||
| 						job.Id = 0 | ||||
| 						job.Progress = 102 // 102 表示资源未下载完成 | ||||
| 						job.Title = v.Title | ||||
| 						job.SongId = v.Id | ||||
| 						job.Duration = int(v.Metadata.Duration) | ||||
| 						job.Prompt = v.Metadata.Prompt | ||||
| 						job.Tags = v.Metadata.Tags | ||||
| 						job.ModelName = v.ModelName | ||||
| 						job.RawData = utils.JsonEncode(v) | ||||
| 						job.CoverURL = v.ImageLargeUrl | ||||
| 						job.AudioURL = v.AudioUrl | ||||
|  | ||||
| 						if err = tx.Create(&job).Error; err != nil { | ||||
| 							logger.Error("create job with error: %v", err) | ||||
| 							tx.Rollback() | ||||
| 							break | ||||
| 						} | ||||
| 						flag = true | ||||
| 					} | ||||
|  | ||||
| 					// 删除旧任务 | ||||
| 					if flag { | ||||
| 						if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil { | ||||
| 							logger.Error("create job with error: %v", err) | ||||
| 							tx.Rollback() | ||||
| 							continue | ||||
| 						} | ||||
| 					} | ||||
| 					tx.Commit() | ||||
|  | ||||
| 				} else if task.Data.FailReason != "" { | ||||
| 					job.Progress = 101 | ||||
| 					job.ErrMsg = task.Data.FailReason | ||||
| 					s.db.Updates(&job) | ||||
| 					s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			time.Sleep(time.Second * 10) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| type QueryRespVo struct { | ||||
| 	Code    string `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Data    struct { | ||||
| 		TaskId     string `json:"task_id"` | ||||
| 		Action     string `json:"action"` | ||||
| 		Status     string `json:"status"` | ||||
| 		FailReason string `json:"fail_reason"` | ||||
| 		SubmitTime int    `json:"submit_time"` | ||||
| 		StartTime  int    `json:"start_time"` | ||||
| 		FinishTime int    `json:"finish_time"` | ||||
| 		Progress   string `json:"progress"` | ||||
| 		Data       []struct { | ||||
| 			Id       string `json:"id"` | ||||
| 			Title    string `json:"title"` | ||||
| 			Status   string `json:"status"` | ||||
| 			Metadata struct { | ||||
| 				Tags         string      `json:"tags"` | ||||
| 				Type         string      `json:"type"` | ||||
| 				Prompt       string      `json:"prompt"` | ||||
| 				Stream       bool        `json:"stream"` | ||||
| 				Duration     float64     `json:"duration"` | ||||
| 				ErrorMessage interface{} `json:"error_message"` | ||||
| 			} `json:"metadata"` | ||||
| 			AudioUrl          string `json:"audio_url"` | ||||
| 			ImageUrl          string `json:"image_url"` | ||||
| 			VideoUrl          string `json:"video_url"` | ||||
| 			ModelName         string `json:"model_name"` | ||||
| 			DisplayName       string `json:"display_name"` | ||||
| 			ImageLargeUrl     string `json:"image_large_url"` | ||||
| 			MajorModelVersion string `json:"major_model_version"` | ||||
| 		} `json:"data"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) { | ||||
| 	// 读取 API KEY | ||||
| 	var apiKey model.ApiKey | ||||
| 	tx := s.db.Session(&gorm.Session{}).Where("type", "suno"). | ||||
| 		Where("api_url", channel). | ||||
| 		Where("enabled", true). | ||||
| 		Order("last_used_at DESC").First(&apiKey) | ||||
| 	if tx.Error != nil { | ||||
| 		return QueryRespVo{}, errors.New("no available API KEY for Suno") | ||||
| 	} | ||||
|  | ||||
| 	apiURL := fmt.Sprintf("%s/task/suno/v1/fetch/%s", apiKey.ApiURL, taskId) | ||||
| 	var res QueryRespVo | ||||
| 	r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err) | ||||
| 	} | ||||
|  | ||||
| 	defer r.Body.Close() | ||||
| 	body, _ := io.ReadAll(r.Body) | ||||
| 	err = json.Unmarshal(body, &res) | ||||
| 	if err != nil { | ||||
| 		return QueryRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| } | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user