mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-10-31 14:23:43 +08:00 
			
		
		
		
	Compare commits
	
		
			353 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 49b5906bc7 | ||
|  | 3075bfb7fc | ||
|  | 82e06fad33 | ||
|  | 4a9028747b | ||
|  | 4a8ff0ccf0 | ||
|  | 99341f0484 | ||
|  | f58ac29ad0 | ||
|  | 7060edb3e5 | ||
|  | 41ae411f9b | ||
|  | 79b7fee47c | ||
|  | 0044bf10af | ||
|  | e9348d3611 | ||
|  | b9236e09a7 | ||
|  | 09b38d5f42 | ||
|  | 7bb539a06e | ||
|  | 5cdada8265 | ||
|  | 4147c217b1 | ||
|  | 8dda639b23 | ||
|  | 8487d2c9eb | ||
|  | c5e583b215 | ||
|  | 549f618cff | ||
|  | e9a3510346 | ||
|  | 30e6e963b3 | ||
|  | c72d963f45 | ||
|  | 172d498618 | ||
|  | 313993532e | ||
|  | e53db3582c | ||
|  | 72c6bd3f77 | ||
|  | ca8b349df3 | ||
|  | 1b206c3640 | ||
|  | c60276fc9f | ||
|  | d00a3167c0 | ||
|  | 6b1cd8c30c | ||
|  | 46f12dc9ad | ||
|  | a3e1d8ae21 | ||
|  | 72a066b93e | ||
|  | 0327a829ac | ||
|  | 882e9b8819 | ||
|  | ef58cfadaa | ||
|  | bf958d6113 | ||
|  | 71611273d7 | ||
|  | b27c654311 | ||
|  | 90930ea9f9 | ||
|  | 1ab2185ff1 | ||
|  | 0f2f978d4c | ||
|  | f61963b0b0 | ||
|  | 2aa413960d | ||
|  | aa4bbba5ec | ||
|  | eba61fea2d | ||
|  | 34e3455128 | ||
|  | 07dca3e739 | ||
|  | 4cb4b145f9 | ||
|  | 1ed417cb69 | ||
|  | 6cf91a84ca | ||
|  | 0b566980fc | ||
|  | f86176b342 | ||
|  | c700b32670 | ||
|  | 22641b452a | ||
|  | d3fbb8c19e | ||
|  | e3bb69ff10 | ||
|  | 770360c614 | ||
|  | f302a0478f | ||
|  | a88697b43a | ||
|  | cc6f140812 | ||
|  | 424f2b3bdc | ||
|  | ec0c13a600 | ||
|  | a1f03bec4c | ||
|  | b5bd4a5e0e | ||
|  | 7c2e49bfdb | ||
|  | f80fe6d041 | ||
|  | 72f80a96bc | ||
|  | 2de655a1cf | ||
|  | da2bd4a501 | ||
|  | e0aa62c40d | ||
|  | 9d26a892d1 | ||
|  | 4ece7f2847 | ||
|  | 32368caf1b | ||
|  | e91f54e79e | ||
|  | bb8f4c57c4 | ||
|  | 43bfac99b6 | ||
|  | be379b6d63 | ||
|  | 17f3c9b840 | ||
|  | 24de97fac2 | ||
|  | bf27b44fee | ||
|  | 1802b4fe4d | ||
|  | 241a5c7bc9 | ||
|  | 557d547bf1 | ||
|  | 2e7b75affb | ||
|  | bc21a1d443 | ||
|  | 3fc9e10a24 | ||
|  | 5fa1aa2060 | ||
|  | ff4b267858 | ||
|  | a590d0497f | ||
|  | ac30d906f0 | ||
|  | 5bc071e038 | ||
|  | 88b956cf98 | ||
|  | f725cf4661 | ||
|  | 057cc1e8a6 | ||
|  | de122735b8 | ||
|  | e87ede981c | ||
|  | 606fb498e1 | ||
|  | a0c06e40a4 | ||
|  | aba8f57279 | ||
|  | 960286a350 | ||
|  | 8c93fa51f6 | ||
|  | cb0e7d64ff | ||
|  | 8e7413da97 | ||
|  | a36f14eb94 | ||
|  | f2f9f6e488 | ||
|  | 85068b8ca2 | ||
|  | f2cfcfeefc | ||
|  | 755273a898 | ||
|  | d4a24a0f1d | ||
|  | 92281fcbb7 | ||
|  | 636db4afcc | ||
|  | ba25b8755e | ||
|  | 6399d13a49 | ||
|  | 06fa54fd25 | ||
|  | a335b965d0 | ||
|  | 725adaa7d0 | ||
|  | 7e7e81e974 | ||
|  | 8cfe6bfc17 | ||
|  | 33de83f2ac | ||
|  | 3f856afec8 | ||
|  | 02a9c422fe | ||
|  | ca69341024 | ||
|  | 169bf069ce | ||
|  | 1bee0ab04d | ||
|  | 440d91dd0e | ||
|  | 8168e246a8 | ||
|  | 2ef07574ae | ||
|  | 37392f2bb2 | ||
|  | a80cd3848e | ||
|  | db6ed84451 | ||
|  | 4463cc5963 | ||
|  | d316158fe2 | ||
|  | e02a8d7586 | ||
|  | 9988dff885 | ||
|  | 35ef5674ff | ||
|  | 976da45bce | ||
|  | c83ac48bd2 | ||
|  | 3d159a833e | ||
|  | 4b09878bdd | ||
|  | b0162e6a92 | ||
|  | 8ab15e5dc4 | ||
|  | d2ac807252 | ||
|  | 0af01f6f1f | ||
|  | 013b319fab | ||
|  | 2899ba5949 | ||
|  | a558b7e104 | ||
|  | 7a833e2233 | ||
|  | bf65746d00 | ||
|  | f08a7862de | ||
|  | 023a2c2f09 | ||
|  | 1bcd0f4c1a | ||
|  | a0f3bc8ccb | ||
|  | dea72738c1 | ||
|  | a1d1fe7763 | ||
|  | a39ed9764c | ||
|  | aaa5ba99aa | ||
|  | 2113508b6d | ||
|  | 7fe4212684 | ||
|  | 8bdda64794 | ||
|  | ec08c24dca | ||
|  | a992a5b3b3 | ||
|  | 0f05970141 | ||
|  | e5e762efcd | ||
|  | b3d0c1ef9c | ||
|  | 397078f7ff | ||
|  | 3ad8065e20 | ||
|  | 66c7717f04 | ||
|  | 412f8ecc6c | ||
|  | 51dcf642b3 | ||
|  | bfeea555b2 | ||
|  | 479f94c372 | ||
|  | 0140713e86 | ||
|  | 15b2ec9721 | ||
|  | c9cd082855 | ||
|  | d7c002890c | ||
|  | 348dd22279 | ||
|  | 3e99b4cbf6 | ||
|  | 6968da3ac7 | ||
|  | bf1c1b84c3 | ||
|  | c70314d930 | ||
|  | 9104ca8e49 | ||
|  | 2af33b3630 | ||
|  | 654e795545 | ||
|  | c62ba2451e | ||
|  | d72d1b8a99 | ||
|  | b939d6016b | ||
|  | 36a2626ccc | ||
|  | bd057a4cc9 | ||
|  | dc24a8c781 | ||
|  | 59fa21779b | ||
|  | a140671aad | ||
|  | 5fe8990fb4 | ||
|  | 12799b7159 | ||
|  | 9929746b1d | ||
|  | d70035ff0c | ||
|  | eec90274d8 | ||
|  | e8fff55c42 | ||
|  | 3cf3cdd705 | ||
|  | 9801fce659 | ||
|  | 4c1f51110b | ||
|  | 913d538587 | ||
|  | 9e704365fc | ||
|  | 485bdbc56a | ||
|  | 7000168fd4 | ||
|  | 5694f97a6b | ||
|  | b677d3fac7 | ||
|  | dc6719cf54 | ||
|  | 7de5b55091 | ||
|  | 76c5101092 | ||
|  | 2f8d2f4854 | ||
|  | b1ee34ba0c | ||
|  | 069ad6a09a | ||
|  | bf1403c818 | ||
|  | bcc622a24d | ||
|  | a06a81a415 | ||
|  | d1950acd01 | ||
|  | 039b70eed2 | ||
|  | d8e4308b1b | ||
|  | 434fbb3463 | ||
|  | de3eb8969c | ||
|  | fbd6eac877 | ||
|  | 1fecab177b | ||
|  | b1b385c455 | ||
|  | 3c6e86d04b | ||
|  | 3d2035d08a | ||
|  | da86f916d8 | ||
|  | e7a07f7e92 | ||
|  | b01e6387fc | ||
|  | d86aca0f5d | ||
|  | 09414fe36a | ||
|  | df0e7508db | ||
|  | 92b1f01118 | ||
|  | 8fb8bd932b | ||
|  | 3f74b94784 | ||
|  | e9467341fa | ||
|  | 131e051ddc | ||
|  | f626fe3166 | ||
|  | 6bc57b6132 | ||
|  | d972e97c88 | ||
|  | 3991f4daec | ||
|  | f6b567d6fc | ||
|  | 8addba8203 | ||
|  | 3ab930a107 | ||
|  | de512a5ea2 | ||
|  | 113cfae2dc | ||
|  | 33aebf9cb5 | ||
|  | 6e58ddf681 | ||
|  | cae5c049e4 | ||
|  | ff76e4bd89 | ||
|  | a0a506a3c4 | ||
|  | aa5a4a9977 | ||
|  | abf4f061c1 | ||
|  | 245cd3ee1a | ||
|  | 45cb29d9a0 | ||
|  | d974b1ff0e | ||
|  | 56269170cb | ||
|  | 4290c4ca22 | ||
|  | 7f7c8e831e | ||
|  | 8f057ca9d1 | ||
|  | 4a56621ec3 | ||
|  | a398e7a550 | ||
|  | 96816c12ca | ||
|  | 9984926f69 | ||
|  | a2a6081027 | ||
|  | 5a10ed37a7 | ||
|  | 1a9dd9de0b | ||
|  | 0dae5bef71 | ||
|  | b4413ed726 | ||
|  | 5e1fe88b8b | ||
|  | 91ed41b536 | ||
|  | 024c0032eb | ||
|  | 4a9f7e3bce | ||
|  | cf4dcc34ec | ||
|  | 4d612c15af | ||
|  | 8aec87cc02 | ||
|  | 442e411cde | ||
|  | acec0194de | ||
|  | 8557f5b94a | ||
|  | babef8baae | ||
|  | efd4ab46f5 | ||
|  | ae8239e5de | ||
|  | f0994ba457 | ||
|  | dae91ed243 | ||
|  | de42a428e6 | ||
|  | 63c7041e1f | ||
|  | b1263ddc69 | ||
|  | 7e50e17aaf | ||
|  | a7265c4251 | ||
|  | 6f39f639bd | ||
|  | a7db123437 | ||
|  | 241c714a8b | ||
|  | 67ac3cfe32 | ||
|  | c926e0afcc | ||
|  | 5bc07e6d57 | ||
|  | c3666a9a71 | ||
|  | 23b5ffa97d | ||
|  | a2c7a75705 | ||
|  | d68f2ef12c | ||
|  | 67d30353f0 | ||
|  | 4813163eac | ||
|  | 5c5210625e | ||
|  | a4a1eec30b | ||
|  | d35164506a | ||
|  | 1ed08f01ea | ||
|  | eca07ab830 | ||
|  | 3512715704 | ||
|  | 6d07881141 | ||
|  | 251fe626f2 | ||
|  | 5fee3a9288 | ||
|  | 9b68d8101e | ||
|  | cfe6f27d48 | ||
|  | b314dd0900 | ||
|  | 950fab6374 | ||
|  | 9d1f5c42ce | ||
|  | a84046390b | ||
|  | aa29323a8a | ||
|  | d5617b7c3a | ||
|  | 1ef60a9e5e | ||
|  | fb6e395ad8 | ||
|  | d9216060bc | ||
|  | bcaa9a92e5 | ||
|  | 576adc9036 | ||
|  | 00de18be9a | ||
|  | c61d32816a | ||
|  | f3fbb0b89c | ||
|  | e311a39632 | ||
|  | 51407abe44 | ||
|  | 8a470b1038 | ||
|  | baddabaa16 | ||
|  | 427b434ce3 | ||
|  | 5f921965e6 | ||
|  | 1e705c8ed5 | ||
|  | b8ae65bb30 | ||
|  | 321e2087ea | ||
|  | aac60edce2 | ||
|  | 9dc9a6923e | ||
|  | 7ca4dfe09b | ||
|  | c584b82ddb | ||
|  | 5f17ab2501 | ||
|  | c84e912dd8 | ||
|  | 2ebff2623f | ||
|  | 72418ce4d7 | ||
|  | e221b1eed4 | ||
|  | 696306f066 | ||
|  | 1807d5b5d4 | ||
|  | 85c12aa322 | ||
|  | da9d0dc3bc | ||
|  | daaca822ac | ||
|  | 59ced3f947 | 
							
								
								
									
										6
									
								
								.dockerignore
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								.dockerignore
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| deploy | ||||
| docs | ||||
| api/static  | ||||
| web/node_modules | ||||
| desktop | ||||
|  | ||||
							
								
								
									
										146
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										146
									
								
								CHANGELOG.md
									
									
									
									
									
								
							| @@ -1,6 +1,151 @@ | ||||
| # 更新日志 | ||||
| ## v4.0.0 | ||||
| 非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。 | ||||
| 只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力... | ||||
|  | ||||
| * 功能重构:重构整体系统,全部采用算力来进行结算 | ||||
| * 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽 | ||||
| * 功能优化:移动端聊天页面图片支持预览和放大功能 | ||||
| * 功能优化:MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题 | ||||
| * 功能优化:**PC端不登录也可以预览功能,只有在发起操作的时候才需要登录** | ||||
| * 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能 | ||||
| * 功能新增:支持H5支付 | ||||
| * 功能优化:支持数学公式的识别和美化输出 | ||||
| * 功能新增:新增算力消费日志功能 | ||||
| * 功能优化:整合 XXL-JOB 实现订单清理,每日算力派发,VIP 算力重置等任务 | ||||
| * 功能新增:管理后台新增7日内新增用户和新增订单统计 | ||||
|  | ||||
| ## v3.2.7 | ||||
| * 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能 | ||||
| * 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能 | ||||
| * Bug修复:修复 issue [ | ||||
|   管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532) | ||||
| * 功能优化:在对话和聊天记录表中新增冗余字段 model,存储对话模型 | ||||
| * Bug修复:IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144) | ||||
| * Bug修复:重新生成按钮功能失效问题 | ||||
| * Bug修复:对话输入HTML标签不显示的问题 | ||||
| * 功能优化:gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY | ||||
| * 功能新增:新增删除文件功能 | ||||
| * Bug修复:解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载 | ||||
| * 功能新增:后台管理新怎对话查看和检索功能 | ||||
|  | ||||
| ## v3.2.6 | ||||
| * 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号 | ||||
| * 功能优化:兼用旧版本微信收款消息解析 | ||||
| * 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源 | ||||
| * 功能新增:新增图片发布功能,画廊只显示用户已发布的图片 | ||||
| * 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码 | ||||
| * 功能新增:新增网站公告,可以在管理后台自定义配置 | ||||
| * 功能新增:新增阿里通义千问大模型支持 | ||||
| * Bug修复:修复 MJ 放大任务失败时候 img_call 会增加的 Bug | ||||
| * 功能优化:新增虎皮椒和PayJS订单状态校验功能,增加安全性 | ||||
| * Bug修复:修复微信转账交易 ID 提取失败 Bug | ||||
| * 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug | ||||
| * 功能新增:新增短信宝短信平台发送平台集成 | ||||
|  | ||||
|  | ||||
| ## v3.2.5 | ||||
| * 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。 | ||||
| * 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅 Plus 账号了!!! | ||||
| * 功能优化:增强 markdown 图片和引用块解析。 | ||||
| * 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。 | ||||
| * 功能优化:function call 兼用中转 API。 | ||||
| * Bug修复:修复部分已知的 Bug。 | ||||
|  | ||||
| ## v3.2.4.1 | ||||
| * 功能新增:新增 PayJs 支付通道 | ||||
| * Bug修复:紧急修复后台添加用户失败问题 | ||||
| * Bug修复:紧急修复使用中转 API-KEY 无法绘图的问题 | ||||
| * Bug修复:允许用户关闭手机和邮箱注册通道,移除验证码依赖 | ||||
|  | ||||
| ## v3.2.4 | ||||
|  | ||||
| * 功能新增:重磅更新,支持邮箱注册 | ||||
| * 功能优化:优化函数调用授权 | ||||
| * 功能优化:给用户表新增 nickname 字段 | ||||
| * 功能优化:管理后台给聊天角色增加启用/禁用开关 | ||||
| * Bug修复:SD绘画出现重复扣减绘图次数 | ||||
| * 功能优化:优化聊天对话导出样式,适应移动端 | ||||
| * 功能新增:众筹核销可以选择兑换对话还是绘图的额度 | ||||
| * Bug修复:修复[从历史记录获取reply有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92) | ||||
| * Bug修复:修复 MidJourney 绘图任务调度Bug,为 task_id 建议唯一索引 | ||||
| * 功能重构:重构了 API KEY模块,支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。 | ||||
|  | ||||
| ## v3.2.3 | ||||
|  | ||||
| * 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现 | ||||
| * 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数 | ||||
| * Bug修复:修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73) | ||||
| * Bug修复:确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道 | ||||
| * 功能新增:新增删除绘图任何和图片功能 | ||||
| * Bug修复:修复虎皮椒支付二维码重复扫码时报错问题 | ||||
| * 功能优化:自动将 AI 绘画中的中文提示词翻译成英文 | ||||
| * 功能优化:优化AI绘画的大图压缩算法,新增图片缓存 | ||||
| * 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率 | ||||
| * Bug修复:修复[Azure Api 更换api-version参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71) | ||||
| * Bug修复:修复科大讯飞 V1.5 API 请求失败的问题 | ||||
| * Bug修复:绘图失败后,自动恢复用户的剩余绘图次数 | ||||
| * 功能新增:为移动端新增 SD 绘图功能,分享功能 | ||||
|  | ||||
| ## v3.2.2 | ||||
|  | ||||
| * 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务 | ||||
| * 功能新增:AI绘画页面增加翻译和重写提示词功能 | ||||
| * 功能优化:OSS上传组件支持在 Bucket 下设置二级目录 | ||||
| * Bug修复:修复阿里云 OSS 访问路径错误 | ||||
| * 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket | ||||
|  | ||||
| ## v3.2.1 | ||||
|  | ||||
| * 功能优化:切换角色和模型的时候自动创建新的对话 | ||||
| * Bug修复:修复文件上传失败No such file bug | ||||
| * 功能新增:MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数 | ||||
| * Bug修复:[PC端对话在刷新后异常](https://github.com/yangjian102621/chatgpt-plus/issues/59) | ||||
| * 功能新增:增加 arm64 架构打包脚本 | ||||
| * 功能新增:支持 dall-e3 绘图的 API 地址自定义配置 | ||||
| * 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道 | ||||
|  | ||||
| ## v3.2.0 | ||||
|  | ||||
| * 功能新增:新增邀请注册功能 | ||||
| * 功能优化:增加中间件自动对HTTP请求的参数去掉首尾空格 | ||||
| * 功能优化:增加中间件自动为大图片生成缩略图 | ||||
| * 功能优化:MidJourney 页面图片加载优化,实现图片预览懒加载 | ||||
| * 功能新增:新增 DALL-E-3 绘画支持,并作为对话页面默认绘画插件 | ||||
| * Bug修复:修复阿里云 OSS 域名设置不起做用的bug | ||||
| * Bug修复:修复MidJourney绘图失败后重复添加到队列的问题 | ||||
|  | ||||
| ## v3.1.9 | ||||
|  | ||||
| * 功能新增:增加讯飞星火大模型 v3.0 支持 | ||||
| * 功能新增:新增找回密码功能 | ||||
| * 功能新增:支持 Markdown 代码复制功能 | ||||
| * Bug修复: xxl-job 任务调度失败的 Bug | ||||
| * 功能优化:优化前端页面菜单图标,使用自定义图标替换 icon-font | ||||
| * Bug修复:Stable-Diffusion 绘画成功之后没有扣减用户画图次数 | ||||
| * 功能优化:优化会员充值页面 ItemList 组件 | ||||
| * 功能优化:给首页 Logo 增加链接 | ||||
| * Bug修复:[新建会话时,提示"请输入合法的手机号" ](https://github.com/yangjian102621/chatgpt-plus/issues/51) | ||||
| * Bug修复:聊天上下文失效问题 | ||||
| * 功能优化:关闭注册时显示联系管理员二维码 | ||||
| * 功能优化:移除 leveldb 依赖,使用 redis 替换相应的功能 | ||||
| * Bug修复:后台启用用户 VIP 不生效问题 | ||||
| * 功能优化:充值支付页面的支付说明文字可以后台配置 | ||||
| * Bug修复:ChatGLM,百度文心,科大讯飞模型输出代码不换行问题 | ||||
|  | ||||
| ## v3.1.8 | ||||
|  | ||||
| 1. 功能新增:新增会员套餐充值,点卡充值,订单系统,集成支付宝支付通道 | ||||
| 2. Bug修复:修复 MidJourney API 参数版本更新导致调用失败问题 | ||||
| 3. Bug修复:修复 Stable Diffusion 调用后没有更新绘图调用次数问题 | ||||
| 4. Bug修复:修复七牛云上传报错 expired token | ||||
| 5. Bug修复:修复高权重模型导致的对话次数为负数的漏洞 | ||||
| 6. 功能优化:将聊天报错信息定义为统一常量,方便修改 | ||||
| 7. 功能优化:优化 markdown 表格显示样式,覆写 Element-Plus 表格样式 | ||||
| 8. 功能优化:增加倒数计时组件,定期自动清理未支付的订单 | ||||
|  | ||||
| ## v3.1.7 | ||||
|  | ||||
| 1. 功能新增:支持文心4.0 AI 模型 | ||||
| 2. 功能新增:可以在管理后台为用户绑定指定的 AI 模型,如只给某个用户使用 GPT-4 模型 | ||||
| 3. 功能新增:模型新增权重字段,不同的模型每次调用耗费的点数可以设置不同,比如GPT4是GPT3.5的10倍 | ||||
| @@ -8,6 +153,7 @@ | ||||
| 5. 功能优化:优化 MidJourney 专业绘画页面图片预览样式 | ||||
|  | ||||
| ## v3.1.6 | ||||
|  | ||||
| 1. 功能新增:新增AI 绘画照片墙功能页面,供用户查看所有的 AI 绘画作品 | ||||
| 2. 功能新增:新增 AI 角色应用功能页面,用户可以添加自己感兴趣的应用 | ||||
| 3. 功能优化:优化瀑布流组件的页面布局 | ||||
|   | ||||
							
								
								
									
										365
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										365
									
								
								README.md
									
									
									
									
									
								
							| @@ -8,8 +8,10 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 | ||||
| * 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。 | ||||
| * 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。 | ||||
| * 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。 | ||||
| * 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。(可定制开发其他支付通道支持) | ||||
| * 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI 绘画函数插件。 | ||||
| * 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。 | ||||
| * 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。 | ||||
| * 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI | ||||
|   绘画函数插件。 | ||||
|  | ||||
| ## 功能截图 | ||||
|  | ||||
| @@ -22,17 +24,28 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 | ||||
|  | ||||
|  | ||||
| ### MidJourney 专业绘画界面 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### Stable-Diffusion 专业绘画页面 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 绘图作品展 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### AI应用列表 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 会员充值 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 自动调用函数插件 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| @@ -50,50 +63,42 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 | ||||
|  | ||||
|  | ||||
|  | ||||
| ### 7. 体验地址 | ||||
| ### 体验地址 | ||||
|  | ||||
| > 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/> | ||||
| > **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!** | ||||
|  | ||||
| ## 快速部署 | ||||
|  | ||||
| **演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。 | ||||
|  | ||||
| ```shell | ||||
| bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.7-6c232bdaf8.sh)" | ||||
| ``` | ||||
|  | ||||
| 最新版本的一键部署脚本请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/install/)。 | ||||
|  | ||||
| 目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址 | ||||
|  | ||||
| * 前端访问地址:http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。 | ||||
| * 后台管理地址:http://localhost:8080/admin | ||||
| * 移动端地址:http://localhost:8080/mobile | ||||
| * 初始后台管理账号:admin/admin123 | ||||
| * 初始前端体验账号:18575670125/12345678 | ||||
|  | ||||
| 服务启动成功之后不能立刻使用,需要先登录管理后台 -> API-KEY 去添加一个 OpenAI 或者文心一言,科大讯飞等至少一个平台的 API | ||||
| KEY。 | ||||
|  | ||||
|  | ||||
|  | ||||
| 另外,如果您目前还没有 OpenAI 的 API KEY的,推荐您去 https://gpt.bemore.lol 购买,**无需魔法,高速稳定,且价格还远低于 OpenAI | ||||
| 官方**。 | ||||
|  | ||||
| ## 使用须知 | ||||
|  | ||||
| 1. 本项目基于 MIT 协议,免费开放全部源代码,可以作为个人学习使用或者商用。 | ||||
| 2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。 | ||||
|  | ||||
| ## 项目介绍 | ||||
|  | ||||
| 这一套完整的系统,包括前端聊天应用和一个后台管理系统。系统有用户鉴权,你可以自己使用,也可以部署直接给 C 端用户提供 | ||||
| ChatGPT 的服务。 | ||||
|  | ||||
| ### 项目的技术架构 | ||||
|  | ||||
| 新版的系统前后端都进行大改动的重构,后端还是用的 Gin Web 框架,但是作者整合了 fx 自动注入框架,整个后端应用结构非常简洁,特别适合二次开发。 | ||||
| 另外,数据存储用 MySQL 替换了 leveldb, 因为要对 C 端,后期会涉及到很多业务数据查询统计,leveldb 已经完全不够用了。 | ||||
|  | ||||
| > Gin + fx + MySQL | ||||
|  | ||||
| 3.0 版本之后会陆续添加其他语言的 API 实现,比如 PHP,Java 等。考虑到作者精力有限,api 目录已经添加了,有兴趣的同学自主去认领各自擅长的语言去实现。 | ||||
|  | ||||
| 前端的框架还是: | ||||
|  | ||||
| > Vue3 + Element-Plus | ||||
|  | ||||
| 前后台的页面风格已经全部变了,几乎所有页面样式代码都重写了。逻辑代码还是沿用之前的,毕竟功能没有太大的变化。 | ||||
|  | ||||
| 此次重构改版主要是为了后面功能的扩展准备了。 | ||||
|  | ||||
| 新版本已经实现的功能如下: | ||||
|  | ||||
| 1. 引入用户体系,新增用户注册和登录功能。 | ||||
| 2. 聊天页面改版,实现了跟 ChatGPT 官方版本一致的聊天体验。 | ||||
| 3. 创建会话的时候可以选择聊天角色和模型。 | ||||
| 4. 新增聊天设置功能,用户可以导入自己的 API KEY | ||||
| 5. 保存聊天记录,支持聊天上下文。 | ||||
| 6. 重构后台管理模块,更友好,扩展性更好的后台管理系统。 | ||||
| 7. 引入 ip2region 组件,记录用户的登录IP和地址。 | ||||
| 8. 支持会话搜索过滤。 | ||||
| 9. 支持微信支付充值 | ||||
|  | ||||
| ## 项目地址 | ||||
|  | ||||
| * Github 地址:https://github.com/yangjian102621/chatgpt-plus | ||||
| @@ -105,291 +110,25 @@ ChatGPT 的服务。 | ||||
|  | ||||
| ## TODOLIST | ||||
|  | ||||
| * [x] 整合 Midjourney AI 绘画 API | ||||
| * [x] 开发移动端聊天页面 | ||||
| * [x] 接入微信收款功能 | ||||
| * [x] 支持 ChatGPT 函数功能,通过函数实现插件 | ||||
| * [x] 开发桌面版应用 | ||||
| * [x] 开发手机 App 客户端 | ||||
| * [x] 支付宝支付功能 | ||||
| * [ ] 支持基于知识库的 AI 问答 | ||||
| * [ ] 会员推广功能 | ||||
| * [ ] 会员邀请注册推广功能 | ||||
| * [ ] 微信支付功能 | ||||
|  | ||||
| ## Docker 快速部署 | ||||
| ## 项目文档 | ||||
|  | ||||
| > | ||||
| 鉴于最新不少网友反馈在部署的时候遇到一些问题,大部分问题都是相同的,所以我这边做了一个视频教程 [五分钟部署自己的 ChatGPT 服务](https://www.bilibili.com/video/BV1H14y1B7Qw/)。 | ||||
| > 习惯看视频教程的朋友可以去看视频教程,视频的语速比较慢,建议 2 倍速观看。 | ||||
| 最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/) | ||||
|  | ||||
| V3.0.0 版本以后已经支持使用容器部署了,跳过所有的繁琐的环境准备,一条命令就可以轻松部署上线。 | ||||
| 详细的部署和开发文档请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/)。 | ||||
|  | ||||
| ### 1. 导入数据库 | ||||
| 加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自Github!!!)。** | ||||
|  | ||||
| 首先我们需要创建一个 MySQL 容器,并导入初始数据库。 | ||||
|  | ||||
| ```shell | ||||
| cd docker/mysql | ||||
| # 创建 mysql 容器 | ||||
| docker-compose up -d | ||||
| # 导入数据库 | ||||
| docker exec -i chatgpt-plus-mysql sh -c 'exec mysql -uroot -p12345678' < ../../database/chatgpt_plus-v3.1.7.sql | ||||
| ``` | ||||
|  | ||||
| 如果你本地已经安装了 MySQL 服务,那么你只需手动导入数据库即可。 | ||||
|  | ||||
| ```shell | ||||
| # 连接数据库 | ||||
| mysql -u username -p password | ||||
| # 导入数据库 | ||||
| source database/chatgpt_plus.sql | ||||
| ``` | ||||
|  | ||||
| ### 2. 修改配置文档 | ||||
|  | ||||
| 修改配置文档 `docker/conf/config.toml` 配置文档,修改代理地址和管理员密码: | ||||
|  | ||||
| ```toml | ||||
| Listen = "0.0.0.0:5678" | ||||
| ProxyURL = "" # 如 http://127.0.0.1:7777 | ||||
| MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8&parseTime=True&loc=Local" | ||||
| StaticDir = "./static" # 静态资源的目录 | ||||
| StaticUrl = "/static" # 静态资源访问 URL | ||||
| AesEncryptKey = "" | ||||
| WeChatBot = false # 是否启动微信机器人 | ||||
|  | ||||
| [Session] | ||||
|   SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换 | ||||
|   MaxAge = 86400 | ||||
|  | ||||
| [Manager] | ||||
|   Username = "admin" | ||||
|   Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改 | ||||
|    | ||||
| [Redis] # redis 配置信息 | ||||
|   Host = "localhost"  | ||||
|   Port = 6379 | ||||
|   Password = "" | ||||
|   DB = 0 | ||||
|    | ||||
| [ApiConfig] # 微博热搜,今日头条等函数服务 API 配置,此为第三方插件服务,如需使用请联系作者开通 | ||||
|   ApiURL = "" | ||||
|   AppId = "" | ||||
|   Token = "" | ||||
|  | ||||
| [SmsConfig] # 阿里云短信服务配置 | ||||
|   AccessKey = "" | ||||
|   AccessSecret = "" | ||||
|   Product = "Dysmsapi" | ||||
|   Domain = "dysmsapi.aliyuncs.com" | ||||
|  | ||||
| [ExtConfig] # MidJourney和微信机器人服务 API 配置,开通此功能需要配合 chatpgt-plus-exts 项目部署 | ||||
|   ApiURL = "" # 插件扩展 API 地址 | ||||
|   Token = "" # 这个 token 随便填,只要确保跟 chatgpt-plus-exts 项目的 token 一样就行  | ||||
|    | ||||
| [OSS] # OSS 配置,用于存储 MJ 绘画图片 | ||||
|    Active = "local" # 默认使用本地文件存储引擎 | ||||
|    [OSS.Local] | ||||
|      BasePath = "./static/upload" # 本地文件上传根路径 | ||||
|      BaseURL = "http://localhost:5678/static/upload" # 本地上传文件根 URL 如果是线上,则直接设置为 /static/upload 即可 | ||||
|    [OSS.Minio] | ||||
|      Endpoint = "" # 如 172.22.11.200:9000 | ||||
|      AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key | ||||
|      AccessSecret = "" | ||||
|      Bucket = "chatgpt-plus" # 替换为你自己创建的 Bucket,注意要给 Bucket 设置公开的读权限,否则会出现图片无法显示。 | ||||
|      UseSSL = false | ||||
|      Domain = "" # 地址必须是能够通过公网访问的,否则会出现图片无法显示。 | ||||
|    [OSS.QiNiu] # 七牛云 OSS 配置 | ||||
|        Zone = "z2" # 区域,z0:华东,z1: 华北,na0:北美,as0:新加坡 | ||||
|        AccessKey = "" | ||||
|        AccessSecret = "" | ||||
|        Bucket = "" | ||||
|        Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com | ||||
|         | ||||
| [MjConfig] # MidJourney AI 绘画配置 | ||||
|   Enabled = false # 是否启动 MidJourney 机器人服务 | ||||
|   UserToken = "" # 用户授权 Token | ||||
|   BotToken = "" # Discord 机器人 Token | ||||
|   GuildId = "" # 服务器 ID | ||||
|   ChanelId = "" # 频道 ID | ||||
|  | ||||
| [SdConfig] | ||||
|   Enabled = false # 是否启动 Stable Diffusion 机器人服务 | ||||
|   ApiURL = "http://172.22.11.200:7860" # stable-diffusion-webui API 地址 | ||||
|   ApiKey = "" # 如果开启了授权,这里需要配置授权的 ApiKey | ||||
|   Txt2ImgJsonPath = "res/text2img.json" # 文生图的 API 请求报文 json 模板,允许自定义请求json报文,因为不同版本的 API 绘图的参数以及 fn_index 会不同。 | ||||
| ``` | ||||
|  | ||||
| > 1. 如果你不知道如何获取 Discord 用户 Token 和 Bot Token | ||||
|      请查参考 [Midjourney|如何集成到自己的平台](https://zhuanlan.zhihu.com/p/631079476)。 | ||||
| > 2. `Txt2ImgJsonPath` | ||||
|      的默认用的是使用最广泛的 [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 项目的 | ||||
|      API,如果你用的是其他版本,比如秋叶的懒人包部署的,那么请将对应的 text2img 的参数报文复制放在 `res/text2img.json` | ||||
|      文件中即可。 | ||||
|  | ||||
| 修改 nginx 配置文档 `docker/conf/nginx/conf.d/chatgpt-plus.conf`,把后端转发的地址改成当前主机的内网 IP 地址。 | ||||
|  | ||||
| ```shell | ||||
|  # 这里配置后端 API 的转发 | ||||
| location /api/ { | ||||
|        proxy_http_version 1.1; | ||||
|        proxy_connect_timeout 300s; | ||||
|        proxy_read_timeout 300s; | ||||
|        proxy_send_timeout 12s; | ||||
|        proxy_set_header Host $host; | ||||
|        proxy_set_header X-Real-IP $remote_addr; | ||||
|        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; | ||||
|        proxy_set_header Upgrade $http_upgrade; | ||||
|        proxy_set_header Connection $connection_upgrade; | ||||
|        proxy_pass http://172.28.173.76:6789; # 这里改成后端服务的内网 IP 地址 | ||||
|         | ||||
| # 静态资源转发 | ||||
| location /static/ { | ||||
|    proxy_pass http://172.22.11.47:5678; # 这里改成后端服务的内网 IP 地址 | ||||
| } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ### 3. 启动应用 | ||||
|  | ||||
| 先修改 `docker/docker-compose.yaml` 文件中的镜像地址,改成最新的版本: | ||||
|  | ||||
| ```yaml | ||||
| version: '3' | ||||
| services: | ||||
|   # 后端 API 镜像 | ||||
|   chatgpt-plus-api: | ||||
|     image: registry.cn-shenzhen.aliyuncs.com/geekmaster/chatgpt-plus-api:v3.1.5 #这里改成最新的 release 版本地址 | ||||
|     container_name: chatgpt-plus-api | ||||
|     restart: always | ||||
|     environment: | ||||
|       - DEBUG=false | ||||
|       - LOG_LEVEL=info | ||||
|       - CONFIG_FILE=config.toml | ||||
|     ports: | ||||
|       - "5678:5678" | ||||
|     volumes: | ||||
|       - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime | ||||
|       - ./conf/config.toml:/var/www/app/config.toml | ||||
|       - ./logs:/var/www/app/logs | ||||
|       - ./static:/var/www/app/static | ||||
|  | ||||
|   # 前端应用镜像 | ||||
|   chatgpt-plus-web: | ||||
|     image: registry.cn-shenzhen.aliyuncs.com/geekmaster/chatgpt-plus-web:v3.1.5 #这里改成最新的 release 版本地址 | ||||
|     container_name: chatgpt-plus-web | ||||
|     restart: always | ||||
|     ports: | ||||
|       - "8080:8080" # 这边是对外的端口,支持 8080,80和443 | ||||
|     volumes: | ||||
|       - ./logs/nginx:/var/log/nginx | ||||
|       - ./conf/nginx/conf.d:/etc/nginx/conf.d | ||||
|       - ./conf/nginx/nginx.conf:/etc/nginx/nginx.conf | ||||
|       - ./ssl:/etc/nginx/ssl | ||||
| ``` | ||||
|  | ||||
| ```shell | ||||
| cd docker | ||||
| docker-compose up -d | ||||
| ``` | ||||
|  | ||||
| * 前端访问地址:http://localhost:8080/chat | ||||
| * 后台管理地址:http://localhost:8080/admin | ||||
| * 移动端地址:http://localhost:8080/mobile | ||||
|  | ||||
| > 注意:你得访问后台管理系统 http://localhost:8080/admin | ||||
| > 输入你前面配置文档中设置的管理员用户名和密码登录。 | ||||
| > 然后进入 `API KEY 管理` 菜单,添加一个 OpenAI 的 API KEY 才可以正常开启 AI 对话。 | ||||
|  | ||||
|  | ||||
|  | ||||
| 最后进入前端聊天页面 [http://localhost:8080/chat](http://localhost:8080/chat) | ||||
| 你可以注册新用户,也可以使用系统默认有个账号:`18575670125/12345678` 登录聊天。 | ||||
|  | ||||
| 祝你使用愉快!!! | ||||
|  | ||||
| ## 本地开发调试 | ||||
|  | ||||
| 本地开发同样要分别运行前端和后端程序。 | ||||
|  | ||||
| ### 运行后端程序 | ||||
|  | ||||
| 1. 同样你首先要 [导入数据库](#1-导入数据库) | ||||
| 2. 然后 [修改配置文档](#2-修改配置文档) | ||||
| 3. 运行后端程序: | ||||
|  | ||||
|     ```shell | ||||
|     cd api  | ||||
|     # 1. 先下载依赖 | ||||
|     go mod tidy | ||||
|     # 2. 运行程序 | ||||
|     go run main.go | ||||
|     # 如果你安装了 fresh 可以使用 fresh 实现热启动 | ||||
|     fresh -c fresh.conf | ||||
|     ``` | ||||
|  | ||||
| ### 运行前端程序 | ||||
|  | ||||
| 同样先拷贝配置文档: | ||||
|  | ||||
| ```shell | ||||
| cd web | ||||
| cp .env.production .env.development | ||||
| ``` | ||||
|  | ||||
| 编辑 `.env.development` 文件,修改后端 API 的访问路径: | ||||
|  | ||||
| ```ini | ||||
| VUE_APP_API_HOST=http://localhost:5678 | ||||
| VUE_APP_WS_HOST=ws://localhost:5678 | ||||
| ``` | ||||
|  | ||||
| 配置好了之后就可以运行前端应用了: | ||||
|  | ||||
| ``` | ||||
| # 安装依赖 | ||||
| npm install | ||||
| # 运行 | ||||
| npm run dev | ||||
| ``` | ||||
|  | ||||
| * 前端页面:http://localhost:8888/chat | ||||
| * 后台管理页面:http://localhost:8888/admin | ||||
|  | ||||
| ## 项目打包 | ||||
|  | ||||
| 由于本项目是采用异构开发的方式,所项目打包分成两步:首先编译后端程序,然后再打包前端应用。 | ||||
|  | ||||
| ### 打包前端 | ||||
|  | ||||
| ```shell | ||||
| cd web | ||||
| npm run build | ||||
| ``` | ||||
|  | ||||
| ### 打包后端 | ||||
|  | ||||
| 你可以根据个人需求将项目打包成 windows/linux/darwin 平台项目。 | ||||
|  | ||||
| ```shell | ||||
| cd api | ||||
| # for all platforms | ||||
| make clean all | ||||
| # for linux only | ||||
| make clean linux | ||||
| ``` | ||||
|  | ||||
| 打包后的可执行文件在 `bin` 目录下。 | ||||
|  | ||||
|  | ||||
| ## 参与贡献 | ||||
|  | ||||
| 个人的力量始终有限,任何形式的贡献都是欢迎的,包括但不限于贡献代码,优化文档,提交 issue 和 PR 等。 | ||||
|  | ||||
| 如果有兴趣的话,也可以加微信进入微信讨论群(**添加好友时请注明来自Github!!!**)。 | ||||
|  | ||||
|  | ||||
|  | ||||
| #### 特此声明:不接受在微信或者微信群给开发者提 Bug,有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合! | ||||
| #### 特此声明:由于个人时间有限,不接受在微信或者微信群给开发者提 Bug,有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合! | ||||
|  | ||||
| ### Commit 类型 | ||||
|  | ||||
| @@ -405,10 +144,6 @@ make clean linux | ||||
|  | ||||
| 如果你觉得这个项目对你有帮助,并且情况允许的话,可以请作者喝杯咖啡,非常感谢你的支持~ | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										19
									
								
								api/Makefile
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								api/Makefile
									
									
									
									
									
								
							| @@ -1,19 +1,14 @@ | ||||
| SHELL=/usr/bin/env bash | ||||
| NAME := chatgpt-plus | ||||
| all: window linux darwin | ||||
| all: amd64 arm64 | ||||
|  | ||||
| amd64: | ||||
| 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-linux main.go | ||||
| .PHONY: amd64 | ||||
|  | ||||
| window: | ||||
| 	CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/$(NAME)-amd64.exe main.go | ||||
| .PHONY: window | ||||
|  | ||||
| linux: | ||||
| 	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-amd64-linux main.go | ||||
| .PHONY: linux | ||||
|  | ||||
| darwin: | ||||
| 	CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/$(NAME)-amd64-darwin main.go | ||||
| .PHONY: darwin | ||||
| arm64: | ||||
| 	CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -o bin/$(NAME)-linux main.go | ||||
| .PHONY: arm64 | ||||
|  | ||||
| clean: | ||||
| 	rm -rf bin/$(NAME)-* | ||||
|   | ||||
| @@ -10,10 +10,6 @@ WeChatBot = false | ||||
|   SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换 | ||||
|   MaxAge = 86400 | ||||
|  | ||||
| [Manager] | ||||
|   Username = "admin" | ||||
|   Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改 | ||||
|  | ||||
| [Redis] # redis 配置信息 | ||||
|   Host = "localhost" | ||||
|   Port = 6379 | ||||
| @@ -25,23 +21,28 @@ WeChatBot = false | ||||
|   AppId = "" | ||||
|   Token = "" | ||||
|  | ||||
| [SmsConfig] # 阿里云短信服务配置 | ||||
|   AccessKey = "" | ||||
|   AccessSecret = "" | ||||
|   Product = "Dysmsapi" | ||||
|   Domain = "dysmsapi.aliyuncs.com" | ||||
|   Sign = "" | ||||
|   CodeTempId = "" | ||||
|  | ||||
| [ExtConfig] # MidJourney和微信机器人服务 API 配置,开通此功能需要配合 chatpgt-plus-exts 项目部署 | ||||
|   ApiURL = "" # 插件扩展 API 地址 | ||||
|   Token = "" # 这个 token 随便填,只要确保跟 chatgpt-plus-exts 项目的 token 一样就行 | ||||
| [SMS] # Sms 配置,用于发送短信 | ||||
|    Active = "Ali" # 当前启用的短信服务,默认使用阿里云 | ||||
|    [SMS.Bao] | ||||
|       Username = "" | ||||
|       Password = "" | ||||
|       Domain = "api.smsbao.com" | ||||
|       Sign = "【极客学长】" | ||||
|       CodeTemplate = "您的验证码是{code}。5分钟有效,若非本人操作,请忽略本短信。" | ||||
|    [SMS.Ali] | ||||
|       AccessKey = "" | ||||
|       AccessSecret = "" | ||||
|       Product = "Dysmsapi" | ||||
|       Domain = "dysmsapi.aliyuncs.com" | ||||
|       Sign = "" | ||||
|       CodeTempId = "" | ||||
|  | ||||
| [OSS] # OSS 配置,用于存储 MJ 绘画图片 | ||||
|    Active = "local" # 默认使用本地文件存储引擎 | ||||
|    [OSS.Local] | ||||
|      BasePath = "./static/upload" # 本地文件上传根路径 | ||||
|      BaseURL = "http://localhost:5678/static/upload" # 本地上传文件根 URL 如果是线上,则直接设置为 /static/upload 即可 | ||||
|      BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP | ||||
|    [OSS.Minio] | ||||
|      Endpoint = "" # 如 172.22.11.200:9000 | ||||
|      AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key | ||||
| @@ -55,16 +56,77 @@ WeChatBot = false | ||||
|        AccessSecret = "" | ||||
|        Bucket = "" | ||||
|        Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com | ||||
|    [OSS.AliYun] | ||||
|        Endpoint = "oss-cn-hangzhou.aliyuncs.com" | ||||
|        AccessKey = "" | ||||
|        AccessSecret = "" | ||||
|        Bucket = "chatgpt-plus" | ||||
|        SubDir = "" | ||||
|        Domain = "" | ||||
|  | ||||
| [MjConfig] | ||||
| [[MjConfigs]] | ||||
|   Enabled = false | ||||
|   UserToken = "" | ||||
|   BotToken = "" | ||||
|   GuildId = "" | ||||
|   ChanelId = "" | ||||
|   UseCDN = false #是否使用反向代理访问,设置为true下面的设置才会生效 | ||||
|   DiscordAPI = "" # discord API 反代地址 | ||||
|   DiscordCDN = "" # mj 图片反代地址 | ||||
|   DiscordGateway = "" # discord 机器人反代地址 | ||||
|  | ||||
| [SdConfig] | ||||
| [[MjPlusConfigs]] | ||||
|   Enabled = false | ||||
|   ApiURL = "http://172.22.11.200:7860" | ||||
|   ApiURL = "https://api.chat-plus.net" | ||||
|   CdnURL = "" # CND 加速的 URL,如果有的话就设置 | ||||
|   Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo | ||||
|   ApiKey = "sk-xxx" | ||||
|   NotifyURL = "https://ai.r9it.com/api/mj/notify" # 这里需要改成你的域名 | ||||
|  | ||||
| [[SdConfigs]] | ||||
|   Enabled = false | ||||
|   ApiURL = "" | ||||
|   ApiKey = "" | ||||
|   Txt2ImgJsonPath = "res/text2img.json" | ||||
|   Txt2ImgJsonPath = "res/sd/text2img.json" | ||||
|  | ||||
| [XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP,如果你没有启用支付服务,则该服务也无需启动 | ||||
|   Enabled = false # 是否启用 XXL JOB 服务 | ||||
|   ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址 | ||||
|   ExecutorIp = "172.22.11.47" # 执行器 IP 地址 | ||||
|   ExecutorPort = "9999" # 执行器服务端口 | ||||
|   AccessToken = "xxl-job-api-token" # 执行器 API 通信 token | ||||
|   RegistryKey = "chatgpt-plus" # 任务注册 key | ||||
|  | ||||
| [AlipayConfig] | ||||
|   Enabled = false # 启用支付宝支付通道 | ||||
|   SandBox = false # 是否启用沙盒模式 | ||||
|   UserId = "2088721020750581" # 商户ID | ||||
|   AppId = "9021000131658023" # App Id | ||||
|   PrivateKey = "certs/alipay/privateKey.txt" # 应用私钥 | ||||
|   PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书 | ||||
|   AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书 | ||||
|   RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书 | ||||
|   NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址 | ||||
|  | ||||
| [HuPiPayConfig] | ||||
|   Enabled = false | ||||
|   Name = "wechat" | ||||
|   AppId = "" | ||||
|   AppSecret = "" | ||||
|   ApiURL = "https://api.xunhupay.com" | ||||
|   NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify" | ||||
|  | ||||
| [SmtpConfig] # 注意,阿里云服务器禁用了25号端口,所以如果需要使用邮件功能,请别用阿里云服务器 | ||||
|   Host = "smtp.163.com" | ||||
|   Port = 25 | ||||
|   AppName = "极客学长" | ||||
|   From = "test@163.com" # 发件邮箱人地址 | ||||
|   Password = "" #邮箱 stmp 服务授权码 | ||||
|  | ||||
| [JPayConfig] # PayJs 支付配置 | ||||
|   Enabled = false | ||||
|   Name = "wechat" # 请不要改动 | ||||
|   AppId = "" # 商户 ID | ||||
|   PrivateKey = "" # 秘钥 | ||||
|   ApiURL = "https://payjs.cn" | ||||
|   NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的 | ||||
| @@ -1,8 +1,8 @@ | ||||
| package core | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/fun" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| @@ -11,9 +11,14 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"github.com/nfnt/resize" | ||||
| 	"gorm.io/gorm" | ||||
| 	"image" | ||||
| 	"image/jpeg" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"runtime/debug" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -23,31 +28,28 @@ type AppServer struct { | ||||
| 	Debug        bool | ||||
| 	Config       *types.AppConfig | ||||
| 	Engine       *gin.Engine | ||||
| 	ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message | ||||
| 	ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message | ||||
|  | ||||
| 	ChatConfig *types.ChatConfig   // chat config cache | ||||
| 	SysConfig  *types.SystemConfig // system config cache | ||||
| 	SysConfig *types.SystemConfig // system config cache | ||||
|  | ||||
| 	// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次 | ||||
| 	// 防止第三方直接连接 socket 调用 OpenAI API | ||||
| 	ChatSession   *types.LMap[string, *types.ChatSession] //map[sessionId]UserId | ||||
| 	ChatClients   *types.LMap[string, *types.WsClient]    // map[sessionId]Websocket 连接集合 | ||||
| 	ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function | ||||
| 	Functions     map[string]fun.Function | ||||
| } | ||||
|  | ||||
| func NewServer(appConfig *types.AppConfig, functions map[string]fun.Function) *AppServer { | ||||
| func NewServer(appConfig *types.AppConfig) *AppServer { | ||||
| 	gin.SetMode(gin.ReleaseMode) | ||||
| 	gin.DefaultWriter = io.Discard | ||||
| 	return &AppServer{ | ||||
| 		Debug:         false, | ||||
| 		Config:        appConfig, | ||||
| 		Engine:        gin.Default(), | ||||
| 		ChatContexts:  types.NewLMap[string, []interface{}](), | ||||
| 		ChatContexts:  types.NewLMap[string, []types.Message](), | ||||
| 		ChatSession:   types.NewLMap[string, *types.ChatSession](), | ||||
| 		ChatClients:   types.NewLMap[string, *types.WsClient](), | ||||
| 		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), | ||||
| 		Functions:     functions, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -57,30 +59,22 @@ func (s *AppServer) Init(debug bool, client *redis.Client) { | ||||
| 		logger.Info("Enabled debug mode") | ||||
| 	} | ||||
| 	s.Engine.Use(corsMiddleware()) | ||||
| 	s.Engine.Use(staticResourceMiddleware()) | ||||
| 	s.Engine.Use(authorizeMiddleware(s, client)) | ||||
| 	s.Engine.Use(parameterHandlerMiddleware()) | ||||
| 	s.Engine.Use(errorHandler) | ||||
| 	// 添加静态资源访问 | ||||
| 	s.Engine.Static("/static", s.Config.StaticDir) | ||||
| } | ||||
|  | ||||
| func (s *AppServer) Run(db *gorm.DB) error { | ||||
| 	// load chat config from database | ||||
| 	var chatConfig model.Config | ||||
| 	res := db.Where("marker", "chat").First(&chatConfig) | ||||
| 	if res.Error != nil { | ||||
| 		return res.Error | ||||
| 	} | ||||
| 	err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// load system configs | ||||
| 	var sysConfig model.Config | ||||
| 	res = db.Where("marker", "system").First(&sysConfig) | ||||
| 	res := db.Where("marker", "system").First(&sysConfig) | ||||
| 	if res.Error != nil { | ||||
| 		return res.Error | ||||
| 	} | ||||
| 	err = utils.JsonDecode(sysConfig.Config, &s.SysConfig) | ||||
| 	err := utils.JsonDecode(sysConfig.Config, &s.SysConfig) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -138,71 +132,64 @@ func corsMiddleware() gin.HandlerFunc { | ||||
| // 用户授权验证 | ||||
| func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		if c.Request.URL.Path == "/api/user/login" || | ||||
| 			c.Request.URL.Path == "/api/admin/login" || | ||||
| 			c.Request.URL.Path == "/api/user/register" || | ||||
| 			c.Request.URL.Path == "/api/reward/notify" || | ||||
| 			c.Request.URL.Path == "/api/mj/notify" || | ||||
| 			c.Request.URL.Path == "/api/chat/history" || | ||||
| 			c.Request.URL.Path == "/api/chat/detail" || | ||||
| 			c.Request.URL.Path == "/api/role/list" || | ||||
| 			c.Request.URL.Path == "/api/mj/jobs" || | ||||
| 			c.Request.URL.Path == "/api/mj/proxy" || | ||||
| 			c.Request.URL.Path == "/api/sd/jobs" || | ||||
| 			strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || | ||||
| 			strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") || | ||||
| 			strings.HasPrefix(c.Request.URL.Path, "/static/") || | ||||
| 			c.Request.URL.Path == "/api/admin/config/get" { | ||||
| 			c.Next() | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var tokenString string | ||||
| 		if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API | ||||
| 		isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/") | ||||
| 		if isAdminApi { // 后台管理 API | ||||
| 			tokenString = c.GetHeader(types.AdminAuthHeader) | ||||
| 		} else if c.Request.URL.Path == "/api/chat/new" || | ||||
| 			c.Request.URL.Path == "/api/mj/client" || | ||||
| 			c.Request.URL.Path == "/api/sd/client" { | ||||
| 		} else if c.Request.URL.Path == "/api/chat/new" { | ||||
| 			tokenString = c.Query("token") | ||||
| 		} else { | ||||
| 			tokenString = c.GetHeader(types.UserAuthHeader) | ||||
| 		} | ||||
|  | ||||
| 		if tokenString == "" { | ||||
| 			resp.ERROR(c, "You should put Authorization in request headers") | ||||
| 			c.Abort() | ||||
| 			return | ||||
| 			if needLogin(c) { | ||||
| 				resp.ERROR(c, "You should put Authorization in request headers") | ||||
| 				c.Abort() | ||||
| 				return | ||||
| 			} else { // 直接放行 | ||||
| 				c.Next() | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { | ||||
| 			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | ||||
| 			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) { | ||||
| 				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) | ||||
| 			} | ||||
| 			if isAdminApi { | ||||
| 				return []byte(s.Config.AdminSession.SecretKey), nil | ||||
| 			} else { | ||||
| 				return []byte(s.Config.Session.SecretKey), nil | ||||
| 			} | ||||
|  | ||||
| 			return []byte(s.Config.Session.SecretKey), nil | ||||
| 		}) | ||||
|  | ||||
| 		if err != nil { | ||||
| 		if err != nil && needLogin(c) { | ||||
| 			resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err)) | ||||
| 			c.Abort() | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		claims, ok := token.Claims.(jwt.MapClaims) | ||||
| 		if !ok || !token.Valid { | ||||
| 		if !ok || !token.Valid && needLogin(c) { | ||||
| 			resp.NotAuth(c, "Token is invalid") | ||||
| 			c.Abort() | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0) | ||||
| 		if expr > 0 && int64(expr) < time.Now().Unix() { | ||||
| 		if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) { | ||||
| 			resp.NotAuth(c, "Token is expired") | ||||
| 			c.Abort() | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		key := fmt.Sprintf("users/%v", claims["user_id"]) | ||||
| 		if _, err := client.Get(context.Background(), key).Result(); err != nil { | ||||
| 		if isAdminApi { | ||||
| 			key = fmt.Sprintf("admin/%v", claims["user_id"]) | ||||
| 		} | ||||
| 		if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) { | ||||
| 			resp.NotAuth(c, "Token is not found in redis") | ||||
| 			c.Abort() | ||||
| 			return | ||||
| @@ -210,3 +197,160 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { | ||||
| 		c.Set(types.LoginUserID, claims["user_id"]) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func needLogin(c *gin.Context) bool { | ||||
| 	if c.Request.URL.Path == "/api/user/login" || | ||||
| 		c.Request.URL.Path == "/api/user/resetPass" || | ||||
| 		c.Request.URL.Path == "/api/admin/login" || | ||||
| 		c.Request.URL.Path == "/api/admin/login/captcha" || | ||||
| 		c.Request.URL.Path == "/api/user/register" || | ||||
| 		c.Request.URL.Path == "/api/chat/history" || | ||||
| 		c.Request.URL.Path == "/api/chat/detail" || | ||||
| 		c.Request.URL.Path == "/api/chat/list" || | ||||
| 		c.Request.URL.Path == "/api/role/list" || | ||||
| 		c.Request.URL.Path == "/api/model/list" || | ||||
| 		c.Request.URL.Path == "/api/mj/imgWall" || | ||||
| 		c.Request.URL.Path == "/api/mj/client" || | ||||
| 		c.Request.URL.Path == "/api/mj/notify" || | ||||
| 		c.Request.URL.Path == "/api/invite/hits" || | ||||
| 		c.Request.URL.Path == "/api/sd/imgWall" || | ||||
| 		c.Request.URL.Path == "/api/sd/client" || | ||||
| 		c.Request.URL.Path == "/api/config/get" || | ||||
| 		c.Request.URL.Path == "/api/product/list" || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/test") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/function/") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/api/payment/") || | ||||
| 		strings.HasPrefix(c.Request.URL.Path, "/static/") { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // 统一参数处理 | ||||
| func parameterHandlerMiddleware() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		// GET 参数处理 | ||||
| 		params := c.Request.URL.Query() | ||||
| 		for key, values := range params { | ||||
| 			for i, value := range values { | ||||
| 				params[key][i] = strings.TrimSpace(value) | ||||
| 			} | ||||
| 		} | ||||
| 		// update get parameters | ||||
| 		c.Request.URL.RawQuery = params.Encode() | ||||
| 		// skip file upload requests | ||||
| 		contentType := c.Request.Header.Get("Content-Type") | ||||
| 		if strings.Contains(contentType, "multipart/form-data") { | ||||
| 			c.Next() | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if strings.Contains(contentType, "application/json") { | ||||
| 			// process POST JSON request body | ||||
| 			bodyBytes, err := io.ReadAll(c.Request.Body) | ||||
| 			if err != nil { | ||||
| 				c.Next() | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// 还原请求体 | ||||
| 			c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) | ||||
| 			// 将请求体解析为 JSON | ||||
| 			var jsonData map[string]interface{} | ||||
| 			if err := c.ShouldBindJSON(&jsonData); err != nil { | ||||
| 				c.Next() | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// 对 JSON 数据中的字符串值去除两端空格 | ||||
| 			trimJSONStrings(jsonData) | ||||
| 			// 更新请求体 | ||||
| 			c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData))) | ||||
| 		} | ||||
|  | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 递归对 JSON 数据中的字符串值去除两端空格 | ||||
| func trimJSONStrings(data interface{}) { | ||||
| 	switch v := data.(type) { | ||||
| 	case map[string]interface{}: | ||||
| 		for key, value := range v { | ||||
| 			switch valueType := value.(type) { | ||||
| 			case string: | ||||
| 				v[key] = strings.TrimSpace(valueType) | ||||
| 			case map[string]interface{}, []interface{}: | ||||
| 				trimJSONStrings(value) | ||||
| 			} | ||||
| 		} | ||||
| 	case []interface{}: | ||||
| 		for i, value := range v { | ||||
| 			switch valueType := value.(type) { | ||||
| 			case string: | ||||
| 				v[i] = strings.TrimSpace(valueType) | ||||
| 			case map[string]interface{}, []interface{}: | ||||
| 				trimJSONStrings(value) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 静态资源中间件 | ||||
| func staticResourceMiddleware() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
|  | ||||
| 		url := c.Request.URL.String() | ||||
| 		// 拦截生成缩略图请求 | ||||
| 		if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") { | ||||
| 			r := strings.SplitAfter(url, "imageView2") | ||||
| 			size := strings.Split(r[1], "/") | ||||
| 			if len(size) != 8 { | ||||
| 				c.String(http.StatusNotFound, "invalid thumb args") | ||||
| 				return | ||||
| 			} | ||||
| 			with := utils.IntValue(size[3], 0) | ||||
| 			height := utils.IntValue(size[5], 0) | ||||
| 			quality := utils.IntValue(size[7], 75) | ||||
|  | ||||
| 			// 打开图片文件 | ||||
| 			filePath := strings.TrimLeft(c.Request.URL.Path, "/") | ||||
| 			file, err := os.Open(filePath) | ||||
| 			if err != nil { | ||||
| 				c.String(http.StatusNotFound, "Image not found") | ||||
| 				return | ||||
| 			} | ||||
| 			defer file.Close() | ||||
|  | ||||
| 			// 解码图片 | ||||
| 			img, _, err := image.Decode(file) | ||||
| 			if err != nil { | ||||
| 				c.String(http.StatusInternalServerError, "Error decoding image") | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			var newImg image.Image | ||||
| 			if height == 0 || with == 0 { | ||||
| 				// 固定宽度,高度自适应 | ||||
| 				newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3) | ||||
| 			} else { | ||||
| 				// 生成缩略图 | ||||
| 				newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3) | ||||
| 			} | ||||
| 			var buffer bytes.Buffer | ||||
| 			err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality}) | ||||
| 			if err != nil { | ||||
| 				log.Fatal(err) | ||||
| 			} | ||||
|  | ||||
| 			// 设置图片缓存有效期为一年 (365天) | ||||
| 			c.Header("Cache-Control", "max-age=31536000, public") | ||||
| 			// 直接输出图像数据流 | ||||
| 			c.Data(http.StatusOK, "image/jpeg", buffer.Bytes()) | ||||
| 			c.Abort() // 中断请求 | ||||
| 		} | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -14,13 +14,11 @@ var logger = logger2.GetLogger() | ||||
|  | ||||
| func NewDefaultConfig() *types.AppConfig { | ||||
| 	return &types.AppConfig{ | ||||
| 		Listen:        "0.0.0.0:5678", | ||||
| 		ProxyURL:      "", | ||||
| 		Manager:       types.Manager{Username: "admin", Password: "admin123"}, | ||||
| 		StaticDir:     "./static", | ||||
| 		StaticUrl:     "http://localhost/5678/static", | ||||
| 		Redis:         types.RedisConfig{Host: "localhost", Port: 6379, Password: ""}, | ||||
| 		AesEncryptKey: utils.RandString(24), | ||||
| 		Listen:    "0.0.0.0:5678", | ||||
| 		ProxyURL:  "", | ||||
| 		StaticDir: "./static", | ||||
| 		StaticUrl: "http://localhost/5678/static", | ||||
| 		Redis:     types.RedisConfig{Host: "localhost", Port: 6379, Password: ""}, | ||||
| 		Session: types.Session{ | ||||
| 			SecretKey: utils.RandString(64), | ||||
| 			MaxAge:    86400, | ||||
| @@ -33,9 +31,8 @@ func NewDefaultConfig() *types.AppConfig { | ||||
| 				BasePath: "./static/upload", | ||||
| 			}, | ||||
| 		}, | ||||
| 		MjConfig:  types.MidJourneyConfig{Enabled: false}, | ||||
| 		SdConfig:  types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"}, | ||||
| 		WeChatBot: false, | ||||
| 		WeChatBot:    false, | ||||
| 		AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -8,7 +8,13 @@ type ApiRequest struct { | ||||
| 	Stream      bool          `json:"stream"` | ||||
| 	Messages    []interface{} `json:"messages,omitempty"` | ||||
| 	Prompt      []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM | ||||
| 	Functions   []Function    `json:"functions,omitempty"` | ||||
| 	Tools       []interface{} `json:"tools,omitempty"` | ||||
| 	Functions   []interface{} `json:"functions,omitempty"` // 兼容中转平台 | ||||
|  | ||||
| 	ToolChoice string `json:"tool_choice,omitempty"` | ||||
|  | ||||
| 	Input      map[string]interface{} `json:"input,omitempty"`      //兼容阿里通义千问 | ||||
| 	Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问 | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| @@ -27,10 +33,14 @@ type ChoiceItem struct { | ||||
| } | ||||
|  | ||||
| type Delta struct { | ||||
| 	Role         string       `json:"role"` | ||||
| 	Name         string       `json:"name"` | ||||
| 	Content      interface{}  `json:"content"` | ||||
| 	FunctionCall FunctionCall `json:"function_call,omitempty"` | ||||
| 	Role         string      `json:"role"` | ||||
| 	Name         string      `json:"name"` | ||||
| 	Content      interface{} `json:"content"` | ||||
| 	ToolCalls    []ToolCall  `json:"tool_calls,omitempty"` | ||||
| 	FunctionCall struct { | ||||
| 		Name      string `json:"name,omitempty"` | ||||
| 		Arguments string `json:"arguments,omitempty"` | ||||
| 	} `json:"function_call,omitempty"` | ||||
| } | ||||
|  | ||||
| // ChatSession 聊天会话对象 | ||||
| @@ -44,10 +54,14 @@ type ChatSession struct { | ||||
| } | ||||
|  | ||||
| type ChatModel struct { | ||||
| 	Id       uint     `json:"id"` | ||||
| 	Platform Platform `json:"platform"` | ||||
| 	Value    string   `json:"value"` | ||||
| 	Weight   int      `json:"weight"` | ||||
| 	Id          uint     `json:"id"` | ||||
| 	Platform    Platform `json:"platform"` | ||||
| 	Name        string   `json:"name"` | ||||
| 	Value       string   `json:"value"` | ||||
| 	Power       int      `json:"power"` | ||||
| 	MaxTokens   int      `json:"max_tokens"`  // 最大响应长度 | ||||
| 	MaxContext  int      `json:"max_context"` // 最大上下文长度 | ||||
| 	Temperature float32  `json:"temperature"` // 模型温度 | ||||
| } | ||||
|  | ||||
| type ApiError struct { | ||||
| @@ -61,17 +75,37 @@ type ApiError struct { | ||||
|  | ||||
| const PromptMsg = "prompt" // prompt message | ||||
| const ReplyMsg = "reply"   // reply message | ||||
| const MjMsg = "mj" | ||||
|  | ||||
| var ModelToTokens = map[string]int{ | ||||
| 	"gpt-3.5-turbo":     4096, | ||||
| 	"gpt-3.5-turbo-16k": 16384, | ||||
| 	"gpt-4":             8192, | ||||
| 	"gpt-4-32k":         32768, | ||||
| 	"chatglm_pro":       32768, // 清华智普 | ||||
| 	"chatglm_std":       16384, | ||||
| 	"chatglm_lite":      4096, | ||||
| 	"ernie_bot_turbo":   8192, // 文心一言 | ||||
| 	"general":           8192, // 科大讯飞 | ||||
| 	"general2":          8192, | ||||
| // PowerType 算力日志类型 | ||||
| type PowerType int | ||||
|  | ||||
| const ( | ||||
| 	PowerRecharge = PowerType(1) // 充值 | ||||
| 	PowerConsume  = PowerType(2) // 消费 | ||||
| 	PowerRefund   = PowerType(3) // 任务(SD,MJ)执行失败,退款 | ||||
| 	PowerInvite   = PowerType(4) // 邀请奖励 | ||||
| 	PowerReward   = PowerType(5) // 众筹 | ||||
| 	PowerGift     = PowerType(6) // 系统赠送 | ||||
| ) | ||||
|  | ||||
| func (t PowerType) String() string { | ||||
| 	switch t { | ||||
| 	case PowerRecharge: | ||||
| 		return "充值" | ||||
| 	case PowerConsume: | ||||
| 		return "消费" | ||||
| 	case PowerRefund: | ||||
| 		return "退款" | ||||
| 	case PowerReward: | ||||
| 		return "众筹" | ||||
|  | ||||
| 	} | ||||
| 	return "其他" | ||||
| } | ||||
|  | ||||
| type PowerMark int | ||||
|  | ||||
| const ( | ||||
| 	PowerSub = PowerMark(0) | ||||
| 	PowerAdd = PowerMark(1) | ||||
| ) | ||||
|   | ||||
| @@ -8,19 +8,33 @@ type AppConfig struct { | ||||
| 	Path          string `toml:"-"` | ||||
| 	Listen        string | ||||
| 	Session       Session | ||||
| 	AdminSession  Session | ||||
| 	ProxyURL      string | ||||
| 	MysqlDns      string            // mysql 连接地址 | ||||
| 	Manager       Manager           // 后台管理员账户信息 | ||||
| 	StaticDir     string            // 静态资源目录 | ||||
| 	StaticUrl     string            // 静态资源 URL | ||||
| 	Redis         RedisConfig       // redis 连接信息 | ||||
| 	ApiConfig     ChatPlusApiConfig // ChatPlus API authorization configs | ||||
| 	AesEncryptKey string | ||||
| 	SmsConfig     AliYunSmsConfig       // AliYun send message service config | ||||
| 	OSS           OSSConfig             // OSS config | ||||
| 	MjConfig      MidJourneyConfig      // mj 绘画配置 | ||||
| 	WeChatBot     bool                  // 是否启用微信机器人 | ||||
| 	SdConfig      StableDiffusionConfig // sd 绘画配置 | ||||
| 	MysqlDns      string                  // mysql 连接地址 | ||||
| 	StaticDir     string                  // 静态资源目录 | ||||
| 	StaticUrl     string                  // 静态资源 URL | ||||
| 	Redis         RedisConfig             // redis 连接信息 | ||||
| 	ApiConfig     ChatPlusApiConfig       // ChatPlus API authorization configs | ||||
| 	SMS           SMSConfig               // send mobile message config | ||||
| 	OSS           OSSConfig               // OSS config | ||||
| 	MjConfigs     []MidJourneyConfig      // mj AI draw service pool | ||||
| 	MjPlusConfigs []MidJourneyPlusConfig  // MJ plus config | ||||
| 	WeChatBot     bool                    // 是否启用微信机器人 | ||||
| 	SdConfigs     []StableDiffusionConfig // sd AI draw service pool | ||||
|  | ||||
| 	XXLConfig     XXLConfig | ||||
| 	AlipayConfig  AlipayConfig | ||||
| 	HuPiPayConfig HuPiPayConfig | ||||
| 	SmtpConfig    SmtpConfig // 邮件发送配置 | ||||
| 	JPayConfig    JPayConfig // payjs 支付配置 | ||||
| } | ||||
|  | ||||
| type SmtpConfig struct { | ||||
| 	Host     string | ||||
| 	Port     int | ||||
| 	AppName  string // 应用名称 | ||||
| 	From     string // 发件人邮箱地址 | ||||
| 	Password string // 发件人邮箱密码 | ||||
| } | ||||
|  | ||||
| type ChatPlusApiConfig struct { | ||||
| @@ -30,15 +44,15 @@ type ChatPlusApiConfig struct { | ||||
| } | ||||
|  | ||||
| type MidJourneyConfig struct { | ||||
| 	Enabled   bool | ||||
| 	UserToken string | ||||
| 	BotToken  string | ||||
| 	GuildId   string // Server ID | ||||
| 	ChanelId  string // Chanel ID | ||||
| } | ||||
|  | ||||
| type WeChatConfig struct { | ||||
| 	Enabled bool | ||||
| 	Enabled        bool | ||||
| 	UserToken      string | ||||
| 	BotToken       string | ||||
| 	GuildId        string // Server ID | ||||
| 	ChanelId       string // Chanel ID | ||||
| 	UseCDN         bool | ||||
| 	ImgCdnURL      string // 图片反代加速地址 | ||||
| 	DiscordAPI     string | ||||
| 	DiscordGateway string | ||||
| } | ||||
|  | ||||
| type StableDiffusionConfig struct { | ||||
| @@ -48,13 +62,56 @@ type StableDiffusionConfig struct { | ||||
| 	Txt2ImgJsonPath string | ||||
| } | ||||
|  | ||||
| type AliYunSmsConfig struct { | ||||
| 	AccessKey    string | ||||
| 	AccessSecret string | ||||
| 	Product      string | ||||
| 	Domain       string | ||||
| 	Sign         string // 短信签名 | ||||
| 	CodeTempId   string // 验证码短信模板 ID | ||||
| type MidJourneyPlusConfig struct { | ||||
| 	Enabled   bool   // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 | ||||
| 	ApiURL    string // api 地址 | ||||
| 	Mode      string // 绘画模式,可选值:fast/turbo/relax | ||||
| 	CdnURL    string // CDN 加速地址 | ||||
| 	ApiKey    string | ||||
| 	NotifyURL string // 任务进度更新回调地址 | ||||
| } | ||||
|  | ||||
| type AlipayConfig struct { | ||||
| 	Enabled         bool   // 是否启用该支付通道 | ||||
| 	SandBox         bool   // 是否沙盒环境 | ||||
| 	AppId           string // 应用 ID | ||||
| 	UserId          string // 支付宝用户 ID | ||||
| 	PrivateKey      string // 用户私钥文件路径 | ||||
| 	PublicKey       string // 用户公钥文件路径 | ||||
| 	AlipayPublicKey string // 支付宝公钥文件路径 | ||||
| 	RootCert        string // Root 秘钥路径 | ||||
| 	NotifyURL       string // 异步通知回调 | ||||
| 	ReturnURL       string // 支付成功返回地址 | ||||
| } | ||||
|  | ||||
| type HuPiPayConfig struct { //虎皮椒第四方支付配置 | ||||
| 	Enabled   bool   // 是否启用该支付通道 | ||||
| 	Name      string // 支付名称,如:wechat/alipay | ||||
| 	AppId     string // App ID | ||||
| 	AppSecret string // app 密钥 | ||||
| 	ApiURL    string // 支付网关 | ||||
| 	NotifyURL string // 异步通知回调 | ||||
| 	ReturnURL string // 支付成功返回地址 | ||||
| } | ||||
|  | ||||
| // JPayConfig PayJs 支付配置 | ||||
| type JPayConfig struct { | ||||
| 	Enabled    bool | ||||
| 	Name       string // 支付名称,默认 wechat | ||||
| 	AppId      string // 商户 ID | ||||
| 	PrivateKey string // 私钥 | ||||
| 	ApiURL     string // API 网关 | ||||
| 	NotifyURL  string // 异步回调地址 | ||||
| 	ReturnURL  string // 支付成功返回地址 | ||||
| } | ||||
|  | ||||
| type XXLConfig struct { // XXL 任务调度配置 | ||||
| 	Enabled      bool | ||||
| 	ServerAddr   string | ||||
| 	ExecutorIp   string | ||||
| 	ExecutorPort string | ||||
| 	AccessToken  string | ||||
| 	RegistryKey  string | ||||
| } | ||||
|  | ||||
| type RedisConfig struct { | ||||
| @@ -68,25 +125,6 @@ func (c RedisConfig) Url() string { | ||||
| 	return fmt.Sprintf("%s:%d", c.Host, c.Port) | ||||
| } | ||||
|  | ||||
| // Manager 管理员 | ||||
| type Manager struct { | ||||
| 	Username string `json:"username"` | ||||
| 	Password string `json:"password"` | ||||
| } | ||||
|  | ||||
| // ChatConfig 系统默认的聊天配置 | ||||
| type ChatConfig struct { | ||||
| 	OpenAI  ModelAPIConfig `json:"open_ai"` | ||||
| 	Azure   ModelAPIConfig `json:"azure"` | ||||
| 	ChatGML ModelAPIConfig `json:"chat_gml"` | ||||
| 	Baidu   ModelAPIConfig `json:"baidu"` | ||||
| 	XunFei  ModelAPIConfig `json:"xun_fei"` | ||||
|  | ||||
| 	EnableContext bool `json:"enable_context"` // 是否开启聊天上下文 | ||||
| 	EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录 | ||||
| 	ContextDeep   int  `json:"context_deep"`   // 上下文深度 | ||||
| } | ||||
|  | ||||
| type Platform string | ||||
|  | ||||
| const OpenAI = Platform("OpenAI") | ||||
| @@ -94,31 +132,34 @@ const Azure = Platform("Azure") | ||||
| const ChatGLM = Platform("ChatGLM") | ||||
| const Baidu = Platform("Baidu") | ||||
| const XunFei = Platform("XunFei") | ||||
|  | ||||
| // UserChatConfig 用户的聊天配置 | ||||
| type UserChatConfig struct { | ||||
| 	ApiKeys map[Platform]string `json:"api_keys"` | ||||
| } | ||||
|  | ||||
| type ModelAPIConfig struct { | ||||
| 	ApiURL      string  `json:"api_url,omitempty"` | ||||
| 	Temperature float32 `json:"temperature"` | ||||
| 	MaxTokens   int     `json:"max_tokens"` | ||||
| 	ApiKey      string  `json:"api_key"` | ||||
| } | ||||
| const QWen = Platform("QWen") | ||||
|  | ||||
| type SystemConfig struct { | ||||
| 	Title           string   `json:"title"` | ||||
| 	AdminTitle      string   `json:"admin_title"` | ||||
| 	Models          []string `json:"models"` | ||||
| 	UserInitCalls   int      `json:"user_init_calls"` // 新用户注册默认总送多少次调用 | ||||
| 	InitImgCalls    int      `json:"init_img_calls"` | ||||
| 	VipMonthCalls   int      `json:"vip_month_calls"` // 会员每个赠送的调用次数 | ||||
| 	EnabledRegister bool     `json:"enabled_register"` | ||||
| 	EnabledMsg      bool     `json:"enabled_msg"`      // 启用短信验证码服务 | ||||
| 	EnabledDraw     bool     `json:"enabled_draw"`     // 启动 AI 绘画功能 | ||||
| 	RewardImg       string   `json:"reward_img"`       // 众筹收款二维码地址 | ||||
| 	EnabledFunction bool     `json:"enabled_function"` // 启用 API 函数功能 | ||||
| 	EnabledReward   bool     `json:"enabled_reward"`   // 启用众筹功能 | ||||
| 	DefaultModels   []string `json:"default_models"`   // 默认开通的 AI 模型 | ||||
| 	Title         string `json:"title,omitempty"` | ||||
| 	AdminTitle    string `json:"admin_title,omitempty"` | ||||
| 	Logo          string `json:"logo,omitempty"` | ||||
| 	InitPower     int    `json:"init_power,omitempty"`      // 新用户注册赠送算力值 | ||||
| 	DailyPower    int    `json:"daily_power,omitempty"`     // 每日赠送算力 | ||||
| 	InvitePower   int    `json:"invite_power,omitempty"`    // 邀请新用户赠送算力值 | ||||
| 	VipMonthPower int    `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值 | ||||
|  | ||||
| 	RegisterWays    []string `json:"register_ways,omitempty"`    // 注册方式:支持手机,邮箱注册,账号密码注册 | ||||
| 	EnabledRegister bool     `json:"enabled_register,omitempty"` // 是否开放注册 | ||||
|  | ||||
| 	RewardImg     string  `json:"reward_img,omitempty"`     // 众筹收款二维码地址 | ||||
| 	EnabledReward bool    `json:"enabled_reward,omitempty"` // 启用众筹功能 | ||||
| 	PowerPrice    float64 `json:"power_price,omitempty"`    // 算力单价 | ||||
|  | ||||
| 	OrderPayTimeout int    `json:"order_pay_timeout,omitempty"` //订单支付超时时间 | ||||
| 	VipInfoText     string `json:"vip_info_text"`               // 会员页面充值说明 | ||||
| 	DefaultModels   []int  `json:"default_models,omitempty"`    // 默认开通的 AI 模型 | ||||
|  | ||||
| 	MjPower   int `json:"mj_power,omitempty"`   // MJ 绘画消耗算力 | ||||
| 	SdPower   int `json:"sd_power,omitempty"`   // SD 绘画消耗算力 | ||||
| 	DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力 | ||||
|  | ||||
| 	WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址 | ||||
|  | ||||
| 	EnableContext bool `json:"enable_context,omitempty"` | ||||
| 	ContextDeep   int  `json:"context_deep,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -1,8 +1,11 @@ | ||||
| package types | ||||
|  | ||||
| type FunctionCall struct { | ||||
| 	Name      string `json:"name"` | ||||
| 	Arguments string `json:"arguments"` | ||||
| type ToolCall struct { | ||||
| 	Type     string `json:"type"` | ||||
| 	Function struct { | ||||
| 		Name      string `json:"name"` | ||||
| 		Arguments string `json:"arguments"` | ||||
| 	} `json:"function"` | ||||
| } | ||||
|  | ||||
| type Function struct { | ||||
| @@ -21,72 +24,3 @@ type Property struct { | ||||
| 	Type        string `json:"type"` | ||||
| 	Description string `json:"description"` | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	FuncZaoBao     = "zao_bao"     // 每日早报 | ||||
| 	FuncHeadLine   = "headline"    // 今日头条 | ||||
| 	FuncWeibo      = "weibo_hot"   // 微博热搜 | ||||
| 	FuncMidJourney = "mid_journey" // MJ 绘画 | ||||
| ) | ||||
|  | ||||
| var InnerFunctions = []Function{ | ||||
| 	{ | ||||
| 		Name:        FuncZaoBao, | ||||
| 		Description: "每日早报,获取当天全球的热门新闻事件列表", | ||||
| 		Parameters: Parameters{ | ||||
|  | ||||
| 			Type: "object", | ||||
| 			Properties: map[string]Property{ | ||||
| 				"text": { | ||||
| 					Type:        "string", | ||||
| 					Description: "", | ||||
| 				}, | ||||
| 			}, | ||||
| 			Required: []string{}, | ||||
| 		}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		Name:        FuncWeibo, | ||||
| 		Description: "新浪微博热搜榜,微博当日热搜榜单", | ||||
| 		Parameters: Parameters{ | ||||
| 			Type: "object", | ||||
| 			Properties: map[string]Property{ | ||||
| 				"text": { | ||||
| 					Type:        "string", | ||||
| 					Description: "", | ||||
| 				}, | ||||
| 			}, | ||||
| 			Required: []string{}, | ||||
| 		}, | ||||
| 	}, | ||||
|  | ||||
| 	{ | ||||
| 		Name:        FuncHeadLine, | ||||
| 		Description: "今日头条,给用户推荐当天的头条新闻,周榜热文", | ||||
| 		Parameters: Parameters{ | ||||
| 			Type: "object", | ||||
| 			Properties: map[string]Property{ | ||||
| 				"text": { | ||||
| 					Type:        "string", | ||||
| 					Description: "", | ||||
| 				}, | ||||
| 			}, | ||||
| 			Required: []string{}, | ||||
| 		}, | ||||
| 	}, | ||||
|  | ||||
| 	{ | ||||
| 		Name:        FuncMidJourney, | ||||
| 		Description: "AI 绘画工具,使用 MJ MidJourney API 进行 AI 绘画", | ||||
| 		Parameters: Parameters{ | ||||
| 			Type: "object", | ||||
| 			Properties: map[string]Property{ | ||||
| 				"prompt": { | ||||
| 					Type:        "string", | ||||
| 					Description: "提示词,如果该参数中有中文的话,则需要翻译成英文。提示词中的参数作为提示的一部分,不要删除", | ||||
| 				}, | ||||
| 			}, | ||||
| 			Required: []string{}, | ||||
| 		}, | ||||
| 	}, | ||||
| } | ||||
|   | ||||
| @@ -6,10 +6,10 @@ import ( | ||||
| ) | ||||
|  | ||||
| type MKey interface { | ||||
| 	string | int | ||||
| 	string | int | uint | ||||
| } | ||||
| type MValue interface { | ||||
| 	*WsClient | *ChatSession | context.CancelFunc | []interface{} | ||||
| 	*WsClient | *ChatSession | context.CancelFunc | []Message | ||||
| } | ||||
| type LMap[K MKey, T MValue] struct { | ||||
| 	lock sync.RWMutex | ||||
|   | ||||
							
								
								
									
										17
									
								
								api/core/types/order.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								api/core/types/order.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | ||||
| package types | ||||
|  | ||||
| type OrderStatus int | ||||
|  | ||||
| const ( | ||||
| 	OrderNotPaid     = OrderStatus(0) | ||||
| 	OrderScanned     = OrderStatus(1) // 已扫码 | ||||
| 	OrderPaidSuccess = OrderStatus(2) | ||||
| ) | ||||
|  | ||||
| type OrderRemark struct { | ||||
| 	Days     int     `json:"days"`  // 有效期 | ||||
| 	Power    int     `json:"power"` // 增加算力点数 | ||||
| 	Name     string  `json:"name"`  // 产品名称 | ||||
| 	Price    float64 `json:"price"` | ||||
| 	Discount float64 `json:"discount"` | ||||
| } | ||||
| @@ -12,6 +12,7 @@ type MiniOssConfig struct { | ||||
| 	AccessKey    string | ||||
| 	AccessSecret string | ||||
| 	Bucket       string | ||||
| 	SubDir       string | ||||
| 	UseSSL       bool | ||||
| 	Domain       string | ||||
| } | ||||
| @@ -21,6 +22,7 @@ type QiNiuOssConfig struct { | ||||
| 	AccessKey    string | ||||
| 	AccessSecret string | ||||
| 	Bucket       string | ||||
| 	SubDir       string | ||||
| 	Domain       string | ||||
| } | ||||
|  | ||||
| @@ -29,6 +31,7 @@ type AliYunOssConfig struct { | ||||
| 	AccessKey    string | ||||
| 	AccessSecret string | ||||
| 	Bucket       string | ||||
| 	SubDir       string | ||||
| 	Domain       string | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										26
									
								
								api/core/types/sms.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								api/core/types/sms.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| package types | ||||
|  | ||||
| type SMSConfig struct { | ||||
| 	Active string | ||||
| 	Ali    SmsConfigAli | ||||
| 	Bao    SmsConfigBao | ||||
| } | ||||
|  | ||||
| // SmsConfigAli 阿里云短信平台配置 | ||||
| type SmsConfigAli struct { | ||||
| 	AccessKey    string | ||||
| 	AccessSecret string | ||||
| 	Product      string | ||||
| 	Domain       string | ||||
| 	Sign         string // 短信签名 | ||||
| 	CodeTempId   string // 验证码短信模板 ID | ||||
| } | ||||
|  | ||||
| // SmsConfigBao 短信宝平台配置 | ||||
| type SmsConfigBao struct { | ||||
| 	Username     string //短信宝平台注册的用户名 | ||||
| 	Password     string //短信宝平台注册的密码 | ||||
| 	Domain       string //域名 | ||||
| 	Sign         string // 短信签名 | ||||
| 	CodeTemplate string // 验证码短信模板 匹配 | ||||
| } | ||||
| @@ -9,30 +9,22 @@ func (t TaskType) String() string { | ||||
|  | ||||
| const ( | ||||
| 	TaskImage     = TaskType("image") | ||||
| 	TaskBlend     = TaskType("blend") | ||||
| 	TaskSwapFace  = TaskType("swapFace") | ||||
| 	TaskUpscale   = TaskType("upscale") | ||||
| 	TaskVariation = TaskType("variation") | ||||
| 	TaskTxt2Img   = TaskType("text2img") | ||||
| ) | ||||
|  | ||||
| // TaskSrc 任务来源 | ||||
| type TaskSrc string | ||||
|  | ||||
| const ( | ||||
| 	TaskSrcChat = TaskSrc("chat") // 来自聊天页面 | ||||
| 	TaskSrcImg  = TaskSrc("img")  // 专业绘画页面 | ||||
| ) | ||||
|  | ||||
| // MjTask MidJourney 任务 | ||||
| type MjTask struct { | ||||
| 	Id          int      `json:"id"` | ||||
| 	Id          uint     `json:"id"` | ||||
| 	TaskId      string   `json:"task_id"` | ||||
| 	ImgArr      []string `json:"img_arr"` | ||||
| 	ChannelId   string   `json:"channel_id"` | ||||
| 	SessionId   string   `json:"session_id"` | ||||
| 	Src         TaskSrc  `json:"src"` | ||||
| 	Type        TaskType `json:"type"` | ||||
| 	UserId      int      `json:"user_id"` | ||||
| 	Prompt      string   `json:"prompt,omitempty"` | ||||
| 	ChatId      string   `json:"chat_id,omitempty"` | ||||
| 	RoleId      int      `json:"role_id,omitempty"` | ||||
| 	Icon        string   `json:"icon,omitempty"` | ||||
| 	Index       int      `json:"index,omitempty"` | ||||
| 	MessageId   string   `json:"message_id,omitempty"` | ||||
| 	MessageHash string   `json:"message_hash,omitempty"` | ||||
| @@ -42,7 +34,6 @@ type MjTask struct { | ||||
| type SdTask struct { | ||||
| 	Id         int          `json:"id"` // job 数据库ID | ||||
| 	SessionId  string       `json:"session_id"` | ||||
| 	Src        TaskSrc      `json:"src"` | ||||
| 	Type       TaskType     `json:"type"` | ||||
| 	UserId     int          `json:"user_id"` | ||||
| 	Prompt     string       `json:"prompt,omitempty"` | ||||
|   | ||||
| @@ -30,6 +30,7 @@ const ( | ||||
| 	Success       = BizCode(0) | ||||
| 	Failed        = BizCode(1) | ||||
| 	NotAuthorized = BizCode(400) // 未授权 | ||||
| 	NotPermission = BizCode(403) // 没有权限 | ||||
|  | ||||
| 	OkMsg       = "Success" | ||||
| 	ErrorMsg    = "系统开小差了" | ||||
|   | ||||
							
								
								
									
										23
									
								
								api/go.mod
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								api/go.mod
									
									
									
									
									
								
							| @@ -6,7 +6,6 @@ require ( | ||||
| 	github.com/BurntSushi/toml v1.1.0 | ||||
| 	github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 | ||||
| 	github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible | ||||
| 	github.com/bwmarrin/discordgo v0.27.1 | ||||
| 	github.com/eatmoreapple/openwechat v1.2.1 | ||||
| 	github.com/gin-gonic/gin v1.9.1 | ||||
| 	github.com/go-redis/redis/v8 v8.11.5 | ||||
| @@ -18,12 +17,26 @@ require ( | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 | ||||
| 	github.com/qiniu/go-sdk/v7 v7.17.1 | ||||
| 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e | ||||
| 	github.com/syndtr/goleveldb v1.0.0 | ||||
| 	github.com/smartwalle/alipay/v3 v3.2.15 | ||||
| 	go.uber.org/zap v1.23.0 | ||||
| 	gopkg.in/natefinch/lumberjack.v2 v2.2.1 | ||||
| 	gorm.io/driver/mysql v1.4.7 | ||||
| ) | ||||
|  | ||||
| require github.com/xxl-job/xxl-job-executor-go v1.2.0 | ||||
|  | ||||
| require github.com/bg5t/mydiscordgo v0.28.1 | ||||
|  | ||||
| require ( | ||||
| 	github.com/mojocn/base64Captcha v1.3.1 | ||||
| 	github.com/shopspring/decimal v1.3.1 | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect | ||||
| 	golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	github.com/andybalholm/brotli v1.0.4 // indirect | ||||
| 	github.com/bytedance/sonic v1.9.1 // indirect | ||||
| @@ -34,6 +47,7 @@ require ( | ||||
| 	github.com/dustin/go-humanize v1.0.1 // indirect | ||||
| 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect | ||||
| 	github.com/gaukas/godicttls v0.0.3 // indirect | ||||
| 	github.com/go-basic/ipv4 v1.0.0 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.7.0 // indirect | ||||
| 	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect | ||||
| 	github.com/goccy/go-json v0.10.2 // indirect | ||||
| @@ -49,6 +63,7 @@ require ( | ||||
| 	github.com/klauspost/cpuid/v2 v2.2.5 // indirect | ||||
| 	github.com/minio/md5-simd v1.1.2 // indirect | ||||
| 	github.com/minio/sha256-simd v1.0.1 // indirect | ||||
| 	github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 | ||||
| 	github.com/onsi/ginkgo/v2 v2.10.0 // indirect | ||||
| 	github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect | ||||
| 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect | ||||
| @@ -59,6 +74,9 @@ require ( | ||||
| 	github.com/refraction-networking/utls v1.3.2 // indirect | ||||
| 	github.com/rs/xid v1.5.0 // indirect | ||||
| 	github.com/sirupsen/logrus v1.9.3 // indirect | ||||
| 	github.com/smartwalle/ncrypto v1.0.2 // indirect | ||||
| 	github.com/smartwalle/ngx v1.0.6 // indirect | ||||
| 	github.com/smartwalle/nsign v1.0.8 // indirect | ||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||
| 	go.uber.org/dig v1.16.1 // indirect | ||||
| 	golang.org/x/arch v0.3.0 // indirect | ||||
| @@ -79,7 +97,6 @@ require ( | ||||
| 	github.com/go-playground/locales v0.14.1 // indirect | ||||
| 	github.com/go-playground/universal-translator v0.18.1 // indirect | ||||
| 	github.com/go-playground/validator/v10 v10.14.0 // indirect | ||||
| 	github.com/golang/snappy v0.0.1 // indirect | ||||
| 	github.com/json-iterator/go v1.1.12 // indirect | ||||
| 	github.com/leodido/go-urn v1.2.4 // indirect | ||||
| 	github.com/mattn/go-isatty v0.0.19 // indirect | ||||
|   | ||||
							
								
								
									
										43
									
								
								api/go.sum
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								api/go.sum
									
									
									
									
									
								
							| @@ -7,8 +7,8 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9 | ||||
| github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= | ||||
| github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= | ||||
| github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= | ||||
| github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY= | ||||
| github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= | ||||
| github.com/bg5t/mydiscordgo v0.28.1 h1:mVH0ZWstVdJffCi/EXJAYQDtXwIKAJYVXLmECu1hEK8= | ||||
| github.com/bg5t/mydiscordgo v0.28.1/go.mod h1:n3aba73N18k1DzM0t0mGE8rwW3Z+vwTvI8pcsBgxN/8= | ||||
| github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= | ||||
| github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= | ||||
| github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | ||||
| @@ -29,7 +29,6 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp | ||||
| github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= | ||||
| github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU= | ||||
| github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8= | ||||
| github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= | ||||
| github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | ||||
| github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= | ||||
| github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= | ||||
| @@ -39,6 +38,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE | ||||
| github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= | ||||
| github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= | ||||
| github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= | ||||
| github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs= | ||||
| github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg= | ||||
| github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= | ||||
| github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= | ||||
| github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= | ||||
| @@ -64,14 +65,12 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG | ||||
| github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= | ||||
| github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= | ||||
| github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= | ||||
| github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= | ||||
| github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= | ||||
| github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= | ||||
| github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= | ||||
| github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= | ||||
| github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | ||||
| github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= | ||||
| github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= | ||||
| github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= | ||||
| github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= | ||||
| github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= | ||||
| github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | ||||
| @@ -87,7 +86,6 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY | ||||
| github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= | ||||
| github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= | ||||
| github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= | ||||
| github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= | ||||
| github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I= | ||||
| github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI= | ||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||
| @@ -133,13 +131,14 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ | ||||
| github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= | ||||
| github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= | ||||
| github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= | ||||
| github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0= | ||||
| github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY= | ||||
| github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= | ||||
| github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= | ||||
| github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= | ||||
| github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= | ||||
| github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= | ||||
| github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= | ||||
| github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs= | ||||
| github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE= | ||||
| github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= | ||||
| github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= | ||||
| github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A= | ||||
| github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU= | ||||
| @@ -171,10 +170,20 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA | ||||
| github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= | ||||
| github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= | ||||
| github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= | ||||
| github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= | ||||
| github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= | ||||
| github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= | ||||
| github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= | ||||
| github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= | ||||
| github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= | ||||
| github.com/smartwalle/alipay/v3 v3.2.15 h1:3fvFJnINKKAOXHR/Iv20k1Z7KJ+nOh3oK214lELPqG8= | ||||
| github.com/smartwalle/alipay/v3 v3.2.15/go.mod h1:niTNB609KyUYuAx9Bex/MawEjv2yPx4XOjxSAkqmGjE= | ||||
| github.com/smartwalle/ncrypto v1.0.2 h1:pTAhCqtPCMhpOwFXX+EcMdR6PNzruBNoGQrN2S1GbGI= | ||||
| github.com/smartwalle/ncrypto v1.0.2/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk= | ||||
| github.com/smartwalle/ngx v1.0.6 h1:JPNqNOIj+2nxxFtrSkJO+vKJfeNUSEQueck/Wworjps= | ||||
| github.com/smartwalle/ngx v1.0.6/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0= | ||||
| github.com/smartwalle/nsign v1.0.8 h1:78KWtwKPrdt4Xsn+tNEBVxaTLIJBX9YRX0ZSrMUeuHo= | ||||
| github.com/smartwalle/nsign v1.0.8/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc= | ||||
| github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||||
| github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= | ||||
| github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= | ||||
| @@ -187,8 +196,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o | ||||
| github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= | ||||
| github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= | ||||
| github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= | ||||
| github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= | ||||
| github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= | ||||
| github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= | ||||
| github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= | ||||
| github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= | ||||
| @@ -197,6 +204,8 @@ github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVK | ||||
| github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= | ||||
| github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= | ||||
| github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= | ||||
| github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ= | ||||
| github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo= | ||||
| github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= | ||||
| github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= | ||||
| go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= | ||||
| @@ -224,11 +233,12 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= | ||||
| golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= | ||||
| golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= | ||||
| golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= | ||||
| golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ= | ||||
| golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= | ||||
| golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= | ||||
| golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= | ||||
| golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= | ||||
| golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= | ||||
| golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= | ||||
| golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | ||||
| golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | ||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||
| @@ -237,13 +247,11 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug | ||||
| golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= | ||||
| golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= | ||||
| golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= | ||||
| golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= | ||||
| golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= | ||||
| golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| @@ -290,15 +298,12 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= | ||||
| gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= | ||||
| gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= | ||||
| gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= | ||||
| gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= | ||||
| gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= | ||||
| gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= | ||||
| gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= | ||||
| gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= | ||||
| gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= | ||||
| gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | ||||
| gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
| gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
| gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||||
|   | ||||
| @@ -6,12 +6,14 @@ import ( | ||||
| 	"chatplus/handler" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"strings" | ||||
| 	"github.com/mojocn/base64Captcha" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -20,47 +22,88 @@ import ( | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| // Manager 管理员 | ||||
| type Manager struct { | ||||
| 	Username  string `json:"username"` | ||||
| 	Password  string `json:"password"` | ||||
| 	Captcha   string `json:"captcha"`    // 验证码 | ||||
| 	CaptchaId string `json:"captcha_id"` // 验证码id | ||||
| } | ||||
|  | ||||
| const SuperManagerID = 1 | ||||
|  | ||||
| type ManagerHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	db    *gorm.DB | ||||
| 	redis *redis.Client | ||||
| } | ||||
|  | ||||
| func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler { | ||||
| 	h := ManagerHandler{db: db, redis: client} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| 	return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client} | ||||
| } | ||||
|  | ||||
| // Login 登录 | ||||
| func (h *ManagerHandler) Login(c *gin.Context) { | ||||
| 	var data types.Manager | ||||
| 	var data Manager | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	manager := h.App.Config.Manager | ||||
| 	if data.Username == manager.Username && data.Password == manager.Password { | ||||
| 		// 创建 token | ||||
| 		token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ | ||||
| 			"user_id": manager.Username, | ||||
| 			"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(), | ||||
| 		}) | ||||
| 		tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, "Failed to generate token, "+err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		// 保存到 redis | ||||
| 		key := "users/" + manager.Username | ||||
| 		if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil { | ||||
| 			resp.ERROR(c, "error with save token: "+err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		resp.SUCCESS(c, tokenString) | ||||
| 	} else { | ||||
| 		resp.ERROR(c, "用户名或者密码错误") | ||||
|  | ||||
| 	// add captcha | ||||
| 	if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) { | ||||
| 		resp.ERROR(c, "验证码错误!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var manager model.AdminUser | ||||
| 	res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "请检查用户名或者密码是否填写正确") | ||||
| 		return | ||||
| 	} | ||||
| 	password := utils.GenPassword(data.Password, manager.Salt) | ||||
| 	if password != manager.Password { | ||||
| 		resp.ERROR(c, "用户名或密码错误") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 超级管理员默认是ID:1 | ||||
| 	if manager.Id != SuperManagerID && manager.Status == false { | ||||
| 		resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 创建 token | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ | ||||
| 		"user_id": manager.Id, | ||||
| 		"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(), | ||||
| 	}) | ||||
| 	tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey)) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "Failed to generate token, "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	// 保存到 redis | ||||
| 	key := fmt.Sprintf("admin/%d", manager.Id) | ||||
| 	if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil { | ||||
| 		resp.ERROR(c, "error with save token: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 更新最后登录时间和IP | ||||
| 	manager.LastLoginIp = c.ClientIP() | ||||
| 	manager.LastLoginAt = time.Now().Unix() | ||||
| 	h.DB.Updates(&manager) | ||||
|  | ||||
| 	var result = struct { | ||||
| 		IsSuperAdmin bool   `json:"is_super_admin"` | ||||
| 		Token        string `json:"token"` | ||||
| 	}{ | ||||
| 		IsSuperAdmin: manager.Id == 1, | ||||
| 		Token:        tokenString, | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, result) | ||||
| } | ||||
|  | ||||
| // Logout 注销 | ||||
| @@ -75,74 +118,155 @@ func (h *ManagerHandler) Logout(c *gin.Context) { | ||||
|  | ||||
| // Session 会话检测 | ||||
| func (h *ManagerHandler) Session(c *gin.Context) { | ||||
| 	token := c.GetHeader(types.AdminAuthHeader) | ||||
| 	if token == "" { | ||||
| 	id := h.GetLoginUserId(c) | ||||
| 	key := fmt.Sprintf("admin/%d", id) | ||||
| 	if _, err := h.redis.Get(context.Background(), key).Result(); err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 	} else { | ||||
| 		resp.SUCCESS(c) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Migrate 数据修正 | ||||
| func (h *ManagerHandler) Migrate(c *gin.Context) { | ||||
| 	opt := c.Query("opt") | ||||
| 	switch opt { | ||||
| 	case "user": | ||||
| 		// 将用户订阅角色的数据结构从 map 改成数组 | ||||
| 		var users []model.User | ||||
| 		h.db.Find(&users) | ||||
| 		for _, u := range users { | ||||
| 			var m map[string]int | ||||
| 			var roleKeys = make([]string, 0) | ||||
| 			err := utils.JsonDecode(u.ChatRoles, &m) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			for k := range m { | ||||
| 				roleKeys = append(roleKeys, k) | ||||
| 			} | ||||
| 			u.ChatRoles = utils.JsonEncode(roleKeys) | ||||
| 			h.db.Updates(&u) | ||||
|  | ||||
| 		} | ||||
| 		break | ||||
| 	case "role": | ||||
| 		// 修改角色图片,改成绝对路径 | ||||
| 		var roles []model.ChatRole | ||||
| 		h.db.Find(&roles) | ||||
| 		for _, r := range roles { | ||||
| 			if !strings.HasPrefix(r.Icon, "/") { | ||||
| 				r.Icon = "/" + r.Icon | ||||
| 				h.db.Updates(&r) | ||||
| 			} | ||||
| 		} | ||||
| 		break | ||||
| 	case "history": | ||||
| 		// 修改角色图片,改成绝对路径 | ||||
| 		var message []model.HistoryMessage | ||||
| 		h.db.Find(&message) | ||||
| 		for _, r := range message { | ||||
| 			if !strings.HasPrefix(r.Icon, "/") { | ||||
| 				r.Icon = "/" + r.Icon | ||||
| 				h.db.Updates(&r) | ||||
| 			} | ||||
|  | ||||
| 		} | ||||
| 		break | ||||
|  | ||||
| 	case "avatar": | ||||
| 		// 更新用户的头像地址 | ||||
| 		var users []model.User | ||||
| 		h.db.Find(&users) | ||||
| 		for _, u := range users { | ||||
| 			if !strings.HasPrefix(u.Avatar, "/") { | ||||
| 				u.Avatar = "/" + u.Avatar | ||||
| 				h.db.Updates(&u) | ||||
| 			} | ||||
| 		} | ||||
| 		break | ||||
| 	var manager model.AdminUser | ||||
| 	res := h.DB.Where("id", id).First(&manager) | ||||
| 	if res.Error != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, "SUCCESS") | ||||
| 	resp.SUCCESS(c, manager) | ||||
| } | ||||
|  | ||||
| // List 数据列表 | ||||
| func (h *ManagerHandler) List(c *gin.Context) { | ||||
| 	var items []model.AdminUser | ||||
| 	res := h.DB.Find(&items) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	users := make([]vo.AdminUser, 0) | ||||
| 	for _, item := range items { | ||||
| 		var u vo.AdminUser | ||||
| 		err := utils.CopyObject(item, &u) | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		u.Id = item.Id | ||||
| 		u.CreatedAt = item.CreatedAt.Unix() | ||||
| 		users = append(users, u) | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, users) | ||||
|  | ||||
| } | ||||
|  | ||||
| func (h *ManagerHandler) Save(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Username string `json:"username"` | ||||
| 		Password string `json:"password"` | ||||
| 		Status   bool   `json:"status"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var user model.AdminUser | ||||
| 	res := h.DB.Where("username", data.Username).First(&user) | ||||
| 	if res.Error == nil { | ||||
| 		resp.ERROR(c, "用户名已存在") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 生成密码 | ||||
| 	salt := utils.RandString(8) | ||||
| 	password := utils.GenPassword(data.Password, salt) | ||||
| 	res = h.DB.Save(&model.AdminUser{ | ||||
| 		Username: data.Username, | ||||
| 		Password: password, | ||||
| 		Salt:     salt, | ||||
| 		Status:   data.Status, | ||||
| 	}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "failed with update database") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // Remove 删除管理员 | ||||
| func (h *ManagerHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	if id <= 0 { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if id == SuperManagerID { | ||||
| 		resp.ERROR(c, "超级管理员不能删除") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Where("id", id).Delete(&model.AdminUser{}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // Enable 启用/禁用 | ||||
| func (h *ManagerHandler) Enable(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id      uint `json:"id"` | ||||
| 		Enabled bool `json:"enabled"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Model(&model.AdminUser{}).Where("id", data.Id).UpdateColumn("status", data.Enabled) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // ResetPass 重置密码 | ||||
| func (h *ManagerHandler) ResetPass(c *gin.Context) { | ||||
| 	id := h.GetLoginUserId(c) | ||||
| 	if id != SuperManagerID { | ||||
| 		resp.ERROR(c, "只有超级管理员能够进行该操作") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var data struct { | ||||
| 		Id       int    `json:"id"` | ||||
| 		Password string `json:"password"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var user model.AdminUser | ||||
| 	res := h.DB.Where("id", data.Id).First(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	password := utils.GenPassword(data.Password, user.Salt) | ||||
| 	user.Password = password | ||||
| 	res = h.DB.Updates(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|   | ||||
| @@ -14,20 +14,22 @@ import ( | ||||
|  | ||||
| type ApiKeyHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	db *gorm.DB | ||||
| } | ||||
|  | ||||
| func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler { | ||||
| 	h := ApiKeyHandler{db: db} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| 	return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}} | ||||
| } | ||||
|  | ||||
| func (h *ApiKeyHandler) Save(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id       uint   `json:"id"` | ||||
| 		Platform string `json:"platform"` | ||||
| 		Name     string `json:"name"` | ||||
| 		Type     string `json:"type"` | ||||
| 		Value    string `json:"value"` | ||||
| 		ApiURL   string `json:"api_url"` | ||||
| 		Enabled  bool   `json:"enabled"` | ||||
| 		ProxyURL string `json:"proxy_url"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| @@ -36,11 +38,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { | ||||
|  | ||||
| 	apiKey := model.ApiKey{} | ||||
| 	if data.Id > 0 { | ||||
| 		h.db.Find(&apiKey, data.Id) | ||||
| 		h.DB.Find(&apiKey, data.Id) | ||||
| 	} | ||||
| 	apiKey.Platform = data.Platform | ||||
| 	apiKey.Value = data.Value | ||||
| 	res := h.db.Debug().Save(&apiKey) | ||||
| 	apiKey.Type = data.Type | ||||
| 	apiKey.ApiURL = data.ApiURL | ||||
| 	apiKey.Enabled = data.Enabled | ||||
| 	apiKey.ProxyURL = data.ProxyURL | ||||
| 	apiKey.Name = data.Name | ||||
| 	res := h.DB.Save(&apiKey) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| @@ -58,9 +65,14 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func (h *ApiKeyHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var items []model.ApiKey | ||||
| 	var keys = make([]vo.ApiKey, 0) | ||||
| 	res := h.db.Find(&items) | ||||
| 	res := h.DB.Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var key vo.ApiKey | ||||
| @@ -78,11 +90,36 @@ func (h *ApiKeyHandler) List(c *gin.Context) { | ||||
| 	resp.SUCCESS(c, keys) | ||||
| } | ||||
|  | ||||
| func (h *ApiKeyHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| func (h *ApiKeyHandler) Set(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id    uint        `json:"id"` | ||||
| 		Filed string      `json:"filed"` | ||||
| 		Value interface{} `json:"value"` | ||||
| 	} | ||||
|  | ||||
| 	if id > 0 { | ||||
| 		res := h.db.Where("id = ?", id).Delete(&model.ApiKey{}) | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *ApiKeyHandler) Remove(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id uint | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	if data.Id > 0 { | ||||
| 		res := h.DB.Where("id = ?", data.Id).Delete(&model.ApiKey{}) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
|   | ||||
							
								
								
									
										39
									
								
								api/handler/admin/captcha_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								api/handler/admin/captcha_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | ||||
| package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/mojocn/base64Captcha" | ||||
| ) | ||||
|  | ||||
| type CaptchaHandler struct { | ||||
| 	handler.BaseHandler | ||||
| } | ||||
|  | ||||
| func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler { | ||||
| 	return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}} | ||||
| } | ||||
|  | ||||
| type CaptchaVo struct { | ||||
| 	CaptchaId string `json:"captcha_id"` | ||||
| 	PicPath   string `json:"pic_path"` | ||||
| } | ||||
|  | ||||
| // GetCaptcha 获取验证码 | ||||
| func (h *CaptchaHandler) GetCaptcha(c *gin.Context) { | ||||
| 	var captchaVo CaptchaVo | ||||
| 	driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10) | ||||
| 	cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore) | ||||
| 	// b64s是图片的base64编码 | ||||
| 	id, b64s, err := cp.Generate() | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "生成验证码错误!") | ||||
| 		return | ||||
| 	} | ||||
| 	captchaVo.CaptchaId = id | ||||
| 	captchaVo.PicPath = b64s | ||||
|  | ||||
| 	resp.SUCCESS(c, captchaVo) | ||||
| } | ||||
							
								
								
									
										266
									
								
								api/handler/admin/chat_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										266
									
								
								api/handler/admin/chat_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,266 @@ | ||||
| package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type ChatHandler struct { | ||||
| 	handler.BaseHandler | ||||
| } | ||||
|  | ||||
| func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler { | ||||
| 	return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| type chatItemVo struct { | ||||
| 	Username  string      `json:"username"` | ||||
| 	UserId    uint        `json:"user_id"` | ||||
| 	ChatId    string      `json:"chat_id"` | ||||
| 	Title     string      `json:"title"` | ||||
| 	Role      vo.ChatRole `json:"role"` | ||||
| 	Model     string      `json:"model"` | ||||
| 	Token     int         `json:"token"` | ||||
| 	CreatedAt int64       `json:"created_at"` | ||||
| 	MsgNum    int         `json:"msg_num"` // 消息数量 | ||||
| } | ||||
|  | ||||
| func (h *ChatHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var data struct { | ||||
| 		Title    string   `json:"title"` | ||||
| 		UserId   uint     `json:"user_id"` | ||||
| 		Model    string   `json:"model"` | ||||
| 		CreateAt []string `json:"created_time"` | ||||
| 		Page     int      `json:"page"` | ||||
| 		PageSize int      `json:"page_size"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if data.Title != "" { | ||||
| 		session = session.Where("title LIKE ?", "%"+data.Title+"%") | ||||
| 	} | ||||
| 	if data.UserId > 0 { | ||||
| 		session = session.Where("user_id = ?", data.UserId) | ||||
| 	} | ||||
| 	if data.Model != "" { | ||||
| 		session = session.Where("model = ?", data.Model) | ||||
| 	} | ||||
| 	if len(data.CreateAt) == 2 { | ||||
| 		start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00") | ||||
| 		end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00") | ||||
| 		session = session.Where("created_at >= ? AND created_at <= ?", start, end) | ||||
| 	} | ||||
|  | ||||
| 	var total int64 | ||||
| 	session.Model(&model.ChatItem{}).Count(&total) | ||||
| 	var items []model.ChatItem | ||||
| 	var list = make([]chatItemVo, 0) | ||||
| 	offset := (data.Page - 1) * data.PageSize | ||||
| 	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		userIds := make([]uint, 0) | ||||
| 		chatIds := make([]string, 0) | ||||
| 		roleIds := make([]uint, 0) | ||||
| 		for _, item := range items { | ||||
| 			userIds = append(userIds, item.UserId) | ||||
| 			chatIds = append(chatIds, item.ChatId) | ||||
| 			roleIds = append(roleIds, item.RoleId) | ||||
| 		} | ||||
| 		var messages []model.ChatMessage | ||||
| 		var users []model.User | ||||
| 		var roles []model.ChatRole | ||||
| 		h.DB.Where("chat_id IN ?", chatIds).Find(&messages) | ||||
| 		h.DB.Where("id IN ?", userIds).Find(&users) | ||||
| 		h.DB.Where("id IN ?", roleIds).Find(&roles) | ||||
|  | ||||
| 		tokenMap := make(map[string]int) | ||||
| 		userMap := make(map[uint]string) | ||||
| 		msgMap := make(map[string]int) | ||||
| 		roleMap := make(map[uint]vo.ChatRole) | ||||
| 		for _, msg := range messages { | ||||
| 			tokenMap[msg.ChatId] += msg.Tokens | ||||
| 			msgMap[msg.ChatId] += 1 | ||||
| 		} | ||||
| 		for _, user := range users { | ||||
| 			userMap[user.Id] = user.Username | ||||
| 		} | ||||
| 		for _, r := range roles { | ||||
| 			var roleVo vo.ChatRole | ||||
| 			err := utils.CopyObject(r, &roleVo) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			roleMap[r.Id] = roleVo | ||||
| 		} | ||||
| 		for _, item := range items { | ||||
| 			list = append(list, chatItemVo{ | ||||
| 				UserId:    item.UserId, | ||||
| 				Username:  userMap[item.UserId], | ||||
| 				ChatId:    item.ChatId, | ||||
| 				Title:     item.Title, | ||||
| 				Model:     item.Model, | ||||
| 				Token:     tokenMap[item.ChatId], | ||||
| 				MsgNum:    msgMap[item.ChatId], | ||||
| 				Role:      roleMap[item.RoleId], | ||||
| 				CreatedAt: item.CreatedAt.Unix(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) | ||||
| } | ||||
|  | ||||
| type chatMessageVo struct { | ||||
| 	Id        uint   `json:"id"` | ||||
| 	UserId    uint   `json:"user_id"` | ||||
| 	Username  string `json:"username"` | ||||
| 	Content   string `json:"content"` | ||||
| 	Type      string `json:"type"` | ||||
| 	Model     string `json:"model"` | ||||
| 	Token     int    `json:"token"` | ||||
| 	Icon      string `json:"icon"` | ||||
| 	CreatedAt int64  `json:"created_at"` | ||||
| } | ||||
|  | ||||
| // Messages 读取聊天记录列表 | ||||
| func (h *ChatHandler) Messages(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		UserId   uint     `json:"user_id"` | ||||
| 		Content  string   `json:"content"` | ||||
| 		Model    string   `json:"model"` | ||||
| 		CreateAt []string `json:"created_time"` | ||||
| 		Page     int      `json:"page"` | ||||
| 		PageSize int      `json:"page_size"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if data.Content != "" { | ||||
| 		session = session.Where("content LIKE ?", "%"+data.Content+"%") | ||||
| 	} | ||||
| 	if data.UserId > 0 { | ||||
| 		session = session.Where("user_id = ?", data.UserId) | ||||
| 	} | ||||
| 	if data.Model != "" { | ||||
| 		session = session.Where("model = ?", data.Model) | ||||
| 	} | ||||
| 	if len(data.CreateAt) == 2 { | ||||
| 		start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00") | ||||
| 		end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00") | ||||
| 		session = session.Where("created_at >= ? AND created_at <= ?", start, end) | ||||
| 	} | ||||
|  | ||||
| 	var total int64 | ||||
| 	session.Model(&model.ChatMessage{}).Count(&total) | ||||
| 	var items []model.ChatMessage | ||||
| 	var list = make([]chatMessageVo, 0) | ||||
| 	offset := (data.Page - 1) * data.PageSize | ||||
| 	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		userIds := make([]uint, 0) | ||||
| 		for _, item := range items { | ||||
| 			userIds = append(userIds, item.UserId) | ||||
| 		} | ||||
| 		var users []model.User | ||||
| 		h.DB.Where("id IN ?", userIds).Find(&users) | ||||
| 		userMap := make(map[uint]string) | ||||
| 		for _, user := range users { | ||||
| 			userMap[user.Id] = user.Username | ||||
| 		} | ||||
| 		for _, item := range items { | ||||
| 			list = append(list, chatMessageVo{ | ||||
| 				Id:        item.Id, | ||||
| 				UserId:    item.UserId, | ||||
| 				Username:  userMap[item.UserId], | ||||
| 				Content:   item.Content, | ||||
| 				Model:     item.Model, | ||||
| 				Token:     item.Tokens, | ||||
| 				Icon:      item.Icon, | ||||
| 				Type:      item.Type, | ||||
| 				CreatedAt: item.CreatedAt.Unix(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) | ||||
| } | ||||
|  | ||||
| // History 获取聊天历史记录 | ||||
| func (h *ChatHandler) History(c *gin.Context) { | ||||
| 	chatId := c.Query("chat_id") // 会话 ID | ||||
| 	var items []model.ChatMessage | ||||
| 	var messages = make([]vo.HistoryMessage, 0) | ||||
| 	res := h.DB.Where("chat_id = ?", chatId).Find(&items) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "No history message") | ||||
| 		return | ||||
| 	} else { | ||||
| 		for _, item := range items { | ||||
| 			var v vo.HistoryMessage | ||||
| 			err := utils.CopyObject(item, &v) | ||||
| 			v.CreatedAt = item.CreatedAt.Unix() | ||||
| 			v.UpdatedAt = item.UpdatedAt.Unix() | ||||
| 			if err == nil { | ||||
| 				messages = append(messages, v) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, messages) | ||||
| } | ||||
|  | ||||
| // RemoveChat 删除对话 | ||||
| func (h *ChatHandler) RemoveChat(c *gin.Context) { | ||||
| 	chatId := h.GetTrim(c, "chat_id") | ||||
| 	if chatId == "" { | ||||
| 		resp.ERROR(c, "请传入 ChatId") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	tx := h.DB.Begin() | ||||
| 	// 删除聊天记录 | ||||
| 	res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "failed to remove chat message") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 删除对话 | ||||
| 	res = tx.Unscoped().Where("chat_id = ?", chatId).Delete(model.ChatItem{}) | ||||
| 	if res.Error != nil { | ||||
| 		tx.Rollback() // 回滚 | ||||
| 		resp.ERROR(c, "failed to remove chat") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	tx.Commit() | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // RemoveMessage 删除聊天记录 | ||||
| func (h *ChatHandler) RemoveMessage(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}) | ||||
| 	if tx.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
| @@ -15,37 +15,48 @@ import ( | ||||
|  | ||||
| type ChatModelHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	db *gorm.DB | ||||
| } | ||||
|  | ||||
| func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler { | ||||
| 	h := ChatModelHandler{db: db} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| 	return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *ChatModelHandler) Save(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id        uint   `json:"id"` | ||||
| 		Name      string `json:"name"` | ||||
| 		Value     string `json:"value"` | ||||
| 		Enabled   bool   `json:"enabled"` | ||||
| 		SortNum   int    `json:"sort_num"` | ||||
| 		Platform  string `json:"platform"` | ||||
| 		Weight    int    `json:"weight"` | ||||
| 		CreatedAt int64  `json:"created_at"` | ||||
| 		Id          uint    `json:"id"` | ||||
| 		Name        string  `json:"name"` | ||||
| 		Value       string  `json:"value"` | ||||
| 		Enabled     bool    `json:"enabled"` | ||||
| 		SortNum     int     `json:"sort_num"` | ||||
| 		Open        bool    `json:"open"` | ||||
| 		Platform    string  `json:"platform"` | ||||
| 		Power       int     `json:"power"` | ||||
| 		MaxTokens   int     `json:"max_tokens"`  // 最大响应长度 | ||||
| 		MaxContext  int     `json:"max_context"` // 最大上下文长度 | ||||
| 		Temperature float32 `json:"temperature"` // 模型温度 | ||||
| 		CreatedAt   int64   `json:"created_at"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	item := model.ChatModel{Platform: data.Platform, Name: data.Name, Value: data.Value, Enabled: data.Enabled, SortNum: data.SortNum, Weight: data.Weight} | ||||
| 	item := model.ChatModel{ | ||||
| 		Platform:    data.Platform, | ||||
| 		Name:        data.Name, | ||||
| 		Value:       data.Value, | ||||
| 		Enabled:     data.Enabled, | ||||
| 		SortNum:     data.SortNum, | ||||
| 		Open:        data.Open, | ||||
| 		MaxTokens:   data.MaxTokens, | ||||
| 		MaxContext:  data.MaxContext, | ||||
| 		Temperature: data.Temperature, | ||||
| 		Power:       data.Power} | ||||
| 	item.Id = data.Id | ||||
| 	if item.Id > 0 { | ||||
| 		item.CreatedAt = time.Unix(data.CreatedAt, 0) | ||||
| 	} | ||||
| 	res := h.db.Save(&item) | ||||
| 	res := h.DB.Save(&item) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| @@ -64,7 +75,12 @@ func (h *ChatModelHandler) Save(c *gin.Context) { | ||||
|  | ||||
| // List 模型列表 | ||||
| func (h *ChatModelHandler) List(c *gin.Context) { | ||||
| 	session := h.db.Session(&gorm.Session{}) | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	enable := h.GetBool(c, "enable") | ||||
| 	if enable { | ||||
| 		session = session.Where("enabled", enable) | ||||
| @@ -89,10 +105,11 @@ func (h *ChatModelHandler) List(c *gin.Context) { | ||||
| 	resp.SUCCESS(c, cms) | ||||
| } | ||||
|  | ||||
| func (h *ChatModelHandler) Enable(c *gin.Context) { | ||||
| func (h *ChatModelHandler) Set(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id      uint `json:"id"` | ||||
| 		Enabled bool `json:"enabled"` | ||||
| 		Id    uint        `json:"id"` | ||||
| 		Filed string      `json:"filed"` | ||||
| 		Value interface{} `json:"value"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| @@ -100,7 +117,7 @@ func (h *ChatModelHandler) Enable(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.db.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update("enabled", data.Enabled) | ||||
| 	res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| @@ -120,7 +137,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	for index, id := range data.Ids { | ||||
| 		res := h.db.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) | ||||
| 		res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| @@ -132,13 +149,15 @@ func (h *ChatModelHandler) Sort(c *gin.Context) { | ||||
|  | ||||
| func (h *ChatModelHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	if id <= 0 { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if id > 0 { | ||||
| 		res := h.db.Where("id = ?", id).Delete(&model.ChatModel{}) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|   | ||||
| @@ -15,13 +15,10 @@ import ( | ||||
|  | ||||
| type ChatRoleHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	db *gorm.DB | ||||
| } | ||||
|  | ||||
| func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { | ||||
| 	h := ChatRoleHandler{db: db} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| 	return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // Save 创建或者更新某个角色 | ||||
| @@ -41,7 +38,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) { | ||||
| 	if data.CreatedAt > 0 { | ||||
| 		role.CreatedAt = time.Unix(data.CreatedAt, 0) | ||||
| 	} | ||||
| 	res := h.db.Save(&role) | ||||
| 	res := h.DB.Save(&role) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| @@ -53,9 +50,14 @@ func (h *ChatRoleHandler) Save(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func (h *ChatRoleHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var items []model.ChatRole | ||||
| 	var roles = make([]vo.ChatRole, 0) | ||||
| 	res := h.db.Order("sort_num ASC").Find(&items) | ||||
| 	res := h.DB.Order("sort_num ASC").Find(&items) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "No data found") | ||||
| 		return | ||||
| @@ -88,7 +90,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	for index, id := range data.Ids { | ||||
| 		res := h.db.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) | ||||
| 		res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| @@ -98,14 +100,39 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) { | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *ChatRoleHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	if id <= 0 { | ||||
| func (h *ChatRoleHandler) Set(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id    uint        `json:"id"` | ||||
| 		Filed string      `json:"filed"` | ||||
| 		Value interface{} `json:"value"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.db.Where("id = ?", id).Delete(&model.ChatRole{}) | ||||
| 	res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *ChatRoleHandler) Remove(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id uint | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	if data.Id <= 0 { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "删除失败!") | ||||
| 		return | ||||
|   | ||||
| @@ -14,36 +14,38 @@ import ( | ||||
|  | ||||
| type ConfigHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	db *gorm.DB | ||||
| } | ||||
|  | ||||
| func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler { | ||||
| 	h := ConfigHandler{db: db} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| 	return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *ConfigHandler) Update(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Key    string                 `json:"key"` | ||||
| 		Config map[string]interface{} `json:"config"` | ||||
| 		Key    string `json:"key"` | ||||
| 		Config struct { | ||||
| 			types.SystemConfig | ||||
| 			Content string `json:"content,omitempty"` | ||||
| 			Updated bool   `json:"updated,omitempty"` | ||||
| 		} `json:"config"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	str := utils.JsonEncode(&data.Config) | ||||
| 	config := model.Config{Key: data.Key, Config: str} | ||||
| 	res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key}) | ||||
|  | ||||
| 	value := utils.JsonEncode(&data.Config) | ||||
| 	config := model.Config{Key: data.Key, Config: value} | ||||
| 	res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if config.Id > 0 { | ||||
| 		config.Config = str | ||||
| 		res := h.db.Updates(&config) | ||||
| 		config.Config = value | ||||
| 		res := h.DB.Updates(&config) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, res.Error.Error()) | ||||
| 			return | ||||
| @@ -51,12 +53,10 @@ func (h *ConfigHandler) Update(c *gin.Context) { | ||||
|  | ||||
| 		// update config cache for AppServer | ||||
| 		var cfg model.Config | ||||
| 		h.db.Where("marker", data.Key).First(&cfg) | ||||
| 		h.DB.Where("marker", data.Key).First(&cfg) | ||||
| 		var err error | ||||
| 		if data.Key == "system" { | ||||
| 			err = utils.JsonDecode(cfg.Config, &h.App.SysConfig) | ||||
| 		} else if data.Key == "chat" { | ||||
| 			err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, "Failed to update config cache: "+err.Error()) | ||||
| @@ -70,20 +70,25 @@ func (h *ConfigHandler) Update(c *gin.Context) { | ||||
|  | ||||
| // Get 获取指定的系统配置 | ||||
| func (h *ConfigHandler) Get(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	key := c.Query("key") | ||||
| 	var config model.Config | ||||
| 	res := h.db.Where("marker", key).First(&config) | ||||
| 	res := h.DB.Where("marker", key).First(&config) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var m map[string]interface{} | ||||
| 	err := utils.JsonDecode(config.Config, &m) | ||||
| 	var value map[string]interface{} | ||||
| 	err := utils.JsonDecode(config.Config, &value) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, m) | ||||
| 	resp.SUCCESS(c, value) | ||||
| } | ||||
|   | ||||
| @@ -2,30 +2,30 @@ package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/shopspring/decimal" | ||||
| 	"gorm.io/gorm" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type DashboardHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	db *gorm.DB | ||||
| } | ||||
|  | ||||
| func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler { | ||||
| 	h := DashboardHandler{db: db} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| 	return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| type statsVo struct { | ||||
| 	Users   int64   `json:"users"` | ||||
| 	Chats   int64   `json:"chats"` | ||||
| 	Tokens  int64   `json:"tokens"` | ||||
| 	Rewards float64 `json:"rewards"` | ||||
| 	Users  int64                         `json:"users"` | ||||
| 	Chats  int64                         `json:"chats"` | ||||
| 	Tokens int                           `json:"tokens"` | ||||
| 	Income float64                       `json:"income"` | ||||
| 	Chart  map[string]map[string]float64 `json:"chart"` | ||||
| } | ||||
|  | ||||
| func (h *DashboardHandler) Stats(c *gin.Context) { | ||||
| @@ -34,30 +34,84 @@ func (h *DashboardHandler) Stats(c *gin.Context) { | ||||
| 	var userCount int64 | ||||
| 	now := time.Now() | ||||
| 	zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) | ||||
| 	res := h.db.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount) | ||||
| 	res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount) | ||||
| 	if res.Error == nil { | ||||
| 		stats.Users = userCount | ||||
| 	} | ||||
|  | ||||
| 	// new chats statistic | ||||
| 	var chatCount int64 | ||||
| 	res = h.db.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount) | ||||
| 	res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount) | ||||
| 	if res.Error == nil { | ||||
| 		stats.Chats = chatCount | ||||
| 	} | ||||
|  | ||||
| 	// tokens took stats | ||||
| 	var tokenCount int64 | ||||
| 	res = h.db.Model(&model.HistoryMessage{}).Select("sum(tokens) as total").Where("created_at > ?", zeroTime).Scan(&tokenCount) | ||||
| 	if res.Error == nil { | ||||
| 		stats.Tokens = tokenCount | ||||
| 	var historyMessages []model.ChatMessage | ||||
| 	res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages) | ||||
| 	for _, item := range historyMessages { | ||||
| 		stats.Tokens += item.Tokens | ||||
| 	} | ||||
|  | ||||
| 	// reward revenue | ||||
| 	var amount float64 | ||||
| 	res = h.db.Model(&model.Reward{}).Select("sum(amount) as total").Where("created_at > ?", zeroTime).Scan(&amount) | ||||
| 	if res.Error == nil { | ||||
| 		stats.Rewards = amount | ||||
| 	// 众筹收入 | ||||
| 	var rewards []model.Reward | ||||
| 	res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards) | ||||
| 	for _, item := range rewards { | ||||
| 		stats.Income += item.Amount | ||||
| 	} | ||||
|  | ||||
| 	// 订单收入 | ||||
| 	var orders []model.Order | ||||
| 	res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders) | ||||
| 	for _, item := range orders { | ||||
| 		stats.Income += item.Amount | ||||
| 	} | ||||
|  | ||||
| 	// 统计7天的订单的图表 | ||||
| 	startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02") | ||||
| 	var statsChart = make(map[string]map[string]float64) | ||||
| 	//// 初始化 | ||||
| 	var userStatistic, historyMessagesStatistic, incomeStatistic = make(map[string]float64), make(map[string]float64), make(map[string]float64) | ||||
| 	for i := 0; i < 7; i++ { | ||||
| 		var initTime = time.Date(now.Year(), now.Month(), now.Day()-i, 0, 0, 0, 0, now.Location()).Format("2006-01-02") | ||||
| 		userStatistic[initTime] = float64(0) | ||||
| 		historyMessagesStatistic[initTime] = float64(0) | ||||
| 		incomeStatistic[initTime] = float64(0) | ||||
| 	} | ||||
|  | ||||
| 	// 统计用户7天增加的曲线 | ||||
| 	var users []model.User | ||||
| 	res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range users { | ||||
| 			userStatistic[item.CreatedAt.Format("2006-01-02")] += 1 | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 统计7天Token 消耗 | ||||
| 	res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages) | ||||
| 	for _, item := range historyMessages { | ||||
| 		historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens) | ||||
| 	} | ||||
|  | ||||
| 	// 浮点数相加? | ||||
| 	// 统计最近7天的众筹 | ||||
| 	res = h.DB.Where("created_at > ?", startDate).Find(&rewards) | ||||
| 	for _, item := range rewards { | ||||
| 		incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64() | ||||
| 	} | ||||
|  | ||||
| 	// 统计最近7天的订单 | ||||
| 	res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders) | ||||
| 	for _, item := range orders { | ||||
| 		incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64() | ||||
| 	} | ||||
|  | ||||
| 	statsChart["users"] = userStatistic | ||||
| 	statsChart["historyMessage"] = historyMessagesStatistic | ||||
| 	statsChart["orders"] = incomeStatistic | ||||
|  | ||||
| 	stats.Chart = statsChart | ||||
|  | ||||
| 	resp.SUCCESS(c, stats) | ||||
| } | ||||
|   | ||||
							
								
								
									
										126
									
								
								api/handler/admin/function_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								api/handler/admin/function_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,126 @@ | ||||
| package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
|  | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type FunctionHandler struct { | ||||
| 	handler.BaseHandler | ||||
| } | ||||
|  | ||||
| func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler { | ||||
| 	return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *FunctionHandler) Save(c *gin.Context) { | ||||
| 	var data vo.Function | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var f = model.Function{ | ||||
| 		Id:          data.Id, | ||||
| 		Name:        data.Name, | ||||
| 		Label:       data.Label, | ||||
| 		Description: data.Description, | ||||
| 		Parameters:  utils.JsonEncode(data.Parameters), | ||||
| 		Action:      data.Action, | ||||
| 		Token:       data.Token, | ||||
| 		Enabled:     data.Enabled, | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Save(&f) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "error with save data:"+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	data.Id = f.Id | ||||
| 	resp.SUCCESS(c, data) | ||||
| } | ||||
|  | ||||
| func (h *FunctionHandler) Set(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id    uint        `json:"id"` | ||||
| 		Filed string      `json:"filed"` | ||||
| 		Value interface{} `json:"value"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *FunctionHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var items []model.Function | ||||
| 	res := h.DB.Find(&items) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "No data found") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	functions := make([]vo.Function, 0) | ||||
| 	for _, v := range items { | ||||
| 		var f vo.Function | ||||
| 		err := utils.CopyObject(v, &f) | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		functions = append(functions, f) | ||||
| 	} | ||||
| 	resp.SUCCESS(c, functions) | ||||
| } | ||||
|  | ||||
| func (h *FunctionHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
|  | ||||
| 	if id > 0 { | ||||
| 		res := h.DB.Delete(&model.Function{Id: uint(id)}) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // GenToken generate function api access token | ||||
| func (h *FunctionHandler) GenToken(c *gin.Context) { | ||||
| 	// 创建 token | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ | ||||
| 		"user_id": 0, | ||||
| 		"expired": 0, | ||||
| 	}) | ||||
| 	tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) | ||||
| 	if err != nil { | ||||
| 		logger.Error("error with generate token", err) | ||||
| 		resp.ERROR(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, tokenString) | ||||
| } | ||||
							
								
								
									
										100
									
								
								api/handler/admin/order_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								api/handler/admin/order_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | ||||
| package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type OrderHandler struct { | ||||
| 	handler.BaseHandler | ||||
| } | ||||
|  | ||||
| func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { | ||||
| 	return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *OrderHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var data struct { | ||||
| 		OrderNo  string   `json:"order_no"` | ||||
| 		Status   int      `json:"status"` | ||||
| 		PayTime  []string `json:"pay_time"` | ||||
| 		Page     int      `json:"page"` | ||||
| 		PageSize int      `json:"page_size"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if data.OrderNo != "" { | ||||
| 		session = session.Where("order_no", data.OrderNo) | ||||
| 	} | ||||
| 	if len(data.PayTime) == 2 { | ||||
| 		start := utils.Str2stamp(data.PayTime[0] + " 00:00:00") | ||||
| 		end := utils.Str2stamp(data.PayTime[1] + " 00:00:00") | ||||
| 		session = session.Where("pay_time >= ? AND pay_time <= ?", start, end) | ||||
| 	} | ||||
| 	if data.Status >= 0 { | ||||
| 		session = session.Where("status", data.Status) | ||||
| 	} | ||||
| 	var total int64 | ||||
| 	session.Model(&model.Order{}).Count(&total) | ||||
| 	var items []model.Order | ||||
| 	var list = make([]vo.Order, 0) | ||||
| 	offset := (data.Page - 1) * data.PageSize | ||||
| 	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var order vo.Order | ||||
| 			err := utils.CopyObject(item, &order) | ||||
| 			if err == nil { | ||||
| 				order.Id = item.Id | ||||
| 				order.CreatedAt = item.CreatedAt.Unix() | ||||
| 				order.UpdatedAt = item.UpdatedAt.Unix() | ||||
| 				list = append(list, order) | ||||
| 			} else { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) | ||||
| } | ||||
|  | ||||
| func (h *OrderHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
|  | ||||
| 	if id > 0 { | ||||
| 		var item model.Order | ||||
| 		res := h.DB.First(&item, id) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "记录不存在!") | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if item.Status == types.OrderPaidSuccess { | ||||
| 			resp.ERROR(c, "已支付订单不允许删除!") | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{}) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
							
								
								
									
										71
									
								
								api/handler/admin/power_log_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								api/handler/admin/power_log_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | ||||
| package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type PowerLogHandler struct { | ||||
| 	handler.BaseHandler | ||||
| } | ||||
|  | ||||
| func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler { | ||||
| 	return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *PowerLogHandler) List(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Username string   `json:"username"` | ||||
| 		Type     int      `json:"type"` | ||||
| 		Model    string   `json:"model"` | ||||
| 		Date     []string `json:"date"` | ||||
| 		Page     int      `json:"page"` | ||||
| 		PageSize int      `json:"page_size"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if data.Model != "" { | ||||
| 		session = session.Where("model", data.Model) | ||||
| 	} | ||||
| 	if data.Type > 0 { | ||||
| 		session = session.Where("type", data.Type) | ||||
| 	} | ||||
| 	if len(data.Date) == 2 { | ||||
| 		start := data.Date[0] + " 00:00:00" | ||||
| 		end := data.Date[1] + " 00:00:00" | ||||
| 		session = session.Where("created_at >= ? AND created_at <= ?", start, end) | ||||
| 	} | ||||
|  | ||||
| 	var total int64 | ||||
| 	session.Model(&model.PowerLog{}).Count(&total) | ||||
| 	var items []model.PowerLog | ||||
| 	var list = make([]vo.PowerLog, 0) | ||||
| 	offset := (data.Page - 1) * data.PageSize | ||||
| 	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var log vo.PowerLog | ||||
| 			err := utils.CopyObject(item, &log) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			log.Id = item.Id | ||||
| 			log.CreatedAt = item.CreatedAt.Unix() | ||||
| 			log.TypeStr = item.Type.String() | ||||
| 			list = append(list, log) | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) | ||||
| } | ||||
							
								
								
									
										152
									
								
								api/handler/admin/product_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										152
									
								
								api/handler/admin/product_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,152 @@ | ||||
| package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type ProductHandler struct { | ||||
| 	handler.BaseHandler | ||||
| } | ||||
|  | ||||
| func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler { | ||||
| 	return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *ProductHandler) Save(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id        uint    `json:"id"` | ||||
| 		Name      string  `json:"name"` | ||||
| 		Price     float64 `json:"price"` | ||||
| 		Discount  float64 `json:"discount"` | ||||
| 		Enabled   bool    `json:"enabled"` | ||||
| 		Days      int     `json:"days"` | ||||
| 		Power     int     `json:"power"` | ||||
| 		CreatedAt int64   `json:"created_at"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	item := model.Product{ | ||||
| 		Name:     data.Name, | ||||
| 		Price:    data.Price, | ||||
| 		Discount: data.Discount, | ||||
| 		Days:     data.Days, | ||||
| 		Power:    data.Power, | ||||
| 		Enabled:  data.Enabled} | ||||
| 	item.Id = data.Id | ||||
| 	if item.Id > 0 { | ||||
| 		item.CreatedAt = time.Unix(data.CreatedAt, 0) | ||||
| 	} | ||||
| 	res := h.DB.Save(&item) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var itemVo vo.Product | ||||
| 	err := utils.CopyObject(item, &itemVo) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "数据拷贝失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	itemVo.Id = item.Id | ||||
| 	itemVo.UpdatedAt = item.UpdatedAt.Unix() | ||||
| 	resp.SUCCESS(c, itemVo) | ||||
| } | ||||
|  | ||||
| // List 模型列表 | ||||
| func (h *ProductHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	enable := h.GetBool(c, "enable") | ||||
| 	if enable { | ||||
| 		session = session.Where("enabled", enable) | ||||
| 	} | ||||
| 	var items []model.Product | ||||
| 	var list = make([]vo.Product, 0) | ||||
| 	res := session.Order("sort_num ASC").Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var product vo.Product | ||||
| 			err := utils.CopyObject(item, &product) | ||||
| 			if err == nil { | ||||
| 				product.Id = item.Id | ||||
| 				product.CreatedAt = item.CreatedAt.Unix() | ||||
| 				product.UpdatedAt = item.UpdatedAt.Unix() | ||||
| 				list = append(list, product) | ||||
| 			} else { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, list) | ||||
| } | ||||
|  | ||||
| func (h *ProductHandler) Enable(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id      uint `json:"id"` | ||||
| 		Enabled bool `json:"enabled"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *ProductHandler) Sort(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Ids   []uint `json:"ids"` | ||||
| 		Sorts []int  `json:"sorts"` | ||||
| 	} | ||||
|  | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for index, id := range data.Ids { | ||||
| 		res := h.DB.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| func (h *ProductHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
|  | ||||
| 	if id > 0 { | ||||
| 		res := h.DB.Where("id = ?", id).Delete(&model.Product{}) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
| @@ -2,6 +2,7 @@ package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| @@ -13,18 +14,20 @@ import ( | ||||
|  | ||||
| type RewardHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	db *gorm.DB | ||||
| } | ||||
|  | ||||
| func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler { | ||||
| 	h := RewardHandler{db: db} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| 	return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *RewardHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var items []model.Reward | ||||
| 	res := h.db.Order("id DESC").Find(&items) | ||||
| 	res := h.DB.Order("id DESC").Find(&items) | ||||
| 	var rewards = make([]vo.Reward, 0) | ||||
| 	if res.Error == nil { | ||||
| 		userIds := make([]uint, 0) | ||||
| @@ -32,7 +35,7 @@ func (h *RewardHandler) List(c *gin.Context) { | ||||
| 			userIds = append(userIds, v.UserId) | ||||
| 		} | ||||
| 		var users []model.User | ||||
| 		h.db.Where("id IN ?", userIds).Find(&users) | ||||
| 		h.DB.Where("id IN ?", userIds).Find(&users) | ||||
| 		var userMap = make(map[uint]model.User) | ||||
| 		for _, u := range users { | ||||
| 			userMap[u.Id] = u | ||||
| @@ -46,7 +49,7 @@ func (h *RewardHandler) List(c *gin.Context) { | ||||
| 			} | ||||
|  | ||||
| 			r.Id = v.Id | ||||
| 			r.Username = userMap[v.UserId].Mobile | ||||
| 			r.Username = userMap[v.UserId].Username | ||||
| 			r.CreatedAt = v.CreatedAt.Unix() | ||||
| 			r.UpdatedAt = v.UpdatedAt.Unix() | ||||
| 			rewards = append(rewards, r) | ||||
| @@ -55,3 +58,21 @@ func (h *RewardHandler) List(c *gin.Context) { | ||||
|  | ||||
| 	resp.SUCCESS(c, rewards) | ||||
| } | ||||
|  | ||||
| func (h *RewardHandler) Remove(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id uint | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	if data.Id > 0 { | ||||
| 		res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{}) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|   | ||||
							
								
								
									
										45
									
								
								api/handler/admin/upload_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								api/handler/admin/upload_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| package admin | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/handler" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type UploadHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	uploaderManager *oss.UploaderManager | ||||
| } | ||||
|  | ||||
| func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler { | ||||
| 	return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager} | ||||
| } | ||||
|  | ||||
| func (h *UploadHandler) Upload(c *gin.Context) { | ||||
| 	file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file") | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	userId := 0 | ||||
| 	res := h.DB.Create(&model.File{ | ||||
| 		UserId:    userId, | ||||
| 		Name:      file.Name, | ||||
| 		ObjKey:    file.ObjKey, | ||||
| 		URL:       file.URL, | ||||
| 		Ext:       file.Ext, | ||||
| 		Size:      file.Size, | ||||
| 		CreatedAt: time.Time{}, | ||||
| 	}) | ||||
| 	if res.Error != nil || res.RowsAffected == 0 { | ||||
| 		resp.ERROR(c, "error with update database: "+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, file) | ||||
| } | ||||
| @@ -8,35 +8,40 @@ import ( | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type UserHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	db *gorm.DB | ||||
| } | ||||
|  | ||||
| func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler { | ||||
| 	h := UserHandler{db: db} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| 	return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // List 用户列表 | ||||
| func (h *UserHandler) List(c *gin.Context) { | ||||
| 	if err := utils.CheckPermission(c, h.DB); err != nil { | ||||
| 		resp.NotPermission(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	page := h.GetInt(c, "page", 1) | ||||
| 	pageSize := h.GetInt(c, "page_size", 20) | ||||
| 	mobile := h.GetTrim(c, "mobile") | ||||
| 	username := h.GetTrim(c, "username") | ||||
|  | ||||
| 	offset := (page - 1) * pageSize | ||||
| 	var items []model.User | ||||
| 	var users = make([]vo.User, 0) | ||||
| 	var total int64 | ||||
|  | ||||
| 	session := h.db.Session(&gorm.Session{}) | ||||
| 	if mobile != "" { | ||||
| 		session = session.Where("mobile LIKE ?", "%"+mobile+"%") | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if username != "" { | ||||
| 		session = session.Where("username LIKE ?", "%"+username+"%") | ||||
| 	} | ||||
|  | ||||
| 	session.Model(&model.User{}).Count(&total) | ||||
| @@ -63,13 +68,13 @@ func (h *UserHandler) Save(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id          uint     `json:"id"` | ||||
| 		Password    string   `json:"password"` | ||||
| 		Mobile      string   `json:"mobile"` | ||||
| 		Calls       int      `json:"calls"` | ||||
| 		ImgCalls    int      `json:"img_calls"` | ||||
| 		Username    string   `json:"username"` | ||||
| 		ChatRoles   []string `json:"chat_roles"` | ||||
| 		ChatModels  []string `json:"chat_models"` | ||||
| 		ChatModels  []int    `json:"chat_models"` | ||||
| 		ExpiredTime string   `json:"expired_time"` | ||||
| 		Status      bool     `json:"status"` | ||||
| 		Vip         bool     `json:"vip"` | ||||
| 		Power       int      `json:"power"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| @@ -79,39 +84,60 @@ func (h *UserHandler) Save(c *gin.Context) { | ||||
| 	var res *gorm.DB | ||||
| 	var userVo vo.User | ||||
| 	if data.Id > 0 { // 更新 | ||||
| 		user.Id = data.Id | ||||
| 		// 此处需要用 map 更新,用结构体无法更新 0 值 | ||||
| 		res = h.db.Model(&user).Updates(map[string]interface{}{ | ||||
| 			"mobile":           data.Mobile, | ||||
| 			"calls":            data.Calls, | ||||
| 			"img_calls":        data.ImgCalls, | ||||
| 			"status":           data.Status, | ||||
| 			"chat_roles_json":  utils.JsonEncode(data.ChatRoles), | ||||
| 			"chat_models_json": utils.JsonEncode(data.ChatModels), | ||||
| 			"expired_time":     utils.Str2stamp(data.ExpiredTime), | ||||
| 		}) | ||||
| 		res = h.DB.Where("id", data.Id).First(&user) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "user not found") | ||||
| 			return | ||||
| 		} | ||||
| 		var oldPower = user.Power | ||||
| 		user.Username = data.Username | ||||
| 		user.Status = data.Status | ||||
| 		user.Vip = data.Vip | ||||
| 		user.Power = data.Power | ||||
| 		user.ChatRoles = utils.JsonEncode(data.ChatRoles) | ||||
| 		user.ChatModels = utils.JsonEncode(data.ChatModels) | ||||
| 		user.ExpiredTime = utils.Str2stamp(data.ExpiredTime) | ||||
|  | ||||
| 		res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "更新数据库失败!") | ||||
| 			return | ||||
| 		} | ||||
| 		// 记录算力日志 | ||||
| 		if oldPower != user.Power { | ||||
| 			mark := types.PowerAdd | ||||
| 			amount := user.Power - oldPower | ||||
| 			if oldPower > user.Power { | ||||
| 				mark = types.PowerSub | ||||
| 				amount = oldPower - user.Power | ||||
| 			} | ||||
| 			h.DB.Create(&model.PowerLog{ | ||||
| 				UserId:    user.Id, | ||||
| 				Username:  user.Username, | ||||
| 				Type:      types.PowerGift, | ||||
| 				Amount:    amount, | ||||
| 				Balance:   user.Power, | ||||
| 				Mark:      mark, | ||||
| 				Model:     "管理员", | ||||
| 				Remark:    fmt.Sprintf("后台管理员强制修改用户算力,修改前:%d,修改后:%d, 管理员ID:%d", oldPower, user.Power, h.GetLoginUserId(c)), | ||||
| 				CreatedAt: time.Now(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} else { | ||||
| 		salt := utils.RandString(8) | ||||
| 		u := model.User{ | ||||
| 			Mobile:      data.Mobile, | ||||
| 			Username:    data.Username, | ||||
| 			Nickname:    fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)), | ||||
| 			Password:    utils.GenPassword(data.Password, salt), | ||||
| 			Avatar:      "/images/avatar/user.png", | ||||
| 			Salt:        salt, | ||||
| 			Power:       data.Power, | ||||
| 			Status:      true, | ||||
| 			ChatRoles:   utils.JsonEncode(data.ChatRoles), | ||||
| 			ChatModels:  utils.JsonEncode(data.ChatModels), | ||||
| 			ExpiredTime: utils.Str2stamp(data.ExpiredTime), | ||||
| 			ChatConfig: utils.JsonEncode(types.UserChatConfig{ | ||||
| 				ApiKeys: map[types.Platform]string{ | ||||
| 					types.OpenAI:  "", | ||||
| 					types.Azure:   "", | ||||
| 					types.ChatGLM: "", | ||||
| 				}, | ||||
| 			}), | ||||
| 			Calls:    data.Calls, | ||||
| 			ImgCalls: data.ImgCalls, | ||||
| 		} | ||||
| 		res = h.db.Create(&u) | ||||
| 		res = h.DB.Create(&u) | ||||
| 		_ = utils.CopyObject(u, &userVo) | ||||
| 		userVo.Id = u.Id | ||||
| 		userVo.CreatedAt = u.CreatedAt.Unix() | ||||
| @@ -138,7 +164,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	var user model.User | ||||
| 	res := h.db.First(&user, data.Id) | ||||
| 	res := h.DB.First(&user, data.Id) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "No user found") | ||||
| 		return | ||||
| @@ -146,7 +172,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) { | ||||
|  | ||||
| 	password := utils.GenPassword(data.Password, user.Salt) | ||||
| 	user.Password = password | ||||
| 	res = h.db.Updates(&user) | ||||
| 	res = h.DB.Updates(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c) | ||||
| 	} else { | ||||
| @@ -156,36 +182,32 @@ func (h *UserHandler) ResetPass(c *gin.Context) { | ||||
|  | ||||
| func (h *UserHandler) Remove(c *gin.Context) { | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	if id > 0 { | ||||
| 		tx := h.db.Begin() | ||||
| 		res := h.db.Where("id = ?", id).Delete(&model.User{}) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "删除失败") | ||||
| 			return | ||||
| 		} | ||||
| 		// 删除聊天记录 | ||||
| 		res = h.db.Where("user_id = ?", id).Delete(&model.ChatItem{}) | ||||
| 		if res.Error != nil { | ||||
| 			tx.Rollback() | ||||
| 			resp.ERROR(c, "删除失败") | ||||
| 			return | ||||
| 		} | ||||
| 		// 删除聊天历史记录 | ||||
| 		res = h.db.Where("user_id = ?", id).Delete(&model.HistoryMessage{}) | ||||
| 		if res.Error != nil { | ||||
| 			tx.Rollback() | ||||
| 			resp.ERROR(c, "删除失败") | ||||
| 			return | ||||
| 		} | ||||
| 		// 删除登录日志 | ||||
| 		res = h.db.Where("user_id = ?", id).Delete(&model.UserLoginLog{}) | ||||
| 		if res.Error != nil { | ||||
| 			tx.Rollback() | ||||
| 			resp.ERROR(c, "删除失败") | ||||
| 			return | ||||
| 		} | ||||
| 		tx.Commit() | ||||
| 	if id <= 0 { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	// 删除用户 | ||||
| 	res := h.DB.Where("id = ?", id).Delete(&model.User{}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "删除失败") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 删除聊天记录 | ||||
| 	h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{}) | ||||
| 	// 删除聊天历史记录 | ||||
| 	h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{}) | ||||
| 	// 删除登录日志 | ||||
| 	h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{}) | ||||
| 	// 删除算力日志 | ||||
| 	h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{}) | ||||
| 	// 删除众筹日志 | ||||
| 	h.DB.Where("user_id = ?", id).Delete(&model.Reward{}) | ||||
| 	// 删除绘图任务 | ||||
| 	h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{}) | ||||
| 	h.DB.Where("user_id = ?", id).Delete(&model.SdJob{}) | ||||
| 	//  删除订单 | ||||
| 	h.DB.Where("user_id = ?", id).Delete(&model.Order{}) | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| @@ -193,10 +215,10 @@ func (h *UserHandler) LoginLog(c *gin.Context) { | ||||
| 	page := h.GetInt(c, "page", 1) | ||||
| 	pageSize := h.GetInt(c, "page_size", 20) | ||||
| 	var total int64 | ||||
| 	h.db.Model(&model.UserLoginLog{}).Count(&total) | ||||
| 	h.DB.Model(&model.UserLoginLog{}).Count(&total) | ||||
| 	offset := (page - 1) * pageSize | ||||
| 	var items []model.UserLoginLog | ||||
| 	res := h.db.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items) | ||||
| 	res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "获取数据失败") | ||||
| 		return | ||||
|   | ||||
| @@ -4,8 +4,11 @@ import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -15,6 +18,7 @@ var logger = logger2.GetLogger() | ||||
|  | ||||
| type BaseHandler struct { | ||||
| 	App *core.AppServer | ||||
| 	DB  *gorm.DB | ||||
| } | ||||
|  | ||||
| func (h *BaseHandler) GetTrim(c *gin.Context, key string) string { | ||||
| @@ -49,3 +53,35 @@ func (h *BaseHandler) GetUserKey(c *gin.Context) string { | ||||
| 	} | ||||
| 	return fmt.Sprintf("users/%v", userId) | ||||
| } | ||||
|  | ||||
| func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint { | ||||
| 	userId, ok := c.Get(types.LoginUserID) | ||||
| 	if !ok { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return uint(utils.IntValue(utils.InterfaceToString(userId), 0)) | ||||
| } | ||||
|  | ||||
| func (h *BaseHandler) IsLogin(c *gin.Context) bool { | ||||
| 	return h.GetLoginUserId(c) > 0 | ||||
| } | ||||
|  | ||||
| func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) { | ||||
| 	value, exists := c.Get(types.LoginUserCache) | ||||
| 	if exists { | ||||
| 		return value.(model.User), nil | ||||
| 	} | ||||
|  | ||||
| 	userId, ok := c.Get(types.LoginUserID) | ||||
| 	if !ok { | ||||
| 		return model.User{}, errors.New("user not login") | ||||
| 	} | ||||
|  | ||||
| 	var user model.User | ||||
| 	res := h.DB.First(&user, userId) | ||||
| 	// 更新缓存 | ||||
| 	if res.Error == nil { | ||||
| 		c.Set(types.LoginUserCache, user) | ||||
| 	} | ||||
| 	return user, res.Error | ||||
| } | ||||
|   | ||||
| @@ -12,34 +12,34 @@ import ( | ||||
|  | ||||
| type ChatModelHandler struct { | ||||
| 	BaseHandler | ||||
| 	db *gorm.DB | ||||
| } | ||||
|  | ||||
| func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler { | ||||
| 	h := ChatModelHandler{db: db} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| 	return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // List 模型列表 | ||||
| func (h *ChatModelHandler) List(c *gin.Context) { | ||||
| 	var items []model.ChatModel | ||||
| 	var chatModels = make([]vo.ChatModel, 0) | ||||
| 	// 只加载用户订阅的 AI 模型 | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	var res *gorm.DB | ||||
| 	// 如果用户没有登录,则加载所有开放模型 | ||||
| 	if !h.IsLogin(c) { | ||||
| 		res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items) | ||||
| 	} else { | ||||
| 		user, _ := h.GetLoginUser(c) | ||||
| 		var models []int | ||||
| 		err := utils.JsonDecode(user.ChatModels, &models) | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, "当前用户没有订阅任何模型") | ||||
| 			return | ||||
| 		} | ||||
| 		// 查询用户有权限访问的模型以及所有开放的模型 | ||||
| 		res = h.DB.Where("enabled = ?", true).Where( | ||||
| 			h.DB.Where("id IN ?", models).Or("open =?", true), | ||||
| 		).Order("sort_num ASC").Find(&items) | ||||
| 	} | ||||
|  | ||||
| 	var models []string | ||||
| 	err = utils.JsonDecode(user.ChatModels, &models) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "当前用户没有订阅任何模型") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.db.Where("enabled = ?", true).Where("value IN ?", models).Order("sort_num ASC").Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var cm vo.ChatModel | ||||
|   | ||||
| @@ -14,27 +14,25 @@ import ( | ||||
|  | ||||
| type ChatRoleHandler struct { | ||||
| 	BaseHandler | ||||
| 	db *gorm.DB | ||||
| } | ||||
|  | ||||
| func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { | ||||
| 	handler := &ChatRoleHandler{db: db} | ||||
| 	handler.App = app | ||||
| 	return handler | ||||
| 	return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // List get user list | ||||
| // List 获取用户聊天应用列表 | ||||
| func (h *ChatRoleHandler) List(c *gin.Context) { | ||||
| 	all := h.GetBool(c, "all") | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	var roles []model.ChatRole | ||||
| 	res := h.db.Where("enable", true).Order("sort_num ASC").Find(&roles) | ||||
| 	res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "No roles found,"+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取所有角色 | ||||
| 	if all { | ||||
| 	if userId == 0 || all { | ||||
| 		// 转成 vo | ||||
| 		var roleVos = make([]vo.ChatRole, 0) | ||||
| 		for _, r := range roles { | ||||
| @@ -49,13 +47,8 @@ func (h *ChatRoleHandler) List(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	if userId == 0 { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
| 	var user model.User | ||||
| 	h.db.First(&user, userId) | ||||
| 	h.DB.First(&user, userId) | ||||
| 	var roleKeys []string | ||||
| 	err := utils.JsonDecode(user.ChatRoles, &roleKeys) | ||||
| 	if err != nil { | ||||
| @@ -80,7 +73,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) { | ||||
|  | ||||
| // UpdateRole 更新用户聊天角色 | ||||
| func (h *ChatRoleHandler) UpdateRole(c *gin.Context) { | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| @@ -94,7 +87,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)) | ||||
| 	res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("添加应用失败:", err) | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
|   | ||||
| @@ -9,7 +9,7 @@ import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -19,7 +19,7 @@ import ( | ||||
| // 微软 Azure 模型消息发送实现 | ||||
|  | ||||
| func (h *ChatHandler) sendAzureMessage( | ||||
| 	chatCtx []interface{}, | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| @@ -29,7 +29,7 @@ func (h *ChatHandler) sendAzureMessage( | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| @@ -44,7 +44,7 @@ func (h *ChatHandler) sendAzureMessage( | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, "") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return err | ||||
| 	} else { | ||||
| 		defer response.Body.Close() | ||||
| @@ -56,9 +56,6 @@ func (h *ChatHandler) sendAzureMessage( | ||||
| 		// 循环读取 Chunk 消息 | ||||
| 		var message = types.Message{} | ||||
| 		var contents = make([]string, 0) | ||||
| 		var functionCall = false | ||||
| 		var functionName string | ||||
| 		var arguments = make([]string, 0) | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
| 		for scanner.Scan() { | ||||
| 			line := scanner.Text() | ||||
| @@ -68,34 +65,17 @@ func (h *ChatHandler) sendAzureMessage( | ||||
|  | ||||
| 			var responseBody = types.ApiResponse{} | ||||
| 			err = json.Unmarshal([]byte(line[6:]), &responseBody) | ||||
| 			if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 | ||||
| 			if err != nil { // 数据解析出错 | ||||
| 				logger.Error(err, line) | ||||
| 				utils.ReplyMessage(ws, ErrorMsg) | ||||
| 				utils.ReplyMessage(ws, "") | ||||
| 				utils.ReplyMessage(ws, ErrImg) | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			fun := responseBody.Choices[0].Delta.FunctionCall | ||||
| 			if functionCall && fun.Name == "" { | ||||
| 				arguments = append(arguments, fun.Arguments) | ||||
| 			if len(responseBody.Choices) == 0 { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if !utils.IsEmptyValue(fun) { | ||||
| 				functionName = fun.Name | ||||
| 				f := h.App.Functions[functionName] | ||||
| 				if f != nil { | ||||
| 					functionCall = true | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())}) | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕 | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			// 初始化 role | ||||
| 			if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { | ||||
| 				message.Role = responseBody.Choices[0].Delta.Role | ||||
| @@ -121,54 +101,8 @@ func (h *ChatHandler) sendAzureMessage( | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if functionCall { // 调用函数完成任务 | ||||
| 			var params map[string]interface{} | ||||
| 			_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms) | ||||
| 			logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params) | ||||
|  | ||||
| 			// for creating image, check if the user's img_calls > 0 | ||||
| 			if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 { | ||||
| 				utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**") | ||||
| 				utils.ReplyMessage(ws, "") | ||||
| 			} else { | ||||
| 				f := h.App.Functions[functionName] | ||||
| 				if functionName == types.FuncMidJourney { | ||||
| 					params["user_id"] = userVo.Id | ||||
| 					params["role_id"] = role.Id | ||||
| 					params["chat_id"] = session.ChatId | ||||
| 					params["icon"] = "/images/avatar/mid_journey.png" | ||||
| 					params["session_id"] = session.SessionId | ||||
| 				} | ||||
| 				data, err := f.Invoke(params) | ||||
| 				if err != nil { | ||||
| 					msg := "调用函数出错:" + err.Error() | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 						Type:    types.WsMiddle, | ||||
| 						Content: msg, | ||||
| 					}) | ||||
| 					contents = append(contents, msg) | ||||
| 				} else { | ||||
| 					content := data | ||||
| 					if functionName == types.FuncMidJourney { | ||||
| 						content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) | ||||
| 						h.mjService.ChatClients.Put(session.SessionId, ws) | ||||
| 						// update user's img_calls | ||||
| 						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) | ||||
| 					} | ||||
|  | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 						Type:    types.WsMiddle, | ||||
| 						Content: content, | ||||
| 					}) | ||||
| 					contents = append(contents, content) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
| 			// 更新用户的对话次数 | ||||
| 			h.subUserCalls(userVo, session) | ||||
|  | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| @@ -177,78 +111,64 @@ func (h *ChatHandler) sendAzureMessage( | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.ChatConfig.EnableContext && functionCall == false { | ||||
| 			if h.App.SysConfig.EnableContext { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			if h.App.ChatConfig.EnableHistory { | ||||
| 				useContext := true | ||||
| 				if functionCall { | ||||
| 					useContext = false | ||||
| 				} | ||||
|  | ||||
| 				// for prompt | ||||
| 				promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 				if err != nil { | ||||
| 					logger.Error(err) | ||||
| 				} | ||||
| 				historyUserMsg := model.HistoryMessage{ | ||||
| 					UserId:     userVo.Id, | ||||
| 					ChatId:     session.ChatId, | ||||
| 					RoleId:     role.Id, | ||||
| 					Type:       types.PromptMsg, | ||||
| 					Icon:       userVo.Avatar, | ||||
| 					Content:    prompt, | ||||
| 					Tokens:     promptToken, | ||||
| 					UseContext: useContext, | ||||
| 				} | ||||
| 				historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 				historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 				res := h.db.Save(&historyUserMsg) | ||||
| 				if res.Error != nil { | ||||
| 					logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 				} | ||||
|  | ||||
| 				// 计算本次对话消耗的总 token 数量 | ||||
| 				var totalTokens = 0 | ||||
| 				if functionCall { // prompt + 函数名 + 参数 token | ||||
| 					tokens, _ := utils.CalcTokens(functionName, req.Model) | ||||
| 					totalTokens += tokens | ||||
| 					tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) | ||||
| 					totalTokens += tokens | ||||
| 				} else { | ||||
| 					totalTokens, _ = utils.CalcTokens(message.Content, req.Model) | ||||
| 				} | ||||
| 				totalTokens += getTotalTokens(req) | ||||
|  | ||||
| 				historyReplyMsg := model.HistoryMessage{ | ||||
| 					UserId:     userVo.Id, | ||||
| 					ChatId:     session.ChatId, | ||||
| 					RoleId:     role.Id, | ||||
| 					Type:       types.ReplyMsg, | ||||
| 					Icon:       role.Icon, | ||||
| 					Content:    message.Content, | ||||
| 					Tokens:     totalTokens, | ||||
| 					UseContext: useContext, | ||||
| 				} | ||||
| 				historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 				historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 				res = h.db.Create(&historyReplyMsg) | ||||
| 				if res.Error != nil { | ||||
| 					logger.Error("failed to save reply history message: ", res.Error) | ||||
| 				} | ||||
|  | ||||
| 				// 更新用户信息 | ||||
| 				h.db.Model(&model.User{}).Where("id = ?", userVo.Id). | ||||
| 					UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 			replyTokens += getTotalTokens(req) | ||||
|  | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    message.Content, | ||||
| 				Tokens:     replyTokens, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| @@ -259,7 +179,8 @@ func (h *ChatHandler) sendAzureMessage( | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				h.db.Create(&chatItem) | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
|   | ||||
| @@ -9,7 +9,7 @@ import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -36,7 +36,7 @@ type baiduResp struct { | ||||
| // 百度文心一言消息发送实现 | ||||
|  | ||||
| func (h *ChatHandler) sendBaiduMessage( | ||||
| 	chatCtx []interface{}, | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| @@ -46,7 +46,7 @@ func (h *ChatHandler) sendBaiduMessage( | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| @@ -61,7 +61,7 @@ func (h *ChatHandler) sendBaiduMessage( | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, "") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return err | ||||
| 	} else { | ||||
| 		defer response.Body.Close() | ||||
| @@ -85,6 +85,11 @@ func (h *ChatHandler) sendBaiduMessage( | ||||
| 				content = line[5:] | ||||
| 			} | ||||
|  | ||||
| 			// 处理代码换行 | ||||
| 			if len(content) == 0 { | ||||
| 				content = "\n" | ||||
| 			} | ||||
|  | ||||
| 			var resp baiduResp | ||||
| 			err := utils.JsonDecode(content, &resp) | ||||
| 			if err != nil { | ||||
| @@ -123,9 +128,6 @@ func (h *ChatHandler) sendBaiduMessage( | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
| 			// 更新用户的对话次数 | ||||
| 			h.subUserCalls(userVo, session) | ||||
|  | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| 			} | ||||
| @@ -133,64 +135,63 @@ func (h *ChatHandler) sendBaiduMessage( | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.ChatConfig.EnableContext { | ||||
| 			if h.App.SysConfig.EnableContext { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			if h.App.ChatConfig.EnableHistory { | ||||
| 				// for prompt | ||||
| 				promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 				if err != nil { | ||||
| 					logger.Error(err) | ||||
| 				} | ||||
| 				historyUserMsg := model.HistoryMessage{ | ||||
| 					UserId:     userVo.Id, | ||||
| 					ChatId:     session.ChatId, | ||||
| 					RoleId:     role.Id, | ||||
| 					Type:       types.PromptMsg, | ||||
| 					Icon:       userVo.Avatar, | ||||
| 					Content:    prompt, | ||||
| 					Tokens:     promptToken, | ||||
| 					UseContext: true, | ||||
| 				} | ||||
| 				historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 				historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 				res := h.db.Save(&historyUserMsg) | ||||
| 				if res.Error != nil { | ||||
| 					logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 				} | ||||
|  | ||||
| 				// for reply | ||||
| 				// 计算本次对话消耗的总 token 数量 | ||||
| 				replyToken, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 				totalTokens := replyToken + getTotalTokens(req) | ||||
| 				historyReplyMsg := model.HistoryMessage{ | ||||
| 					UserId:     userVo.Id, | ||||
| 					ChatId:     session.ChatId, | ||||
| 					RoleId:     role.Id, | ||||
| 					Type:       types.ReplyMsg, | ||||
| 					Icon:       role.Icon, | ||||
| 					Content:    message.Content, | ||||
| 					Tokens:     totalTokens, | ||||
| 					UseContext: true, | ||||
| 				} | ||||
| 				historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 				historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 				res = h.db.Create(&historyReplyMsg) | ||||
| 				if res.Error != nil { | ||||
| 					logger.Error("failed to save reply history message: ", res.Error) | ||||
| 				} | ||||
| 				// 更新用户信息 | ||||
| 				h.db.Model(&model.User{}).Where("id = ?", userVo.Id). | ||||
| 					UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// for reply | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 			totalTokens := replyTokens + getTotalTokens(req) | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    message.Content, | ||||
| 				Tokens:     totalTokens, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| @@ -201,7 +202,8 @@ func (h *ChatHandler) sendBaiduMessage( | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				h.db.Create(&chatItem) | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
|   | ||||
| @@ -6,8 +6,7 @@ import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/handler" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/service/mj" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| @@ -16,40 +15,44 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"gorm.io/gorm" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。" | ||||
|  | ||||
| var ErrImg = "" | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| type ChatHandler struct { | ||||
| 	handler.BaseHandler | ||||
| 	db        *gorm.DB | ||||
| 	leveldb   *store.LevelDB | ||||
| 	redis     *redis.Client | ||||
| 	mjService *mj.Service | ||||
| 	redis         *redis.Client | ||||
| 	uploadManager *oss.UploaderManager | ||||
| } | ||||
|  | ||||
| func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler { | ||||
| 	h := ChatHandler{ | ||||
| 		db:        db, | ||||
| 		leveldb:   levelDB, | ||||
| 		redis:     redis, | ||||
| 		mjService: service, | ||||
| func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler { | ||||
| 	return &ChatHandler{ | ||||
| 		BaseHandler:   handler.BaseHandler{App: app, DB: db}, | ||||
| 		redis:         redis, | ||||
| 		uploadManager: manager, | ||||
| 	} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| } | ||||
|  | ||||
| var chatConfig types.ChatConfig | ||||
| func (h *ChatHandler) Init() { | ||||
| 	// 如果后台有上传微信客服微信二维码,则覆盖 | ||||
| 	if h.App.SysConfig.WechatCardURL != "" { | ||||
| 		ErrImg = fmt.Sprintf("", h.App.SysConfig.WechatCardURL) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ChatHandle 处理聊天 WebSocket 请求 | ||||
| func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
| @@ -67,7 +70,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
| 	client := types.NewWsClient(ws) | ||||
| 	// get model info | ||||
| 	var chatModel model.ChatModel | ||||
| 	res := h.db.First(&chatModel, modelId) | ||||
| 	res := h.DB.First(&chatModel, modelId) | ||||
| 	if res.Error != nil || chatModel.Enabled == false { | ||||
| 		utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!") | ||||
| 		c.Abort() | ||||
| @@ -76,7 +79,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
|  | ||||
| 	session := h.App.ChatSession.Get(sessionId) | ||||
| 	if session == nil { | ||||
| 		user, err := utils.GetLoginUser(c, h.db) | ||||
| 		user, err := h.GetLoginUser(c) | ||||
| 		if err != nil { | ||||
| 			logger.Info("用户未登录") | ||||
| 			c.Abort() | ||||
| @@ -85,7 +88,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
| 		session = &types.ChatSession{ | ||||
| 			SessionId: sessionId, | ||||
| 			ClientIP:  c.ClientIP(), | ||||
| 			Username:  user.Mobile, | ||||
| 			Username:  user.Username, | ||||
| 			UserId:    user.Id, | ||||
| 		} | ||||
| 		h.App.ChatSession.Put(sessionId, session) | ||||
| @@ -93,7 +96,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
|  | ||||
| 	// use old chat data override the chat model and role ID | ||||
| 	var chat model.ChatItem | ||||
| 	res = h.db.Where("chat_id=?", chatId).First(&chat) | ||||
| 	res = h.DB.Where("chat_id = ?", chatId).First(&chat) | ||||
| 	if res.Error == nil { | ||||
| 		chatModel.Id = chat.ModelId | ||||
| 		roleId = int(chat.RoleId) | ||||
| @@ -101,28 +104,24 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
|  | ||||
| 	session.ChatId = chatId | ||||
| 	session.Model = types.ChatModel{ | ||||
| 		Id:       chatModel.Id, | ||||
| 		Value:    chatModel.Value, | ||||
| 		Weight:   chatModel.Weight, | ||||
| 		Platform: types.Platform(chatModel.Platform)} | ||||
| 		Id:          chatModel.Id, | ||||
| 		Name:        chatModel.Name, | ||||
| 		Value:       chatModel.Value, | ||||
| 		Power:       chatModel.Power, | ||||
| 		MaxTokens:   chatModel.MaxTokens, | ||||
| 		MaxContext:  chatModel.MaxContext, | ||||
| 		Temperature: chatModel.Temperature, | ||||
| 		Platform:    types.Platform(chatModel.Platform)} | ||||
| 	logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username) | ||||
| 	var chatRole model.ChatRole | ||||
| 	res = h.db.First(&chatRole, roleId) | ||||
| 	res = h.DB.First(&chatRole, roleId) | ||||
| 	if res.Error != nil || !chatRole.Enable { | ||||
| 		utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 初始化聊天配置 | ||||
| 	var config model.Config | ||||
| 	h.db.Where("marker", "chat").First(&config) | ||||
| 	err = utils.JsonDecode(config.Config, &chatConfig) | ||||
| 	if err != nil { | ||||
| 		utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!") | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
| 	h.Init() | ||||
|  | ||||
| 	// 保存会话连接 | ||||
| 	h.App.ChatClients.Put(sessionId, client) | ||||
| @@ -130,7 +129,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
| 		for { | ||||
| 			_, msg, err := client.Receive() | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 				client.Close() | ||||
| 				h.App.ChatClients.Delete(sessionId) | ||||
| 				cancelFunc := h.App.ReqCancelFunc.Get(sessionId) | ||||
| @@ -141,19 +139,30 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			message := string(msg) | ||||
| 			logger.Info("Receive a message: ", message) | ||||
| 			//utils.ReplyMessage(client, "这是一条测试消息!") | ||||
| 			var message types.WsMessage | ||||
| 			err = utils.JsonDecode(string(msg), &message) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// 心跳消息 | ||||
| 			if message.Type == "heartbeat" { | ||||
| 				logger.Debug("收到 Chat 心跳消息:", message.Content) | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			logger.Info("Receive a message: ", message.Content) | ||||
|  | ||||
| 			ctx, cancel := context.WithCancel(context.Background()) | ||||
| 			h.App.ReqCancelFunc.Put(sessionId, cancel) | ||||
| 			// 回复消息 | ||||
| 			err = h.sendMessage(ctx, session, chatRole, message, client) | ||||
| 			err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) | ||||
| 			} else { | ||||
| 				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) | ||||
| 				logger.Info("回答完毕: " + string(message)) | ||||
| 				logger.Infof("回答完毕: %v", message.Content) | ||||
| 			} | ||||
|  | ||||
| 		} | ||||
| @@ -161,16 +170,18 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error { | ||||
| 	defer func() { | ||||
| 		if r := recover(); r != nil { | ||||
| 			logger.Error("Recover message from error: ", r) | ||||
| 		} | ||||
| 	}() | ||||
| 	if !h.App.Debug { | ||||
| 		defer func() { | ||||
| 			if r := recover(); r != nil { | ||||
| 				logger.Error("Recover message from error: ", r) | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	var user model.User | ||||
| 	res := h.db.Model(&model.User{}).First(&user, session.UserId) | ||||
| 	res := h.DB.Model(&model.User{}).First(&user, session.UserId) | ||||
| 	if res.Error != nil { | ||||
| 		utils.ReplyMessage(ws, "非法用户,请联系管理员!") | ||||
| 		utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作!") | ||||
| 		return res.Error | ||||
| 	} | ||||
| 	var userVo vo.User | ||||
| @@ -182,98 +193,99 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio | ||||
|  | ||||
| 	if userVo.Status == false { | ||||
| 		utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!") | ||||
| 		utils.ReplyMessage(ws, "") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { | ||||
| 		utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!") | ||||
| 		utils.ReplyMessage(ws, "") | ||||
| 	if userVo.Power < session.Model.Power { | ||||
| 		utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型的单次对话需要消耗的算力(%d)!", userVo.Power, session.Model.Power)) | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() { | ||||
| 		utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!") | ||||
| 		utils.ReplyMessage(ws, "") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度 | ||||
| 	promptTokens, err := utils.CalcTokens(prompt, session.Model.Value) | ||||
| 	if promptTokens > session.Model.MaxContext { | ||||
| 		utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!") | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	var req = types.ApiRequest{ | ||||
| 		Model:  session.Model.Value, | ||||
| 		Stream: true, | ||||
| 	} | ||||
| 	switch session.Model.Platform { | ||||
| 	case types.Azure: | ||||
| 		req.Temperature = h.App.ChatConfig.Azure.Temperature | ||||
| 		req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens | ||||
| 	case types.Azure, types.ChatGLM, types.Baidu, types.XunFei: | ||||
| 		req.Temperature = session.Model.Temperature | ||||
| 		req.MaxTokens = session.Model.MaxTokens | ||||
| 		break | ||||
| 	case types.ChatGLM: | ||||
| 		req.Temperature = h.App.ChatConfig.ChatGML.Temperature | ||||
| 		req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens | ||||
| 		break | ||||
| 	case types.Baidu: | ||||
| 		req.Temperature = h.App.ChatConfig.OpenAI.Temperature | ||||
| 		// TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持 | ||||
| 	case types.OpenAI: | ||||
| 		req.Temperature = h.App.ChatConfig.OpenAI.Temperature | ||||
| 		req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens | ||||
| 		req.Temperature = session.Model.Temperature | ||||
| 		req.MaxTokens = session.Model.MaxTokens | ||||
| 		// OpenAI 支持函数功能 | ||||
| 		if h.App.SysConfig.EnabledFunction { | ||||
| 			var functions = make([]types.Function, 0) | ||||
| 			for _, f := range types.InnerFunctions { | ||||
| 				if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney { | ||||
| 					continue | ||||
| 				} | ||||
| 				functions = append(functions, f) | ||||
| 			} | ||||
| 			req.Functions = functions | ||||
| 		var items []model.Function | ||||
| 		res := h.DB.Where("enabled", true).Find(&items) | ||||
| 		if res.Error != nil { | ||||
| 			break | ||||
| 		} | ||||
| 	case types.XunFei: | ||||
| 		req.Temperature = h.App.ChatConfig.XunFei.Temperature | ||||
| 		req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens | ||||
|  | ||||
| 		var tools = make([]interface{}, 0) | ||||
| 		for _, v := range items { | ||||
| 			var parameters map[string]interface{} | ||||
| 			err = utils.JsonDecode(v.Parameters, ¶meters) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			required := parameters["required"] | ||||
| 			delete(parameters, "required") | ||||
| 			tools = append(tools, gin.H{ | ||||
| 				"type": "function", | ||||
| 				"function": gin.H{ | ||||
| 					"name":        v.Name, | ||||
| 					"description": v.Description, | ||||
| 					"parameters":  parameters, | ||||
| 					"required":    required, | ||||
| 				}, | ||||
| 			}) | ||||
| 		} | ||||
|  | ||||
| 		if len(tools) > 0 { | ||||
| 			req.Tools = tools | ||||
| 			req.ToolChoice = "auto" | ||||
| 		} | ||||
| 	case types.QWen: | ||||
| 		req.Parameters = map[string]interface{}{ | ||||
| 			"max_tokens":  session.Model.MaxTokens, | ||||
| 			"temperature": session.Model.Temperature, | ||||
| 		} | ||||
| 		break | ||||
|  | ||||
| 	default: | ||||
| 		utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!") | ||||
| 		utils.ReplyMessage(ws, "") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// 加载聊天上下文 | ||||
| 	var chatCtx []interface{} | ||||
| 	if h.App.ChatConfig.EnableContext { | ||||
| 	chatCtx := make([]types.Message, 0) | ||||
| 	messages := make([]types.Message, 0) | ||||
| 	if h.App.SysConfig.EnableContext { | ||||
| 		if h.App.ChatContexts.Has(session.ChatId) { | ||||
| 			chatCtx = h.App.ChatContexts.Get(session.ChatId) | ||||
| 			messages = h.App.ChatContexts.Get(session.ChatId) | ||||
| 		} else { | ||||
| 			// calculate the tokens of current request, to prevent to exceeding the max tokens num | ||||
| 			tokens := req.MaxTokens | ||||
| 			for _, f := range types.InnerFunctions { | ||||
| 				tks, _ := utils.CalcTokens(utils.JsonEncode(f), req.Model) | ||||
| 				tokens += tks | ||||
| 			} | ||||
|  | ||||
| 			// loading the role context | ||||
| 			var messages []types.Message | ||||
| 			err := utils.JsonDecode(role.Context, &messages) | ||||
| 			if err == nil { | ||||
| 				for _, v := range messages { | ||||
| 					tks, _ := utils.CalcTokens(v.Content, req.Model) | ||||
| 					if tokens+tks >= types.ModelToTokens[req.Model] { | ||||
| 						break | ||||
| 					} | ||||
| 					tokens += tks | ||||
| 					chatCtx = append(chatCtx, v) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			// loading recent chat history as chat context | ||||
| 			if chatConfig.ContextDeep > 0 { | ||||
| 				var historyMessages []model.HistoryMessage | ||||
| 				res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages) | ||||
| 			_ = utils.JsonDecode(role.Context, &messages) | ||||
| 			if h.App.SysConfig.ContextDeep > 0 { | ||||
| 				var historyMessages []model.ChatMessage | ||||
| 				res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages) | ||||
| 				if res.Error == nil { | ||||
| 					for i := len(historyMessages) - 1; i >= 0; i-- { | ||||
| 						msg := historyMessages[i] | ||||
| 						if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] { | ||||
| 							break | ||||
| 						} | ||||
| 						tokens += msg.Tokens | ||||
| 						ms := types.Message{Role: "user", Content: msg.Content} | ||||
| 						if msg.Type == types.ReplyMsg { | ||||
| 							ms.Role = "assistant" | ||||
| @@ -283,6 +295,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 计算当前请求的 token 总长度,确保不会超出最大上下文长度 | ||||
| 		// MaxContextLength = Response + Tool + Prompt + Context | ||||
| 		tokens := req.MaxTokens // 最大响应长度 | ||||
| 		tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model) | ||||
| 		tokens += tks + promptTokens | ||||
|  | ||||
| 		for _, v := range messages { | ||||
| 			tks, _ := utils.CalcTokens(v.Content, req.Model) | ||||
| 			// 上下文 token 超出了模型的最大上下文长度 | ||||
| 			if tokens+tks >= session.Model.MaxContext { | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			// 上下文的深度超出了模型的最大上下文深度 | ||||
| 			if len(chatCtx) >= h.App.SysConfig.ContextDeep { | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			tokens += tks | ||||
| 			chatCtx = append(chatCtx, v) | ||||
| 		} | ||||
|  | ||||
| 		logger.Debugf("聊天上下文:%+v", chatCtx) | ||||
| 	} | ||||
| 	reqMgs := make([]interface{}, 0) | ||||
| @@ -290,10 +325,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio | ||||
| 		reqMgs = append(reqMgs, m) | ||||
| 	} | ||||
|  | ||||
| 	req.Messages = append(reqMgs, map[string]interface{}{ | ||||
| 		"role":    "user", | ||||
| 		"content": prompt, | ||||
| 	}) | ||||
| 	if session.Model.Platform == types.QWen { | ||||
| 		req.Input = map[string]interface{}{"prompt": prompt} | ||||
| 		if len(reqMgs) > 0 { | ||||
| 			req.Input["messages"] = reqMgs | ||||
| 		} | ||||
| 	} else { | ||||
| 		req.Messages = append(reqMgs, map[string]interface{}{ | ||||
| 			"role":    "user", | ||||
| 			"content": prompt, | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	switch session.Model.Platform { | ||||
| 	case types.Azure: | ||||
| @@ -306,7 +348,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio | ||||
| 		return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
| 	case types.XunFei: | ||||
| 		return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
|  | ||||
| 	case types.QWen: | ||||
| 		return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) | ||||
| 	} | ||||
| 	utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 		Type:    types.WsMiddle, | ||||
| @@ -318,8 +361,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio | ||||
| // Tokens 统计 token 数量 | ||||
| func (h *ChatHandler) Tokens(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Text  string `json:"text"` | ||||
| 		Model string `json:"model"` | ||||
| 		Text   string `json:"text"` | ||||
| 		Model  string `json:"model"` | ||||
| 		ChatId string `json:"chat_id"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| @@ -327,10 +371,10 @@ func (h *ChatHandler) Tokens(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文) | ||||
| 	if data.Text == "" { | ||||
| 		var item model.HistoryMessage | ||||
| 	if data.Text == "" && data.ChatId != "" { | ||||
| 		var item model.ChatMessage | ||||
| 		userId, _ := c.Get(types.LoginUserID) | ||||
| 		res := h.db.Where("user_id = ?", userId).Last(&item) | ||||
| 		res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, res.Error.Error()) | ||||
| 			return | ||||
| @@ -380,39 +424,37 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) { | ||||
|  | ||||
| // 发送请求到 OpenAI 服务器 | ||||
| // useOwnApiKey: 是否使用了用户自己的 API KEY | ||||
| func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) { | ||||
|  | ||||
| func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) { | ||||
| 	res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) | ||||
| 	if res.Error != nil { | ||||
| 		return nil, errors.New("no available key, please import key") | ||||
| 	} | ||||
| 	var apiURL string | ||||
| 	switch platform { | ||||
| 	case types.Azure: | ||||
| 		md := strings.Replace(req.Model, ".", "", 1) | ||||
| 		apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1) | ||||
| 		apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1) | ||||
| 		break | ||||
| 	case types.ChatGLM: | ||||
| 		apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1) | ||||
| 		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) | ||||
| 		req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段 | ||||
| 		req.Messages = nil | ||||
| 		break | ||||
| 	case types.Baidu: | ||||
| 		apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1) | ||||
| 		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) | ||||
| 		break | ||||
| 	case types.QWen: | ||||
| 		apiURL = apiKey.ApiURL | ||||
| 		req.Messages = nil | ||||
| 		break | ||||
| 	default: | ||||
| 		apiURL = h.App.ChatConfig.OpenAI.ApiURL | ||||
| 		apiURL = apiKey.ApiURL | ||||
| 	} | ||||
| 	if *apiKey == "" { | ||||
| 		var key model.ApiKey | ||||
| 		res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key) | ||||
| 		if res.Error != nil { | ||||
| 			return nil, errors.New("no available key, please import key") | ||||
| 		} | ||||
| 		// 更新 API KEY 的最后使用时间 | ||||
| 		h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
| 		*apiKey = key.Value | ||||
| 	} | ||||
|  | ||||
| 	// 更新 API KEY 的最后使用时间 | ||||
| 	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
| 	// 百度文心,需要串接 access_token | ||||
| 	if platform == types.Baidu { | ||||
| 		token, err := h.getBaiduToken(*apiKey) | ||||
| 		token, err := h.getBaiduToken(apiKey.Value) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -420,6 +462,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf | ||||
| 		apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token) | ||||
| 	} | ||||
|  | ||||
| 	logger.Debugf(utils.JsonEncode(req)) | ||||
|  | ||||
| 	// 创建 HttpClient 请求对象 | ||||
| 	var client *http.Client | ||||
| 	requestBody, err := json.Marshal(req) | ||||
| @@ -433,9 +477,9 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf | ||||
|  | ||||
| 	request = request.WithContext(ctx) | ||||
| 	request.Header.Set("Content-Type", "application/json") | ||||
| 	proxyURL := h.App.Config.ProxyURL | ||||
| 	if proxyURL != "" && platform == types.OpenAI { // 使用代理 | ||||
| 		proxy, _ := url.Parse(proxyURL) | ||||
| 	var proxyURL string | ||||
| 	if apiKey.ProxyURL != "" { // 使用代理 | ||||
| 		proxy, _ := url.Parse(apiKey.ProxyURL) | ||||
| 		client = &http.Client{ | ||||
| 			Transport: &http.Transport{ | ||||
| 				Proxy: http.ProxyURL(proxy), | ||||
| @@ -444,35 +488,79 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf | ||||
| 	} else { | ||||
| 		client = http.DefaultClient | ||||
| 	} | ||||
| 	logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model) | ||||
| 	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model) | ||||
| 	switch platform { | ||||
| 	case types.Azure: | ||||
| 		request.Header.Set("api-key", *apiKey) | ||||
| 		request.Header.Set("api-key", apiKey.Value) | ||||
| 		break | ||||
| 	case types.ChatGLM: | ||||
| 		token, err := h.getChatGLMToken(*apiKey) | ||||
| 		token, err := h.getChatGLMToken(apiKey.Value) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		logger.Info(token) | ||||
| 		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) | ||||
| 		break | ||||
| 	case types.Baidu: | ||||
| 		request.RequestURI = "" | ||||
| 	case types.OpenAI: | ||||
| 		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) | ||||
| 		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) | ||||
| 		break | ||||
| 	case types.QWen: | ||||
| 		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) | ||||
| 		request.Header.Set("X-DashScope-SSE", "enable") | ||||
| 		break | ||||
| 	} | ||||
| 	return client.Do(request) | ||||
| } | ||||
|  | ||||
| // 扣减用户的对话次数 | ||||
| func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) { | ||||
| 	// 仅当用户没有导入自己的 API KEY 时才进行扣减 | ||||
| 	if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { | ||||
| 		num := 1 | ||||
| 		if session.Model.Weight > 0 { | ||||
| 			num = session.Model.Weight | ||||
| 		} | ||||
| 		h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num)) | ||||
| // 扣减用户算力 | ||||
| func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) { | ||||
| 	power := 1 | ||||
| 	if session.Model.Power > 0 { | ||||
| 		power = session.Model.Power | ||||
| 	} | ||||
| 	res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power)) | ||||
| 	if res.Error == nil { | ||||
| 		// 记录算力消费日志 | ||||
| 		var u model.User | ||||
| 		h.DB.Where("id", userVo.Id).First(&u) | ||||
| 		h.DB.Create(&model.PowerLog{ | ||||
| 			UserId:    userVo.Id, | ||||
| 			Username:  userVo.Username, | ||||
| 			Type:      types.PowerConsume, | ||||
| 			Amount:    power, | ||||
| 			Mark:      types.PowerSub, | ||||
| 			Balance:   u.Power, | ||||
| 			Model:     session.Model.Value, | ||||
| 			Remark:    fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens), | ||||
| 			CreatedAt: time.Now(), | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| // 将AI回复消息中生成的图片链接下载到本地 | ||||
| func (h *ChatHandler) extractImgUrl(text string) string { | ||||
| 	pattern := `!\[([^\]]*)]\(([^)]+)\)` | ||||
| 	re := regexp.MustCompile(pattern) | ||||
| 	matches := re.FindAllStringSubmatch(text, -1) | ||||
|  | ||||
| 	// 下载图片并替换链接地址 | ||||
| 	for _, match := range matches { | ||||
| 		imageURL := match[2] | ||||
| 		logger.Debug(imageURL) | ||||
| 		// 对于相同地址的图片,已经被替换了,就不再重复下载了 | ||||
| 		if !strings.Contains(text, imageURL) { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false) | ||||
| 		if err != nil { | ||||
| 			logger.Error("error with download image: ", err) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		text = strings.ReplaceAll(text, imageURL, newImgURL) | ||||
| 	} | ||||
| 	return text | ||||
| } | ||||
|   | ||||
| @@ -6,27 +6,29 @@ import ( | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // List 获取会话列表 | ||||
| func (h *ChatHandler) List(c *gin.Context) { | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	if userId == 0 { | ||||
| 		resp.ERROR(c, "The parameter 'user_id' is needed.") | ||||
| 	if !h.IsLogin(c) { | ||||
| 		resp.SUCCESS(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	var items = make([]vo.ChatItem, 0) | ||||
| 	var chats []model.ChatItem | ||||
| 	res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats) | ||||
| 	res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats) | ||||
| 	if res.Error == nil { | ||||
| 		var roleIds = make([]uint, 0) | ||||
| 		for _, chat := range chats { | ||||
| 			roleIds = append(roleIds, chat.RoleId) | ||||
| 		} | ||||
| 		var roles []model.ChatRole | ||||
| 		res = h.db.Find(&roles, roleIds) | ||||
| 		res = h.DB.Find(&roles, roleIds) | ||||
| 		if res.Error == nil { | ||||
| 			roleMap := make(map[uint]model.ChatRole) | ||||
| 			for _, role := range roles { | ||||
| @@ -58,7 +60,7 @@ func (h *ChatHandler) Update(c *gin.Context) { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title) | ||||
| 	res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Failed to update database") | ||||
| 		return | ||||
| @@ -70,14 +72,14 @@ func (h *ChatHandler) Update(c *gin.Context) { | ||||
| // Clear 清空所有聊天记录 | ||||
| func (h *ChatHandler) Clear(c *gin.Context) { | ||||
| 	// 获取当前登录用户所有的聊天会话 | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var chats []model.ChatItem | ||||
| 	res := h.db.Where("user_id = ?", user.Id).Find(&chats) | ||||
| 	res := h.DB.Where("user_id = ?", user.Id).Find(&chats) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "No chats found") | ||||
| 		return | ||||
| @@ -89,13 +91,13 @@ func (h *ChatHandler) Clear(c *gin.Context) { | ||||
| 		// 清空会话上下文 | ||||
| 		h.App.ChatContexts.Delete(chat.ChatId) | ||||
| 	} | ||||
| 	err = h.db.Transaction(func(tx *gorm.DB) error { | ||||
| 		res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{}) | ||||
| 	err = h.DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{}) | ||||
| 		if res.Error != nil { | ||||
| 			return res.Error | ||||
| 		} | ||||
|  | ||||
| 		res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.HistoryMessage{}) | ||||
| 		res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{}) | ||||
| 		if res.Error != nil { | ||||
| 			return res.Error | ||||
| 		} | ||||
| @@ -116,9 +118,9 @@ func (h *ChatHandler) Clear(c *gin.Context) { | ||||
| // History 获取聊天历史记录 | ||||
| func (h *ChatHandler) History(c *gin.Context) { | ||||
| 	chatId := c.Query("chat_id") // 会话 ID | ||||
| 	var items []model.HistoryMessage | ||||
| 	var items []model.ChatMessage | ||||
| 	var messages = make([]vo.HistoryMessage, 0) | ||||
| 	res := h.db.Where("chat_id = ?", chatId).Find(&items) | ||||
| 	res := h.DB.Where("chat_id = ?", chatId).Find(&items) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "No history message") | ||||
| 		return | ||||
| @@ -144,20 +146,20 @@ func (h *ChatHandler) Remove(c *gin.Context) { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{}) | ||||
| 	res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Failed to update database") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 删除当前会话的聊天记录 | ||||
| 	res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{}) | ||||
| 	res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Failed to remove chat from database.") | ||||
| 		return | ||||
| @@ -179,7 +181,7 @@ func (h *ChatHandler) Detail(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	var chatItem model.ChatItem | ||||
| 	res := h.db.Where("chat_id = ?", chatId).First(&chatItem) | ||||
| 	res := h.DB.Where("chat_id = ?", chatId).First(&chatItem) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "No chat found") | ||||
| 		return | ||||
|   | ||||
| @@ -10,7 +10,7 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"gorm.io/gorm" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -20,7 +20,7 @@ import ( | ||||
| // 清华大学 ChatGML 消息发送实现 | ||||
|  | ||||
| func (h *ChatHandler) sendChatGLMMessage( | ||||
| 	chatCtx []interface{}, | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| @@ -30,7 +30,7 @@ func (h *ChatHandler) sendChatGLMMessage( | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| @@ -45,7 +45,7 @@ func (h *ChatHandler) sendChatGLMMessage( | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, "") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return err | ||||
| 	} else { | ||||
| 		defer response.Body.Close() | ||||
| @@ -72,6 +72,10 @@ func (h *ChatHandler) sendChatGLMMessage( | ||||
| 			if strings.HasPrefix(line, "data:") { | ||||
| 				content = line[5:] | ||||
| 			} | ||||
| 			// 处理代码换行 | ||||
| 			if len(content) == 0 { | ||||
| 				content = "\n" | ||||
| 			} | ||||
| 			switch event { | ||||
| 			case "add": | ||||
| 				if len(contents) == 0 { | ||||
| @@ -103,9 +107,6 @@ func (h *ChatHandler) sendChatGLMMessage( | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
| 			// 更新用户的对话次数 | ||||
| 			h.subUserCalls(userVo, session) | ||||
|  | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| 			} | ||||
| @@ -113,64 +114,64 @@ func (h *ChatHandler) sendChatGLMMessage( | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.ChatConfig.EnableContext { | ||||
| 			if h.App.SysConfig.EnableContext { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			if h.App.ChatConfig.EnableHistory { | ||||
| 				// for prompt | ||||
| 				promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 				if err != nil { | ||||
| 					logger.Error(err) | ||||
| 				} | ||||
| 				historyUserMsg := model.HistoryMessage{ | ||||
| 					UserId:     userVo.Id, | ||||
| 					ChatId:     session.ChatId, | ||||
| 					RoleId:     role.Id, | ||||
| 					Type:       types.PromptMsg, | ||||
| 					Icon:       userVo.Avatar, | ||||
| 					Content:    prompt, | ||||
| 					Tokens:     promptToken, | ||||
| 					UseContext: true, | ||||
| 				} | ||||
| 				historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 				historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 				res := h.db.Save(&historyUserMsg) | ||||
| 				if res.Error != nil { | ||||
| 					logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 				} | ||||
|  | ||||
| 				// for reply | ||||
| 				// 计算本次对话消耗的总 token 数量 | ||||
| 				replyToken, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 				totalTokens := replyToken + getTotalTokens(req) | ||||
| 				historyReplyMsg := model.HistoryMessage{ | ||||
| 					UserId:     userVo.Id, | ||||
| 					ChatId:     session.ChatId, | ||||
| 					RoleId:     role.Id, | ||||
| 					Type:       types.ReplyMsg, | ||||
| 					Icon:       role.Icon, | ||||
| 					Content:    message.Content, | ||||
| 					Tokens:     totalTokens, | ||||
| 					UseContext: true, | ||||
| 				} | ||||
| 				historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 				historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 				res = h.db.Create(&historyReplyMsg) | ||||
| 				if res.Error != nil { | ||||
| 					logger.Error("failed to save reply history message: ", res.Error) | ||||
| 				} | ||||
| 				// 更新用户信息 | ||||
| 				h.db.Model(&model.User{}).Where("id = ?", userVo.Id). | ||||
| 					UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// for reply | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 			totalTokens := replyTokens + getTotalTokens(req) | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    message.Content, | ||||
| 				Tokens:     totalTokens, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| @@ -181,7 +182,8 @@ func (h *ChatHandler) sendChatGLMMessage( | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				h.db.Create(&chatItem) | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
|   | ||||
| @@ -9,16 +9,18 @@ import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
|  | ||||
| 	req2 "github.com/imroc/req/v3" | ||||
| ) | ||||
|  | ||||
| // OPenAI 消息发送实现 | ||||
| func (h *ChatHandler) sendOpenAiMessage( | ||||
| 	chatCtx []interface{}, | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| @@ -28,7 +30,7 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| @@ -43,7 +45,11 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, "") | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		if response.Body != nil { | ||||
| 			all, _ := io.ReadAll(response.Body) | ||||
| 			logger.Error(string(all)) | ||||
| 		} | ||||
| 		return err | ||||
| 	} else { | ||||
| 		defer response.Body.Close() | ||||
| @@ -55,8 +61,8 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 		// 循环读取 Chunk 消息 | ||||
| 		var message = types.Message{} | ||||
| 		var contents = make([]string, 0) | ||||
| 		var functionCall = false | ||||
| 		var functionName string | ||||
| 		var function model.Function | ||||
| 		var toolCall = false | ||||
| 		var arguments = make([]string, 0) | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
| 		for scanner.Scan() { | ||||
| @@ -70,28 +76,41 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 			if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 | ||||
| 				logger.Error(err, line) | ||||
| 				utils.ReplyMessage(ws, ErrorMsg) | ||||
| 				utils.ReplyMessage(ws, "") | ||||
| 				utils.ReplyMessage(ws, ErrImg) | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			var tool types.ToolCall | ||||
| 			if len(responseBody.Choices[0].Delta.ToolCalls) > 0 { | ||||
| 				tool = responseBody.Choices[0].Delta.ToolCalls[0] | ||||
| 				if toolCall && tool.Function.Name == "" { | ||||
| 					arguments = append(arguments, tool.Function.Arguments) | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			// 兼容 Function Call | ||||
| 			fun := responseBody.Choices[0].Delta.FunctionCall | ||||
| 			if functionCall && fun.Name == "" { | ||||
| 			if fun.Name != "" { | ||||
| 				tool = *new(types.ToolCall) | ||||
| 				tool.Function.Name = fun.Name | ||||
| 			} else if toolCall { | ||||
| 				arguments = append(arguments, fun.Arguments) | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if !utils.IsEmptyValue(fun) { | ||||
| 				functionName = fun.Name | ||||
| 				f := h.App.Functions[functionName] | ||||
| 				if f != nil { | ||||
| 					functionCall = true | ||||
| 			if !utils.IsEmptyValue(tool) { | ||||
| 				res := h.DB.Where("name = ?", tool.Function.Name).First(&function) | ||||
| 				if res.Error == nil { | ||||
| 					toolCall = true | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())}) | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)}) | ||||
| 				} | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕 | ||||
| 			if responseBody.Choices[0].FinishReason == "tool_calls" || | ||||
| 				responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕 | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| @@ -120,55 +139,40 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if functionCall { // 调用函数完成任务 | ||||
| 		if toolCall { // 调用函数完成任务 | ||||
| 			var params map[string]interface{} | ||||
| 			_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms) | ||||
| 			logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params) | ||||
|  | ||||
| 			// for creating image, check if the user's img_calls > 0 | ||||
| 			if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 { | ||||
| 				utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**") | ||||
| 				utils.ReplyMessage(ws, "") | ||||
| 			logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params) | ||||
| 			params["user_id"] = userVo.Id | ||||
| 			var apiRes types.BizVo | ||||
| 			r, err := req2.C().R().SetHeader("Content-Type", "application/json"). | ||||
| 				SetHeader("Authorization", function.Token). | ||||
| 				SetBody(params). | ||||
| 				SetSuccessResult(&apiRes).Post(function.Action) | ||||
| 			errMsg := "" | ||||
| 			if err != nil { | ||||
| 				errMsg = err.Error() | ||||
| 			} else if r.IsErrorState() { | ||||
| 				errMsg = r.Status | ||||
| 			} | ||||
| 			if errMsg != "" || apiRes.Code != types.Success { | ||||
| 				msg := "调用函数工具出错:" + apiRes.Message + errMsg | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 					Type:    types.WsMiddle, | ||||
| 					Content: msg, | ||||
| 				}) | ||||
| 				contents = append(contents, msg) | ||||
| 			} else { | ||||
| 				f := h.App.Functions[functionName] | ||||
| 				if functionName == types.FuncMidJourney { | ||||
| 					params["user_id"] = userVo.Id | ||||
| 					params["role_id"] = role.Id | ||||
| 					params["chat_id"] = session.ChatId | ||||
| 					params["icon"] = "/images/avatar/mid_journey.png" | ||||
| 					params["session_id"] = session.SessionId | ||||
| 				} | ||||
| 				data, err := f.Invoke(params) | ||||
| 				if err != nil { | ||||
| 					msg := "调用函数出错:" + err.Error() | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 						Type:    types.WsMiddle, | ||||
| 						Content: msg, | ||||
| 					}) | ||||
| 					contents = append(contents, msg) | ||||
| 				} else { | ||||
| 					content := data | ||||
| 					if functionName == types.FuncMidJourney { | ||||
| 						content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) | ||||
| 						h.mjService.ChatClients.Put(session.SessionId, ws) | ||||
| 						// update user's img_calls | ||||
| 						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) | ||||
| 					} | ||||
|  | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 						Type:    types.WsMiddle, | ||||
| 						Content: content, | ||||
| 					}) | ||||
| 					contents = append(contents, content) | ||||
| 				} | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 					Type:    types.WsMiddle, | ||||
| 					Content: apiRes.Data, | ||||
| 				}) | ||||
| 				contents = append(contents, utils.InterfaceToString(apiRes.Data)) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
| 			// 更新用户的对话次数 | ||||
| 			h.subUserCalls(userVo, session) | ||||
|  | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| 			} | ||||
| @@ -176,78 +180,77 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.ChatConfig.EnableContext && functionCall == false { | ||||
| 			if h.App.SysConfig.EnableContext && toolCall == false { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			if h.App.ChatConfig.EnableHistory { | ||||
| 				useContext := true | ||||
| 				if functionCall { | ||||
| 					useContext = false | ||||
| 				} | ||||
|  | ||||
| 				// for prompt | ||||
| 				promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 				if err != nil { | ||||
| 					logger.Error(err) | ||||
| 				} | ||||
| 				historyUserMsg := model.HistoryMessage{ | ||||
| 					UserId:     userVo.Id, | ||||
| 					ChatId:     session.ChatId, | ||||
| 					RoleId:     role.Id, | ||||
| 					Type:       types.PromptMsg, | ||||
| 					Icon:       userVo.Avatar, | ||||
| 					Content:    prompt, | ||||
| 					Tokens:     promptToken, | ||||
| 					UseContext: useContext, | ||||
| 				} | ||||
| 				historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 				historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 				res := h.db.Save(&historyUserMsg) | ||||
| 				if res.Error != nil { | ||||
| 					logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 				} | ||||
|  | ||||
| 				// 计算本次对话消耗的总 token 数量 | ||||
| 				var totalTokens = 0 | ||||
| 				if functionCall { // prompt + 函数名 + 参数 token | ||||
| 					tokens, _ := utils.CalcTokens(functionName, req.Model) | ||||
| 					totalTokens += tokens | ||||
| 					tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) | ||||
| 					totalTokens += tokens | ||||
| 				} else { | ||||
| 					totalTokens, _ = utils.CalcTokens(message.Content, req.Model) | ||||
| 				} | ||||
| 				totalTokens += getTotalTokens(req) | ||||
|  | ||||
| 				historyReplyMsg := model.HistoryMessage{ | ||||
| 					UserId:     userVo.Id, | ||||
| 					ChatId:     session.ChatId, | ||||
| 					RoleId:     role.Id, | ||||
| 					Type:       types.ReplyMsg, | ||||
| 					Icon:       role.Icon, | ||||
| 					Content:    message.Content, | ||||
| 					Tokens:     totalTokens, | ||||
| 					UseContext: useContext, | ||||
| 				} | ||||
| 				historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 				historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 				res = h.db.Create(&historyReplyMsg) | ||||
| 				if res.Error != nil { | ||||
| 					logger.Error("failed to save reply history message: ", res.Error) | ||||
| 				} | ||||
|  | ||||
| 				// 更新用户信息 | ||||
| 				h.db.Model(&model.User{}).Where("id = ?", userVo.Id). | ||||
| 					UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) | ||||
| 			useContext := true | ||||
| 			if toolCall { | ||||
| 				useContext = false | ||||
| 			} | ||||
|  | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: useContext, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			var replyTokens = 0 | ||||
| 			if toolCall { // prompt + 函数名 + 参数 token | ||||
| 				tokens, _ := utils.CalcTokens(function.Name, req.Model) | ||||
| 				replyTokens += tokens | ||||
| 				tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) | ||||
| 				replyTokens += tokens | ||||
| 			} else { | ||||
| 				replyTokens, _ = utils.CalcTokens(message.Content, req.Model) | ||||
| 			} | ||||
| 			replyTokens += getTotalTokens(req) | ||||
|  | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    h.extractImgUrl(message.Content), | ||||
| 				Tokens:     replyTokens, | ||||
| 				UseContext: useContext, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| @@ -258,17 +261,20 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				h.db.Create(&chatItem) | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		body, err := io.ReadAll(response.Body) | ||||
| 		if err != nil { | ||||
| 			utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error()) | ||||
| 			return fmt.Errorf("error with reading response: %v", err) | ||||
| 		} | ||||
| 		var res types.ApiError | ||||
| 		err = json.Unmarshal(body, &res) | ||||
| 		if err != nil { | ||||
| 			utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```") | ||||
| 			return fmt.Errorf("error with decode response: %v", err) | ||||
| 		} | ||||
|  | ||||
| @@ -276,7 +282,7 @@ func (h *ChatHandler) sendOpenAiMessage( | ||||
| 		if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { | ||||
| 			utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") | ||||
| 			// 移除当前 API key | ||||
| 			h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{}) | ||||
| 			h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{}) | ||||
| 		} else if strings.Contains(res.Error.Message, "You exceeded your current quota") { | ||||
| 			utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") | ||||
| 		} else if strings.Contains(res.Error.Message, "This model's maximum context length") { | ||||
|   | ||||
							
								
								
									
										240
									
								
								api/handler/chatimpl/qwen_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								api/handler/chatimpl/qwen_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,240 @@ | ||||
| package chatimpl | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| ) | ||||
|  | ||||
| type qWenResp struct { | ||||
| 	Output struct { | ||||
| 		FinishReason string `json:"finish_reason"` | ||||
| 		Text         string `json:"text"` | ||||
| 	} `json:"output,omitempty"` | ||||
| 	Usage struct { | ||||
| 		TotalTokens  int `json:"total_tokens"` | ||||
| 		InputTokens  int `json:"input_tokens"` | ||||
| 		OutputTokens int `json:"output_tokens"` | ||||
| 	} `json:"usage,omitempty"` | ||||
| 	RequestID string `json:"request_id"` | ||||
|  | ||||
| 	Code    string `json:"code,omitempty"` | ||||
| 	Message string `json:"message,omitempty"` | ||||
| } | ||||
|  | ||||
| // 通义千问消息发送实现 | ||||
| func (h *ChatHandler) sendQWenMessage( | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| 	session *types.ChatSession, | ||||
| 	role model.ChatRole, | ||||
| 	prompt string, | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	start := time.Now() | ||||
| 	var apiKey = model.ApiKey{} | ||||
| 	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) | ||||
| 	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) | ||||
| 	if err != nil { | ||||
| 		if strings.Contains(err.Error(), "context canceled") { | ||||
| 			logger.Info("用户取消了请求:", prompt) | ||||
| 			return nil | ||||
| 		} else if strings.Contains(err.Error(), "no available key") { | ||||
| 			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
|  | ||||
| 		utils.ReplyMessage(ws, ErrorMsg) | ||||
| 		utils.ReplyMessage(ws, ErrImg) | ||||
| 		return err | ||||
| 	} else { | ||||
| 		defer response.Body.Close() | ||||
| 	} | ||||
| 	contentType := response.Header.Get("Content-Type") | ||||
| 	if strings.Contains(contentType, "text/event-stream") { | ||||
| 		replyCreatedAt := time.Now() // 记录回复时间 | ||||
| 		// 循环读取 Chunk 消息 | ||||
| 		var message = types.Message{} | ||||
| 		var contents = make([]string, 0) | ||||
| 		scanner := bufio.NewScanner(response.Body) | ||||
|  | ||||
| 		var content, lastText, newText string | ||||
| 		var outPutStart = false | ||||
|  | ||||
| 		for scanner.Scan() { | ||||
| 			line := scanner.Text() | ||||
| 			if len(line) < 5 || strings.HasPrefix(line, "id:") || | ||||
| 				strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if strings.HasPrefix(line, "data:") { | ||||
| 				content = line[5:] | ||||
| 			} | ||||
|  | ||||
| 			var resp qWenResp | ||||
| 			if len(contents) == 0 { // 发送消息头 | ||||
| 				if !outPutStart { | ||||
| 					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) | ||||
| 					outPutStart = true | ||||
| 					continue | ||||
| 				} else { | ||||
| 					// 处理代码换行 | ||||
| 					content = "\n" | ||||
| 				} | ||||
| 			} else { | ||||
| 				err := utils.JsonDecode(content, &resp) | ||||
| 				if err != nil { | ||||
| 					logger.Error("error with parse data line: ", content) | ||||
| 					utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) | ||||
| 					break | ||||
| 				} | ||||
| 				if resp.Message != "" { | ||||
| 					utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message)) | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			//通过比较 lastText(上一次的文本)和 currentText(当前的文本), | ||||
| 			//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。 | ||||
| 			//每次循环结束后,lastText 会更新为当前的完整文本,以便于下一次循环进行比较。 | ||||
| 			currentText := resp.Output.Text | ||||
| 			if currentText != lastText { | ||||
| 				// 提取新增文本 | ||||
| 				newText = strings.Replace(currentText, lastText, "", 1) | ||||
| 				utils.ReplyChunkMessage(ws, types.WsMessage{ | ||||
| 					Type:    types.WsMiddle, | ||||
| 					Content: utils.InterfaceToString(newText), | ||||
| 				}) | ||||
| 				lastText = currentText // 更新 lastText | ||||
| 			} | ||||
| 			contents = append(contents, newText) | ||||
|  | ||||
| 			if resp.Output.FinishReason == "stop" { | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 		} //end for | ||||
|  | ||||
| 		if err := scanner.Err(); err != nil { | ||||
| 			if strings.Contains(err.Error(), "context canceled") { | ||||
| 				logger.Info("用户取消了请求:", prompt) | ||||
| 			} else { | ||||
| 				logger.Error("信息读取出错:", err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 消息发送成功 | ||||
| 		if len(contents) > 0 { | ||||
| 			if message.Role == "" { | ||||
| 				message.Role = "assistant" | ||||
| 			} | ||||
| 			message.Content = strings.Join(contents, "") | ||||
| 			useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 			// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 			if h.App.SysConfig.EnableContext { | ||||
| 				chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 				chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 				h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 			} | ||||
|  | ||||
| 			// 追加聊天记录 | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    template.HTMLEscapeString(prompt), | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.DB.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// for reply | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 			totalTokens := replyTokens + getTotalTokens(req) | ||||
| 			historyReplyMsg := model.ChatMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    message.Content, | ||||
| 				Tokens:     totalTokens, | ||||
| 				UseContext: true, | ||||
| 				Model:      req.Model, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.DB.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// 更新用户算力 | ||||
| 			h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 			// 保存当前会话 | ||||
| 			var chatItem model.ChatItem | ||||
| 			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 			if res.Error != nil { | ||||
| 				chatItem.ChatId = session.ChatId | ||||
| 				chatItem.UserId = session.UserId | ||||
| 				chatItem.RoleId = role.Id | ||||
| 				chatItem.ModelId = session.Model.Id | ||||
| 				if utf8.RuneCountInString(prompt) > 30 { | ||||
| 					chatItem.Title = string([]rune(prompt)[:30]) + "..." | ||||
| 				} else { | ||||
| 					chatItem.Title = prompt | ||||
| 				} | ||||
| 				chatItem.Model = req.Model | ||||
| 				h.DB.Create(&chatItem) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		body, err := io.ReadAll(response.Body) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with reading response: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		var res struct { | ||||
| 			Code int    `json:"error_code"` | ||||
| 			Msg  string `json:"error_msg"` | ||||
| 		} | ||||
| 		err = json.Unmarshal(body, &res) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error with decode response: %v", err) | ||||
| 		} | ||||
| 		utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -12,7 +12,7 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"gorm.io/gorm" | ||||
| 	"html/template" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| @@ -49,10 +49,17 @@ type xunFeiResp struct { | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| var Model2URL = map[string]string{ | ||||
| 	"general":     "v1.1", | ||||
| 	"generalv2":   "v2.1", | ||||
| 	"generalv3":   "v3.1", | ||||
| 	"generalv3.5": "v3.5", | ||||
| } | ||||
|  | ||||
| // 科大讯飞消息发送实现 | ||||
|  | ||||
| func (h *ChatHandler) sendXunFeiMessage( | ||||
| 	chatCtx []interface{}, | ||||
| 	chatCtx []types.Message, | ||||
| 	req types.ApiRequest, | ||||
| 	userVo vo.User, | ||||
| 	ctx context.Context, | ||||
| @@ -61,35 +68,26 @@ func (h *ChatHandler) sendXunFeiMessage( | ||||
| 	prompt string, | ||||
| 	ws *types.WsClient) error { | ||||
| 	promptCreatedAt := time.Now() // 记录提问时间 | ||||
| 	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] | ||||
| 	if apiKey == "" { | ||||
| 		var key model.ApiKey | ||||
| 		res := h.db.Where("platform = ?", session.Model.Platform).Order("last_used_at ASC").First(&key) | ||||
| 		if res.Error != nil { | ||||
| 			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 			return nil | ||||
| 		} | ||||
| 		// 更新 API KEY 的最后使用时间 | ||||
| 		h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
| 		apiKey = key.Value | ||||
| 	var apiKey model.ApiKey | ||||
| 	res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) | ||||
| 	if res.Error != nil { | ||||
| 		utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") | ||||
| 		return nil | ||||
| 	} | ||||
| 	// 更新 API KEY 的最后使用时间 | ||||
| 	h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
|  | ||||
| 	d := websocket.Dialer{ | ||||
| 		HandshakeTimeout: 5 * time.Second, | ||||
| 	} | ||||
| 	key := strings.Split(apiKey, "|") | ||||
| 	key := strings.Split(apiKey.Value, "|") | ||||
| 	if len(key) != 3 { | ||||
| 		utils.ReplyMessage(ws, "非法的 API KEY!") | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	var apiURL string | ||||
| 	if req.Model == "generalv2" { | ||||
| 		apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v2.1", 1) | ||||
| 	} else { | ||||
| 		apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v1.1", 1) | ||||
| 	} | ||||
|  | ||||
| 	apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1) | ||||
| 	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) | ||||
| 	wsURL, err := assembleAuthUrl(apiURL, key[1], key[2]) | ||||
| 	//握手并建立websocket 连接 | ||||
| 	conn, resp, err := d.Dial(wsURL, nil) | ||||
| @@ -139,6 +137,10 @@ func (h *ChatHandler) sendXunFeiMessage( | ||||
| 		} | ||||
|  | ||||
| 		content = result.Payload.Choices.Text[0].Content | ||||
| 		// 处理代码换行 | ||||
| 		if len(content) == 0 { | ||||
| 			content = "\n" | ||||
| 		} | ||||
| 		contents = append(contents, content) | ||||
| 		// 第一个结果 | ||||
| 		if result.Payload.Choices.Status == 0 { | ||||
| @@ -166,9 +168,6 @@ func (h *ChatHandler) sendXunFeiMessage( | ||||
|  | ||||
| 	// 消息发送成功 | ||||
| 	if len(contents) > 0 { | ||||
| 		// 更新用户的对话次数 | ||||
| 		h.subUserCalls(userVo, session) | ||||
|  | ||||
| 		if message.Role == "" { | ||||
| 			message.Role = "assistant" | ||||
| 		} | ||||
| @@ -176,64 +175,64 @@ func (h *ChatHandler) sendXunFeiMessage( | ||||
| 		useMsg := types.Message{Role: "user", Content: prompt} | ||||
|  | ||||
| 		// 更新上下文消息,如果是调用函数则不需要更新上下文 | ||||
| 		if h.App.ChatConfig.EnableContext { | ||||
| 		if h.App.SysConfig.EnableContext { | ||||
| 			chatCtx = append(chatCtx, useMsg)  // 提问消息 | ||||
| 			chatCtx = append(chatCtx, message) // 回复消息 | ||||
| 			h.App.ChatContexts.Put(session.ChatId, chatCtx) | ||||
| 		} | ||||
|  | ||||
| 		// 追加聊天记录 | ||||
| 		if h.App.ChatConfig.EnableHistory { | ||||
| 			// for prompt | ||||
| 			promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			historyUserMsg := model.HistoryMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.PromptMsg, | ||||
| 				Icon:       userVo.Avatar, | ||||
| 				Content:    prompt, | ||||
| 				Tokens:     promptToken, | ||||
| 				UseContext: true, | ||||
| 			} | ||||
| 			historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 			historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 			res := h.db.Save(&historyUserMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 			} | ||||
|  | ||||
| 			// for reply | ||||
| 			// 计算本次对话消耗的总 token 数量 | ||||
| 			replyToken, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 			totalTokens := replyToken + getTotalTokens(req) | ||||
| 			historyReplyMsg := model.HistoryMessage{ | ||||
| 				UserId:     userVo.Id, | ||||
| 				ChatId:     session.ChatId, | ||||
| 				RoleId:     role.Id, | ||||
| 				Type:       types.ReplyMsg, | ||||
| 				Icon:       role.Icon, | ||||
| 				Content:    message.Content, | ||||
| 				Tokens:     totalTokens, | ||||
| 				UseContext: true, | ||||
| 			} | ||||
| 			historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 			historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 			res = h.db.Create(&historyReplyMsg) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("failed to save reply history message: ", res.Error) | ||||
| 			} | ||||
| 			// 更新用户信息 | ||||
| 			h.db.Model(&model.User{}).Where("id = ?", userVo.Id). | ||||
| 				UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) | ||||
| 		// for prompt | ||||
| 		promptToken, err := utils.CalcTokens(prompt, req.Model) | ||||
| 		if err != nil { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
| 		historyUserMsg := model.ChatMessage{ | ||||
| 			UserId:     userVo.Id, | ||||
| 			ChatId:     session.ChatId, | ||||
| 			RoleId:     role.Id, | ||||
| 			Type:       types.PromptMsg, | ||||
| 			Icon:       userVo.Avatar, | ||||
| 			Content:    template.HTMLEscapeString(prompt), | ||||
| 			Tokens:     promptToken, | ||||
| 			UseContext: true, | ||||
| 			Model:      req.Model, | ||||
| 		} | ||||
| 		historyUserMsg.CreatedAt = promptCreatedAt | ||||
| 		historyUserMsg.UpdatedAt = promptCreatedAt | ||||
| 		res := h.DB.Save(&historyUserMsg) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("failed to save prompt history message: ", res.Error) | ||||
| 		} | ||||
|  | ||||
| 		// for reply | ||||
| 		// 计算本次对话消耗的总 token 数量 | ||||
| 		replyTokens, _ := utils.CalcTokens(message.Content, req.Model) | ||||
| 		totalTokens := replyTokens + getTotalTokens(req) | ||||
| 		historyReplyMsg := model.ChatMessage{ | ||||
| 			UserId:     userVo.Id, | ||||
| 			ChatId:     session.ChatId, | ||||
| 			RoleId:     role.Id, | ||||
| 			Type:       types.ReplyMsg, | ||||
| 			Icon:       role.Icon, | ||||
| 			Content:    message.Content, | ||||
| 			Tokens:     totalTokens, | ||||
| 			UseContext: true, | ||||
| 			Model:      req.Model, | ||||
| 		} | ||||
| 		historyReplyMsg.CreatedAt = replyCreatedAt | ||||
| 		historyReplyMsg.UpdatedAt = replyCreatedAt | ||||
| 		res = h.DB.Create(&historyReplyMsg) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("failed to save reply history message: ", res.Error) | ||||
| 		} | ||||
|  | ||||
| 		// 更新用户算力 | ||||
| 		h.subUserPower(userVo, session, promptToken, replyTokens) | ||||
|  | ||||
| 		// 保存当前会话 | ||||
| 		var chatItem model.ChatItem | ||||
| 		res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 		res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) | ||||
| 		if res.Error != nil { | ||||
| 			chatItem.ChatId = session.ChatId | ||||
| 			chatItem.UserId = session.UserId | ||||
| @@ -244,7 +243,8 @@ func (h *ChatHandler) sendXunFeiMessage( | ||||
| 			} else { | ||||
| 				chatItem.Title = prompt | ||||
| 			} | ||||
| 			h.db.Create(&chatItem) | ||||
| 			chatItem.Model = req.Model | ||||
| 			h.DB.Create(&chatItem) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -260,7 +260,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} { | ||||
| 		"parameter": map[string]interface{}{ | ||||
| 			"chat": map[string]interface{}{ | ||||
| 				"domain":      req.Model, | ||||
| 				"temperature": float64(req.Temperature), | ||||
| 				"temperature": req.Temperature, | ||||
| 				"top_k":       int64(6), | ||||
| 				"max_tokens":  int64(req.MaxTokens), | ||||
| 				"auditing":    "default", | ||||
|   | ||||
							
								
								
									
										39
									
								
								api/handler/config_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								api/handler/config_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type ConfigHandler struct { | ||||
| 	BaseHandler | ||||
| } | ||||
|  | ||||
| func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler { | ||||
| 	return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // Get 获取指定的系统配置 | ||||
| func (h *ConfigHandler) Get(c *gin.Context) { | ||||
| 	key := c.Query("key") | ||||
| 	var config model.Config | ||||
| 	res := h.DB.Where("marker", key).First(&config) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var value map[string]interface{} | ||||
| 	err := utils.JsonDecode(config.Config, &value) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, value) | ||||
| } | ||||
							
								
								
									
										274
									
								
								api/handler/function_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										274
									
								
								api/handler/function_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,274 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"gorm.io/gorm" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type FunctionHandler struct { | ||||
| 	BaseHandler | ||||
| 	config        types.ChatPlusApiConfig | ||||
| 	uploadManager *oss.UploaderManager | ||||
| } | ||||
|  | ||||
| func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler { | ||||
| 	return &FunctionHandler{ | ||||
| 		BaseHandler: BaseHandler{ | ||||
| 			App: server, | ||||
| 			DB:  db, | ||||
| 		}, | ||||
| 		config:        config.ApiConfig, | ||||
| 		uploadManager: manager, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type resVo struct { | ||||
| 	Code    types.BizCode `json:"code"` | ||||
| 	Message string        `json:"message"` | ||||
| 	Data    struct { | ||||
| 		Title     string     `json:"title"` | ||||
| 		UpdatedAt string     `json:"updated_at"` | ||||
| 		Items     []dataItem `json:"items"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type dataItem struct { | ||||
| 	Title  string `json:"title"` | ||||
| 	Url    string `json:"url"` | ||||
| 	Remark string `json:"remark"` | ||||
| } | ||||
|  | ||||
| // check authorization | ||||
| func (h *FunctionHandler) checkAuth(c *gin.Context) error { | ||||
| 	tokenString := c.GetHeader(types.UserAuthHeader) | ||||
| 	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { | ||||
| 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | ||||
| 			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) | ||||
| 		} | ||||
|  | ||||
| 		return []byte(h.App.Config.Session.SecretKey), nil | ||||
| 	}) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error with parse auth token: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	claims, ok := token.Claims.(jwt.MapClaims) | ||||
| 	if !ok || !token.Valid { | ||||
| 		return errors.New("token is invalid") | ||||
| 	} | ||||
|  | ||||
| 	expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0) | ||||
| 	if expr > 0 && int64(expr) < time.Now().Unix() { | ||||
| 		return errors.New("token is expired") | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WeiBo 微博热搜 | ||||
| func (h *FunctionHandler) WeiBo(c *gin.Context) { | ||||
| 	if err := h.checkAuth(c); err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if h.config.Token == "" { | ||||
| 		resp.ERROR(c, "无效的 API Token") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL) | ||||
| 	var res resVo | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("AppId", h.config.AppId). | ||||
| 		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)). | ||||
| 		SetSuccessResult(&res).Get(url) | ||||
| 	if err != nil || r.IsErrorState() { | ||||
| 		resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err)) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if res.Code != types.Success { | ||||
| 		resp.ERROR(c, res.Message) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	builder := make([]string, 0) | ||||
| 	builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt)) | ||||
| 	for i, v := range res.Data.Items { | ||||
| 		builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark)) | ||||
| 	} | ||||
| 	resp.SUCCESS(c, strings.Join(builder, "\n\n")) | ||||
| } | ||||
|  | ||||
| // ZaoBao 今日早报 | ||||
| func (h *FunctionHandler) ZaoBao(c *gin.Context) { | ||||
| 	if err := h.checkAuth(c); err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if h.config.Token == "" { | ||||
| 		resp.ERROR(c, "无效的 API Token") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL) | ||||
| 	var res resVo | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("AppId", h.config.AppId). | ||||
| 		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)). | ||||
| 		SetSuccessResult(&res).Get(url) | ||||
| 	if err != nil || r.IsErrorState() { | ||||
| 		resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err)) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if res.Code != types.Success { | ||||
| 		resp.ERROR(c, res.Message) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	builder := make([]string, 0) | ||||
| 	builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt)) | ||||
| 	for _, v := range res.Data.Items { | ||||
| 		builder = append(builder, v.Title) | ||||
| 	} | ||||
| 	builder = append(builder, fmt.Sprintf("%s", res.Data.Title)) | ||||
| 	resp.SUCCESS(c, strings.Join(builder, "\n\n")) | ||||
| } | ||||
|  | ||||
| type imgReq struct { | ||||
| 	Model  string `json:"model"` | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	N      int    `json:"n"` | ||||
| 	Size   string `json:"size"` | ||||
| } | ||||
|  | ||||
| type imgRes struct { | ||||
| 	Created int64 `json:"created"` | ||||
| 	Data    []struct { | ||||
| 		RevisedPrompt string `json:"revised_prompt"` | ||||
| 		Url           string `json:"url"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type ErrRes struct { | ||||
| 	Error struct { | ||||
| 		Code    interface{} `json:"code"` | ||||
| 		Message string      `json:"message"` | ||||
| 		Param   interface{} `json:"param"` | ||||
| 		Type    string      `json:"type"` | ||||
| 	} `json:"error"` | ||||
| } | ||||
|  | ||||
| // Dall3 DallE3 AI 绘图 | ||||
| func (h *FunctionHandler) Dall3(c *gin.Context) { | ||||
| 	if err := h.checkAuth(c); err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var params map[string]interface{} | ||||
| 	if err := c.ShouldBindJSON(¶ms); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	logger.Debugf("绘画参数:%+v", params) | ||||
| 	var user model.User | ||||
| 	tx := h.DB.Where("id = ?", params["user_id"]).First(&user) | ||||
| 	if tx.Error != nil { | ||||
| 		resp.ERROR(c, "当前用户不存在!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if user.Power < h.App.SysConfig.DallPower { | ||||
| 		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	prompt := utils.InterfaceToString(params["prompt"]) | ||||
| 	// get image generation API KEY | ||||
| 	var apiKey model.ApiKey | ||||
| 	tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) | ||||
| 	if tx.Error != nil { | ||||
| 		resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// translate prompt | ||||
| 	const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" | ||||
| 	pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"])) | ||||
| 	if err == nil { | ||||
| 		logger.Debugf("翻译绘画提示词,原文:%s,译文:%s", prompt, pt) | ||||
| 		prompt = pt | ||||
| 	} | ||||
| 	var res imgRes | ||||
| 	var errRes ErrRes | ||||
| 	var request *req.Request | ||||
| 	if apiKey.ProxyURL != "" { | ||||
| 		request = req.C().SetProxyURL(apiKey.ProxyURL).R() | ||||
| 	} else { | ||||
| 		request = req.C().R() | ||||
| 	} | ||||
| 	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) | ||||
| 	r, err := request.SetHeader("Content-Type", "application/json"). | ||||
| 		SetHeader("Authorization", "Bearer "+apiKey.Value). | ||||
| 		SetBody(imgReq{ | ||||
| 			Model:  "dall-e-3", | ||||
| 			Prompt: prompt, | ||||
| 			N:      1, | ||||
| 			Size:   "1024x1024", | ||||
| 		}). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		SetSuccessResult(&res).Post(apiKey.ApiURL) | ||||
| 	if r.IsErrorState() { | ||||
| 		resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message) | ||||
| 		return | ||||
| 	} | ||||
| 	// 更新 API KEY 的最后使用时间 | ||||
| 	h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) | ||||
| 	logger.Debugf("%+v", res) | ||||
| 	// 存储图片 | ||||
| 	imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "下载图片失败: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n\n", prompt, imgURL) | ||||
| 	// 更新用户算力 | ||||
| 	tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower)) | ||||
| 	// 记录算力变化日志 | ||||
| 	if tx.Error == nil && tx.RowsAffected > 0 { | ||||
| 		var u model.User | ||||
| 		h.DB.Where("id", user.Id).First(&u) | ||||
| 		h.DB.Create(&model.PowerLog{ | ||||
| 			UserId:    user.Id, | ||||
| 			Username:  user.Username, | ||||
| 			Type:      types.PowerConsume, | ||||
| 			Amount:    h.App.SysConfig.DallPower, | ||||
| 			Balance:   u.Power, | ||||
| 			Mark:      types.PowerSub, | ||||
| 			Model:     "dall-e-3", | ||||
| 			Remark:    fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)), | ||||
| 			CreatedAt: time.Now(), | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, content) | ||||
| } | ||||
							
								
								
									
										93
									
								
								api/handler/invite_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								api/handler/invite_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // InviteHandler 用户邀请 | ||||
| type InviteHandler struct { | ||||
| 	BaseHandler | ||||
| } | ||||
|  | ||||
| func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler { | ||||
| 	return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // Code 获取当前用户邀请码 | ||||
| func (h *InviteHandler) Code(c *gin.Context) { | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	var inviteCode model.InviteCode | ||||
| 	res := h.DB.Where("user_id = ?", userId).First(&inviteCode) | ||||
| 	// 如果邀请码不存在,则创建一个 | ||||
| 	if res.Error != nil { | ||||
| 		code := strings.ToUpper(utils.RandString(8)) | ||||
| 		for { | ||||
| 			res = h.DB.Where("code = ?", code).First(&inviteCode) | ||||
| 			if res.Error != nil { // 不存在相同的邀请码则退出 | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		inviteCode.UserId = userId | ||||
| 		inviteCode.Code = code | ||||
| 		h.DB.Create(&inviteCode) | ||||
| 	} | ||||
|  | ||||
| 	var codeVo vo.InviteCode | ||||
| 	err := utils.CopyObject(inviteCode, &codeVo) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "拷贝对象失败") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, codeVo) | ||||
| } | ||||
|  | ||||
| // List Log 用户邀请记录 | ||||
| func (h *InviteHandler) List(c *gin.Context) { | ||||
|  | ||||
| 	var data struct { | ||||
| 		Page     int `json:"page"` | ||||
| 		PageSize int `json:"page_size"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId) | ||||
| 	var total int64 | ||||
| 	session.Model(&model.InviteLog{}).Count(&total) | ||||
| 	var items []model.InviteLog | ||||
| 	var list = make([]vo.InviteLog, 0) | ||||
| 	offset := (data.Page - 1) * data.PageSize | ||||
| 	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var v vo.InviteLog | ||||
| 			err := utils.CopyObject(item, &v) | ||||
| 			if err == nil { | ||||
| 				v.Id = item.Id | ||||
| 				v.CreatedAt = item.CreatedAt.Unix() | ||||
| 				list = append(list, v) | ||||
| 			} else { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) | ||||
| } | ||||
|  | ||||
| // Hits 访问邀请码 | ||||
| func (h *InviteHandler) Hits(c *gin.Context) { | ||||
| 	code := c.Query("code") | ||||
| 	h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1)) | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
| @@ -3,66 +3,58 @@ package handler | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/mj" | ||||
| 	"chatplus/service/mj/plus" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"gorm.io/gorm" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type MidJourneyHandler struct { | ||||
| 	BaseHandler | ||||
| 	redis     *redis.Client | ||||
| 	db        *gorm.DB | ||||
| 	mjService *mj.Service | ||||
| 	pool      *mj.ServicePool | ||||
| 	snowflake *service.Snowflake | ||||
| 	uploader  *oss.UploaderManager | ||||
| } | ||||
|  | ||||
| func NewMidJourneyHandler( | ||||
| 	app *core.AppServer, | ||||
| 	client *redis.Client, | ||||
| 	db *gorm.DB, | ||||
| 	mjService *mj.Service) *MidJourneyHandler { | ||||
| 	h := MidJourneyHandler{ | ||||
| 		redis:     client, | ||||
| 		db:        db, | ||||
| 		mjService: mjService, | ||||
| func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler { | ||||
| 	return &MidJourneyHandler{ | ||||
| 		snowflake: snowflake, | ||||
| 		pool:      pool, | ||||
| 		uploader:  manager, | ||||
| 		BaseHandler: BaseHandler{ | ||||
| 			App: app, | ||||
| 			DB:  db, | ||||
| 		}, | ||||
| 	} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| } | ||||
|  | ||||
| // Client WebSocket 客户端,用于通知任务状态变更 | ||||
| func (h *MidJourneyHandler) Client(c *gin.Context) { | ||||
| 	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) | ||||
| 	if err != nil { | ||||
| 		logger.Error(err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	sessionId := c.Query("session_id") | ||||
| 	client := types.NewWsClient(ws) | ||||
| 	h.mjService.Clients.Put(sessionId, client) | ||||
| 	logger.Infof("New websocket connected, IP: %s", c.ClientIP()) | ||||
| } | ||||
|  | ||||
| func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool { | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if user.ImgCalls <= 0 { | ||||
| 		resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") | ||||
| 	if user.Power < h.App.SysConfig.MjPower { | ||||
| 		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if !h.pool.HasAvailableService() { | ||||
| 		resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!") | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| @@ -70,30 +62,50 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool { | ||||
|  | ||||
| } | ||||
|  | ||||
| // Image 创建一个绘画任务 | ||||
| func (h *MidJourneyHandler) Image(c *gin.Context) { | ||||
| 	if !h.App.Config.MjConfig.Enabled { | ||||
| 		resp.ERROR(c, "MidJourney service is disabled") | ||||
| // Client WebSocket 客户端,用于通知任务状态变更 | ||||
| func (h *MidJourneyHandler) Client(c *gin.Context) { | ||||
| 	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) | ||||
| 	if err != nil { | ||||
| 		logger.Error(err) | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	if userId == 0 { | ||||
| 		logger.Info("Invalid user ID") | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	client := types.NewWsClient(ws) | ||||
| 	h.pool.Clients.Put(uint(userId), client) | ||||
| 	logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) | ||||
| } | ||||
|  | ||||
| // Image 创建一个绘画任务 | ||||
| func (h *MidJourneyHandler) Image(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		SessionId string  `json:"session_id"` | ||||
| 		Prompt    string  `json:"prompt"` | ||||
| 		Rate      string  `json:"rate"` | ||||
| 		Model     string  `json:"model"` | ||||
| 		Chaos     int     `json:"chaos"` | ||||
| 		Raw       bool    `json:"raw"` | ||||
| 		Seed      int64   `json:"seed"` | ||||
| 		Stylize   int     `json:"stylize"` | ||||
| 		Img       string  `json:"img"` | ||||
| 		Weight    float32 `json:"weight"` | ||||
| 		SessionId string   `json:"session_id"` | ||||
| 		TaskType  string   `json:"task_type"` | ||||
| 		Prompt    string   `json:"prompt"` | ||||
| 		NegPrompt string   `json:"neg_prompt"` | ||||
| 		Rate      string   `json:"rate"` | ||||
| 		Model     string   `json:"model"` | ||||
| 		Chaos     int      `json:"chaos"` | ||||
| 		Raw       bool     `json:"raw"` | ||||
| 		Seed      int64    `json:"seed"` | ||||
| 		Stylize   int      `json:"stylize"` | ||||
| 		ImgArr    []string `json:"img_arr"` | ||||
| 		Tile      bool     `json:"tile"` | ||||
| 		Quality   float32  `json:"quality"` | ||||
| 		Weight    float32  `json:"weight"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	if !h.checkLimits(c) { | ||||
| 	if !h.preCheck(c) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -110,57 +122,99 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { | ||||
| 	if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") { | ||||
| 		prompt += fmt.Sprintf(" --c %d", data.Chaos) | ||||
| 	} | ||||
| 	if data.Img != "" { | ||||
| 		prompt = fmt.Sprintf("%s %s", data.Img, prompt) | ||||
| 		if data.Weight > 0 { | ||||
| 			prompt += fmt.Sprintf(" --iw %f", data.Weight) | ||||
| 		} | ||||
| 	if data.Weight > 0 { | ||||
| 		prompt += fmt.Sprintf(" --iw %f", data.Weight) | ||||
| 	} | ||||
| 	if data.Raw { | ||||
| 		prompt += " --style raw" | ||||
| 	} | ||||
| 	if data.Quality > 0 { | ||||
| 		prompt += fmt.Sprintf(" --q %.2f", data.Quality) | ||||
| 	} | ||||
| 	if data.NegPrompt != "" { | ||||
| 		prompt += fmt.Sprintf(" --no %s", data.NegPrompt) | ||||
| 	} | ||||
| 	if data.Tile { | ||||
| 		prompt += " --tile " | ||||
| 	} | ||||
| 	if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { | ||||
| 		prompt += data.Model | ||||
| 		prompt += fmt.Sprintf(" %s", data.Model) | ||||
| 	} | ||||
|  | ||||
| 	// 处理融图和换脸的提示词 | ||||
| 	if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() { | ||||
| 		prompt = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ",")) | ||||
| 	} | ||||
|  | ||||
| 	idValue, _ := c.Get(types.LoginUserID) | ||||
| 	userId := utils.IntValue(utils.InterfaceToString(idValue), 0) | ||||
| 	// generate task id | ||||
| 	taskId, err := h.snowflake.Next(true) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "error with generate task id: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	job := model.MidJourneyJob{ | ||||
| 		Type:      types.TaskImage.String(), | ||||
| 		Type:      data.TaskType, | ||||
| 		UserId:    userId, | ||||
| 		TaskId:    taskId, | ||||
| 		Progress:  0, | ||||
| 		Prompt:    prompt, | ||||
| 		Power:     h.App.SysConfig.MjPower, | ||||
| 		CreatedAt: time.Now(), | ||||
| 	} | ||||
| 	if res := h.db.Create(&job); res.Error != nil { | ||||
| 	opt := "绘图" | ||||
| 	if data.TaskType == types.TaskBlend.String() { | ||||
| 		job.Prompt = "融图:" + strings.Join(data.ImgArr, ",") | ||||
| 		opt = "融图" | ||||
| 	} else if data.TaskType == types.TaskSwapFace.String() { | ||||
| 		job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",") | ||||
| 		opt = "换脸" | ||||
| 	} | ||||
|  | ||||
| 	if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { | ||||
| 		resp.ERROR(c, "添加任务失败:"+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h.mjService.PushTask(types.MjTask{ | ||||
| 		Id:        int(job.Id), | ||||
| 	h.pool.PushTask(types.MjTask{ | ||||
| 		Id:        job.Id, | ||||
| 		TaskId:    taskId, | ||||
| 		SessionId: data.SessionId, | ||||
| 		Src:       types.TaskSrcImg, | ||||
| 		Type:      types.TaskImage, | ||||
| 		Type:      types.TaskType(data.TaskType), | ||||
| 		Prompt:    prompt, | ||||
| 		UserId:    userId, | ||||
| 		ImgArr:    data.ImgArr, | ||||
| 	}) | ||||
|  | ||||
| 	var jobVo vo.MidJourneyJob | ||||
| 	err := utils.CopyObject(job, &jobVo) | ||||
| 	if err == nil { | ||||
| 		// 推送任务到前端 | ||||
| 		client := h.mjService.Clients.Get(data.SessionId) | ||||
| 		if client != nil { | ||||
| 			utils.ReplyChunkMessage(client, jobVo) | ||||
| 		} | ||||
| 	client := h.pool.Clients.Get(uint(job.UserId)) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 	} | ||||
|  | ||||
| 	// update user's power | ||||
| 	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) | ||||
| 	// 记录算力变化日志 | ||||
| 	if tx.Error == nil && tx.RowsAffected > 0 { | ||||
| 		user, _ := h.GetLoginUser(c) | ||||
| 		h.DB.Create(&model.PowerLog{ | ||||
| 			UserId:    user.Id, | ||||
| 			Username:  user.Username, | ||||
| 			Type:      types.PowerConsume, | ||||
| 			Amount:    job.Power, | ||||
| 			Balance:   user.Power - job.Power, | ||||
| 			Mark:      types.PowerSub, | ||||
| 			Model:     "mid-journey", | ||||
| 			Remark:    fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId), | ||||
| 			CreatedAt: time.Now(), | ||||
| 		}) | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| type reqVo struct { | ||||
| 	Src         string `json:"src"` | ||||
| 	Index       int    `json:"index"` | ||||
| 	ChannelId   string `json:"channel_id"` | ||||
| 	MessageId   string `json:"message_id"` | ||||
| 	MessageHash string `json:"message_hash"` | ||||
| 	SessionId   string `json:"session_id"` | ||||
| @@ -178,65 +232,44 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !h.checkLimits(c) { | ||||
| 	if !h.preCheck(c) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	idValue, _ := c.Get(types.LoginUserID) | ||||
| 	jobId := 0 | ||||
| 	userId := utils.IntValue(utils.InterfaceToString(idValue), 0) | ||||
| 	src := types.TaskSrc(data.Src) | ||||
| 	if src == types.TaskSrcImg { | ||||
| 		job := model.MidJourneyJob{ | ||||
| 			Type:      types.TaskUpscale.String(), | ||||
| 			UserId:    userId, | ||||
| 			Hash:      data.MessageHash, | ||||
| 			Progress:  0, | ||||
| 			Prompt:    data.Prompt, | ||||
| 			CreatedAt: time.Now(), | ||||
| 		} | ||||
| 		if res := h.db.Create(&job); res.Error == nil { | ||||
| 			jobId = int(job.Id) | ||||
| 		} else { | ||||
| 			resp.ERROR(c, "添加任务失败:"+res.Error.Error()) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var jobVo vo.MidJourneyJob | ||||
| 		err := utils.CopyObject(job, &jobVo) | ||||
| 		if err == nil { | ||||
| 			// 推送任务到前端 | ||||
| 			client := h.mjService.Clients.Get(data.SessionId) | ||||
| 			if client != nil { | ||||
| 				utils.ReplyChunkMessage(client, jobVo) | ||||
| 			} | ||||
| 		} | ||||
| 	taskId, _ := h.snowflake.Next(true) | ||||
| 	job := model.MidJourneyJob{ | ||||
| 		Type:        types.TaskUpscale.String(), | ||||
| 		ReferenceId: data.MessageId, | ||||
| 		UserId:      userId, | ||||
| 		TaskId:      taskId, | ||||
| 		Progress:    0, | ||||
| 		Prompt:      data.Prompt, | ||||
| 		CreatedAt:   time.Now(), | ||||
| 	} | ||||
| 	h.mjService.PushTask(types.MjTask{ | ||||
| 		Id:          jobId, | ||||
| 	if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { | ||||
| 		resp.ERROR(c, "添加任务失败:"+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h.pool.PushTask(types.MjTask{ | ||||
| 		Id:          job.Id, | ||||
| 		SessionId:   data.SessionId, | ||||
| 		Src:         src, | ||||
| 		Type:        types.TaskUpscale, | ||||
| 		Prompt:      data.Prompt, | ||||
| 		UserId:      userId, | ||||
| 		RoleId:      data.RoleId, | ||||
| 		Icon:        data.Icon, | ||||
| 		ChatId:      data.ChatId, | ||||
| 		ChannelId:   data.ChannelId, | ||||
| 		Index:       data.Index, | ||||
| 		MessageId:   data.MessageId, | ||||
| 		MessageHash: data.MessageHash, | ||||
| 	}) | ||||
|  | ||||
| 	if src == types.TaskSrcChat { | ||||
| 		wsClient := h.App.ChatClients.Get(data.SessionId) | ||||
| 		if wsClient != nil { | ||||
| 			content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) | ||||
| 			utils.ReplyMessage(wsClient, content) | ||||
| 			if h.mjService.ChatClients.Get(data.SessionId) == nil { | ||||
| 				h.mjService.ChatClients.Put(data.SessionId, wsClient) | ||||
| 			} | ||||
| 		} | ||||
| 	client := h.pool.Clients.Get(uint(job.UserId)) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| @@ -248,79 +281,100 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !h.checkLimits(c) { | ||||
| 	if !h.preCheck(c) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	idValue, _ := c.Get(types.LoginUserID) | ||||
| 	jobId := 0 | ||||
| 	userId := utils.IntValue(utils.InterfaceToString(idValue), 0) | ||||
| 	src := types.TaskSrc(data.Src) | ||||
| 	if src == types.TaskSrcImg { | ||||
| 		job := model.MidJourneyJob{ | ||||
| 			Type:      types.TaskVariation.String(), | ||||
| 			UserId:    userId, | ||||
| 			ImgURL:    "", | ||||
| 			Hash:      data.MessageHash, | ||||
| 			Progress:  0, | ||||
| 			Prompt:    data.Prompt, | ||||
| 			CreatedAt: time.Now(), | ||||
| 		} | ||||
| 		if res := h.db.Create(&job); res.Error == nil { | ||||
| 			jobId = int(job.Id) | ||||
| 		} else { | ||||
| 			resp.ERROR(c, "添加任务失败:"+res.Error.Error()) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var jobVo vo.MidJourneyJob | ||||
| 		err := utils.CopyObject(job, &jobVo) | ||||
| 		if err == nil { | ||||
| 			// 推送任务到前端 | ||||
| 			client := h.mjService.Clients.Get(data.SessionId) | ||||
| 			if client != nil { | ||||
| 				utils.ReplyChunkMessage(client, jobVo) | ||||
| 			} | ||||
| 		} | ||||
| 	taskId, _ := h.snowflake.Next(true) | ||||
| 	job := model.MidJourneyJob{ | ||||
| 		Type:        types.TaskVariation.String(), | ||||
| 		ChannelId:   data.ChannelId, | ||||
| 		ReferenceId: data.MessageId, | ||||
| 		UserId:      userId, | ||||
| 		TaskId:      taskId, | ||||
| 		Progress:    0, | ||||
| 		Prompt:      data.Prompt, | ||||
| 		Power:       h.App.SysConfig.MjPower, | ||||
| 		CreatedAt:   time.Now(), | ||||
| 	} | ||||
| 	h.mjService.PushTask(types.MjTask{ | ||||
| 		Id:          jobId, | ||||
| 	if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { | ||||
| 		resp.ERROR(c, "添加任务失败:"+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h.pool.PushTask(types.MjTask{ | ||||
| 		Id:          job.Id, | ||||
| 		SessionId:   data.SessionId, | ||||
| 		Src:         src, | ||||
| 		Type:        types.TaskVariation, | ||||
| 		Prompt:      data.Prompt, | ||||
| 		UserId:      userId, | ||||
| 		RoleId:      data.RoleId, | ||||
| 		Icon:        data.Icon, | ||||
| 		ChatId:      data.ChatId, | ||||
| 		Index:       data.Index, | ||||
| 		ChannelId:   data.ChannelId, | ||||
| 		MessageId:   data.MessageId, | ||||
| 		MessageHash: data.MessageHash, | ||||
| 	}) | ||||
|  | ||||
| 	if src == types.TaskSrcChat { | ||||
| 		// 从聊天窗口发送的请求,记录客户端信息 | ||||
| 		wsClient := h.mjService.ChatClients.Get(data.SessionId) | ||||
| 		if wsClient != nil { | ||||
| 			content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) | ||||
| 			utils.ReplyMessage(wsClient, content) | ||||
| 			if h.mjService.Clients.Get(data.SessionId) == nil { | ||||
| 				h.mjService.Clients.Put(data.SessionId, wsClient) | ||||
| 			} | ||||
| 		} | ||||
| 	client := h.pool.Clients.Get(uint(job.UserId)) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 	} | ||||
|  | ||||
| 	// update user's power | ||||
| 	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) | ||||
| 	// 记录算力变化日志 | ||||
| 	if tx.Error == nil && tx.RowsAffected > 0 { | ||||
| 		user, _ := h.GetLoginUser(c) | ||||
| 		h.DB.Create(&model.PowerLog{ | ||||
| 			UserId:    user.Id, | ||||
| 			Username:  user.Username, | ||||
| 			Type:      types.PowerConsume, | ||||
| 			Amount:    job.Power, | ||||
| 			Balance:   user.Power - job.Power, | ||||
| 			Mark:      types.PowerSub, | ||||
| 			Model:     "mid-journey", | ||||
| 			Remark:    fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId), | ||||
| 			CreatedAt: time.Now(), | ||||
| 		}) | ||||
| 	} | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // JobList 获取 MJ 任务列表 | ||||
| func (h *MidJourneyHandler) JobList(c *gin.Context) { | ||||
| 	status := h.GetInt(c, "status", 0) | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| // ImgWall 照片墙 | ||||
| func (h *MidJourneyHandler) ImgWall(c *gin.Context) { | ||||
| 	page := h.GetInt(c, "page", 0) | ||||
| 	pageSize := h.GetInt(c, "page_size", 0) | ||||
| 	err, jobs := h.getData(true, 0, page, pageSize, true) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.db.Session(&gorm.Session{}) | ||||
| 	if status == 1 { | ||||
| 	resp.SUCCESS(c, jobs) | ||||
| } | ||||
|  | ||||
| // JobList 获取 MJ 任务列表 | ||||
| func (h *MidJourneyHandler) JobList(c *gin.Context) { | ||||
| 	status := h.GetBool(c, "status") | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	page := h.GetInt(c, "page", 0) | ||||
| 	pageSize := h.GetInt(c, "page_size", 0) | ||||
| 	publish := h.GetBool(c, "publish") | ||||
|  | ||||
| 	err, jobs := h.getData(status, userId, page, pageSize, publish) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, jobs) | ||||
| } | ||||
|  | ||||
| // JobList 获取 MJ 任务列表 | ||||
| func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) { | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if finish { | ||||
| 		session = session.Where("progress = ?", 100).Order("id DESC") | ||||
| 	} else { | ||||
| 		session = session.Where("progress < ?", 100).Order("id ASC") | ||||
| @@ -328,6 +382,9 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { | ||||
| 	if userId > 0 { | ||||
| 		session = session.Where("user_id = ?", userId) | ||||
| 	} | ||||
| 	if publish { | ||||
| 		session = session.Where("publish = ?", publish) | ||||
| 	} | ||||
| 	if page > 0 && pageSize > 0 { | ||||
| 		offset := (page - 1) * pageSize | ||||
| 		session = session.Offset(offset).Limit(pageSize) | ||||
| @@ -336,8 +393,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { | ||||
| 	var items []model.MidJourneyJob | ||||
| 	res := session.Find(&items) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, types.NoData) | ||||
| 		return | ||||
| 		return res.Error, nil | ||||
| 	} | ||||
|  | ||||
| 	var jobs = make([]vo.MidJourneyJob, 0) | ||||
| @@ -347,20 +403,94 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if item.Progress < 100 { | ||||
| 			// 30 分钟还没完成的任务直接删除 | ||||
| 			if time.Now().Sub(item.CreatedAt) > time.Minute*30 { | ||||
| 				h.db.Delete(&item) | ||||
| 				continue | ||||
| 			} | ||||
| 			if item.ImgURL != "" { // 正在运行中任务使用代理访问图片 | ||||
| 				image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) | ||||
|  | ||||
| 		if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" { | ||||
| 			// discord 服务器图片需要使用代理转发图片数据流 | ||||
| 			if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") { | ||||
| 				image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) | ||||
| 				if err == nil { | ||||
| 					job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) | ||||
| 				} | ||||
| 			} else { | ||||
| 				job.ImgURL = job.OrgURL | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		jobs = append(jobs, job) | ||||
| 	} | ||||
| 	resp.SUCCESS(c, jobs) | ||||
| 	return nil, jobs | ||||
| } | ||||
|  | ||||
| // Remove remove task image | ||||
| func (h *MidJourneyHandler) Remove(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id     uint   `json:"id"` | ||||
| 		UserId uint   `json:"user_id"` | ||||
| 		ImgURL string `json:"img_url"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove job recode | ||||
| 	res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove image | ||||
| 	err := h.uploader.GetUploadHandler().Delete(data.ImgURL) | ||||
| 	if err != nil { | ||||
| 		logger.Error("remove image failed: ", err) | ||||
| 	} | ||||
|  | ||||
| 	client := h.pool.Clients.Get(data.UserId) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // Notify MidJourney Plus 服务任务回调处理 | ||||
| func (h *MidJourneyHandler) Notify(c *gin.Context) { | ||||
| 	var data plus.CBReq | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		logger.Error("非法任务回调:%+v", err) | ||||
| 		return | ||||
| 	} | ||||
| 	err := h.pool.Notify(data) | ||||
| 	if err != nil { | ||||
| 		logger.Error(err) | ||||
| 	} else { | ||||
| 		userId := h.GetLoginUserId(c) | ||||
| 		client := h.pool.Clients.Get(userId) | ||||
| 		if client != nil { | ||||
| 			_ = client.Send([]byte("Task Updated")) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // Publish 发布图片到画廊显示 | ||||
| func (h *MidJourneyHandler) Publish(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id     uint `json:"id"` | ||||
| 		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享 | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|   | ||||
							
								
								
									
										55
									
								
								api/handler/order_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								api/handler/order_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type OrderHandler struct { | ||||
| 	BaseHandler | ||||
| } | ||||
|  | ||||
| func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { | ||||
| 	return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *OrderHandler) List(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Page     int `json:"page"` | ||||
| 		PageSize int `json:"page_size"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess) | ||||
| 	var total int64 | ||||
| 	session.Model(&model.Order{}).Count(&total) | ||||
| 	var items []model.Order | ||||
| 	var list = make([]vo.Order, 0) | ||||
| 	offset := (data.Page - 1) * data.PageSize | ||||
| 	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var order vo.Order | ||||
| 			err := utils.CopyObject(item, &order) | ||||
| 			if err == nil { | ||||
| 				order.Id = item.Id | ||||
| 				order.CreatedAt = item.CreatedAt.Unix() | ||||
| 				order.UpdatedAt = item.UpdatedAt.Unix() | ||||
| 				list = append(list, order) | ||||
| 			} else { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) | ||||
| } | ||||
							
								
								
									
										590
									
								
								api/handler/payment_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										590
									
								
								api/handler/payment_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,590 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/payment" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"embed" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"github.com/shopspring/decimal" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	PayWayAlipay = "支付宝" | ||||
| 	PayWayXunHu  = "虎皮椒" | ||||
| 	PayWayJs     = "PayJS" | ||||
| ) | ||||
|  | ||||
| // PaymentHandler 支付服务回调 handler | ||||
| type PaymentHandler struct { | ||||
| 	BaseHandler | ||||
| 	alipayService  *payment.AlipayService | ||||
| 	huPiPayService *payment.HuPiPayService | ||||
| 	js             *payment.PayJS | ||||
| 	snowflake      *service.Snowflake | ||||
| 	fs             embed.FS | ||||
| 	lock           sync.Mutex | ||||
| } | ||||
|  | ||||
| func NewPaymentHandler( | ||||
| 	server *core.AppServer, | ||||
| 	alipayService *payment.AlipayService, | ||||
| 	huPiPayService *payment.HuPiPayService, | ||||
| 	js *payment.PayJS, | ||||
| 	db *gorm.DB, | ||||
| 	snowflake *service.Snowflake, | ||||
| 	fs embed.FS) *PaymentHandler { | ||||
| 	return &PaymentHandler{ | ||||
| 		alipayService:  alipayService, | ||||
| 		huPiPayService: huPiPayService, | ||||
| 		js:             js, | ||||
| 		snowflake:      snowflake, | ||||
| 		fs:             fs, | ||||
| 		lock:           sync.Mutex{}, | ||||
| 		BaseHandler: BaseHandler{ | ||||
| 			App: server, | ||||
| 			DB:  db, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h *PaymentHandler) DoPay(c *gin.Context) { | ||||
| 	orderNo := h.GetTrim(c, "order_no") | ||||
| 	payWay := h.GetTrim(c, "pay_way") | ||||
|  | ||||
| 	if orderNo == "" { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var order model.Order | ||||
| 	res := h.DB.Where("order_no = ?", orderNo).First(&order) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Order not found") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回 | ||||
| 	if order.Status == types.OrderPaidSuccess { | ||||
| 		resp.ERROR(c, "This order had been paid, please do not pay twice") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 更新扫码状态 | ||||
| 	h.DB.Model(&order).UpdateColumn("status", types.OrderScanned) | ||||
| 	if payWay == "alipay" { // 支付宝 | ||||
| 		// 生成支付链接 | ||||
| 		notifyURL := h.App.Config.AlipayConfig.NotifyURL | ||||
| 		returnURL := "" // 关闭同步回跳 | ||||
| 		amount := fmt.Sprintf("%.2f", order.Amount) | ||||
|  | ||||
| 		uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject) | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, "error with generate pay url: "+err.Error()) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		c.Redirect(302, uri) | ||||
| 		return | ||||
| 	} else if payWay == "hupi" { // 虎皮椒支付 | ||||
| 		params := payment.HuPiPayReq{ | ||||
| 			Version:      "1.1", | ||||
| 			TradeOrderId: orderNo, | ||||
| 			TotalFee:     fmt.Sprintf("%f", order.Amount), | ||||
| 			Title:        order.Subject, | ||||
| 			NotifyURL:    h.App.Config.HuPiPayConfig.NotifyURL, | ||||
| 			WapName:      "极客学长", | ||||
| 		} | ||||
| 		r, err := h.huPiPayService.Pay(params) | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, err.Error()) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		c.Redirect(302, r.URL) | ||||
| 	} | ||||
| 	resp.ERROR(c, "Invalid operations") | ||||
| } | ||||
|  | ||||
| // OrderQuery 查询订单状态 | ||||
| func (h *PaymentHandler) OrderQuery(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		OrderNo string `json:"order_no"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var order model.Order | ||||
| 	res := h.DB.Where("order_no = ?", data.OrderNo).First(&order) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Order not found") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if order.Status == types.OrderPaidSuccess { | ||||
| 		resp.SUCCESS(c, gin.H{"status": order.Status}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	counter := 0 | ||||
| 	for { | ||||
| 		time.Sleep(time.Second) | ||||
| 		var item model.Order | ||||
| 		h.DB.Where("order_no = ?", data.OrderNo).First(&item) | ||||
| 		if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status { | ||||
| 			order.Status = item.Status | ||||
| 			break | ||||
| 		} | ||||
| 		counter++ | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, gin.H{"status": order.Status}) | ||||
| } | ||||
|  | ||||
| // PayQrcode 生成支付 URL 二维码 | ||||
| func (h *PaymentHandler) PayQrcode(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		PayWay    string `json:"pay_way"` // 支付方式 | ||||
| 		ProductId uint   `json:"product_id"` | ||||
| 		UserId    int    `json:"user_id"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var product model.Product | ||||
| 	res := h.DB.First(&product, data.ProductId) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Product not found") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	orderNo, err := h.snowflake.Next(false) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "error with generate trade no: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	var user model.User | ||||
| 	res = h.DB.First(&user, data.UserId) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Invalid user ID") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var payWay string | ||||
| 	var notifyURL string | ||||
| 	switch data.PayWay { | ||||
| 	case "hupi": | ||||
| 		payWay = PayWayXunHu | ||||
| 		notifyURL = h.App.Config.HuPiPayConfig.NotifyURL | ||||
| 	case "payjs": | ||||
| 		payWay = PayWayJs | ||||
| 		notifyURL = h.App.Config.JPayConfig.NotifyURL | ||||
| 	default: | ||||
| 		payWay = PayWayAlipay | ||||
| 		notifyURL = h.App.Config.AlipayConfig.NotifyURL | ||||
| 	} | ||||
| 	// 创建订单 | ||||
| 	remark := types.OrderRemark{ | ||||
| 		Days:     product.Days, | ||||
| 		Power:    product.Power, | ||||
| 		Name:     product.Name, | ||||
| 		Price:    product.Price, | ||||
| 		Discount: product.Discount, | ||||
| 	} | ||||
|  | ||||
| 	amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64() | ||||
| 	order := model.Order{ | ||||
| 		UserId:    user.Id, | ||||
| 		Username:  user.Username, | ||||
| 		ProductId: product.Id, | ||||
| 		OrderNo:   orderNo, | ||||
| 		Subject:   product.Name, | ||||
| 		Amount:    amount, | ||||
| 		Status:    types.OrderNotPaid, | ||||
| 		PayWay:    payWay, | ||||
| 		Remark:    utils.JsonEncode(remark), | ||||
| 	} | ||||
| 	res = h.DB.Create(&order) | ||||
| 	if res.Error != nil || res.RowsAffected == 0 { | ||||
| 		resp.ERROR(c, "error with create order: "+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// PayJs 单独处理,只能用官方生成的二维码 | ||||
| 	if data.PayWay == "payjs" { | ||||
| 		params := payment.JPayReq{ | ||||
| 			TotalFee:   int(math.Ceil(order.Amount * 100)), | ||||
| 			OutTradeNo: order.OrderNo, | ||||
| 			Subject:    product.Name, | ||||
| 		} | ||||
| 		r := h.js.Pay(params) | ||||
| 		if r.IsOK() { | ||||
| 			resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode}) | ||||
| 			return | ||||
| 		} else { | ||||
| 			resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var logo string | ||||
| 	if data.PayWay == "alipay" { | ||||
| 		logo = "res/img/alipay.jpg" | ||||
| 	} else if data.PayWay == "hupi" { | ||||
| 		if h.App.Config.HuPiPayConfig.Name == "wechat" { | ||||
| 			logo = "res/img/wechat-pay.jpg" | ||||
| 		} else { | ||||
| 			logo = "res/img/alipay.jpg" | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	file, err := h.fs.Open(logo) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "error with open qrcode log file: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	parse, err := url.Parse(notifyURL) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s", parse.Scheme, parse.Host, orderNo, data.PayWay) | ||||
| 	imgData, err := utils.GenQrcode(imageURL, 400, file) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	imgDataBase64 := base64.StdEncoding.EncodeToString(imgData) | ||||
| 	resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL}) | ||||
| } | ||||
|  | ||||
| // Mobile 移动端支付 | ||||
| func (h *PaymentHandler) Mobile(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		PayWay    string `json:"pay_way"` // 支付方式 | ||||
| 		ProductId uint   `json:"product_id"` | ||||
| 		UserId    int    `json:"user_id"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var product model.Product | ||||
| 	res := h.DB.First(&product, data.ProductId) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Product not found") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	orderNo, err := h.snowflake.Next(false) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "error with generate trade no: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	var user model.User | ||||
| 	res = h.DB.First(&user, data.UserId) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "Invalid user ID") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64() | ||||
| 	var payWay string | ||||
| 	var notifyURL, returnURL string | ||||
| 	var payURL string | ||||
| 	switch data.PayWay { | ||||
| 	case "hupi": | ||||
| 		payWay = PayWayXunHu | ||||
| 		notifyURL = h.App.Config.HuPiPayConfig.NotifyURL | ||||
| 		returnURL = h.App.Config.HuPiPayConfig.ReturnURL | ||||
| 		params := payment.HuPiPayReq{ | ||||
| 			Version:      "1.1", | ||||
| 			TradeOrderId: orderNo, | ||||
| 			TotalFee:     fmt.Sprintf("%f", amount), | ||||
| 			Title:        product.Name, | ||||
| 			NotifyURL:    notifyURL, | ||||
| 			ReturnURL:    returnURL, | ||||
| 			CallbackURL:  returnURL, | ||||
| 			WapName:      "极客学长", | ||||
| 		} | ||||
| 		r, err := h.huPiPayService.Pay(params) | ||||
| 		if err != nil { | ||||
| 			logger.Error("error with generating Pay URL: ", err.Error()) | ||||
| 			resp.ERROR(c, "error with generating Pay URL: "+err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		payURL = r.URL | ||||
| 	case "payjs": | ||||
| 		payWay = PayWayJs | ||||
| 		notifyURL = h.App.Config.JPayConfig.NotifyURL | ||||
| 		returnURL = h.App.Config.JPayConfig.ReturnURL | ||||
| 		totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart() | ||||
| 		params := url.Values{} | ||||
| 		params.Add("total_fee", fmt.Sprintf("%d", totalFee)) | ||||
| 		params.Add("out_trade_no", orderNo) | ||||
| 		params.Add("body", product.Name) | ||||
| 		params.Add("notify_url", notifyURL) | ||||
| 		params.Add("auto", "0") | ||||
| 		payURL = h.js.PayH5(params) | ||||
| 	case "alipay": | ||||
| 		payWay = PayWayAlipay | ||||
| 		notifyURL = h.App.Config.AlipayConfig.NotifyURL | ||||
| 		returnURL = h.App.Config.AlipayConfig.ReturnURL | ||||
| 		payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name) | ||||
| 		if err != nil { | ||||
| 			resp.ERROR(c, "error with generating Pay URL: "+err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 	default: | ||||
| 		resp.ERROR(c, "Unsupported pay way: "+data.PayWay) | ||||
| 		return | ||||
| 	} | ||||
| 	// 创建订单 | ||||
| 	remark := types.OrderRemark{ | ||||
| 		Days:     product.Days, | ||||
| 		Power:    product.Power, | ||||
| 		Name:     product.Name, | ||||
| 		Price:    product.Price, | ||||
| 		Discount: product.Discount, | ||||
| 	} | ||||
|  | ||||
| 	order := model.Order{ | ||||
| 		UserId:    user.Id, | ||||
| 		Username:  user.Username, | ||||
| 		ProductId: product.Id, | ||||
| 		OrderNo:   orderNo, | ||||
| 		Subject:   product.Name, | ||||
| 		Amount:    amount, | ||||
| 		Status:    types.OrderNotPaid, | ||||
| 		PayWay:    payWay, | ||||
| 		Remark:    utils.JsonEncode(remark), | ||||
| 	} | ||||
| 	res = h.DB.Create(&order) | ||||
| 	if res.Error != nil || res.RowsAffected == 0 { | ||||
| 		resp.ERROR(c, "error with create order: "+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, payURL) | ||||
| } | ||||
|  | ||||
| // 异步通知回调公共逻辑 | ||||
| func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { | ||||
| 	var order model.Order | ||||
| 	res := h.DB.Where("order_no = ?", orderNo).First(&order) | ||||
| 	if res.Error != nil { | ||||
| 		err := fmt.Errorf("error with fetch order: %v", res.Error) | ||||
| 		logger.Error(err) | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	h.lock.Lock() | ||||
| 	defer h.lock.Unlock() | ||||
|  | ||||
| 	// 已支付订单,直接返回 | ||||
| 	if order.Status == types.OrderPaidSuccess { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	var user model.User | ||||
| 	res = h.DB.First(&user, order.UserId) | ||||
| 	if res.Error != nil { | ||||
| 		err := fmt.Errorf("error with fetch user info: %v", res.Error) | ||||
| 		logger.Error(err) | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var remark types.OrderRemark | ||||
| 	err := utils.JsonDecode(order.Remark, &remark) | ||||
| 	if err != nil { | ||||
| 		err := fmt.Errorf("error with decode order remark: %v", err) | ||||
| 		logger.Error(err) | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var opt string | ||||
| 	var power int | ||||
| 	if user.Vip { // 已经是 VIP 用户 | ||||
| 		if remark.Days > 0 { // 只延期 VIP,不增加调用次数 | ||||
| 			user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix() | ||||
| 		} else { // 充值点卡,直接增加次数即可 | ||||
| 			user.Power += remark.Power | ||||
| 			opt = "点卡充值" | ||||
| 			power = remark.Power | ||||
| 		} | ||||
|  | ||||
| 	} else {                 // 非 VIP 用户 | ||||
| 		if remark.Days > 0 { // vip 套餐:days > 0, power == 0 | ||||
| 			user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix() | ||||
| 			user.Power += h.App.SysConfig.VipMonthPower | ||||
| 			user.Vip = true | ||||
| 			opt = "VIP充值" | ||||
| 			power = h.App.SysConfig.VipMonthPower | ||||
| 		} else { //点卡:days == 0, calls > 0 | ||||
| 			user.Power += remark.Power | ||||
| 			opt = "点卡充值" | ||||
| 			power = remark.Power | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 更新用户信息 | ||||
| 	res = h.DB.Updates(&user) | ||||
| 	if res.Error != nil { | ||||
| 		err := fmt.Errorf("error with update user info: %v", res.Error) | ||||
| 		logger.Error(err) | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 更新订单状态 | ||||
| 	order.PayTime = time.Now().Unix() | ||||
| 	order.Status = types.OrderPaidSuccess | ||||
| 	order.TradeNo = tradeNo | ||||
| 	res = h.DB.Updates(&order) | ||||
| 	if res.Error != nil { | ||||
| 		err := fmt.Errorf("error with update order info: %v", res.Error) | ||||
| 		logger.Error(err) | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 更新产品销量 | ||||
| 	h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1)) | ||||
|  | ||||
| 	// 记录算力充值日志 | ||||
| 	if opt != "" { | ||||
| 		h.DB.Create(&model.PowerLog{ | ||||
| 			UserId:    user.Id, | ||||
| 			Username:  user.Username, | ||||
| 			Type:      types.PowerRecharge, | ||||
| 			Amount:    power, | ||||
| 			Balance:   user.Power, | ||||
| 			Mark:      types.PowerAdd, | ||||
| 			Model:     order.PayWay, | ||||
| 			Remark:    fmt.Sprintf("%s,金额:%f,订单号:%s", opt, order.Amount, order.OrderNo), | ||||
| 			CreatedAt: time.Now(), | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // GetPayWays 获取支付方式 | ||||
| func (h *PaymentHandler) GetPayWays(c *gin.Context) { | ||||
| 	data := gin.H{} | ||||
| 	if h.App.Config.AlipayConfig.Enabled { | ||||
| 		data["alipay"] = gin.H{"name": "alipay"} | ||||
| 	} | ||||
| 	if h.App.Config.HuPiPayConfig.Enabled { | ||||
| 		data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name} | ||||
| 	} | ||||
| 	if h.App.Config.JPayConfig.Enabled { | ||||
| 		data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, data) | ||||
| } | ||||
|  | ||||
| // HuPiPayNotify 虎皮椒支付异步回调 | ||||
| func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) { | ||||
| 	err := c.Request.ParseForm() | ||||
| 	if err != nil { | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	orderNo := c.Request.Form.Get("trade_order_id") | ||||
| 	tradeNo := c.Request.Form.Get("open_order_id") | ||||
| 	logger.Infof("收到虎皮椒订单支付回调,订单 NO:%s,交易流水号:%s", orderNo, tradeNo) | ||||
|  | ||||
| 	if err = h.huPiPayService.Check(tradeNo); err != nil { | ||||
| 		logger.Error("订单校验失败:", err) | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
| 	err = h.notify(orderNo, tradeNo) | ||||
| 	if err != nil { | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	c.String(http.StatusOK, "success") | ||||
| } | ||||
|  | ||||
| // AlipayNotify 支付宝支付回调 | ||||
| func (h *PaymentHandler) AlipayNotify(c *gin.Context) { | ||||
| 	err := c.Request.ParseForm() | ||||
| 	if err != nil { | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// TODO:验证交易签名 | ||||
| 	res := h.alipayService.TradeVerify(c.Request.Form) | ||||
| 	logger.Infof("验证支付结果:%+v", res) | ||||
| 	if !res.Success() { | ||||
| 		logger.Error("订单校验失败:", res.Message) | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	tradeNo := c.Request.Form.Get("trade_no") | ||||
| 	err = h.notify(res.OutTradeNo, tradeNo) | ||||
| 	if err != nil { | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	c.String(http.StatusOK, "success") | ||||
| } | ||||
|  | ||||
| // PayJsNotify PayJs 支付异步回调 | ||||
| func (h *PaymentHandler) PayJsNotify(c *gin.Context) { | ||||
| 	err := c.Request.ParseForm() | ||||
| 	if err != nil { | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	orderNo := c.Request.Form.Get("out_trade_no") | ||||
| 	returnCode := c.Request.Form.Get("return_code") | ||||
| 	logger.Infof("收到订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode) | ||||
| 	// 支付失败 | ||||
| 	if returnCode != "1" { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 校验订单支付状态 | ||||
| 	tradeNo := c.Request.Form.Get("payjs_order_id") | ||||
| 	err = h.js.Check(tradeNo) | ||||
| 	if err != nil { | ||||
| 		logger.Error("订单校验失败:", err) | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.notify(orderNo, tradeNo) | ||||
| 	if err != nil { | ||||
| 		c.String(http.StatusOK, "fail") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	c.String(http.StatusOK, "success") | ||||
| } | ||||
							
								
								
									
										67
									
								
								api/handler/power_log_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								api/handler/power_log_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type PowerLogHandler struct { | ||||
| 	BaseHandler | ||||
| } | ||||
|  | ||||
| func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler { | ||||
| 	return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| func (h *PowerLogHandler) List(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Model    string   `json:"model"` | ||||
| 		Date     []string `json:"date"` | ||||
| 		Page     int      `json:"page"` | ||||
| 		PageSize int      `json:"page_size"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	session = session.Where("user_id", userId) | ||||
| 	if data.Model != "" { | ||||
| 		session = session.Where("model", data.Model) | ||||
| 	} | ||||
| 	if len(data.Date) == 2 { | ||||
| 		start := data.Date[0] + " 00:00:00" | ||||
| 		end := data.Date[1] + " 00:00:00" | ||||
| 		session = session.Where("created_at >= ? AND created_at <= ?", start, end) | ||||
| 	} | ||||
|  | ||||
| 	var total int64 | ||||
| 	session.Model(&model.PowerLog{}).Count(&total) | ||||
| 	var items []model.PowerLog | ||||
| 	var list = make([]vo.PowerLog, 0) | ||||
| 	offset := (data.Page - 1) * data.PageSize | ||||
| 	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var log vo.PowerLog | ||||
| 			err := utils.CopyObject(item, &log) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			log.Id = item.Id | ||||
| 			log.CreatedAt = item.CreatedAt.Unix() | ||||
| 			log.TypeStr = item.Type.String() | ||||
| 			list = append(list, log) | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) | ||||
| } | ||||
							
								
								
									
										41
									
								
								api/handler/product_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								api/handler/product_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type ProductHandler struct { | ||||
| 	BaseHandler | ||||
| } | ||||
|  | ||||
| func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler { | ||||
| 	return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // List 模型列表 | ||||
| func (h *ProductHandler) List(c *gin.Context) { | ||||
| 	var items []model.Product | ||||
| 	var list = make([]vo.Product, 0) | ||||
| 	res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items) | ||||
| 	if res.Error == nil { | ||||
| 		for _, item := range items { | ||||
| 			var product vo.Product | ||||
| 			err := utils.CopyObject(item, &product) | ||||
| 			if err == nil { | ||||
| 				product.Id = item.Id | ||||
| 				product.CreatedAt = item.CreatedAt.Unix() | ||||
| 				product.UpdatedAt = item.UpdatedAt.Unix() | ||||
| 				list = append(list, product) | ||||
| 			} else { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	resp.SUCCESS(c, list) | ||||
| } | ||||
							
								
								
									
										60
									
								
								api/handler/prompt_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								api/handler/prompt_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" | ||||
| const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" | ||||
|  | ||||
| type PromptHandler struct { | ||||
| 	BaseHandler | ||||
| } | ||||
|  | ||||
| func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler { | ||||
| 	return &PromptHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // Rewrite translate and rewrite prompt with ChatGPT | ||||
| func (h *PromptHandler) Rewrite(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Prompt string `json:"prompt"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(rewritePromptTemplate, data.Prompt)) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, content) | ||||
| } | ||||
|  | ||||
| func (h *PromptHandler) Translate(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Prompt string `json:"prompt"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, data.Prompt)) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, content) | ||||
| } | ||||
| @@ -4,22 +4,25 @@ import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"math" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type RewardHandler struct { | ||||
| 	BaseHandler | ||||
| 	db *gorm.DB | ||||
| 	lock sync.Mutex | ||||
| } | ||||
|  | ||||
| func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler { | ||||
| 	h := RewardHandler{db: db} | ||||
| 	h.App = server | ||||
| 	return &h | ||||
| func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler { | ||||
| 	return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}} | ||||
| } | ||||
|  | ||||
| // Verify 打赏码核销 | ||||
| @@ -32,11 +35,20 @@ func (h *RewardHandler) Verify(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.HACKER(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 移除转账单号中间的空格,防止有人复制的时候多复制了空格 | ||||
| 	data.TxId = strings.ReplaceAll(data.TxId, " ", "") | ||||
|  | ||||
| 	h.lock.Lock() | ||||
| 	defer h.lock.Unlock() | ||||
|  | ||||
| 	var item model.Reward | ||||
| 	res := h.db.Where("tx_id = ?", data.TxId).First(&item) | ||||
| 	res := h.DB.Where("tx_id = ?", data.TxId).First(&item) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "无效的众筹交易流水号!") | ||||
| 		return | ||||
| @@ -47,16 +59,13 @@ func (h *RewardHandler) Verify(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	if err != nil { | ||||
| 		resp.HACKER(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	tx := h.db.Begin() | ||||
| 	calls := (item.Amount + 0.1) * 10 | ||||
| 	res = h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls + ?", calls)) | ||||
| 	tx := h.DB.Begin() | ||||
| 	exchange := vo.RewardExchange{} | ||||
| 	power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice) | ||||
| 	exchange.Power = int(power) | ||||
| 	res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power)) | ||||
| 	if res.Error != nil { | ||||
| 		tx.Rollback() | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
| @@ -64,13 +73,26 @@ func (h *RewardHandler) Verify(c *gin.Context) { | ||||
| 	// 更新核销状态 | ||||
| 	item.Status = true | ||||
| 	item.UserId = user.Id | ||||
| 	res = h.db.Updates(&item) | ||||
| 	item.Exchange = utils.JsonEncode(exchange) | ||||
| 	res = tx.Updates(&item) | ||||
| 	if res.Error != nil { | ||||
| 		tx.Rollback() | ||||
| 		resp.ERROR(c, "更新数据库失败!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 记录算力充值日志 | ||||
| 	h.DB.Create(&model.PowerLog{ | ||||
| 		UserId:    user.Id, | ||||
| 		Username:  user.Username, | ||||
| 		Type:      types.PowerReward, | ||||
| 		Amount:    exchange.Power, | ||||
| 		Balance:   user.Power + exchange.Power, | ||||
| 		Mark:      types.PowerAdd, | ||||
| 		Model:     "众筹支付", | ||||
| 		Remark:    fmt.Sprintf("众筹充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice), | ||||
| 		CreatedAt: time.Now(), | ||||
| 	}) | ||||
| 	tx.Commit() | ||||
| 	resp.SUCCESS(c) | ||||
|  | ||||
|   | ||||
| @@ -3,35 +3,40 @@ package handler | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/service/sd" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"gorm.io/gorm" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gorilla/websocket" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type SdJobHandler struct { | ||||
| 	BaseHandler | ||||
| 	redis   *redis.Client | ||||
| 	db      *gorm.DB | ||||
| 	service *sd.Service | ||||
| 	redis    *redis.Client | ||||
| 	pool     *sd.ServicePool | ||||
| 	uploader *oss.UploaderManager | ||||
| } | ||||
|  | ||||
| func NewSdJobHandler(app *core.AppServer, redisCli *redis.Client, db *gorm.DB, service *sd.Service) *SdJobHandler { | ||||
| 	h := SdJobHandler{ | ||||
| 		redis:   redisCli, | ||||
| 		db:      db, | ||||
| 		service: service, | ||||
| func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler { | ||||
| 	return &SdJobHandler{ | ||||
| 		pool:     pool, | ||||
| 		uploader: manager, | ||||
| 		BaseHandler: BaseHandler{ | ||||
| 			App: app, | ||||
| 			DB:  db, | ||||
| 		}, | ||||
| 	} | ||||
| 	h.App = app | ||||
| 	return &h | ||||
| } | ||||
|  | ||||
| // Client WebSocket 客户端,用于通知任务状态变更 | ||||
| @@ -39,25 +44,36 @@ func (h *SdJobHandler) Client(c *gin.Context) { | ||||
| 	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) | ||||
| 	if err != nil { | ||||
| 		logger.Error(err) | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| 	if userId == 0 { | ||||
| 		logger.Info("Invalid user ID") | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	sessionId := c.Query("session_id") | ||||
| 	client := types.NewWsClient(ws) | ||||
| 	// 删除旧的连接 | ||||
| 	h.service.Clients.Put(sessionId, client) | ||||
| 	logger.Infof("New websocket connected, IP: %s", c.ClientIP()) | ||||
| 	h.pool.Clients.Put(uint(userId), client) | ||||
| 	logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) | ||||
| } | ||||
|  | ||||
| func (h *SdJobHandler) checkLimits(c *gin.Context) bool { | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if user.ImgCalls <= 0 { | ||||
| 		resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") | ||||
| 	if !h.pool.HasAvailableService() { | ||||
| 		resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!") | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if user.Power < h.App.SysConfig.SdPower { | ||||
| 		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| @@ -67,11 +83,6 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool { | ||||
|  | ||||
| // Image 创建一个绘画任务 | ||||
| func (h *SdJobHandler) Image(c *gin.Context) { | ||||
| 	if !h.App.Config.SdConfig.Enabled { | ||||
| 		resp.ERROR(c, "Stable Diffusion service is disabled") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !h.checkLimits(c) { | ||||
| 		return | ||||
| 	} | ||||
| @@ -129,45 +140,85 @@ func (h *SdJobHandler) Image(c *gin.Context) { | ||||
| 		Params:    utils.JsonEncode(params), | ||||
| 		Prompt:    data.Prompt, | ||||
| 		Progress:  0, | ||||
| 		Started:   false, | ||||
| 		Power:     h.App.SysConfig.SdPower, | ||||
| 		CreatedAt: time.Now(), | ||||
| 	} | ||||
| 	res := h.db.Create(&job) | ||||
| 	res := h.DB.Create(&job) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "error with save job: "+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h.service.PushTask(types.SdTask{ | ||||
| 	h.pool.PushTask(types.SdTask{ | ||||
| 		Id:        int(job.Id), | ||||
| 		SessionId: data.SessionId, | ||||
| 		Src:       types.TaskSrcImg, | ||||
| 		Type:      types.TaskImage, | ||||
| 		Prompt:    data.Prompt, | ||||
| 		Params:    params, | ||||
| 		UserId:    userId, | ||||
| 	}) | ||||
| 	var jobVo vo.SdJob | ||||
| 	err := utils.CopyObject(job, &jobVo) | ||||
| 	if err == nil { | ||||
| 		// 推送任务到前端 | ||||
| 		client := h.service.Clients.Get(data.SessionId) | ||||
| 		if client != nil { | ||||
| 			utils.ReplyChunkMessage(client, jobVo) | ||||
| 		} | ||||
|  | ||||
| 	client := h.pool.Clients.Get(uint(job.UserId)) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 	} | ||||
|  | ||||
| 	// update user's power | ||||
| 	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) | ||||
| 	// 记录算力变化日志 | ||||
| 	if tx.Error == nil && tx.RowsAffected > 0 { | ||||
| 		user, _ := h.GetLoginUser(c) | ||||
| 		h.DB.Create(&model.PowerLog{ | ||||
| 			UserId:    user.Id, | ||||
| 			Username:  user.Username, | ||||
| 			Type:      types.PowerConsume, | ||||
| 			Amount:    job.Power, | ||||
| 			Balance:   user.Power - job.Power, | ||||
| 			Mark:      types.PowerSub, | ||||
| 			Model:     "stable-diffusion", | ||||
| 			Remark:    fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId), | ||||
| 			CreatedAt: time.Now(), | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // JobList 获取 stable diffusion 任务列表 | ||||
| func (h *SdJobHandler) JobList(c *gin.Context) { | ||||
| 	status := h.GetInt(c, "status", 0) | ||||
| 	userId := h.GetInt(c, "user_id", 0) | ||||
| // ImgWall 照片墙 | ||||
| func (h *SdJobHandler) ImgWall(c *gin.Context) { | ||||
| 	page := h.GetInt(c, "page", 0) | ||||
| 	pageSize := h.GetInt(c, "page_size", 0) | ||||
| 	err, jobs := h.getData(true, 0, page, pageSize, true) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	session := h.db.Session(&gorm.Session{}) | ||||
| 	if status == 1 { | ||||
| 	resp.SUCCESS(c, jobs) | ||||
| } | ||||
|  | ||||
| // JobList 获取 SD 任务列表 | ||||
| func (h *SdJobHandler) JobList(c *gin.Context) { | ||||
| 	status := h.GetBool(c, "status") | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	page := h.GetInt(c, "page", 0) | ||||
| 	pageSize := h.GetInt(c, "page_size", 0) | ||||
| 	publish := h.GetBool(c, "publish") | ||||
|  | ||||
| 	err, jobs := h.getData(status, userId, page, pageSize, publish) | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, jobs) | ||||
| } | ||||
|  | ||||
| // JobList 获取 MJ 任务列表 | ||||
| func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) { | ||||
|  | ||||
| 	session := h.DB.Session(&gorm.Session{}) | ||||
| 	if finish { | ||||
| 		session = session.Where("progress = ?", 100).Order("id DESC") | ||||
| 	} else { | ||||
| 		session = session.Where("progress < ?", 100).Order("id ASC") | ||||
| @@ -175,6 +226,9 @@ func (h *SdJobHandler) JobList(c *gin.Context) { | ||||
| 	if userId > 0 { | ||||
| 		session = session.Where("user_id = ?", userId) | ||||
| 	} | ||||
| 	if publish { | ||||
| 		session = session.Where("publish", publish) | ||||
| 	} | ||||
| 	if page > 0 && pageSize > 0 { | ||||
| 		offset := (page - 1) * pageSize | ||||
| 		session = session.Offset(offset).Limit(pageSize) | ||||
| @@ -183,8 +237,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) { | ||||
| 	var items []model.SdJob | ||||
| 	res := session.Find(&items) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, types.NoData) | ||||
| 		return | ||||
| 		return res.Error, nil | ||||
| 	} | ||||
|  | ||||
| 	var jobs = make([]vo.SdJob, 0) | ||||
| @@ -194,14 +247,69 @@ func (h *SdJobHandler) JobList(c *gin.Context) { | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if item.Progress < 100 { | ||||
| 			// 30 分钟还没完成的任务直接删除 | ||||
| 			if time.Now().Sub(item.CreatedAt) > time.Minute*30 { | ||||
| 				h.db.Delete(&item) | ||||
| 				continue | ||||
| 			// 正在运行中任务使用代理访问图片 | ||||
| 			image, err := utils.DownloadImage(item.ImgURL, "") | ||||
| 			if err == nil { | ||||
| 				job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) | ||||
| 			} | ||||
| 		} | ||||
| 		jobs = append(jobs, job) | ||||
| 	} | ||||
| 	resp.SUCCESS(c, jobs) | ||||
|  | ||||
| 	return nil, jobs | ||||
| } | ||||
|  | ||||
| // Remove remove task image | ||||
| func (h *SdJobHandler) Remove(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id     uint   `json:"id"` | ||||
| 		UserId uint   `json:"user_id"` | ||||
| 		ImgURL string `json:"img_url"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove job recode | ||||
| 	res := h.DB.Delete(&model.SdJob{Id: data.Id}) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove image | ||||
| 	err := h.uploader.GetUploadHandler().Delete(data.ImgURL) | ||||
| 	if err != nil { | ||||
| 		logger.Error("remove image failed: ", err) | ||||
| 	} | ||||
|  | ||||
| 	client := h.pool.Clients.Get(data.UserId) | ||||
| 	if client != nil { | ||||
| 		_ = client.Send([]byte("Task Updated")) | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // Publish 发布/取消发布图片到画廊显示 | ||||
| func (h *SdJobHandler) Publish(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Id     uint `json:"id"` | ||||
| 		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享 | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|   | ||||
| @@ -4,33 +4,45 @@ import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/service/sms" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| ) | ||||
|  | ||||
| const CodeStorePrefix = "/verify/codes/" | ||||
|  | ||||
| type SmsHandler struct { | ||||
| 	BaseHandler | ||||
| 	leveldb *store.LevelDB | ||||
| 	sms     *service.AliYunSmsService | ||||
| 	redis   *redis.Client | ||||
| 	sms     *sms.ServiceManager | ||||
| 	smtp    *service.SmtpService | ||||
| 	captcha *service.CaptchaService | ||||
| } | ||||
|  | ||||
| func NewSmsHandler(app *core.AppServer, db *store.LevelDB, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler { | ||||
| 	handler := &SmsHandler{leveldb: db, sms: sms, captcha: captcha} | ||||
| 	handler.App = app | ||||
| 	return handler | ||||
| func NewSmsHandler( | ||||
| 	app *core.AppServer, | ||||
| 	client *redis.Client, | ||||
| 	sms *sms.ServiceManager, | ||||
| 	smtp *service.SmtpService, | ||||
| 	captcha *service.CaptchaService) *SmsHandler { | ||||
| 	return &SmsHandler{ | ||||
| 		redis:       client, | ||||
| 		sms:         sms, | ||||
| 		captcha:     captcha, | ||||
| 		smtp:        smtp, | ||||
| 		BaseHandler: BaseHandler{App: app}} | ||||
| } | ||||
|  | ||||
| // SendCode 发送验证码短信 | ||||
| // SendCode 发送验证码 | ||||
| func (h *SmsHandler) SendCode(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Mobile string `json:"mobile"` | ||||
| 		Key    string `json:"key"` | ||||
| 		Dots   string `json:"dots"` | ||||
| 		Receiver string `json:"receiver"` // 接收者 | ||||
| 		Key      string `json:"key"` | ||||
| 		Dots     string `json:"dots"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| @@ -43,14 +55,28 @@ func (h *SmsHandler) SendCode(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	code := utils.RandomNumber(6) | ||||
| 	err := h.sms.SendVerifyCode(data.Mobile, code) | ||||
| 	var err error | ||||
| 	if strings.Contains(data.Receiver, "@") { // email | ||||
| 		if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") { | ||||
| 			resp.ERROR(c, "系统已禁用邮箱注册!") | ||||
| 			return | ||||
| 		} | ||||
| 		err = h.smtp.SendVerifyCode(data.Receiver, code) | ||||
| 	} else { | ||||
| 		if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") { | ||||
| 			resp.ERROR(c, "系统已禁用手机号注册!") | ||||
| 			return | ||||
| 		} | ||||
| 		err = h.sms.GetService().SendVerifyCode(data.Receiver, code) | ||||
|  | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 存储验证码,等待后面注册验证 | ||||
| 	err = h.leveldb.Put(CodeStorePrefix+data.Mobile, code) | ||||
| 	_, err = h.redis.Set(c, CodeStorePrefix+data.Receiver, code, 0).Result() | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, "验证码保存失败") | ||||
| 		return | ||||
| @@ -58,13 +84,3 @@ func (h *SmsHandler) SendCode(c *gin.Context) { | ||||
|  | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| type statusVo struct { | ||||
| 	EnabledMsgService bool `json:"enabled_msg_service"` | ||||
| 	EnabledRegister   bool `json:"enabled_register"` | ||||
| } | ||||
|  | ||||
| // Status check if the message service is enabled | ||||
| func (h *SmsHandler) Status(c *gin.Context) { | ||||
| 	resp.SUCCESS(c, statusVo{EnabledMsgService: h.App.SysConfig.EnabledMsg, EnabledRegister: h.App.SysConfig.EnabledRegister}) | ||||
| } | ||||
|   | ||||
							
								
								
									
										17
									
								
								api/handler/test_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								api/handler/test_handler.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/payment" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type TestHandler struct { | ||||
| 	db        *gorm.DB | ||||
| 	snowflake *service.Snowflake | ||||
| 	js        *payment.PayJS | ||||
| } | ||||
|  | ||||
| func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler { | ||||
| 	return &TestHandler{db: db, snowflake: snowflake, js: js} | ||||
| } | ||||
| @@ -3,29 +3,92 @@ package handler | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"gorm.io/gorm" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type UploadHandler struct { | ||||
| 	BaseHandler | ||||
| 	db              *gorm.DB | ||||
| 	uploaderManager *oss.UploaderManager | ||||
| } | ||||
|  | ||||
| func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler { | ||||
| 	handler := &UploadHandler{db: db, uploaderManager: manager} | ||||
| 	handler.App = app | ||||
| 	return handler | ||||
| 	return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager} | ||||
| } | ||||
|  | ||||
| func (h *UploadHandler) Upload(c *gin.Context) { | ||||
| 	fileURL, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file") | ||||
| 	file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file") | ||||
| 	if err != nil { | ||||
| 		resp.ERROR(c, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, fileURL) | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	res := h.DB.Create(&model.File{ | ||||
| 		UserId:    int(userId), | ||||
| 		Name:      file.Name, | ||||
| 		ObjKey:    file.ObjKey, | ||||
| 		URL:       file.URL, | ||||
| 		Ext:       file.Ext, | ||||
| 		Size:      file.Size, | ||||
| 		CreatedAt: time.Time{}, | ||||
| 	}) | ||||
| 	if res.Error != nil || res.RowsAffected == 0 { | ||||
| 		resp.ERROR(c, "error with update database: "+res.Error.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, file) | ||||
| } | ||||
|  | ||||
| func (h *UploadHandler) List(c *gin.Context) { | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	var items []model.File | ||||
| 	var files = make([]vo.File, 0) | ||||
| 	h.DB.Where("user_id = ?", userId).Find(&items) | ||||
| 	if len(items) > 0 { | ||||
| 		for _, v := range items { | ||||
| 			var file vo.File | ||||
| 			err := utils.CopyObject(v, &file) | ||||
| 			if err != nil { | ||||
| 				logger.Error(err) | ||||
| 				continue | ||||
| 			} | ||||
| 			file.CreatedAt = v.CreatedAt.Unix() | ||||
| 			files = append(files, file) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	resp.SUCCESS(c, files) | ||||
| } | ||||
|  | ||||
| // Remove remove files | ||||
| func (h *UploadHandler) Remove(c *gin.Context) { | ||||
| 	userId := h.GetLoginUserId(c) | ||||
| 	id := h.GetInt(c, "id", 0) | ||||
| 	var file model.File | ||||
| 	tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file) | ||||
| 	if tx.Error != nil || file.Id == 0 { | ||||
| 		resp.ERROR(c, "file not existed") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// remove database | ||||
| 	tx = h.DB.Model(&model.File{}).Delete("id = ?", id) | ||||
| 	if tx.Error != nil || tx.RowsAffected == 0 { | ||||
| 		resp.ERROR(c, "failed to update database") | ||||
| 		return | ||||
| 	} | ||||
| 	// remove files | ||||
| 	objectKey := file.ObjKey | ||||
| 	if objectKey == "" { | ||||
| 		objectKey = file.URL | ||||
| 	} | ||||
| 	_ = h.uploaderManager.GetUploadHandler().Delete(objectKey) | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|   | ||||
| @@ -3,17 +3,17 @@ package handler | ||||
| import ( | ||||
| 	"chatplus/core" | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"chatplus/utils/resp" | ||||
| 	"fmt" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/lionsoul2014/ip2region/binding/golang/xdb" | ||||
| 	"gorm.io/gorm" | ||||
| @@ -21,9 +21,7 @@ import ( | ||||
|  | ||||
| type UserHandler struct { | ||||
| 	BaseHandler | ||||
| 	db       *gorm.DB | ||||
| 	searcher *xdb.Searcher | ||||
| 	leveldb  *store.LevelDB | ||||
| 	redis    *redis.Client | ||||
| } | ||||
|  | ||||
| @@ -31,85 +29,113 @@ func NewUserHandler( | ||||
| 	app *core.AppServer, | ||||
| 	db *gorm.DB, | ||||
| 	searcher *xdb.Searcher, | ||||
| 	levelDB *store.LevelDB, | ||||
| 	client *redis.Client) *UserHandler { | ||||
| 	handler := &UserHandler{db: db, searcher: searcher, leveldb: levelDB, redis: client} | ||||
| 	handler.App = app | ||||
| 	return handler | ||||
| 	return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client} | ||||
| } | ||||
|  | ||||
| // Register user register | ||||
| func (h *UserHandler) Register(c *gin.Context) { | ||||
| 	// parameters process | ||||
| 	var data struct { | ||||
| 		Mobile   string `json:"mobile"` | ||||
| 		Password string `json:"password"` | ||||
| 		Code     int    `json:"code"` | ||||
| 		RegWay     string `json:"reg_way"` | ||||
| 		Username   string `json:"username"` | ||||
| 		Password   string `json:"password"` | ||||
| 		Code       string `json:"code"` | ||||
| 		InviteCode string `json:"invite_code"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
| 	data.Password = strings.TrimSpace(data.Password) | ||||
|  | ||||
| 	if len(data.Mobile) < 10 { | ||||
| 		resp.ERROR(c, "请输入合法的手机号") | ||||
| 		return | ||||
| 	} | ||||
| 	if len(data.Password) < 8 { | ||||
| 		resp.ERROR(c, "密码长度不能少于8个字符") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 检查验证码 | ||||
| 	key := CodeStorePrefix + data.Mobile | ||||
| 	if h.App.SysConfig.EnabledMsg { | ||||
| 		var code int | ||||
| 		err := h.leveldb.Get(key, &code) | ||||
| 	var key string | ||||
| 	if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" { | ||||
| 		key = CodeStorePrefix + data.Username | ||||
| 		code, err := h.redis.Get(c, key).Result() | ||||
| 		if err != nil || code != data.Code { | ||||
| 			resp.ERROR(c, "短信验证码错误") | ||||
| 			resp.ERROR(c, "验证码错误") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 验证邀请码 | ||||
| 	inviteCode := model.InviteCode{} | ||||
| 	if data.InviteCode != "" { | ||||
| 		res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode) | ||||
| 		if res.Error != nil { | ||||
| 			resp.ERROR(c, "无效的邀请码") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// check if the username is exists | ||||
| 	var item model.User | ||||
| 	res := h.db.Where("mobile = ?", data.Mobile).First(&item) | ||||
| 	if res.RowsAffected > 0 { | ||||
| 		resp.ERROR(c, "该手机号码已经被注册,请更换其他手机号") | ||||
| 	res := h.DB.Where("username = ?", data.Username).First(&item) | ||||
| 	if item.Id > 0 { | ||||
| 		resp.ERROR(c, "该用户名已经被注册") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	salt := utils.RandString(8) | ||||
| 	user := model.User{ | ||||
| 		Username:   data.Username, | ||||
| 		Password:   utils.GenPassword(data.Password, salt), | ||||
| 		Nickname:   fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)), | ||||
| 		Avatar:     "/images/avatar/user.png", | ||||
| 		Salt:       salt, | ||||
| 		Status:     true, | ||||
| 		Mobile:     data.Mobile, | ||||
| 		ChatRoles:  utils.JsonEncode([]string{"gpt"}),               // 默认只订阅通用助手角色 | ||||
| 		ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型 | ||||
| 		ChatConfig: utils.JsonEncode(types.UserChatConfig{ | ||||
| 			ApiKeys: map[types.Platform]string{ | ||||
| 				types.OpenAI:  "", | ||||
| 				types.Azure:   "", | ||||
| 				types.ChatGLM: "", | ||||
| 			}, | ||||
| 		}), | ||||
| 		Calls:    h.App.SysConfig.UserInitCalls, | ||||
| 		ImgCalls: h.App.SysConfig.InitImgCalls, | ||||
| 		Power:      h.App.SysConfig.InitPower, | ||||
| 	} | ||||
| 	res = h.db.Create(&user) | ||||
|  | ||||
| 	res = h.DB.Create(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "保存数据失败") | ||||
| 		logger.Error(res.Error) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if h.App.SysConfig.EnabledMsg { | ||||
| 		_ = h.leveldb.Delete(key) // 注册成功,删除短信验证码 | ||||
| 	// 记录邀请关系 | ||||
| 	if data.InviteCode != "" { | ||||
| 		// 增加邀请数量 | ||||
| 		h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1)) | ||||
| 		if h.App.SysConfig.InvitePower > 0 { | ||||
| 			h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower)) | ||||
| 			// 记录邀请算力充值日志 | ||||
| 			var inviter model.User | ||||
| 			h.DB.Where("id", inviteCode.UserId).First(&inviter) | ||||
| 			h.DB.Create(&model.PowerLog{ | ||||
| 				UserId:    inviter.Id, | ||||
| 				Username:  inviter.Username, | ||||
| 				Type:      types.PowerInvite, | ||||
| 				Amount:    h.App.SysConfig.InvitePower, | ||||
| 				Balance:   inviter.Power, | ||||
| 				Mark:      types.PowerAdd, | ||||
| 				Model:     "", | ||||
| 				Remark:    fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username), | ||||
| 				CreatedAt: time.Now(), | ||||
| 			}) | ||||
| 		} | ||||
|  | ||||
| 		// 添加邀请记录 | ||||
| 		h.DB.Create(&model.InviteLog{ | ||||
| 			InviterId:  inviteCode.UserId, | ||||
| 			UserId:     user.Id, | ||||
| 			Username:   user.Username, | ||||
| 			InviteCode: inviteCode.Code, | ||||
| 			Remark:     fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower), | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	_ = h.redis.Del(c, key) // 注册成功,删除短信验证码 | ||||
|  | ||||
| 	// 自动登录创建 token | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ | ||||
| 		"user_id": user.Id, | ||||
| @@ -132,7 +158,7 @@ func (h *UserHandler) Register(c *gin.Context) { | ||||
| // Login 用户登录 | ||||
| func (h *UserHandler) Login(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Mobile   string `json:"username"` | ||||
| 		Username string `json:"username"` | ||||
| 		Password string `json:"password"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| @@ -140,7 +166,7 @@ func (h *UserHandler) Login(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	var user model.User | ||||
| 	res := h.db.Where("mobile = ?", data.Mobile).First(&user) | ||||
| 	res := h.DB.Where("username = ?", data.Username).First(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "用户名不存在") | ||||
| 		return | ||||
| @@ -160,11 +186,11 @@ func (h *UserHandler) Login(c *gin.Context) { | ||||
| 	// 更新最后登录时间和IP | ||||
| 	user.LastLoginIp = c.ClientIP() | ||||
| 	user.LastLoginAt = time.Now().Unix() | ||||
| 	h.db.Model(&user).Updates(user) | ||||
| 	h.DB.Model(&user).Updates(user) | ||||
|  | ||||
| 	h.db.Create(&model.UserLoginLog{ | ||||
| 	h.DB.Create(&model.UserLoginLog{ | ||||
| 		UserId:       user.Id, | ||||
| 		Username:     user.Mobile, | ||||
| 		Username:     user.Username, | ||||
| 		LoginIp:      c.ClientIP(), | ||||
| 		LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()), | ||||
| 	}) | ||||
| @@ -207,7 +233,7 @@ func (h *UserHandler) Logout(c *gin.Context) { | ||||
|  | ||||
| // Session 获取/验证会话 | ||||
| func (h *UserHandler) Session(c *gin.Context) { | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err == nil { | ||||
| 		var userVo vo.User | ||||
| 		err := utils.CopyObject(user, &userVo) | ||||
| @@ -223,23 +249,23 @@ func (h *UserHandler) Session(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| type userProfile struct { | ||||
| 	Id          uint                 `json:"id"` | ||||
| 	Mobile      string               `json:"mobile"` | ||||
| 	Avatar      string               `json:"avatar"` | ||||
| 	ChatConfig  types.UserChatConfig `json:"chat_config"` | ||||
| 	Calls       int                  `json:"calls"` | ||||
| 	ImgCalls    int                  `json:"img_calls"` | ||||
| 	TotalTokens int64                `json:"total_tokens"` | ||||
| 	Id          uint   `json:"id"` | ||||
| 	Nickname    string `json:"nickname"` | ||||
| 	Username    string `json:"username"` | ||||
| 	Avatar      string `json:"avatar"` | ||||
| 	Power       int    `json:"power"` | ||||
| 	ExpiredTime int64  `json:"expired_time"` | ||||
| 	Vip         bool   `json:"vip"` | ||||
| } | ||||
|  | ||||
| func (h *UserHandler) Profile(c *gin.Context) { | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h.db.First(&user, user.Id) | ||||
| 	h.DB.First(&user, user.Id) | ||||
| 	var profile userProfile | ||||
| 	err = utils.CopyObject(user, &profile) | ||||
| 	if err != nil { | ||||
| @@ -259,15 +285,15 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
| 	h.db.First(&user, user.Id) | ||||
| 	h.DB.First(&user, user.Id) | ||||
| 	user.Avatar = data.Avatar | ||||
| 	user.ChatConfig = utils.JsonEncode(data.ChatConfig) | ||||
| 	res := h.db.Updates(&user) | ||||
| 	user.Nickname = data.Nickname | ||||
| 	res := h.DB.Updates(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新用户信息失败") | ||||
| 		return | ||||
| @@ -276,8 +302,8 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) { | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // Password 更新密码 | ||||
| func (h *UserHandler) Password(c *gin.Context) { | ||||
| // UpdatePass 更新密码 | ||||
| func (h *UserHandler) UpdatePass(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		OldPass  string `json:"old_pass"` | ||||
| 		Password string `json:"password"` | ||||
| @@ -292,21 +318,21 @@ func (h *UserHandler) Password(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	password := utils.GenPassword(data.OldPass, user.Salt) | ||||
| 	logger.Info(user.Salt, ",", user.Password, ",", password, ",", data.OldPass) | ||||
| 	logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass) | ||||
| 	if password != user.Password { | ||||
| 		resp.ERROR(c, "原密码错误") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	newPass := utils.GenPassword(data.Password, user.Salt) | ||||
| 	res := h.db.Model(&user).UpdateColumn("password", newPass) | ||||
| 	res := h.DB.Model(&user).UpdateColumn("password", newPass) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("更新数据库失败: ", res.Error) | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| @@ -316,46 +342,83 @@ func (h *UserHandler) Password(c *gin.Context) { | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|  | ||||
| // BindMobile 绑定手机号 | ||||
| func (h *UserHandler) BindMobile(c *gin.Context) { | ||||
| // ResetPass 重置密码 | ||||
| func (h *UserHandler) ResetPass(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Mobile string `json:"mobile"` | ||||
| 		Code   int    `json:"code"` | ||||
| 		Username string `json:"username"` | ||||
| 		Code     string `json:"code"`     // 验证码 | ||||
| 		Password string `json:"password"` // 新密码 | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 检查手机号是否被其他账号绑定 | ||||
| 	var item model.User | ||||
| 	res := h.db.Where("mobile = ?", data.Mobile).First(&item) | ||||
| 	if res.Error == nil { | ||||
| 		resp.ERROR(c, "该手机号已经被其他账号绑定") | ||||
| 	var user model.User | ||||
| 	res := h.DB.Where("username", data.Username).First(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "用户不存在!") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 检查验证码 | ||||
| 	key := CodeStorePrefix + data.Mobile | ||||
| 	var code int | ||||
| 	err := h.leveldb.Get(key, &code) | ||||
| 	key := CodeStorePrefix + data.Username | ||||
| 	code, err := h.redis.Get(c, key).Result() | ||||
| 	if err != nil || code != data.Code { | ||||
| 		resp.ERROR(c, "短信验证码错误") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	user, err := utils.GetLoginUser(c, h.db) | ||||
| 	password := utils.GenPassword(data.Password, user.Salt) | ||||
| 	user.Password = password | ||||
| 	res = h.DB.Updates(&user) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c) | ||||
| 	} else { | ||||
| 		h.redis.Del(c, key) | ||||
| 		resp.SUCCESS(c) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // BindUsername 重置账号 | ||||
| func (h *UserHandler) BindUsername(c *gin.Context) { | ||||
| 	var data struct { | ||||
| 		Username string `json:"username"` | ||||
| 		Code     string `json:"code"` | ||||
| 	} | ||||
| 	if err := c.ShouldBindJSON(&data); err != nil { | ||||
| 		resp.ERROR(c, types.InvalidArgs) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 检查验证码 | ||||
| 	key := CodeStorePrefix + data.Username | ||||
| 	code, err := h.redis.Get(c, key).Result() | ||||
| 	if err != nil || code != data.Code { | ||||
| 		resp.ERROR(c, "验证码错误") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 检查手机号是否被其他账号绑定 | ||||
| 	var item model.User | ||||
| 	res := h.DB.Where("username = ?", data.Username).First(&item) | ||||
| 	if res.Error == nil { | ||||
| 		resp.ERROR(c, "该账号已经被其他账号绑定") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	user, err := h.GetLoginUser(c) | ||||
| 	if err != nil { | ||||
| 		resp.NotAuth(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	res = h.db.Model(&user).UpdateColumn("mobile", data.Mobile) | ||||
| 	res = h.DB.Model(&user).UpdateColumn("username", data.Username) | ||||
| 	if res.Error != nil { | ||||
| 		resp.ERROR(c, "更新数据库失败") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	_ = h.leveldb.Delete(key) // 删除短信验证码 | ||||
| 	_ = h.redis.Del(c, key) // 删除短信验证码 | ||||
| 	resp.SUCCESS(c) | ||||
| } | ||||
|   | ||||
							
								
								
									
										217
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										217
									
								
								api/main.go
									
									
									
									
									
								
							| @@ -8,15 +8,15 @@ import ( | ||||
| 	"chatplus/handler/chatimpl" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/service" | ||||
| 	"chatplus/service/fun" | ||||
| 	"chatplus/service/mj" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/service/payment" | ||||
| 	"chatplus/service/sd" | ||||
| 	"chatplus/service/sms" | ||||
| 	"chatplus/service/wx" | ||||
| 	"chatplus/store" | ||||
| 	"context" | ||||
| 	"embed" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"os" | ||||
| @@ -25,6 +25,8 @@ import ( | ||||
| 	"syscall" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
|  | ||||
| 	"github.com/lionsoul2014/ip2region/binding/golang/xdb" | ||||
| 	"go.uber.org/fx" | ||||
| 	"gorm.io/gorm" | ||||
| @@ -32,7 +34,7 @@ import ( | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| //go:embed res/ip2region.xdb | ||||
| //go:embed res | ||||
| var xdbFS embed.FS | ||||
|  | ||||
| // AppLifecycle 应用程序生命周期 | ||||
| @@ -56,19 +58,15 @@ func main() { | ||||
| 	if configFile == "" { | ||||
| 		configFile = "config.toml" | ||||
| 	} | ||||
| 	var debug bool | ||||
| 	debugEnv := os.Getenv("DEBUG") | ||||
| 	if debugEnv == "" { | ||||
| 		debug = true | ||||
| 	} else { | ||||
| 		debug, _ = strconv.ParseBool(os.Getenv("DEBUG")) | ||||
| 	} | ||||
| 	debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG")) | ||||
| 	logger.Info("Loading config file: ", configFile) | ||||
| 	defer func() { | ||||
| 		if err := recover(); err != nil { | ||||
| 			logger.Error("Panic Error:", err) | ||||
| 		} | ||||
| 	}() | ||||
| 	if !debug { | ||||
| 		defer func() { | ||||
| 			if err := recover(); err != nil { | ||||
| 				logger.Error("Panic Error:", err) | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	app := fx.New( | ||||
| 		// 初始化配置应用配置 | ||||
| @@ -93,9 +91,12 @@ func main() { | ||||
| 		// 初始化数据库 | ||||
| 		fx.Provide(store.NewGormConfig), | ||||
| 		fx.Provide(store.NewMysql), | ||||
| 		fx.Provide(store.NewLevelDB), | ||||
| 		fx.Provide(store.NewRedisClient), | ||||
|  | ||||
| 		fx.Provide(func() embed.FS { | ||||
| 			return xdbFS | ||||
| 		}), | ||||
|  | ||||
| 		// 创建 Ip2Region 查询对象 | ||||
| 		fx.Provide(func() (*xdb.Searcher, error) { | ||||
| 			file, err := xdbFS.Open("res/ip2region.xdb") | ||||
| @@ -110,9 +111,6 @@ func main() { | ||||
| 			return xdb.NewWithBuffer(cBuff) | ||||
| 		}), | ||||
|  | ||||
| 		// 创建函数 | ||||
| 		fx.Provide(fun.NewFunctions), | ||||
|  | ||||
| 		// 创建控制器 | ||||
| 		fx.Provide(handler.NewChatRoleHandler), | ||||
| 		fx.Provide(handler.NewUserHandler), | ||||
| @@ -124,6 +122,11 @@ func main() { | ||||
| 		fx.Provide(handler.NewMidJourneyHandler), | ||||
| 		fx.Provide(handler.NewChatModelHandler), | ||||
| 		fx.Provide(handler.NewSdJobHandler), | ||||
| 		fx.Provide(handler.NewPaymentHandler), | ||||
| 		fx.Provide(handler.NewOrderHandler), | ||||
| 		fx.Provide(handler.NewProductHandler), | ||||
| 		fx.Provide(handler.NewConfigHandler), | ||||
| 		fx.Provide(handler.NewPowerLogHandler), | ||||
|  | ||||
| 		fx.Provide(admin.NewConfigHandler), | ||||
| 		fx.Provide(admin.NewAdminHandler), | ||||
| @@ -133,15 +136,22 @@ func main() { | ||||
| 		fx.Provide(admin.NewRewardHandler), | ||||
| 		fx.Provide(admin.NewDashboardHandler), | ||||
| 		fx.Provide(admin.NewChatModelHandler), | ||||
| 		fx.Provide(admin.NewProductHandler), | ||||
| 		fx.Provide(admin.NewOrderHandler), | ||||
| 		fx.Provide(admin.NewChatHandler), | ||||
| 		fx.Provide(admin.NewPowerLogHandler), | ||||
|  | ||||
| 		// 创建服务 | ||||
| 		fx.Provide(service.NewAliYunSmsService), | ||||
| 		fx.Provide(sms.NewSendServiceManager), | ||||
| 		fx.Provide(func(config *types.AppConfig) *service.CaptchaService { | ||||
| 			return service.NewCaptchaService(config.ApiConfig) | ||||
| 		}), | ||||
| 		fx.Provide(oss.NewUploaderManager), | ||||
| 		fx.Provide(mj.NewService), | ||||
|  | ||||
| 		// 邮件服务 | ||||
| 		fx.Provide(service.NewSmtpService), | ||||
|  | ||||
| 		// 微信机器人服务 | ||||
| 		fx.Provide(wx.NewWeChatBot), | ||||
| 		fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) { | ||||
| @@ -153,34 +163,32 @@ func main() { | ||||
| 			} | ||||
| 		}), | ||||
|  | ||||
| 		// MidJourney 机器人 | ||||
| 		fx.Provide(mj.NewBot), | ||||
| 		fx.Provide(mj.NewClient), | ||||
| 		fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) { | ||||
| 			if config.MjConfig.Enabled { | ||||
| 				err := bot.Run() | ||||
| 				if err != nil { | ||||
| 					log.Fatal("MidJourney 服务启动失败:", err) | ||||
| 				} | ||||
| 			} | ||||
| 		}), | ||||
| 		fx.Invoke(func(config *types.AppConfig, mjService *mj.Service) { | ||||
| 			if config.MjConfig.Enabled { | ||||
| 				go func() { | ||||
| 					mjService.Run() | ||||
| 				}() | ||||
| 		// MidJourney service pool | ||||
| 		fx.Provide(mj.NewServicePool), | ||||
| 		fx.Invoke(func(pool *mj.ServicePool) { | ||||
| 			if pool.HasAvailableService() { | ||||
| 				pool.DownloadImages() | ||||
| 				pool.CheckTaskNotify() | ||||
| 				pool.SyncTaskProgress() | ||||
| 			} | ||||
| 		}), | ||||
|  | ||||
| 		// Stable Diffusion 机器人 | ||||
| 		fx.Provide(sd.NewService), | ||||
| 		fx.Invoke(func(config *types.AppConfig, service *sd.Service) { | ||||
| 			if config.SdConfig.Enabled { | ||||
| 		fx.Provide(sd.NewServicePool), | ||||
|  | ||||
| 		fx.Provide(payment.NewAlipayService), | ||||
| 		fx.Provide(payment.NewHuPiPay), | ||||
| 		fx.Provide(payment.NewPayJS), | ||||
| 		fx.Provide(service.NewSnowflake), | ||||
| 		fx.Provide(service.NewXXLJobExecutor), | ||||
| 		fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) { | ||||
| 			if config.XXLConfig.Enabled { | ||||
| 				go func() { | ||||
| 					service.Run() | ||||
| 					log.Fatal(exec.Run()) | ||||
| 				}() | ||||
| 			} | ||||
| 		}), | ||||
|  | ||||
| 		// 注册路由 | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { | ||||
| 			group := s.Engine.Group("/api/role/") | ||||
| @@ -195,8 +203,9 @@ func main() { | ||||
| 			group.GET("session", h.Session) | ||||
| 			group.GET("profile", h.Profile) | ||||
| 			group.POST("profile/update", h.ProfileUpdate) | ||||
| 			group.POST("password", h.Password) | ||||
| 			group.POST("bind/mobile", h.BindMobile) | ||||
| 			group.POST("password", h.UpdatePass) | ||||
| 			group.POST("bind/username", h.BindUsername) | ||||
| 			group.POST("resetPass", h.ResetPass) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) { | ||||
| 			group := s.Engine.Group("/api/chat/") | ||||
| @@ -212,10 +221,11 @@ func main() { | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) { | ||||
| 			s.Engine.POST("/api/upload", h.Upload) | ||||
| 			s.Engine.GET("/api/upload/list", h.List) | ||||
| 			s.Engine.GET("/api/upload/remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) { | ||||
| 			group := s.Engine.Group("/api/sms/") | ||||
| 			group.GET("status", h.Status) | ||||
| 			group.POST("code", h.SendCode) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) { | ||||
| @@ -229,17 +239,28 @@ func main() { | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { | ||||
| 			group := s.Engine.Group("/api/mj/") | ||||
| 			group.Any("client", h.Client) | ||||
| 			group.POST("image", h.Image) | ||||
| 			group.POST("upscale", h.Upscale) | ||||
| 			group.POST("variation", h.Variation) | ||||
| 			group.GET("jobs", h.JobList) | ||||
| 			group.Any("client", h.Client) | ||||
| 			group.GET("imgWall", h.ImgWall) | ||||
| 			group.POST("remove", h.Remove) | ||||
| 			group.POST("notify", h.Notify) | ||||
| 			group.POST("publish", h.Publish) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { | ||||
| 			group := s.Engine.Group("/api/sd") | ||||
| 			group.Any("client", h.Client) | ||||
| 			group.POST("image", h.Image) | ||||
| 			group.GET("jobs", h.JobList) | ||||
| 			group.Any("client", h.Client) | ||||
| 			group.GET("imgWall", h.ImgWall) | ||||
| 			group.POST("remove", h.Remove) | ||||
| 			group.POST("publish", h.Publish) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) { | ||||
| 			group := s.Engine.Group("/api/config/") | ||||
| 			group.GET("get", h.Get) | ||||
| 		}), | ||||
|  | ||||
| 		// 管理后台控制器 | ||||
| @@ -253,13 +274,18 @@ func main() { | ||||
| 			group.POST("login", h.Login) | ||||
| 			group.GET("logout", h.Logout) | ||||
| 			group.GET("session", h.Session) | ||||
| 			group.GET("migrate", h.Migrate) | ||||
| 			group.GET("list", h.List) | ||||
| 			group.POST("save", h.Save) | ||||
| 			group.POST("enable", h.Enable) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 			group.POST("resetPass", h.ResetPass) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/apikey/") | ||||
| 			group.POST("save", h.Save) | ||||
| 			group.GET("list", h.List) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 			group.POST("set", h.Set) | ||||
| 			group.POST("remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/user/") | ||||
| @@ -274,11 +300,13 @@ func main() { | ||||
| 			group.GET("list", h.List) | ||||
| 			group.POST("save", h.Save) | ||||
| 			group.POST("sort", h.Sort) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 			group.POST("set", h.Set) | ||||
| 			group.POST("remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/reward/") | ||||
| 			group.GET("list", h.List) | ||||
| 			group.POST("remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/dashboard/") | ||||
| @@ -292,18 +320,109 @@ func main() { | ||||
| 			group := s.Engine.Group("/api/admin/model/") | ||||
| 			group.POST("save", h.Save) | ||||
| 			group.GET("list", h.List) | ||||
| 			group.POST("set", h.Set) | ||||
| 			group.POST("sort", h.Sort) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) { | ||||
| 			group := s.Engine.Group("/api/payment/") | ||||
| 			group.GET("doPay", h.DoPay) | ||||
| 			group.GET("payWays", h.GetPayWays) | ||||
| 			group.POST("query", h.OrderQuery) | ||||
| 			group.POST("qrcode", h.PayQrcode) | ||||
| 			group.POST("mobile", h.Mobile) | ||||
| 			group.POST("alipay/notify", h.AlipayNotify) | ||||
| 			group.POST("hupipay/notify", h.HuPiPayNotify) | ||||
| 			group.POST("payjs/notify", h.PayJsNotify) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/product/") | ||||
| 			group.POST("save", h.Save) | ||||
| 			group.GET("list", h.List) | ||||
| 			group.POST("enable", h.Enable) | ||||
| 			group.POST("sort", h.Sort) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.OrderHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/order/") | ||||
| 			group.POST("list", h.List) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) { | ||||
| 			group := s.Engine.Group("/api/order/") | ||||
| 			group.POST("list", h.List) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) { | ||||
| 			group := s.Engine.Group("/api/product/") | ||||
| 			group.GET("list", h.List) | ||||
| 		}), | ||||
|  | ||||
| 		fx.Provide(handler.NewInviteHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) { | ||||
| 			group := s.Engine.Group("/api/invite/") | ||||
| 			group.GET("code", h.Code) | ||||
| 			group.POST("list", h.List) | ||||
| 			group.GET("hits", h.Hits) | ||||
| 		}), | ||||
|  | ||||
| 		fx.Provide(handler.NewPromptHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) { | ||||
| 			group := s.Engine.Group("/api/prompt/") | ||||
| 			group.POST("rewrite", h.Rewrite) | ||||
| 			group.POST("translate", h.Translate) | ||||
| 		}), | ||||
|  | ||||
| 		fx.Provide(admin.NewFunctionHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/function/") | ||||
| 			group.POST("save", h.Save) | ||||
| 			group.POST("set", h.Set) | ||||
| 			group.GET("list", h.List) | ||||
| 			group.GET("remove", h.Remove) | ||||
| 			group.GET("token", h.GenToken) | ||||
| 		}), | ||||
|  | ||||
| 		// 验证码 | ||||
| 		fx.Provide(admin.NewCaptchaHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/login/") | ||||
| 			group.GET("captcha", h.GetCaptcha) | ||||
| 		}), | ||||
|  | ||||
| 		fx.Provide(admin.NewUploadHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) { | ||||
| 			s.Engine.POST("/api/admin/upload", h.Upload) | ||||
| 		}), | ||||
|  | ||||
| 		fx.Provide(handler.NewFunctionHandler), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) { | ||||
| 			group := s.Engine.Group("/api/function/") | ||||
| 			group.POST("weibo", h.WeiBo) | ||||
| 			group.POST("zaobao", h.ZaoBao) | ||||
| 			group.POST("dalle3", h.Dall3) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/chat/") | ||||
| 			group.POST("list", h.List) | ||||
| 			group.POST("message", h.Messages) | ||||
| 			group.GET("history", h.History) | ||||
| 			group.GET("remove", h.RemoveChat) | ||||
| 			group.GET("message/remove", h.RemoveMessage) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) { | ||||
| 			group := s.Engine.Group("/api/powerLog/") | ||||
| 			group.POST("list", h.List) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) { | ||||
| 			group := s.Engine.Group("/api/admin/powerLog/") | ||||
| 			group.POST("list", h.List) | ||||
| 		}), | ||||
| 		fx.Invoke(func(s *core.AppServer, db *gorm.DB) { | ||||
| 			err := s.Run(db) | ||||
| 			if err != nil { | ||||
| 				log.Fatal(err) | ||||
| 			} | ||||
| 		}), | ||||
|  | ||||
| 		// 注册生命周期回调函数 | ||||
| 		fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) { | ||||
| 			lifecycle.Append(fx.Hook{ | ||||
|   | ||||
							
								
								
									
										38
									
								
								api/res/certs/alipay/alipayPublicCert.crt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								api/res/certs/alipay/alipayPublicCert.crt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| -----BEGIN CERTIFICATE----- | ||||
| MIIDszCCApugAwIBAgIQICMRB0rBU2/rZJbfJGMYIzANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE | ||||
| BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv | ||||
| biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1 | ||||
| dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDYzNTQxWhcNMjQxMTA2MDYzNTQxWjCB | ||||
| hDELMAkGA1UEBhMCQ04xHzAdBgNVBAoMFm1ib25meTkwMTVAc2FuZGJveC5jb20xDzANBgNVBAsM | ||||
| BkFsaXBheTFDMEEGA1UEAww65pSv5LuY5a6dKOS4reWbvSnnvZHnu5zmioDmnK/mnInpmZDlhazl | ||||
| j7gtMjA4ODcyMTAyMDc1MDU4MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKsoKcw5 | ||||
| sxaiyV7mpWzDtnQ1K518eQLP0+dJlZAf06aBep/Aj9DIqrba/k7DHt8dKQvILMLAMpN1+2IRxbaO | ||||
| yxMa/laj3lZ1eHrB6F077O3D62oHcE3noZtXL0N1zZAxpmkNmYIHeLZS2oLMS4ANu47O/wpDC7BV | ||||
| HjdpZugtdPJ4mxdCpM9GDdLs7W4s5QI4PUPK4skFNMFoKI+0cYP/9ju87UP//IHC/K510GWNl+Gn | ||||
| Cvgag3AmiIB0utJNsGhxm6zT1T9tUWjW9iz/BxBKiPatsCX9VpPQzGnW7ZonRQtiZSokIlP2IPvl | ||||
| H5DcwpWUz3/LUY0SmKxnKOEYeOOqCW8CAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3 | ||||
| DQEBCwUAA4IBAQAtgxF2EzjOndEFxBUD9tFwcSt6XKGggOp52oft1pvynPg4ALTLafOtfEPDrFBH | ||||
| PwpYrSu9s9C8NJtaA2HrlCfBjIuwEFTXiN+HPvS0SwSPKt9AXEiTcOF8vDcGamEen8QI4fo5Jia7 | ||||
| 2VRKkerkww5/+FzSaVO7ZUKuL80M1QJStmAZc8kPPwdYOTTW2bGf8BcmSDL6SPElBkt7tCCRd4sn | ||||
| +jq4cZ0yb2i77rBZCwHcTvfTqIBblPwLv4uGvg3+83BxIB5w6Kqp06bKEAPmobFY5IVHa+ON0/qi | ||||
| BXxXr+WQ3piKRVQEN64+PTAjSc67Ix1umvpLl3Ko6Ry7NJmpDcUn | ||||
| -----END CERTIFICATE----- | ||||
| -----BEGIN CERTIFICATE----- | ||||
| MIIDszCCApugAwIBAgIQIBkIGbgVxq210KxLJ+YA/TANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UE | ||||
| BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxJTAjBgNVBAsMHENlcnRpZmljYXRpb24gQXV0 | ||||
| aG9yaXR5IHRlc3QxNjA0BgNVBAMMLUFudCBGaW5hbmNpYWwgQ2VydGlmaWNhdGlvbiBBdXRob3Jp | ||||
| dHkgUjEgdGVzdDAeFw0xOTA4MTkxMTE2MDBaFw0yNDA4MDExMTE2MDBaMIGRMQswCQYDVQQGEwJD | ||||
| TjEbMBkGA1UECgwSQW50IEZpbmFuY2lhbCB0ZXN0MSUwIwYDVQQLDBxDZXJ0aWZpY2F0aW9uIEF1 | ||||
| dGhvcml0eSB0ZXN0MT4wPAYDVQQDDDVBbnQgRmluYW5jaWFsIENlcnRpZmljYXRpb24gQXV0aG9y | ||||
| aXR5IENsYXNzIDIgUjEgdGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMh4FKYO | ||||
| ZyRQHD6eFbPKZeSAnrfjfU7xmS9Yoozuu+iuqZlb6Z0SPLUqqTZAFZejOcmr07ln/pwZxluqplxC | ||||
| 5+B48End4nclDMlT5HPrDr3W0frs6Xsa2ZNcyil/iKNB5MbGll8LRAxntsKvZZj6vUTMb705gYgm | ||||
| VUMILwi/ZxKTQqBtkT/kQQ5y6nOZsj7XI5rYdz6qqOROrpvS/d7iypdHOMIM9Iz9DlL1mrCykbBi | ||||
| t25y+gTeXmuisHUwqaRpwtCGK4BayCqxRGbNipe6W73EK9lBrrzNtTr9NaysesT/v+l25JHCL9tG | ||||
| wpNr1oWFzk4IHVOg0ORiQ6SUgxZUTYcCAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3 | ||||
| DQEBCwUAA4IBAQBWThEoIaQoBX2YeRY/I8gu6TYnFXtyuCljANnXnM38ft+ikhE5mMNgKmJYLHvT | ||||
| yWWWgwHoSAWEuml7EGbE/2AK2h3k0MdfiWLzdmpPCRG/RJHk6UB1pMHPilI+c0MVu16OPpKbg5Vf | ||||
| LTv7dsAB40AzKsvyYw88/Ezi1osTXo6QQwda7uefvudirtb8FcQM9R66cJxl3kt1FXbpYwheIm/p | ||||
| j1mq64swCoIYu4NrsUYtn6CV542DTQMI5QdXkn+PzUUly8F6kDp+KpMNd0avfWNL5+O++z+F5Szy | ||||
| 1CPta1D7EQ/eYmMP+mOQ35oifWIoFCpN6qQVBS/Hob1J/UUyg7BW | ||||
| -----END CERTIFICATE----- | ||||
							
								
								
									
										88
									
								
								api/res/certs/alipay/alipayRootCert.crt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								api/res/certs/alipay/alipayRootCert.crt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,88 @@ | ||||
| -----BEGIN CERTIFICATE----- | ||||
| MIIBszCCAVegAwIBAgIIaeL+wBcKxnswDAYIKoEcz1UBg3UFADAuMQswCQYDVQQG | ||||
| EwJDTjEOMAwGA1UECgwFTlJDQUMxDzANBgNVBAMMBlJPT1RDQTAeFw0xMjA3MTQw | ||||
| MzExNTlaFw00MjA3MDcwMzExNTlaMC4xCzAJBgNVBAYTAkNOMQ4wDAYDVQQKDAVO | ||||
| UkNBQzEPMA0GA1UEAwwGUk9PVENBMFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAE | ||||
| MPCca6pmgcchsTf2UnBeL9rtp4nw+itk1Kzrmbnqo05lUwkwlWK+4OIrtFdAqnRT | ||||
| V7Q9v1htkv42TsIutzd126NdMFswHwYDVR0jBBgwFoAUTDKxl9kzG8SmBcHG5Yti | ||||
| W/CXdlgwDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAQYwHQYDVR0OBBYEFEwysZfZ | ||||
| MxvEpgXBxuWLYlvwl3ZYMAwGCCqBHM9VAYN1BQADSAAwRQIgG1bSLeOXp3oB8H7b | ||||
| 53W+CKOPl2PknmWEq/lMhtn25HkCIQDaHDgWxWFtnCrBjH16/W3Ezn7/U/Vjo5xI | ||||
| pDoiVhsLwg== | ||||
| -----END CERTIFICATE----- | ||||
|  | ||||
| -----BEGIN CERTIFICATE----- | ||||
| MIIF0zCCA7ugAwIBAgIIH8+hjWpIDREwDQYJKoZIhvcNAQELBQAwejELMAkGA1UE | ||||
| BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNVBAsMF0NlcnRpZmlj | ||||
| YXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5jaWFsIENlcnRpZmlj | ||||
| YXRpb24gQXV0aG9yaXR5IFIxMB4XDTE4MDMyMTEzNDg0MFoXDTM4MDIyODEzNDg0 | ||||
| MFowejELMAkGA1UEBhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNV | ||||
| BAsMF0NlcnRpZmljYXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5j | ||||
| aWFsIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFIxMIICIjANBgkqhkiG9w0BAQEF | ||||
| AAOCAg8AMIICCgKCAgEAtytTRcBNuur5h8xuxnlKJetT65cHGemGi8oD+beHFPTk | ||||
| rUTlFt9Xn7fAVGo6QSsPb9uGLpUFGEdGmbsQ2q9cV4P89qkH04VzIPwT7AywJdt2 | ||||
| xAvMs+MgHFJzOYfL1QkdOOVO7NwKxH8IvlQgFabWomWk2Ei9WfUyxFjVO1LVh0Bp | ||||
| dRBeWLMkdudx0tl3+21t1apnReFNQ5nfX29xeSxIhesaMHDZFViO/DXDNW2BcTs6 | ||||
| vSWKyJ4YIIIzStumD8K1xMsoaZBMDxg4itjWFaKRgNuPiIn4kjDY3kC66Sl/6yTl | ||||
| YUz8AybbEsICZzssdZh7jcNb1VRfk79lgAprm/Ktl+mgrU1gaMGP1OE25JCbqli1 | ||||
| Pbw/BpPynyP9+XulE+2mxFwTYhKAwpDIDKuYsFUXuo8t261pCovI1CXFzAQM2w7H | ||||
| DtA2nOXSW6q0jGDJ5+WauH+K8ZSvA6x4sFo4u0KNCx0ROTBpLif6GTngqo3sj+98 | ||||
| SZiMNLFMQoQkjkdN5Q5g9N6CFZPVZ6QpO0JcIc7S1le/g9z5iBKnifrKxy0TQjtG | ||||
| PsDwc8ubPnRm/F82RReCoyNyx63indpgFfhN7+KxUIQ9cOwwTvemmor0A+ZQamRe | ||||
| 9LMuiEfEaWUDK+6O0Gl8lO571uI5onYdN1VIgOmwFbe+D8TcuzVjIZ/zvHrAGUcC | ||||
| AwEAAaNdMFswCwYDVR0PBAQDAgEGMAwGA1UdEwQFMAMBAf8wHQYDVR0OBBYEFF90 | ||||
| tATATwda6uWx2yKjh0GynOEBMB8GA1UdIwQYMBaAFF90tATATwda6uWx2yKjh0Gy | ||||
| nOEBMA0GCSqGSIb3DQEBCwUAA4ICAQCVYaOtqOLIpsrEikE5lb+UARNSFJg6tpkf | ||||
| tJ2U8QF/DejemEHx5IClQu6ajxjtu0Aie4/3UnIXop8nH/Q57l+Wyt9T7N2WPiNq | ||||
| JSlYKYbJpPF8LXbuKYG3BTFTdOVFIeRe2NUyYh/xs6bXGr4WKTXb3qBmzR02FSy3 | ||||
| IODQw5Q6zpXj8prYqFHYsOvGCEc1CwJaSaYwRhTkFedJUxiyhyB5GQwoFfExCVHW | ||||
| 05ZFCAVYFldCJvUzfzrWubN6wX0DD2dwultgmldOn/W/n8at52mpPNvIdbZb2F41 | ||||
| T0YZeoWnCJrYXjq/32oc1cmifIHqySnyMnavi75DxPCdZsCOpSAT4j4lAQRGsfgI | ||||
| kkLPGQieMfNNkMCKh7qjwdXAVtdqhf0RVtFILH3OyEodlk1HYXqX5iE5wlaKzDop | ||||
| PKwf2Q3BErq1xChYGGVS+dEvyXc/2nIBlt7uLWKp4XFjqekKbaGaLJdjYP5b2s7N | ||||
| 1dM0MXQ/f8XoXKBkJNzEiM3hfsU6DOREgMc1DIsFKxfuMwX3EkVQM1If8ghb6x5Y | ||||
| jXayv+NLbidOSzk4vl5QwngO/JYFMkoc6i9LNwEaEtR9PhnrdubxmrtM+RjfBm02 | ||||
| 77q3dSWFESFQ4QxYWew4pHE0DpWbWy/iMIKQ6UZ5RLvB8GEcgt8ON7BBJeMc+Dyi | ||||
| kT9qhqn+lw== | ||||
| -----END CERTIFICATE----- | ||||
|  | ||||
| -----BEGIN CERTIFICATE----- | ||||
| MIICiDCCAgygAwIBAgIIQX76UsB/30owDAYIKoZIzj0EAwMFADB6MQswCQYDVQQG | ||||
| EwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UECwwXQ2VydGlmaWNh | ||||
| dGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNpYWwgQ2VydGlmaWNh | ||||
| dGlvbiBBdXRob3JpdHkgRTEwHhcNMTkwNDI4MTYyMDQ0WhcNNDkwNDIwMTYyMDQ0 | ||||
| WjB6MQswCQYDVQQGEwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UE | ||||
| CwwXQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNp | ||||
| YWwgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkgRTEwdjAQBgcqhkjOPQIBBgUrgQQA | ||||
| IgNiAASCCRa94QI0vR5Up9Yr9HEupz6hSoyjySYqo7v837KnmjveUIUNiuC9pWAU | ||||
| WP3jwLX3HkzeiNdeg22a0IZPoSUCpasufiLAnfXh6NInLiWBrjLJXDSGaY7vaokt | ||||
| rpZvAdmjXTBbMAsGA1UdDwQEAwIBBjAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBRZ | ||||
| 4ZTgDpksHL2qcpkFkxD2zVd16TAfBgNVHSMEGDAWgBRZ4ZTgDpksHL2qcpkFkxD2 | ||||
| zVd16TAMBggqhkjOPQQDAwUAA2gAMGUCMQD4IoqT2hTUn0jt7oXLdMJ8q4vLp6sg | ||||
| wHfPiOr9gxreb+e6Oidwd2LDnC4OUqCWiF8CMAzwKs4SnDJYcMLf2vpkbuVE4dTH | ||||
| Rglz+HGcTLWsFs4KxLsq7MuU+vJTBUeDJeDjdA== | ||||
| -----END CERTIFICATE----- | ||||
|  | ||||
| -----BEGIN CERTIFICATE----- | ||||
| MIIDxTCCAq2gAwIBAgIUEMdk6dVgOEIS2cCP0Q43P90Ps5YwDQYJKoZIhvcNAQEF | ||||
| BQAwajELMAkGA1UEBhMCQ04xEzARBgNVBAoMCmlUcnVzQ2hpbmExHDAaBgNVBAsM | ||||
| E0NoaW5hIFRydXN0IE5ldHdvcmsxKDAmBgNVBAMMH2lUcnVzQ2hpbmEgQ2xhc3Mg | ||||
| MiBSb290IENBIC0gRzMwHhcNMTMwNDE4MDkzNjU2WhcNMzMwNDE4MDkzNjU2WjBq | ||||
| MQswCQYDVQQGEwJDTjETMBEGA1UECgwKaVRydXNDaGluYTEcMBoGA1UECwwTQ2hp | ||||
| bmEgVHJ1c3QgTmV0d29yazEoMCYGA1UEAwwfaVRydXNDaGluYSBDbGFzcyAyIFJv | ||||
| b3QgQ0EgLSBHMzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOPPShpV | ||||
| nJbMqqCw6Bz1kehnoPst9pkr0V9idOwU2oyS47/HjJXk9Rd5a9xfwkPO88trUpz5 | ||||
| 4GmmwspDXjVFu9L0eFaRuH3KMha1Ak01citbF7cQLJlS7XI+tpkTGHEY5pt3EsQg | ||||
| wykfZl/A1jrnSkspMS997r2Gim54cwz+mTMgDRhZsKK/lbOeBPpWtcFizjXYCqhw | ||||
| WktvQfZBYi6o4sHCshnOswi4yV1p+LuFcQ2ciYdWvULh1eZhLxHbGXyznYHi0dGN | ||||
| z+I9H8aXxqAQfHVhbdHNzi77hCxFjOy+hHrGsyzjrd2swVQ2iUWP8BfEQqGLqM1g | ||||
| KgWKYfcTGdbPB1MCAwEAAaNjMGEwHQYDVR0OBBYEFG/oAMxTVe7y0+408CTAK8hA | ||||
| uTyRMB8GA1UdIwQYMBaAFG/oAMxTVe7y0+408CTAK8hAuTyRMA8GA1UdEwEB/wQF | ||||
| MAMBAf8wDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3DQEBBQUAA4IBAQBLnUTfW7hp | ||||
| emMbuUGCk7RBswzOT83bDM6824EkUnf+X0iKS95SUNGeeSWK2o/3ALJo5hi7GZr3 | ||||
| U8eLaWAcYizfO99UXMRBPw5PRR+gXGEronGUugLpxsjuynoLQu8GQAeysSXKbN1I | ||||
| UugDo9u8igJORYA+5ms0s5sCUySqbQ2R5z/GoceyI9LdxIVa1RjVX8pYOj8JFwtn | ||||
| DJN3ftSFvNMYwRuILKuqUYSHc2GPYiHVflDh5nDymCMOQFcFG3WsEuB+EYQPFgIU | ||||
| 1DHmdZcz7Llx8UOZXX2JupWCYzK1XhJb+r4hK5ncf/w8qGtYlmyJpxk3hr1TfUJX | ||||
| Yf4Zr0fJsGuv | ||||
| -----END CERTIFICATE----- | ||||
							
								
								
									
										19
									
								
								api/res/certs/alipay/appPublicCert.crt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								api/res/certs/alipay/appPublicCert.crt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| -----BEGIN CERTIFICATE----- | ||||
| MIIDmTCCAoGgAwIBAgIQICMRB2LW76yahgdg3IFNPDANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE | ||||
| BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv | ||||
| biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1 | ||||
| dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDU0NjE5WhcNMjQxMTExMDU0NjE5WjBr | ||||
| MQswCQYDVQQGEwJDTjEfMB0GA1UECgwWbWJvbmZ5OTAxNUBzYW5kYm94LmNvbTEPMA0GA1UECwwG | ||||
| QWxpcGF5MSowKAYDVQQDDCEyMDg4NzIxMDIwNzUwNTgxLTkwMjEwMDAxMzE2NTgwMjMwggEiMA0G | ||||
| CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCxihQPf1Q+g9ArgM46shVqL5sbRha/df95D1PsWyEq | ||||
| ANmWmG4zZ+ksYDVQrc4KzhSRoi56sm/7TDFYTmM6bW99e/nKW58WxyZB4ie5qA3F4n17psPyDqb8 | ||||
| IokcQmCphSFDaXQD6AoXoLNtTM0vAI2cWxAgebZ/vsrdj5Ntjt+Rp3NYMCk1i5xovHcfILzLEGbX | ||||
| QXoT9fo5AhHotTWa6xHVLPUGY9qwLzQxHzBmvy5ZMfnOfJkm/mDisTSqAUB59F3dzU/1ARVkEZ1w | ||||
| Mgb4XohWBw6iurQfbMnH2mIomAAwwZVFv+sXDbL9yMbSMo/SjVsTQprn0Q0EnwLo7nmmOM6HAgMB | ||||
| AAGjEjAQMA4GA1UdDwEB/wQEAwIE8DANBgkqhkiG9w0BAQsFAAOCAQEAn3Y4/C1h9R6ONsBqX3/q | ||||
| XfHX7yX1FM0Y1x48X3/Yxk6HivAkTukhhhVYVKJsbrbzRqHDp9vhAP/FR6o6pAevaYMmLov0VMXU | ||||
| 7oAuetgkaYEYkDuNen5/Hpdhqi2vTtdT+q9w8zHJd6MDQ0aoHgIxpLKw5vof2R1N4fwSgNXMiXE5 | ||||
| kmllKQMem/+on2p+Sj80/2asxryHIGlH87qPzkffv+kIOkZthbTApTFLLjdVri2QHGe8/cc4xy01 | ||||
| /9iR3IUzNahotT41lJ4bMevBY7XMAS3n5ekyABN/9ZRJqhWdXgmFCRN/u56qd6lDgu7R2M2QUoyc | ||||
| LuW5DfgRItKlmUB7sw== | ||||
| -----END CERTIFICATE----- | ||||
							
								
								
									
										1
									
								
								api/res/certs/alipay/privateKey.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								api/res/certs/alipay/privateKey.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| MIIEpQIBAAKCAQEAsYoUD39UPoPQK4DOOrIVai+bG0YWv3X/eQ9T7FshKgDZlphuM2fpLGA1UK3OCs4UkaIuerJv+0wxWE5jOm1vfXv5ylufFscmQeInuagNxeJ9e6bD8g6m/CKJHEJgqYUhQ2l0A+gKF6CzbUzNLwCNnFsQIHm2f77K3Y+TbY7fkadzWDApNYucaLx3HyC8yxBm10F6E/X6OQIR6LU1musR1Sz1BmPasC80MR8wZr8uWTH5znyZJv5g4rE0qgFAefRd3c1P9QEVZBGdcDIG+F6IVgcOorq0H2zJx9piKJgAMMGVRb/rFw2y/cjG0jKP0o1bE0Ka59ENBJ8C6O55pjjOhwIDAQABAoIBAFetNfz1R7hbxjlFshMAkVzQR8wvT9qbvl+dtzdZRcaFhu89NecDIP7+QDYor0FcxoGpU0TazDyRQyk2BQD8vHt+9zv9BVLtZLJSqoWgPbUFBi1DjS8EF2ka8RVYnn35NhUhhd7L//ftL88Bh673mfembQ9srDjoEy1Z01feoABAnCMkNFl986DmEwnarvEufXSDIgeN4ioMxha4NvfIPuI0zpVdV1O9sv+SGC+VEWZBtN3GNsaf4zS/f8FVGvTiU/Abz0gSw/iwSPHclDWQDTN3yFHf/tfqlzh0mH0WfhnuOBFWXzK+R7fbnM+asI9ttvzRcfpzgRGXdPcNcOv/6cECgYEA3DVqpi1k8MYfJixju6SG5gfyhM4VFksFmCMaNPgtatDMBKLMTgV/Ej6LXREojcy29uZl83F09pVlpd41eG39ULIPktixA/BqErQ2UaWh6kOxifycpu22Jh0r09hax6UgVrcBrrnCJEjcFsuJlrZvXQSzc3PBxjWy5gjabS5h9iECgYEAzmVAIh2frF01Y95zsLueAhhZwCtPanm6kf7ivR4r1plIX3b2sNRhWGmEHFgaCE6Braa0ogQ73Hd26kw4ZW+D6QMGC/zjCBEzDLLf++SjdVUHiY5AR4WHqXzq1jdAlsVyo9R661oAOp3lhiJVGLNXkHyEfEVPHsaxJh4osYSbX6cCgYEAx32Qx0i6eDFTyLZQB46uMrgiaVN04QRH5iJuvGvUYT8UhGKjaU8rZfDJOh+wOH2rhxMEaz1uc3C2bERY9mfWI4Ob/jFWc7YZsiYWS3Mcsuhubw4tMECLUg39RWZsHw8ls8kIuixIh6yFzhTH6YQOcRswIrhMZG8DScfdcSmiz2ECgYEAkWP1t5KSpkLKl11etcKUXfl1T8+yk9jIOowIgRw92WAFAWq2AH67TCKYM7dEL1HOO9tRJ0hAOt/U3ttuZtYVYBEHM26jJ02mXm2rJrA7DS4mrxmL4lYH6LbcXqZxU0Qnq4zEQgIWYzRTORf6Rfof1uJAGaJhR9bDd4yLMfGt2cUCgYEAo216Y61xOHUTA4AF1eekk+r+uOcQgQDvLXfs9FkDdJLk0mPG48/+eIYpPFnANJ/riF/DWOp8WGEe2IzA9yUFexzDbNQK8ha9kGcxaSAyiCwzjZ/t9/+hScDSV8kNqWSRSisu/YOFleEHbokT6mbLZ+gdqES8mUUanaEBzRQYGxo= | ||||
							
								
								
									
										
											BIN
										
									
								
								api/res/img/alipay.jpg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								api/res/img/alipay.jpg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 15 KiB | 
							
								
								
									
										
											BIN
										
									
								
								api/res/img/wechat-pay.jpg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								api/res/img/wechat-pay.jpg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 5.7 KiB | 
| @@ -1,36 +1,41 @@ | ||||
| { | ||||
|   "data": [ | ||||
|     "task(38194gitxp745ha)", | ||||
|     "A beautiful Chinese girl riding on a tiger", | ||||
|     "task(cxvkpawy8onnfti)", | ||||
|     "a  cute girl", | ||||
|     "", | ||||
|     [], | ||||
|     20, | ||||
|     "Euler a", | ||||
|     false, | ||||
|     false, | ||||
|     "DPM++ 2M Karras", | ||||
|     1, | ||||
|     1, | ||||
|     7, | ||||
|     -1, | ||||
|     -1, | ||||
|     0, | ||||
|     0, | ||||
|     0, | ||||
|     512, | ||||
|     512, | ||||
|     false, | ||||
|     512, | ||||
|     512, | ||||
|     true, | ||||
|     0.7, | ||||
|     2, | ||||
|     "ESRGAN_4x", | ||||
|     "Latent", | ||||
|     0, | ||||
|     0, | ||||
|     0, | ||||
|     "Use same checkpoint", | ||||
|     "Use same sampler", | ||||
|     "", | ||||
|     "", | ||||
|     [], | ||||
|     "None", | ||||
|     false, | ||||
|     "", | ||||
|     0.8, | ||||
|     -1, | ||||
|     false, | ||||
|     -1, | ||||
|     0, | ||||
|     0, | ||||
|     0, | ||||
|     null, | ||||
|     null, | ||||
|     null, | ||||
|     null, | ||||
|     false, | ||||
|     false, | ||||
| @@ -54,36 +59,13 @@ | ||||
|     false, | ||||
|     false, | ||||
|     0, | ||||
|     "Not set", | ||||
|     true, | ||||
|     true, | ||||
|     "", | ||||
|     "", | ||||
|     "", | ||||
|     "", | ||||
|     "", | ||||
|     1.3, | ||||
|     "Not set", | ||||
|     "Not set", | ||||
|     1.3, | ||||
|     "Not set", | ||||
|     1.3, | ||||
|     "Not set", | ||||
|     1.3, | ||||
|     1.3, | ||||
|     "Not set", | ||||
|     1.3, | ||||
|     "Not set", | ||||
|     1.3, | ||||
|     "Not set", | ||||
|     1.3, | ||||
|     "Not set", | ||||
|     1.3, | ||||
|     "Not set", | ||||
|     1.3, | ||||
|     "Not set", | ||||
|     null, | ||||
|     null, | ||||
|     false, | ||||
|     "None", | ||||
|     null, | ||||
|     null, | ||||
|     false, | ||||
|     null, | ||||
|     null, | ||||
|     false, | ||||
|     50, | ||||
| @@ -93,6 +75,6 @@ | ||||
|     "" | ||||
|   ], | ||||
|   "event_data": null, | ||||
|   "fn_index": 232, | ||||
|   "session_hash": "3xedmn4nuzq" | ||||
|   "fn_index": 446, | ||||
|   "session_hash": "nk5noh1rz1o" | ||||
| } | ||||
| @@ -1,42 +0,0 @@ | ||||
| package fun | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/mj" | ||||
| 	"chatplus/utils" | ||||
| ) | ||||
|  | ||||
| // AI 绘画函数 | ||||
|  | ||||
| type FuncMidJourney struct { | ||||
| 	name    string | ||||
| 	service *mj.Service | ||||
| } | ||||
|  | ||||
| func NewMidJourneyFunc(mjService *mj.Service) FuncMidJourney { | ||||
| 	return FuncMidJourney{ | ||||
| 		name:    "MidJourney AI 绘画", | ||||
| 		service: mjService} | ||||
| } | ||||
|  | ||||
| func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { | ||||
| 	logger.Infof("MJ 绘画参数:%+v", params) | ||||
| 	prompt := utils.InterfaceToString(params["prompt"]) | ||||
| 	f.service.PushTask(types.MjTask{ | ||||
| 		SessionId: utils.InterfaceToString(params["session_id"]), | ||||
| 		Src:       types.TaskSrcChat, | ||||
| 		Type:      types.TaskImage, | ||||
| 		Prompt:    prompt, | ||||
| 		UserId:    utils.IntValue(utils.InterfaceToString(params["user_id"]), 0), | ||||
| 		RoleId:    utils.IntValue(utils.InterfaceToString(params["role_id"]), 0), | ||||
| 		Icon:      utils.InterfaceToString(params["icon"]), | ||||
| 		ChatId:    utils.InterfaceToString(params["chat_id"]), | ||||
| 	}) | ||||
| 	return prompt, nil | ||||
| } | ||||
|  | ||||
| func (f FuncMidJourney) Name() string { | ||||
| 	return f.name | ||||
| } | ||||
|  | ||||
| var _ Function = &FuncMidJourney{} | ||||
| @@ -1,39 +0,0 @@ | ||||
| package fun | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/service/mj" | ||||
| ) | ||||
|  | ||||
| type Function interface { | ||||
| 	Invoke(map[string]interface{}) (string, error) | ||||
| 	Name() string | ||||
| } | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| type resVo struct { | ||||
| 	Code    types.BizCode `json:"code"` | ||||
| 	Message string        `json:"message"` | ||||
| 	Data    struct { | ||||
| 		Title     string     `json:"title"` | ||||
| 		UpdatedAt string     `json:"updated_at"` | ||||
| 		Items     []dataItem `json:"items"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type dataItem struct { | ||||
| 	Title  string `json:"title"` | ||||
| 	Url    string `json:"url"` | ||||
| 	Remark string `json:"remark"` | ||||
| } | ||||
|  | ||||
| func NewFunctions(config *types.AppConfig, mjService *mj.Service) map[string]Function { | ||||
| 	return map[string]Function{ | ||||
| 		types.FuncZaoBao:     NewZaoBao(config.ApiConfig), | ||||
| 		types.FuncWeibo:      NewWeiboHot(config.ApiConfig), | ||||
| 		types.FuncHeadLine:   NewHeadLines(config.ApiConfig), | ||||
| 		types.FuncMidJourney: NewMidJourneyFunc(mjService), | ||||
| 	} | ||||
| } | ||||
| @@ -1,58 +0,0 @@ | ||||
| package fun | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // 今日头条函数实现 | ||||
|  | ||||
| type FuncHeadlines struct { | ||||
| 	name   string | ||||
| 	config types.ChatPlusApiConfig | ||||
| 	client *req.Client | ||||
| } | ||||
|  | ||||
| func NewHeadLines(config types.ChatPlusApiConfig) FuncHeadlines { | ||||
| 	return FuncHeadlines{ | ||||
| 		name:   "今日头条", | ||||
| 		config: config, | ||||
| 		client: req.C().SetTimeout(10 * time.Second)} | ||||
| } | ||||
|  | ||||
| func (f FuncHeadlines) Invoke(map[string]interface{}) (string, error) { | ||||
| 	if f.config.Token == "" { | ||||
| 		return "", errors.New("无效的 API Token") | ||||
| 	} | ||||
|  | ||||
| 	url := fmt.Sprintf("%s/api/headline/fetch", f.config.ApiURL) | ||||
| 	var res resVo | ||||
| 	r, err := f.client.R(). | ||||
| 		SetHeader("AppId", f.config.AppId). | ||||
| 		SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)). | ||||
| 		SetSuccessResult(&res).Get(url) | ||||
| 	if err != nil || r.IsErrorState() { | ||||
| 		return "", fmt.Errorf("%v%v", err, r.Err) | ||||
| 	} | ||||
|  | ||||
| 	if res.Code != types.Success { | ||||
| 		return "", errors.New(res.Message) | ||||
| 	} | ||||
|  | ||||
| 	builder := make([]string, 0) | ||||
| 	builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt)) | ||||
| 	for i, v := range res.Data.Items { | ||||
| 		builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [%s]", i+1, v.Title, v.Url, v.Remark)) | ||||
| 	} | ||||
| 	return strings.Join(builder, "\n\n"), nil | ||||
| } | ||||
|  | ||||
| func (f FuncHeadlines) Name() string { | ||||
| 	return f.name | ||||
| } | ||||
|  | ||||
| var _ Function = &FuncHeadlines{} | ||||
| @@ -1,58 +0,0 @@ | ||||
| package fun | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // 微博热搜函数实现 | ||||
|  | ||||
| type FuncWeiboHot struct { | ||||
| 	name   string | ||||
| 	config types.ChatPlusApiConfig | ||||
| 	client *req.Client | ||||
| } | ||||
|  | ||||
| func NewWeiboHot(config types.ChatPlusApiConfig) FuncWeiboHot { | ||||
| 	return FuncWeiboHot{ | ||||
| 		name:   "微博热搜", | ||||
| 		config: config, | ||||
| 		client: req.C().SetTimeout(10 * time.Second)} | ||||
| } | ||||
|  | ||||
| func (f FuncWeiboHot) Invoke(map[string]interface{}) (string, error) { | ||||
| 	if f.config.Token == "" { | ||||
| 		return "", errors.New("无效的 API Token") | ||||
| 	} | ||||
|  | ||||
| 	url := fmt.Sprintf("%s/api/weibo/fetch", f.config.ApiURL) | ||||
| 	var res resVo | ||||
| 	r, err := f.client.R(). | ||||
| 		SetHeader("AppId", f.config.AppId). | ||||
| 		SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)). | ||||
| 		SetSuccessResult(&res).Get(url) | ||||
| 	if err != nil || r.IsErrorState() { | ||||
| 		return "", fmt.Errorf("%v%v", err, r.Err) | ||||
| 	} | ||||
|  | ||||
| 	if res.Code != types.Success { | ||||
| 		return "", errors.New(res.Message) | ||||
| 	} | ||||
|  | ||||
| 	builder := make([]string, 0) | ||||
| 	builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt)) | ||||
| 	for i, v := range res.Data.Items { | ||||
| 		builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark)) | ||||
| 	} | ||||
| 	return strings.Join(builder, "\n\n"), nil | ||||
| } | ||||
|  | ||||
| func (f FuncWeiboHot) Name() string { | ||||
| 	return f.name | ||||
| } | ||||
|  | ||||
| var _ Function = &FuncWeiboHot{} | ||||
| @@ -1,59 +0,0 @@ | ||||
| package fun | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // 每日早报函数实现 | ||||
|  | ||||
| type FuncZaoBao struct { | ||||
| 	name   string | ||||
| 	config types.ChatPlusApiConfig | ||||
| 	client *req.Client | ||||
| } | ||||
|  | ||||
| func NewZaoBao(config types.ChatPlusApiConfig) FuncZaoBao { | ||||
| 	return FuncZaoBao{ | ||||
| 		name:   "每日早报", | ||||
| 		config: config, | ||||
| 		client: req.C().SetTimeout(10 * time.Second)} | ||||
| } | ||||
|  | ||||
| func (f FuncZaoBao) Invoke(map[string]interface{}) (string, error) { | ||||
| 	if f.config.Token == "" { | ||||
| 		return "", errors.New("无效的 API Token") | ||||
| 	} | ||||
|  | ||||
| 	url := fmt.Sprintf("%s/api/zaobao/fetch", f.config.ApiURL) | ||||
| 	var res resVo | ||||
| 	r, err := f.client.R(). | ||||
| 		SetHeader("AppId", f.config.AppId). | ||||
| 		SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)). | ||||
| 		SetSuccessResult(&res).Get(url) | ||||
| 	if err != nil || r.IsErrorState() { | ||||
| 		return "", fmt.Errorf("%v%v", err, r.Err) | ||||
| 	} | ||||
|  | ||||
| 	if res.Code != types.Success { | ||||
| 		return "", errors.New(res.Message) | ||||
| 	} | ||||
|  | ||||
| 	builder := make([]string, 0) | ||||
| 	builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt)) | ||||
| 	for _, v := range res.Data.Items { | ||||
| 		builder = append(builder, v.Title) | ||||
| 	} | ||||
| 	builder = append(builder, fmt.Sprintf("%s", res.Data.Title)) | ||||
| 	return strings.Join(builder, "\n\n"), nil | ||||
| } | ||||
|  | ||||
| func (f FuncZaoBao) Name() string { | ||||
| 	return f.name | ||||
| } | ||||
|  | ||||
| var _ Function = &FuncZaoBao{} | ||||
| @@ -4,7 +4,7 @@ import ( | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/utils" | ||||
| 	"github.com/bwmarrin/discordgo" | ||||
| 	discordgo "github.com/bg5t/mydiscordgo" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| @@ -17,32 +17,49 @@ import ( | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| type Bot struct { | ||||
| 	config  *types.MidJourneyConfig | ||||
| 	config  types.MidJourneyConfig | ||||
| 	bot     *discordgo.Session | ||||
| 	name    string | ||||
| 	service *Service | ||||
| } | ||||
|  | ||||
| func NewBot(config *types.AppConfig, service *Service) (*Bot, error) { | ||||
| 	discord, err := discordgo.New("Bot " + config.MjConfig.BotToken) | ||||
| func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) { | ||||
| 	bot, err := discordgo.New("Bot " + config.BotToken) | ||||
| 	if err != nil { | ||||
| 		logger.Error(err) | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if config.ProxyURL != "" { | ||||
| 		proxy, _ := url.Parse(config.ProxyURL) | ||||
| 		discord.Client = &http.Client{ | ||||
| 			Transport: &http.Transport{ | ||||
| 	// use CDN reverse proxy | ||||
| 	if config.UseCDN { | ||||
| 		discordgo.SetEndpointDiscord(config.DiscordAPI) | ||||
| 		discordgo.SetEndpointCDN("https://cdn.discordapp.com") | ||||
| 		discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/") | ||||
| 		bot.MjGateway = config.DiscordGateway + "/" | ||||
| 	} else { // use proxy | ||||
| 		discordgo.SetEndpointDiscord("https://discord.com") | ||||
| 		discordgo.SetEndpointCDN("https://cdn.discordapp.com") | ||||
| 		discordgo.SetEndpointStatus("https://discord.com/api/v2/") | ||||
| 		bot.MjGateway = "wss://gateway.discord.gg" | ||||
|  | ||||
| 		if proxy != "" { | ||||
| 			proxy, _ := url.Parse(proxy) | ||||
| 			bot.Client = &http.Client{ | ||||
| 				Transport: &http.Transport{ | ||||
| 					Proxy: http.ProxyURL(proxy), | ||||
| 				}, | ||||
| 			} | ||||
| 			bot.Dialer = &websocket.Dialer{ | ||||
| 				Proxy: http.ProxyURL(proxy), | ||||
| 			}, | ||||
| 		} | ||||
| 		discord.Dialer = &websocket.Dialer{ | ||||
| 			Proxy: http.ProxyURL(proxy), | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	return &Bot{ | ||||
| 		config:  &config.MjConfig, | ||||
| 		bot:     discord, | ||||
| 		config:  config, | ||||
| 		bot:     bot, | ||||
| 		name:    name, | ||||
| 		service: service, | ||||
| 	}, nil | ||||
| } | ||||
| @@ -52,13 +69,13 @@ func (b *Bot) Run() error { | ||||
| 	b.bot.AddHandler(b.messageCreate) | ||||
| 	b.bot.AddHandler(b.messageUpdate) | ||||
|  | ||||
| 	logger.Info("Starting MidJourney Bot...") | ||||
| 	logger.Infof("Starting MidJourney %s", b.name) | ||||
| 	err := b.bot.Open() | ||||
| 	if err != nil { | ||||
| 		logger.Error("Error opening Discord connection:", err) | ||||
| 		logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err) | ||||
| 		return err | ||||
| 	} | ||||
| 	logger.Info("Starting MidJourney Bot successfully!") | ||||
| 	logger.Infof("Starting MidJourney %s successfully!", b.name) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -87,7 +104,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { | ||||
| 		return | ||||
| 	} | ||||
| 	// ignore messages for self | ||||
| 	if m.Author.ID == s.State.User.ID { | ||||
| 	if m.Author == nil || m.Author.ID == s.State.User.ID { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -99,6 +116,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { | ||||
| 	if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") { | ||||
| 		// parse content | ||||
| 		req := CBReq{ | ||||
| 			ChannelId:   m.ChannelID, | ||||
| 			MessageId:   m.ID, | ||||
| 			ReferenceId: referenceId, | ||||
| 			Prompt:      extractPrompt(m.Content), | ||||
| @@ -109,7 +127,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments) | ||||
| 	b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments) | ||||
| } | ||||
|  | ||||
| func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { | ||||
| @@ -118,7 +136,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { | ||||
| 		return | ||||
| 	} | ||||
| 	// ignore messages for self | ||||
| 	if m.Author.ID == s.State.User.ID { | ||||
| 	if m.Author == nil || m.Author.ID == s.State.User.ID { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -130,6 +148,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { | ||||
| 	} | ||||
| 	if strings.Contains(m.Content, "(Stopped)") { | ||||
| 		req := CBReq{ | ||||
| 			ChannelId:   m.ChannelID, | ||||
| 			MessageId:   m.ID, | ||||
| 			ReferenceId: referenceId, | ||||
| 			Prompt:      extractPrompt(m.Content), | ||||
| @@ -140,11 +159,11 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments) | ||||
| 	b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments) | ||||
|  | ||||
| } | ||||
|  | ||||
| func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) { | ||||
| func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) { | ||||
| 	progress := extractProgress(content) | ||||
| 	var status TaskStatus | ||||
| 	if progress == 100 { | ||||
| @@ -166,6 +185,7 @@ func (b *Bot) addAttachment(messageId string, referenceId string, content string | ||||
| 			Hash:     extractHashFromFilename(attachment.Filename), | ||||
| 		} | ||||
| 		req := CBReq{ | ||||
| 			ChannelId:   channelId, | ||||
| 			MessageId:   messageId, | ||||
| 			ReferenceId: referenceId, | ||||
| 			Image:       image, | ||||
|   | ||||
| @@ -2,36 +2,46 @@ package mj | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/imroc/req/v3" | ||||
| ) | ||||
|  | ||||
| // MidJourney client | ||||
|  | ||||
| type Client struct { | ||||
| 	client *req.Client | ||||
| 	config *types.MidJourneyConfig | ||||
| 	Config types.MidJourneyConfig | ||||
| 	apiURL string | ||||
| } | ||||
|  | ||||
| func NewClient(config *types.AppConfig) *Client { | ||||
| func NewClient(config types.MidJourneyConfig, proxy string) *Client { | ||||
| 	client := req.C().SetTimeout(10 * time.Second) | ||||
| 	var apiURL string | ||||
| 	// set proxy URL | ||||
| 	if config.ProxyURL != "" { | ||||
| 		client.SetProxyURL(config.ProxyURL) | ||||
| 	if config.UseCDN { | ||||
| 		apiURL = config.DiscordAPI + "/api/v9/interactions" | ||||
| 	} else { | ||||
| 		apiURL = "https://discord.com/api/v9/interactions" | ||||
| 		if proxy != "" { | ||||
| 			client.SetProxyURL(proxy) | ||||
| 		} | ||||
| 	} | ||||
| 	return &Client{client: client, config: &config.MjConfig} | ||||
|  | ||||
| 	return &Client{client: client, Config: config, apiURL: apiURL} | ||||
| } | ||||
|  | ||||
| func (c *Client) Imagine(prompt string) error { | ||||
| func (c *Client) Imagine(task types.MjTask) error { | ||||
| 	interactionsReq := &InteractionsRequest{ | ||||
| 		Type:          2, | ||||
| 		ApplicationID: ApplicationID, | ||||
| 		GuildID:       c.config.GuildId, | ||||
| 		ChannelID:     c.config.ChanelId, | ||||
| 		GuildID:       c.Config.GuildId, | ||||
| 		ChannelID:     c.Config.ChanelId, | ||||
| 		SessionID:     SessionID, | ||||
| 		Data: map[string]any{ | ||||
| 			"version": "1118961510123847772", | ||||
| 			"version": "1166847114203123795", | ||||
| 			"id":      "938956540159881230", | ||||
| 			"name":    "imagine", | ||||
| 			"type":    "1", | ||||
| @@ -39,7 +49,7 @@ func (c *Client) Imagine(prompt string) error { | ||||
| 				{ | ||||
| 					"type":  3, | ||||
| 					"name":  "prompt", | ||||
| 					"value": prompt, | ||||
| 					"value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt), | ||||
| 				}, | ||||
| 			}, | ||||
| 			"application_command": map[string]any{ | ||||
| @@ -66,11 +76,10 @@ func (c *Client) Imagine(prompt string) error { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	url := "https://discord.com/api/v9/interactions" | ||||
| 	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). | ||||
| 	r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken). | ||||
| 		SetHeader("Content-Type", "application/json"). | ||||
| 		SetBody(interactionsReq). | ||||
| 		Post(url) | ||||
| 		Post(c.apiURL) | ||||
|  | ||||
| 	if err != nil || r.IsErrorState() { | ||||
| 		return fmt.Errorf("error with http request: %w%v", err, r.Err) | ||||
| @@ -79,31 +88,38 @@ func (c *Client) Imagine(prompt string) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *Client) Blend(task types.MjTask) error { | ||||
| 	return errors.New("function not implemented") | ||||
| } | ||||
|  | ||||
| func (c *Client) SwapFace(task types.MjTask) error { | ||||
| 	return errors.New("function not implemented") | ||||
| } | ||||
|  | ||||
| // Upscale 放大指定的图片 | ||||
| func (c *Client) Upscale(index int, messageId string, hash string) error { | ||||
| func (c *Client) Upscale(task types.MjTask) error { | ||||
| 	flags := 0 | ||||
| 	interactionsReq := &InteractionsRequest{ | ||||
| 		Type:          3, | ||||
| 		ApplicationID: ApplicationID, | ||||
| 		GuildID:       c.config.GuildId, | ||||
| 		ChannelID:     c.config.ChanelId, | ||||
| 		MessageFlags:  &flags, | ||||
| 		MessageID:     &messageId, | ||||
| 		GuildID:       c.Config.GuildId, | ||||
| 		ChannelID:     c.Config.ChanelId, | ||||
| 		MessageFlags:  flags, | ||||
| 		MessageID:     task.MessageId, | ||||
| 		SessionID:     SessionID, | ||||
| 		Data: map[string]any{ | ||||
| 			"component_type": 2, | ||||
| 			"custom_id":      fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash), | ||||
| 			"custom_id":      fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), | ||||
| 		}, | ||||
| 		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), | ||||
| 	} | ||||
|  | ||||
| 	url := "https://discord.com/api/v9/interactions" | ||||
| 	var res InteractionsResult | ||||
| 	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). | ||||
| 	r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken). | ||||
| 		SetHeader("Content-Type", "application/json"). | ||||
| 		SetBody(interactionsReq). | ||||
| 		SetErrorResult(&res). | ||||
| 		Post(url) | ||||
| 		Post(c.apiURL) | ||||
| 	if err != nil || r.IsErrorState() { | ||||
| 		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) | ||||
| 	} | ||||
| @@ -112,30 +128,29 @@ func (c *Client) Upscale(index int, messageId string, hash string) error { | ||||
| } | ||||
|  | ||||
| // Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 | ||||
| func (c *Client) Variation(index int, messageId string, hash string) error { | ||||
| func (c *Client) Variation(task types.MjTask) error { | ||||
| 	flags := 0 | ||||
| 	interactionsReq := &InteractionsRequest{ | ||||
| 		Type:          3, | ||||
| 		ApplicationID: ApplicationID, | ||||
| 		GuildID:       c.config.GuildId, | ||||
| 		ChannelID:     c.config.ChanelId, | ||||
| 		MessageFlags:  &flags, | ||||
| 		MessageID:     &messageId, | ||||
| 		GuildID:       c.Config.GuildId, | ||||
| 		ChannelID:     c.Config.ChanelId, | ||||
| 		MessageFlags:  flags, | ||||
| 		MessageID:     task.MessageId, | ||||
| 		SessionID:     SessionID, | ||||
| 		Data: map[string]any{ | ||||
| 			"component_type": 2, | ||||
| 			"custom_id":      fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash), | ||||
| 			"custom_id":      fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), | ||||
| 		}, | ||||
| 		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), | ||||
| 	} | ||||
|  | ||||
| 	url := "https://discord.com/api/v9/interactions" | ||||
| 	var res InteractionsResult | ||||
| 	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). | ||||
| 	r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken). | ||||
| 		SetHeader("Content-Type", "application/json"). | ||||
| 		SetBody(interactionsReq). | ||||
| 		SetErrorResult(&res). | ||||
| 		Post(url) | ||||
| 		Post(c.apiURL) | ||||
| 	if err != nil || r.IsErrorState() { | ||||
| 		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										292
									
								
								api/service/mj/plus/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										292
									
								
								api/service/mj/plus/client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,292 @@ | ||||
| package plus | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"chatplus/utils" | ||||
| 	"encoding/base64" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"io" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| // Client MidJourney Plus Client | ||||
| type Client struct { | ||||
| 	Config types.MidJourneyPlusConfig | ||||
| 	apiURL string | ||||
| } | ||||
|  | ||||
| func NewClient(config types.MidJourneyPlusConfig) *Client { | ||||
| 	var apiURL string | ||||
| 	if config.CdnURL != "" { | ||||
| 		apiURL = config.CdnURL | ||||
| 	} else { | ||||
| 		apiURL = config.ApiURL | ||||
| 	} | ||||
| 	if config.Mode == "" { | ||||
| 		config.Mode = "fast" | ||||
| 	} | ||||
| 	return &Client{Config: config, apiURL: apiURL} | ||||
| } | ||||
|  | ||||
| type ImageReq struct { | ||||
| 	BotType       string   `json:"botType"` | ||||
| 	Prompt        string   `json:"prompt,omitempty"` | ||||
| 	Dimensions    string   `json:"dimensions,omitempty"` | ||||
| 	Base64Array   []string `json:"base64Array,omitempty"` | ||||
| 	AccountFilter struct { | ||||
| 		InstanceId          string        `json:"instanceId"` | ||||
| 		Modes               []interface{} `json:"modes"` | ||||
| 		Remix               bool          `json:"remix"` | ||||
| 		RemixAutoConsidered bool          `json:"remixAutoConsidered"` | ||||
| 	} `json:"accountFilter,omitempty"` | ||||
| 	NotifyHook string `json:"notifyHook"` | ||||
| 	State      string `json:"state,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageRes struct { | ||||
| 	Code        int    `json:"code"` | ||||
| 	Description string `json:"description"` | ||||
| 	Properties  struct { | ||||
| 	} `json:"properties"` | ||||
| 	Result string `json:"result"` | ||||
| } | ||||
|  | ||||
| type ErrRes struct { | ||||
| 	Error struct { | ||||
| 		Message string `json:"message"` | ||||
| 	} `json:"error"` | ||||
| } | ||||
|  | ||||
| func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { | ||||
| 	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode) | ||||
| 	body := ImageReq{ | ||||
| 		BotType:     "MID_JOURNEY", | ||||
| 		Prompt:      task.Prompt, | ||||
| 		NotifyHook:  c.Config.NotifyURL, | ||||
| 		Base64Array: make([]string, 0), | ||||
| 	} | ||||
| 	// 生成图片 Base64 编码 | ||||
| 	if len(task.ImgArr) > 0 { | ||||
| 		imageData, err := utils.DownloadImage(task.ImgArr[0], "") | ||||
| 		if err != nil { | ||||
| 			logger.Error("error with download image: ", err) | ||||
| 		} else { | ||||
| 			body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| 		SetSuccessResult(&res). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		errStr, _ := io.ReadAll(r.Body) | ||||
| 		logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL) | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		errStr, _ := io.ReadAll(r.Body) | ||||
| 		return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr)) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| } | ||||
|  | ||||
| // Blend 融图 | ||||
| func (c *Client) Blend(task types.MjTask) (ImageRes, error) { | ||||
| 	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) | ||||
| 	body := ImageReq{ | ||||
| 		BotType:     "MID_JOURNEY", | ||||
| 		Dimensions:  "SQUARE", | ||||
| 		NotifyHook:  c.Config.NotifyURL, | ||||
| 		Base64Array: make([]string, 0), | ||||
| 	} | ||||
| 	// 生成图片 Base64 编码 | ||||
| 	if len(task.ImgArr) > 0 { | ||||
| 		for _, imgURL := range task.ImgArr { | ||||
| 			imageData, err := utils.DownloadImage(imgURL, "") | ||||
| 			if err != nil { | ||||
| 				logger.Error("error with download image: ", err) | ||||
| 			} else { | ||||
| 				body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| 		SetSuccessResult(&res). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		errStr, _ := io.ReadAll(r.Body) | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr)) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| } | ||||
|  | ||||
| // SwapFace 换脸 | ||||
| func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { | ||||
| 	apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) | ||||
| 	// 生成图片 Base64 编码 | ||||
| 	if len(task.ImgArr) != 2 { | ||||
| 		return ImageRes{}, errors.New("参数错误,必须上传2张图片") | ||||
| 	} | ||||
| 	var sourceBase64 string | ||||
| 	var targetBase64 string | ||||
| 	imageData, err := utils.DownloadImage(task.ImgArr[0], "") | ||||
| 	if err != nil { | ||||
| 		logger.Error("error with download image: ", err) | ||||
| 	} else { | ||||
| 		sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) | ||||
| 	} | ||||
| 	imageData, err = utils.DownloadImage(task.ImgArr[1], "") | ||||
| 	if err != nil { | ||||
| 		logger.Error("error with download image: ", err) | ||||
| 	} else { | ||||
| 		targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) | ||||
| 	} | ||||
|  | ||||
| 	body := gin.H{ | ||||
| 		"sourceBase64": sourceBase64, | ||||
| 		"targetBase64": targetBase64, | ||||
| 		"accountFilter": gin.H{ | ||||
| 			"instanceId": "", | ||||
| 		}, | ||||
| 		"notifyHook": c.Config.NotifyURL, | ||||
| 		"state":      "", | ||||
| 	} | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| 		SetSuccessResult(&res). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		errStr, _ := io.ReadAll(r.Body) | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr)) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| } | ||||
|  | ||||
| // Upscale 放大指定的图片 | ||||
| func (c *Client) Upscale(task types.MjTask) (ImageRes, error) { | ||||
| 	body := map[string]string{ | ||||
| 		"customId":   fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), | ||||
| 		"taskId":     task.MessageId, | ||||
| 		"notifyHook": c.Config.NotifyURL, | ||||
| 	} | ||||
| 	apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL) | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| 		SetSuccessResult(&res). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| } | ||||
|  | ||||
| // Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 | ||||
| func (c *Client) Variation(task types.MjTask) (ImageRes, error) { | ||||
| 	body := map[string]string{ | ||||
| 		"customId":   fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), | ||||
| 		"taskId":     task.MessageId, | ||||
| 		"notifyHook": c.Config.NotifyURL, | ||||
| 	} | ||||
| 	apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL) | ||||
| 	var res ImageRes | ||||
| 	var errRes ErrRes | ||||
| 	r, err := req.C().R(). | ||||
| 		SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetBody(body). | ||||
| 		SetSuccessResult(&res). | ||||
| 		SetErrorResult(&errRes). | ||||
| 		Post(apiURL) | ||||
| 	if err != nil { | ||||
| 		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| } | ||||
|  | ||||
| type QueryRes struct { | ||||
| 	Action  string `json:"action"` | ||||
| 	Buttons []struct { | ||||
| 		CustomId string `json:"customId"` | ||||
| 		Emoji    string `json:"emoji"` | ||||
| 		Label    string `json:"label"` | ||||
| 		Style    int    `json:"style"` | ||||
| 		Type     int    `json:"type"` | ||||
| 	} `json:"buttons"` | ||||
| 	Description string `json:"description"` | ||||
| 	FailReason  string `json:"failReason"` | ||||
| 	FinishTime  int    `json:"finishTime"` | ||||
| 	Id          string `json:"id"` | ||||
| 	ImageUrl    string `json:"imageUrl"` | ||||
| 	Progress    string `json:"progress"` | ||||
| 	Prompt      string `json:"prompt"` | ||||
| 	PromptEn    string `json:"promptEn"` | ||||
| 	Properties  struct { | ||||
| 	} `json:"properties"` | ||||
| 	StartTime  int    `json:"startTime"` | ||||
| 	State      string `json:"state"` | ||||
| 	Status     string `json:"status"` | ||||
| 	SubmitTime int    `json:"submitTime"` | ||||
| } | ||||
|  | ||||
| func (c *Client) QueryTask(taskId string) (QueryRes, error) { | ||||
| 	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId) | ||||
| 	var res QueryRes | ||||
| 	r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey). | ||||
| 		SetSuccessResult(&res). | ||||
| 		Get(apiURL) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return QueryRes{}, err | ||||
| 	} | ||||
|  | ||||
| 	if r.IsErrorState() { | ||||
| 		return QueryRes{}, errors.New("error status:" + r.Status) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| } | ||||
							
								
								
									
										198
									
								
								api/service/mj/plus/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										198
									
								
								api/service/mj/plus/service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,198 @@ | ||||
| package plus | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/utils" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // Service MJ 绘画服务 | ||||
| type Service struct { | ||||
| 	Name             string  // service Name | ||||
| 	Client           *Client // MJ Client | ||||
| 	taskQueue        *store.RedisQueue | ||||
| 	notifyQueue      *store.RedisQueue | ||||
| 	db               *gorm.DB | ||||
| 	maxHandleTaskNum int32             // max task number current service can handle | ||||
| 	HandledTaskNum   int32             // already handled task number | ||||
| 	taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout | ||||
| 	taskTimeout      int64 | ||||
| } | ||||
|  | ||||
| func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { | ||||
| 	return &Service{ | ||||
| 		Name:             name, | ||||
| 		db:               db, | ||||
| 		taskQueue:        taskQueue, | ||||
| 		notifyQueue:      notifyQueue, | ||||
| 		Client:           client, | ||||
| 		taskTimeout:      timeout, | ||||
| 		maxHandleTaskNum: maxTaskNum, | ||||
| 		taskStartTimes:   make(map[int]time.Time, 0), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Service) Run() { | ||||
| 	logger.Infof("Starting MidJourney job consumer for %s", s.Name) | ||||
| 	for { | ||||
| 		s.checkTasks() | ||||
| 		if !s.canHandleTask() { | ||||
| 			// current service is full, can not handle more task | ||||
| 			// waiting for running task finish | ||||
| 			time.Sleep(time.Second * 3) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		var task types.MjTask | ||||
| 		err := s.taskQueue.LPop(&task) | ||||
| 		if err != nil { | ||||
| 			logger.Errorf("taking task with error: %v", err) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// if it's reference message, check if it's this channel's  message | ||||
| 		//if task.ChannelId != "" && task.ChannelId != s.Name { | ||||
| 		//	logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId) | ||||
| 		//	s.taskQueue.RPush(task) | ||||
| 		//	time.Sleep(time.Second) | ||||
| 		//	continue | ||||
| 		//} | ||||
|  | ||||
| 		logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task) | ||||
| 		var res ImageRes | ||||
| 		switch task.Type { | ||||
| 		case types.TaskImage: | ||||
| 			res, err = s.Client.Imagine(task) | ||||
| 			break | ||||
| 		case types.TaskUpscale: | ||||
| 			res, err = s.Client.Upscale(task) | ||||
| 			break | ||||
| 		case types.TaskVariation: | ||||
| 			res, err = s.Client.Variation(task) | ||||
| 			break | ||||
| 		case types.TaskBlend: | ||||
| 			res, err = s.Client.Blend(task) | ||||
| 			break | ||||
| 		case types.TaskSwapFace: | ||||
| 			res, err = s.Client.SwapFace(task) | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		var job model.MidJourneyJob | ||||
| 		s.db.Where("id = ?", task.Id).First(&job) | ||||
| 		if err != nil || (res.Code != 1 && res.Code != 22) { | ||||
| 			errMsg := fmt.Sprintf("%v,%s", err, res.Description) | ||||
| 			logger.Error("绘画任务执行失败:", errMsg) | ||||
| 			job.Progress = -1 | ||||
| 			job.ErrMsg = errMsg | ||||
| 			// update the task progress | ||||
| 			s.db.Updates(&job) | ||||
| 			// 任务失败,通知前端 | ||||
| 			s.notifyQueue.RPush(task.UserId) | ||||
| 			continue | ||||
| 		} | ||||
| 		logger.Infof("任务提交成功:%+v", res) | ||||
| 		// lock the task until the execute timeout | ||||
| 		s.taskStartTimes[int(task.Id)] = time.Now() | ||||
| 		atomic.AddInt32(&s.HandledTaskNum, 1) | ||||
| 		// 更新任务 ID/频道 | ||||
| 		job.TaskId = res.Result | ||||
| 		job.ChannelId = s.Name | ||||
| 		s.db.Updates(&job) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // check if current service instance can handle more task | ||||
| func (s *Service) canHandleTask() bool { | ||||
| 	handledNum := atomic.LoadInt32(&s.HandledTaskNum) | ||||
| 	return handledNum < s.maxHandleTaskNum | ||||
| } | ||||
|  | ||||
| // remove the expired tasks | ||||
| func (s *Service) checkTasks() { | ||||
| 	for k, t := range s.taskStartTimes { | ||||
| 		if time.Now().Unix()-t.Unix() > s.taskTimeout { | ||||
| 			delete(s.taskStartTimes, k) | ||||
| 			atomic.AddInt32(&s.HandledTaskNum, -1) | ||||
| 			// delete task from database | ||||
| 			s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type CBReq struct { | ||||
| 	Id          string      `json:"id"` | ||||
| 	Action      string      `json:"action"` | ||||
| 	Status      string      `json:"status"` | ||||
| 	Prompt      string      `json:"prompt"` | ||||
| 	PromptEn    string      `json:"promptEn"` | ||||
| 	Description string      `json:"description"` | ||||
| 	SubmitTime  int64       `json:"submitTime"` | ||||
| 	StartTime   int64       `json:"startTime"` | ||||
| 	FinishTime  int64       `json:"finishTime"` | ||||
| 	Progress    string      `json:"progress"` | ||||
| 	ImageUrl    string      `json:"imageUrl"` | ||||
| 	FailReason  interface{} `json:"failReason"` | ||||
| 	Properties  struct { | ||||
| 		FinalPrompt string `json:"finalPrompt"` | ||||
| 	} `json:"properties"` | ||||
| } | ||||
|  | ||||
| func (s *Service) Notify(job model.MidJourneyJob) error { | ||||
| 	task, err := s.Client.QueryTask(job.TaskId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 任务执行失败了 | ||||
| 	if task.FailReason != "" { | ||||
| 		s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ | ||||
| 			"progress": -1, | ||||
| 			"err_msg":  task.FailReason, | ||||
| 		}) | ||||
| 		return fmt.Errorf("task failed: %v", task.FailReason) | ||||
| 	} | ||||
|  | ||||
| 	if len(task.Buttons) > 0 { | ||||
| 		job.Hash = GetImageHash(task.Buttons[0].CustomId) | ||||
| 	} | ||||
| 	oldProgress := job.Progress | ||||
| 	job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0) | ||||
| 	job.Prompt = task.PromptEn | ||||
| 	if task.ImageUrl != "" { | ||||
| 		if s.Client.Config.CdnURL != "" { | ||||
| 			job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1) | ||||
| 		} else { | ||||
| 			job.OrgURL = task.ImageUrl | ||||
| 		} | ||||
| 	} | ||||
| 	job.MessageId = task.Id | ||||
| 	tx := s.db.Updates(&job) | ||||
| 	if tx.Error != nil { | ||||
| 		return fmt.Errorf("error with update database: %v", tx.Error) | ||||
| 	} | ||||
| 	if task.Status == "SUCCESS" { | ||||
| 		// release lock task | ||||
| 		atomic.AddInt32(&s.HandledTaskNum, -1) | ||||
| 	} | ||||
| 	// 通知前端更新任务进度 | ||||
| 	if oldProgress != job.Progress { | ||||
| 		s.notifyQueue.RPush(job.UserId) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetImageHash(action string) string { | ||||
| 	split := strings.Split(action, "::") | ||||
| 	if len(split) > 5 { | ||||
| 		return split[4] | ||||
| 	} | ||||
| 	return split[len(split)-1] | ||||
| } | ||||
							
								
								
									
										248
									
								
								api/service/mj/pool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										248
									
								
								api/service/mj/pool.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,248 @@ | ||||
| package mj | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/mj/plus" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"fmt" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // ServicePool Mj service pool | ||||
| type ServicePool struct { | ||||
| 	services        []interface{} | ||||
| 	taskQueue       *store.RedisQueue | ||||
| 	notifyQueue     *store.RedisQueue | ||||
| 	db              *gorm.DB | ||||
| 	uploaderManager *oss.UploaderManager | ||||
| 	Clients         *types.LMap[uint, *types.WsClient] // UserId => Client | ||||
| } | ||||
|  | ||||
| func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { | ||||
| 	services := make([]interface{}, 0) | ||||
| 	taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) | ||||
| 	notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) | ||||
|  | ||||
| 	for k, config := range appConfig.MjPlusConfigs { | ||||
| 		if config.Enabled == false { | ||||
| 			continue | ||||
| 		} | ||||
| 		client := plus.NewClient(config) | ||||
| 		name := fmt.Sprintf("mj-service-plus-%d", k) | ||||
| 		servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client) | ||||
| 		go func() { | ||||
| 			servicePlus.Run() | ||||
| 		}() | ||||
| 		services = append(services, servicePlus) | ||||
| 	} | ||||
|  | ||||
| 	if len(services) == 0 { | ||||
| 		// create mj client and service | ||||
| 		for k, config := range appConfig.MjConfigs { | ||||
| 			if config.Enabled == false { | ||||
| 				continue | ||||
| 			} | ||||
| 			// create mj client | ||||
| 			client := NewClient(config, appConfig.ProxyURL) | ||||
|  | ||||
| 			name := fmt.Sprintf("MjService-%d", k) | ||||
| 			// create mj service | ||||
| 			service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client) | ||||
| 			botName := fmt.Sprintf("MjBot-%d", k) | ||||
| 			bot, err := NewBot(botName, appConfig.ProxyURL, config, service) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			err = bot.Run() | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// run mj service | ||||
| 			go func() { | ||||
| 				service.Run() | ||||
| 			}() | ||||
|  | ||||
| 			services = append(services, service) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &ServicePool{ | ||||
| 		taskQueue:       taskQueue, | ||||
| 		notifyQueue:     notifyQueue, | ||||
| 		services:        services, | ||||
| 		uploaderManager: manager, | ||||
| 		db:              db, | ||||
| 		Clients:         types.NewLMap[uint, *types.WsClient](), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p *ServicePool) CheckTaskNotify() { | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			var userId uint | ||||
| 			err := p.notifyQueue.LPop(&userId) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			client := p.Clients.Get(userId) | ||||
| 			if client == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			err = client.Send([]byte("Task Updated")) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (p *ServicePool) DownloadImages() { | ||||
| 	go func() { | ||||
| 		var items []model.MidJourneyJob | ||||
| 		for { | ||||
| 			res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items) | ||||
| 			if res.Error != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// download images | ||||
| 			for _, v := range items { | ||||
| 				if v.OrgURL == "" { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				logger.Infof("try to download image: %s", v.OrgURL) | ||||
| 				var imgURL string | ||||
| 				var err error | ||||
| 				if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil { | ||||
| 					task, _ := servicePlus.Client.QueryTask(v.TaskId) | ||||
| 					if len(task.Buttons) > 0 { | ||||
| 						v.Hash = plus.GetImageHash(task.Buttons[0].CustomId) | ||||
| 					} | ||||
| 					imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false) | ||||
| 				} else { | ||||
| 					imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true) | ||||
| 				} | ||||
| 				if err != nil { | ||||
| 					logger.Errorf("error with download image %s, %v", v.OrgURL, err) | ||||
| 					continue | ||||
| 				} else { | ||||
| 					logger.Infof("download image %s successfully.", v.OrgURL) | ||||
| 				} | ||||
|  | ||||
| 				v.ImgURL = imgURL | ||||
| 				p.db.Updates(&v) | ||||
|  | ||||
| 				client := p.Clients.Get(uint(v.UserId)) | ||||
| 				if client == nil { | ||||
| 					continue | ||||
| 				} | ||||
| 				err = client.Send([]byte("Task Updated")) | ||||
| 				if err != nil { | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			time.Sleep(time.Second * 5) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| // PushTask push a new mj task in to task queue | ||||
| func (p *ServicePool) PushTask(task types.MjTask) { | ||||
| 	logger.Debugf("add a new MidJourney task to the task list: %+v", task) | ||||
| 	p.taskQueue.RPush(task) | ||||
| } | ||||
|  | ||||
| // HasAvailableService check if it has available mj service in pool | ||||
| func (p *ServicePool) HasAvailableService() bool { | ||||
| 	return len(p.services) > 0 | ||||
| } | ||||
|  | ||||
| func (p *ServicePool) Notify(data plus.CBReq) error { | ||||
| 	logger.Debugf("收到任务回调:%+v", data) | ||||
| 	var job model.MidJourneyJob | ||||
| 	res := p.db.Where("task_id = ?", data.Id).First(&job) | ||||
| 	if res.Error != nil { | ||||
| 		return fmt.Errorf("非法任务:%s", data.Id) | ||||
| 	} | ||||
|  | ||||
| 	// 任务已经拉取完成 | ||||
| 	if job.Progress == 100 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil { | ||||
| 		return servicePlus.Notify(job) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SyncTaskProgress 异步拉取任务 | ||||
| func (p *ServicePool) SyncTaskProgress() { | ||||
| 	go func() { | ||||
| 		var items []model.MidJourneyJob | ||||
| 		for { | ||||
| 			res := p.db.Where("progress < ?", 100).Find(&items) | ||||
| 			if res.Error != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			for _, job := range items { | ||||
| 				// 失败或者 30 分钟还没完成的任务删除并退回算力 | ||||
| 				if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 { | ||||
| 					p.db.Delete(&job) | ||||
| 					// 略过 Upscale 任务 | ||||
| 					if job.Type != types.TaskUpscale.String() { | ||||
| 						continue | ||||
| 					} | ||||
| 					tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) | ||||
| 					if tx.Error == nil && tx.RowsAffected > 0 { | ||||
| 						var user model.User | ||||
| 						p.db.Where("id = ?", job.UserId).First(&user) | ||||
| 						p.db.Create(&model.PowerLog{ | ||||
| 							UserId:    user.Id, | ||||
| 							Username:  user.Username, | ||||
| 							Type:      types.PowerConsume, | ||||
| 							Amount:    job.Power, | ||||
| 							Balance:   user.Power + job.Power, | ||||
| 							Mark:      types.PowerAdd, | ||||
| 							Model:     "mid-journey", | ||||
| 							Remark:    fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId), | ||||
| 							CreatedAt: time.Now(), | ||||
| 						}) | ||||
| 					} | ||||
| 				} | ||||
|  | ||||
| 				if !strings.HasPrefix(job.ChannelId, "mj-service-plus") { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil { | ||||
| 					_ = servicePlus.Notify(job) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			time.Sleep(time.Second) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (p *ServicePool) getServicePlus(name string) *plus.Service { | ||||
| 	for _, s := range p.services { | ||||
| 		if servicePlus, ok := s.(*plus.Service); ok { | ||||
| 			if servicePlus.Name == name { | ||||
| 				return servicePlus | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| @@ -2,110 +2,128 @@ package mj | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"gorm.io/gorm" | ||||
| 	"strings" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // MJ 绘画服务 | ||||
|  | ||||
| const RunningJobKey = "MidJourney_Running_Job" | ||||
|  | ||||
| // Service MJ 绘画服务 | ||||
| type Service struct { | ||||
| 	client        *Client // MJ 客户端 | ||||
| 	taskQueue     *store.RedisQueue | ||||
| 	redis         *redis.Client | ||||
| 	db            *gorm.DB | ||||
| 	uploadManager *oss.UploaderManager | ||||
| 	Clients       *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息 | ||||
| 	ChatClients   *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息 | ||||
| 	proxyURL      string | ||||
| 	name             string  // service name | ||||
| 	client           *Client // MJ client | ||||
| 	taskQueue        *store.RedisQueue | ||||
| 	notifyQueue      *store.RedisQueue | ||||
| 	db               *gorm.DB | ||||
| 	maxHandleTaskNum int32             // max task number current service can handle | ||||
| 	handledTaskNum   int32             // already handled task number | ||||
| 	taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout | ||||
| 	taskTimeout      int64 | ||||
| } | ||||
|  | ||||
| func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service { | ||||
| func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { | ||||
| 	return &Service{ | ||||
| 		redis:         redisCli, | ||||
| 		db:            db, | ||||
| 		taskQueue:     store.NewRedisQueue("MidJourney_Task_Queue", redisCli), | ||||
| 		client:        client, | ||||
| 		uploadManager: manager, | ||||
| 		Clients:       types.NewLMap[string, *types.WsClient](), | ||||
| 		ChatClients:   types.NewLMap[string, *types.WsClient](), | ||||
| 		proxyURL:      config.ProxyURL, | ||||
| 		name:             name, | ||||
| 		db:               db, | ||||
| 		taskQueue:        taskQueue, | ||||
| 		notifyQueue:      notifyQueue, | ||||
| 		client:           client, | ||||
| 		taskTimeout:      timeout, | ||||
| 		maxHandleTaskNum: maxTaskNum, | ||||
| 		taskStartTimes:   make(map[int]time.Time, 0), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Service) Run() { | ||||
| 	logger.Info("Starting MidJourney job consumer.") | ||||
| 	ctx := context.Background() | ||||
| 	logger.Infof("Starting MidJourney job consumer for %s", s.name) | ||||
| 	for { | ||||
| 		_, err := s.redis.Get(ctx, RunningJobKey).Result() | ||||
| 		if err == nil { // 队列串行执行 | ||||
| 		s.checkTasks() | ||||
| 		if !s.canHandleTask() { | ||||
| 			// current service is full, can not handle more task | ||||
| 			// waiting for running task finish | ||||
| 			time.Sleep(time.Second * 3) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		var task types.MjTask | ||||
| 		err = s.taskQueue.LPop(&task) | ||||
| 		err := s.taskQueue.LPop(&task) | ||||
| 		if err != nil { | ||||
| 			logger.Errorf("taking task with error: %v", err) | ||||
| 			continue | ||||
| 		} | ||||
| 		logger.Infof("Consuming Task: %+v", task) | ||||
| 		switch task.Type { | ||||
| 		case types.TaskImage: | ||||
| 			err = s.client.Imagine(task.Prompt) | ||||
| 			break | ||||
| 		case types.TaskUpscale: | ||||
| 			err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash) | ||||
|  | ||||
| 			break | ||||
| 		case types.TaskVariation: | ||||
| 			err = s.client.Variation(task.Index, task.MessageId, task.MessageHash) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			logger.Error("绘画任务执行失败:", err) | ||||
| 			if task.RetryCount <= 5 { | ||||
| 				s.taskQueue.RPush(task) | ||||
| 			} | ||||
| 			task.RetryCount += 1 | ||||
| 			time.Sleep(time.Second * 3) | ||||
| 		// if it's reference message, check if it's this channel's  message | ||||
| 		if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId { | ||||
| 			s.taskQueue.RPush(task) | ||||
| 			time.Sleep(time.Second) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// 更新任务的执行状态 | ||||
| 		s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) | ||||
| 		// 锁定任务执行通道,直到任务超时(5分钟) | ||||
| 		s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5) | ||||
| 		logger.Infof("%s handle a new MidJourney task: %+v", s.name, task) | ||||
| 		switch task.Type { | ||||
| 		case types.TaskImage: | ||||
| 			err = s.client.Imagine(task) | ||||
| 			break | ||||
| 		case types.TaskUpscale: | ||||
| 			err = s.client.Upscale(task) | ||||
| 			break | ||||
| 		case types.TaskVariation: | ||||
| 			err = s.client.Variation(task) | ||||
| 			break | ||||
| 		case types.TaskBlend: | ||||
| 			err = s.client.Blend(task) | ||||
| 			break | ||||
| 		case types.TaskSwapFace: | ||||
| 			err = s.client.SwapFace(task) | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		if err != nil { | ||||
| 			logger.Error("绘画任务执行失败:", err.Error()) | ||||
| 			// update the task progress | ||||
| 			s.db.Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ | ||||
| 				"progress": -1, | ||||
| 				"err_msg":  err.Error(), | ||||
| 			}) | ||||
| 			s.notifyQueue.RPush(task.UserId) | ||||
| 			// restore img_call quota | ||||
| 			if task.Type.String() != types.TaskUpscale.String() { | ||||
| 				s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
| 		logger.Infof("Task Executed: %+v", task) | ||||
| 		// lock the task until the execute timeout | ||||
| 		s.taskStartTimes[int(task.Id)] = time.Now() | ||||
| 		atomic.AddInt32(&s.handledTaskNum, 1) | ||||
|  | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Service) PushTask(task types.MjTask) { | ||||
| 	logger.Infof("add a new MidJourney Task: %+v", task) | ||||
| 	s.taskQueue.RPush(task) | ||||
| // check if current service instance can handle more task | ||||
| func (s *Service) canHandleTask() bool { | ||||
| 	handledNum := atomic.LoadInt32(&s.handledTaskNum) | ||||
| 	return handledNum < s.maxHandleTaskNum | ||||
| } | ||||
|  | ||||
| // remove the expired tasks | ||||
| func (s *Service) checkTasks() { | ||||
| 	for k, t := range s.taskStartTimes { | ||||
| 		if time.Now().Unix()-t.Unix() > s.taskTimeout { | ||||
| 			delete(s.taskStartTimes, k) | ||||
| 			atomic.AddInt32(&s.handledTaskNum, -1) | ||||
| 			// delete task from database | ||||
| 			s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Service) Notify(data CBReq) { | ||||
| 	taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result() | ||||
| 	if err != nil { // 过期任务,丢弃 | ||||
| 		logger.Warn("任务已过期:", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var task types.MjTask | ||||
| 	err = utils.JsonDecode(taskString, &task) | ||||
| 	if err != nil { // 非标准任务,丢弃 | ||||
| 		logger.Warn("任务解析失败:", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// extract the task ID | ||||
| 	split := strings.Split(data.Prompt, " ") | ||||
| 	var job model.MidJourneyJob | ||||
| 	res := s.db.Where("message_id = ?", data.MessageId).First(&job) | ||||
| 	if res.Error == nil && data.Status == Finished { | ||||
| @@ -113,137 +131,48 @@ func (s *Service) Notify(data CBReq) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if task.Src == types.TaskSrcImg { // 绘画任务 | ||||
| 		var job model.MidJourneyJob | ||||
| 		res := s.db.Where("id = ?", task.Id).First(&job) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Warn("非法任务:", res.Error) | ||||
| 			return | ||||
| 		} | ||||
| 		job.MessageId = data.MessageId | ||||
| 		job.ReferenceId = data.ReferenceId | ||||
| 		job.Progress = data.Progress | ||||
| 		job.Prompt = data.Prompt | ||||
| 		job.Hash = data.Image.Hash | ||||
|  | ||||
| 		// 任务完成,将最终的图片下载下来 | ||||
| 		if data.Progress == 100 { | ||||
| 			imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) | ||||
| 			if err != nil { | ||||
| 				logger.Error("error with download img: ", err.Error()) | ||||
| 				return | ||||
| 			} | ||||
| 			job.ImgURL = imgURL | ||||
| 		} else { | ||||
| 			// 临时图片直接保存,访问的时候使用代理进行转发 | ||||
| 			job.ImgURL = data.Image.URL | ||||
| 		} | ||||
| 		res = s.db.Updates(&job) | ||||
| 		if res.Error != nil { | ||||
| 			logger.Error("error with update job: ", res.Error) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var jobVo vo.MidJourneyJob | ||||
| 		err := utils.CopyObject(job, &jobVo) | ||||
| 		if err == nil { | ||||
| 			if data.Progress < 100 { | ||||
| 				image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL) | ||||
| 				if err == nil { | ||||
| 					jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			// 推送任务到前端 | ||||
| 			client := s.Clients.Get(task.SessionId) | ||||
| 			if client != nil { | ||||
| 				utils.ReplyChunkMessage(client, jobVo) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 	} else if task.Src == types.TaskSrcChat { // 聊天任务 | ||||
| 		wsClient := s.ChatClients.Get(task.SessionId) | ||||
| 		if data.Status == Finished { | ||||
| 			if wsClient != nil && data.ReferenceId != "" { | ||||
| 				content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt) | ||||
| 				utils.ReplyMessage(wsClient, content) | ||||
| 			} | ||||
| 			// download image | ||||
| 			imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) | ||||
| 			if err != nil { | ||||
| 				logger.Error("error with download image: ", err) | ||||
| 				if wsClient != nil && data.ReferenceId != "" { | ||||
| 					content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error()) | ||||
| 					utils.ReplyMessage(wsClient, content) | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			tx := s.db.Begin() | ||||
| 			data.Image.URL = imgURL | ||||
| 			message := model.HistoryMessage{ | ||||
| 				UserId:     uint(task.UserId), | ||||
| 				ChatId:     task.ChatId, | ||||
| 				RoleId:     uint(task.RoleId), | ||||
| 				Type:       types.MjMsg, | ||||
| 				Icon:       task.Icon, | ||||
| 				Content:    utils.JsonEncode(data), | ||||
| 				Tokens:     0, | ||||
| 				UseContext: false, | ||||
| 			} | ||||
| 			res = tx.Create(&message) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("error with update database: ", err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// save the job | ||||
| 			job.UserId = task.UserId | ||||
| 			job.Type = task.Type.String() | ||||
| 			job.MessageId = data.MessageId | ||||
| 			job.ReferenceId = data.ReferenceId | ||||
| 			job.Prompt = data.Prompt | ||||
| 			job.ImgURL = imgURL | ||||
| 			job.Progress = data.Progress | ||||
| 			job.Hash = data.Image.Hash | ||||
| 			job.CreatedAt = time.Now() | ||||
| 			res = tx.Create(&job) | ||||
| 			if res.Error != nil { | ||||
| 				logger.Error("error with update database: ", err) | ||||
| 				tx.Rollback() | ||||
| 				return | ||||
| 			} | ||||
| 			tx.Commit() | ||||
| 		} | ||||
|  | ||||
| 		if wsClient == nil { // 客户端断线,则丢弃 | ||||
| 			logger.Errorf("Client is offline: %+v", data) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if data.Status == Finished { | ||||
| 			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) | ||||
| 			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) | ||||
| 			// 本次绘画完毕,移除客户端 | ||||
| 			s.ChatClients.Delete(task.SessionId) | ||||
| 		} else { | ||||
| 			// 使用代理临时转发图片 | ||||
| 			if data.Image.URL != "" { | ||||
| 				image, err := utils.DownloadImage(data.Image.URL, s.proxyURL) | ||||
| 				if err == nil { | ||||
| 					data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) | ||||
| 				} | ||||
| 			} | ||||
| 			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) | ||||
| 		} | ||||
| 	tx := s.db.Session(&gorm.Session{}).Where("progress < ?", 100).Order("id ASC") | ||||
| 	if data.ReferenceId != "" { | ||||
| 		tx = tx.Where("reference_id = ?", data.ReferenceId) | ||||
| 	} else { | ||||
| 		tx = tx.Where("task_id = ?", split[0]) | ||||
| 	} | ||||
| 	// fixed: 修复 U/V 操作任务混淆覆盖的 Bug | ||||
| 	if strings.Contains(data.Prompt, "** - Image #") { // for upscale | ||||
| 		tx = tx.Where("type = ?", types.TaskUpscale.String()) | ||||
| 	} else if strings.Contains(data.Prompt, "** - Variations (Strong)") { // for Variations | ||||
| 		tx = tx.Where("type = ?", types.TaskVariation.String()) | ||||
| 	} | ||||
| 	res = tx.First(&job) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Warn("非法任务:", res.Error) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	job.ChannelId = data.ChannelId | ||||
| 	job.MessageId = data.MessageId | ||||
| 	job.ReferenceId = data.ReferenceId | ||||
| 	job.Progress = data.Progress | ||||
| 	job.Prompt = data.Prompt | ||||
| 	job.Hash = data.Image.Hash | ||||
| 	if s.client.Config.UseCDN { | ||||
| 		job.UseProxy = true | ||||
| 		job.OrgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.ImgCdnURL) | ||||
| 	} else { | ||||
| 		job.OrgURL = data.Image.URL | ||||
| 	} | ||||
|  | ||||
| 	res = s.db.Updates(&job) | ||||
| 	if res.Error != nil { | ||||
| 		logger.Error("error with update job: ", res.Error) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 更新用户剩余绘图次数 | ||||
| 	// TODO: 放大图片是否需要消耗绘图次数? | ||||
| 	if data.Status == Finished { | ||||
| 		s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) | ||||
| 		// 解除任务锁定 | ||||
| 		s.redis.Del(context.Background(), RunningJobKey) | ||||
| 		// release lock task | ||||
| 		atomic.AddInt32(&s.handledTaskNum, -1) | ||||
| 	} | ||||
|  | ||||
| 	s.notifyQueue.RPush(job.UserId) | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -8,8 +8,8 @@ const ( | ||||
| type InteractionsRequest struct { | ||||
| 	Type          int            `json:"type"` | ||||
| 	ApplicationID string         `json:"application_id"` | ||||
| 	MessageFlags  *int           `json:"message_flags,omitempty"` | ||||
| 	MessageID     *string        `json:"message_id,omitempty"` | ||||
| 	MessageFlags  int            `json:"message_flags,omitempty"` | ||||
| 	MessageID     string         `json:"message_id,omitempty"` | ||||
| 	GuildID       string         `json:"guild_id"` | ||||
| 	ChannelID     string         `json:"channel_id"` | ||||
| 	SessionID     string         `json:"session_id"` | ||||
| @@ -24,6 +24,7 @@ type InteractionsResult struct { | ||||
| } | ||||
|  | ||||
| type CBReq struct { | ||||
| 	ChannelId   string     `json:"channel_id"` | ||||
| 	MessageId   string     `json:"message_id"` | ||||
| 	ReferenceId string     `json:"reference_id"` | ||||
| 	Image       Image      `json:"image"` | ||||
|   | ||||
| @@ -5,11 +5,13 @@ import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"fmt" | ||||
| 	"github.com/aliyun/aliyun-oss-go-sdk/oss" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/url" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/aliyun/aliyun-oss-go-sdk/oss" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type AliYunOss struct { | ||||
| @@ -32,6 +34,10 @@ func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if config.SubDir == "" { | ||||
| 		config.SubDir = "gpt" | ||||
| 	} | ||||
|  | ||||
| 	return &AliYunOss{ | ||||
| 		config:   config, | ||||
| 		bucket:   bucket, | ||||
| @@ -40,28 +46,34 @@ func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) { | ||||
|  | ||||
| } | ||||
|  | ||||
| func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) { | ||||
| func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) { | ||||
| 	// 解析表单 | ||||
| 	file, err := ctx.FormFile(name) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return File{}, err | ||||
| 	} | ||||
| 	// 打开上传文件 | ||||
| 	src, err := file.Open() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return File{}, err | ||||
| 	} | ||||
| 	defer src.Close() | ||||
|  | ||||
| 	fileExt := filepath.Ext(file.Filename) | ||||
| 	objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt) | ||||
| 	objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) | ||||
| 	// 上传文件 | ||||
| 	err = s.bucket.PutObject(objectKey, src) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return File{}, err | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Endpoint, objectKey), nil | ||||
| 	return File{ | ||||
| 		Name:   file.Filename, | ||||
| 		ObjKey: objectKey, | ||||
| 		URL:    fmt.Sprintf("%s/%s", s.config.Domain, objectKey), | ||||
| 		Ext:    fileExt, | ||||
| 		Size:   file.Size, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| @@ -79,19 +91,25 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with parse image URL: %v", err) | ||||
| 	} | ||||
| 	fileExt := filepath.Ext(parse.Path) | ||||
| 	objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt) | ||||
| 	fileExt := utils.GetImgExt(parse.Path) | ||||
| 	objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) | ||||
| 	// 上传文件字节数据 | ||||
| 	err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData)) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Endpoint, objectKey), nil | ||||
| 	return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil | ||||
| } | ||||
|  | ||||
| func (s AliYunOss) Delete(fileURL string) error { | ||||
| 	objectName := filepath.Base(fileURL) | ||||
| 	return s.bucket.DeleteObject(objectName) | ||||
| 	var objectKey string | ||||
| 	if strings.HasPrefix(fileURL, "http") { | ||||
| 		filename := filepath.Base(fileURL) | ||||
| 		objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename) | ||||
| 	} else { | ||||
| 		objectKey = fileURL | ||||
| 	} | ||||
| 	return s.bucket.DeleteObject(objectKey) | ||||
| } | ||||
|  | ||||
| var _ Uploader = AliYunOss{} | ||||
|   | ||||
| @@ -4,11 +4,12 @@ import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type LocalStorage struct { | ||||
| @@ -23,23 +24,30 @@ func NewLocalStorage(config *types.AppConfig) LocalStorage { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s LocalStorage) PutFile(ctx *gin.Context, name string) (string, error) { | ||||
| func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) { | ||||
| 	file, err := ctx.FormFile(name) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with get form: %v", err) | ||||
| 		return File{}, fmt.Errorf("error with get form: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	filePath, err := utils.GenUploadPath(s.config.BasePath, file.Filename) | ||||
| 	path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with generate filename: %s", err.Error()) | ||||
| 		return File{}, fmt.Errorf("error with generate filename: %s", err.Error()) | ||||
| 	} | ||||
| 	// 将文件保存到指定路径 | ||||
| 	err = ctx.SaveUploadedFile(file, filePath) | ||||
| 	err = ctx.SaveUploadedFile(file, path) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with save upload file: %s", err.Error()) | ||||
| 		return File{}, fmt.Errorf("error with save upload file: %s", err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil | ||||
| 	ext := filepath.Ext(file.Filename) | ||||
| 	return File{ | ||||
| 		Name:   file.Filename, | ||||
| 		ObjKey: path, | ||||
| 		URL:    utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, path), | ||||
| 		Ext:    ext, | ||||
| 		Size:   file.Size, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| @@ -48,7 +56,7 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 		return "", fmt.Errorf("error with parse image URL: %v", err) | ||||
| 	} | ||||
| 	filename := filepath.Base(parse.Path) | ||||
| 	filePath, err := utils.GenUploadPath(s.config.BasePath, filename) | ||||
| 	filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with generate image dir: %v", err) | ||||
| 	} | ||||
| @@ -66,6 +74,9 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| } | ||||
|  | ||||
| func (s LocalStorage) Delete(fileURL string) error { | ||||
| 	if _, err := os.Stat(fileURL); err == nil { | ||||
| 		return os.Remove(fileURL) | ||||
| 	} | ||||
| 	filePath := strings.Replace(fileURL, s.config.BaseURL, s.config.BasePath, 1) | ||||
| 	return os.Remove(filePath) | ||||
| } | ||||
|   | ||||
| @@ -5,13 +5,14 @@ import ( | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/minio/minio-go/v7" | ||||
| 	"github.com/minio/minio-go/v7/pkg/credentials" | ||||
| 	"net/url" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/minio/minio-go/v7" | ||||
| 	"github.com/minio/minio-go/v7/pkg/credentials" | ||||
| ) | ||||
|  | ||||
| type MiniOss struct { | ||||
| @@ -29,6 +30,9 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) { | ||||
| 	if err != nil { | ||||
| 		return MiniOss{}, err | ||||
| 	} | ||||
| 	if config.SubDir == "" { | ||||
| 		config.SubDir = "gpt" | ||||
| 	} | ||||
| 	return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil | ||||
| } | ||||
|  | ||||
| @@ -48,7 +52,7 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 		return "", fmt.Errorf("error with parse image URL: %v", err) | ||||
| 	} | ||||
| 	fileExt := filepath.Ext(parse.Path) | ||||
| 	filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt) | ||||
| 	filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) | ||||
| 	info, err := s.client.PutObject( | ||||
| 		context.Background(), | ||||
| 		s.config.Bucket, | ||||
| @@ -62,33 +66,45 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 	return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil | ||||
| } | ||||
|  | ||||
| func (s MiniOss) PutFile(ctx *gin.Context, name string) (string, error) { | ||||
| func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) { | ||||
| 	file, err := ctx.FormFile(name) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with get form: %v", err) | ||||
| 		return File{}, fmt.Errorf("error with get form: %v", err) | ||||
| 	} | ||||
| 	// Open the uploaded file | ||||
| 	fileReader, err := file.Open() | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error opening file: %v", err) | ||||
| 		return File{}, fmt.Errorf("error opening file: %v", err) | ||||
| 	} | ||||
| 	defer fileReader.Close() | ||||
|  | ||||
| 	fileExt := filepath.Ext(file.Filename) | ||||
| 	filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt) | ||||
| 	fileExt := utils.GetImgExt(file.Filename) | ||||
| 	filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) | ||||
| 	info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{ | ||||
| 		ContentType: file.Header.Get("Content-Type"), | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error uploading to MinIO: %v", err) | ||||
| 		return File{}, fmt.Errorf("error uploading to MinIO: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil | ||||
| 	return File{ | ||||
| 		Name:   file.Filename, | ||||
| 		ObjKey: info.Key, | ||||
| 		URL:    fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), | ||||
| 		Ext:    fileExt, | ||||
| 		Size:   file.Size, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s MiniOss) Delete(fileURL string) error { | ||||
| 	objectName := filepath.Base(fileURL) | ||||
| 	return s.client.RemoveObject(context.Background(), s.config.Bucket, objectName, minio.RemoveObjectOptions{}) | ||||
| 	var objectKey string | ||||
| 	if strings.HasPrefix(fileURL, "http") { | ||||
| 		filename := filepath.Base(fileURL) | ||||
| 		objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename) | ||||
| 	} else { | ||||
| 		objectKey = fileURL | ||||
| 	} | ||||
| 	return s.client.RemoveObject(context.Background(), s.config.Bucket, objectKey, minio.RemoveObjectOptions{}) | ||||
| } | ||||
|  | ||||
| var _ Uploader = MiniOss{} | ||||
|   | ||||
| @@ -6,21 +6,23 @@ import ( | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/qiniu/go-sdk/v7/auth/qbox" | ||||
| 	"github.com/qiniu/go-sdk/v7/storage" | ||||
| 	"net/url" | ||||
| 	"path/filepath" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type QinNiuOss struct { | ||||
| 	config   *types.QiNiuOssConfig | ||||
| 	token    string | ||||
| 	uploader *storage.FormUploader | ||||
| 	manager  *storage.BucketManager | ||||
| 	proxyURL string | ||||
| 	dir      string | ||||
| 	config    *types.QiNiuOssConfig | ||||
| 	mac       *qbox.Mac | ||||
| 	putPolicy storage.PutPolicy | ||||
| 	uploader  *storage.FormUploader | ||||
| 	manager   *storage.BucketManager | ||||
| 	proxyURL  string | ||||
| } | ||||
|  | ||||
| func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss { | ||||
| @@ -37,40 +39,50 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss { | ||||
| 	putPolicy := storage.PutPolicy{ | ||||
| 		Scope: config.Bucket, | ||||
| 	} | ||||
| 	if config.SubDir == "" { | ||||
| 		config.SubDir = "gpt" | ||||
| 	} | ||||
| 	return QinNiuOss{ | ||||
| 		config:   config, | ||||
| 		token:    putPolicy.UploadToken(mac), | ||||
| 		uploader: formUploader, | ||||
| 		manager:  storage.NewBucketManager(mac, &storeConfig), | ||||
| 		proxyURL: appConfig.ProxyURL, | ||||
| 		dir:      "chatgpt-plus", | ||||
| 		config:    config, | ||||
| 		mac:       mac, | ||||
| 		putPolicy: putPolicy, | ||||
| 		uploader:  formUploader, | ||||
| 		manager:   storage.NewBucketManager(mac, &storeConfig), | ||||
| 		proxyURL:  appConfig.ProxyURL, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) { | ||||
| func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) { | ||||
| 	// 解析表单 | ||||
| 	file, err := ctx.FormFile(name) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return File{}, err | ||||
| 	} | ||||
| 	// 打开上传文件 | ||||
| 	src, err := file.Open() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return File{}, err | ||||
| 	} | ||||
| 	defer src.Close() | ||||
|  | ||||
| 	fileExt := filepath.Ext(file.Filename) | ||||
| 	key := fmt.Sprintf("%s/%d%s", s.dir, time.Now().UnixMicro(), fileExt) | ||||
| 	key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) | ||||
| 	// 上传文件 | ||||
| 	ret := storage.PutRet{} | ||||
| 	extra := storage.PutExtra{} | ||||
| 	err = s.uploader.Put(ctx, &ret, s.token, key, src, file.Size, &extra) | ||||
| 	err = s.uploader.Put(ctx, &ret, s.putPolicy.UploadToken(s.mac), key, src, file.Size, &extra) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return File{}, err | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil | ||||
| 	return File{ | ||||
| 		Name:   file.Filename, | ||||
| 		ObjKey: key, | ||||
| 		URL:    fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), | ||||
| 		Ext:    fileExt, | ||||
| 		Size:   file.Size, | ||||
| 	}, nil | ||||
|  | ||||
| } | ||||
|  | ||||
| func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| @@ -88,12 +100,12 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error with parse image URL: %v", err) | ||||
| 	} | ||||
| 	fileExt := filepath.Ext(parse.Path) | ||||
| 	key := fmt.Sprintf("%s/%d%s", s.dir, time.Now().UnixMicro(), fileExt) | ||||
| 	fileExt := utils.GetImgExt(parse.Path) | ||||
| 	key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) | ||||
| 	ret := storage.PutRet{} | ||||
| 	extra := storage.PutExtra{} | ||||
| 	// 上传文件字节数据 | ||||
| 	err = s.uploader.Put(context.Background(), &ret, s.token, key, bytes.NewReader(imageData), int64(len(imageData)), &extra) | ||||
| 	err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(imageData), int64(len(imageData)), &extra) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| @@ -101,9 +113,15 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) { | ||||
| } | ||||
|  | ||||
| func (s QinNiuOss) Delete(fileURL string) error { | ||||
| 	objectName := filepath.Base(fileURL) | ||||
| 	key := fmt.Sprintf("%s/%s", s.dir, objectName) | ||||
| 	return s.manager.Delete(s.config.Bucket, key) | ||||
| 	var objectKey string | ||||
| 	if strings.HasPrefix(fileURL, "http") { | ||||
| 		filename := filepath.Base(fileURL) | ||||
| 		objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename) | ||||
| 	} else { | ||||
| 		objectKey = fileURL | ||||
| 	} | ||||
|  | ||||
| 	return s.manager.Delete(s.config.Bucket, objectKey) | ||||
| } | ||||
|  | ||||
| var _ Uploader = QinNiuOss{} | ||||
|   | ||||
| @@ -2,8 +2,20 @@ package oss | ||||
|  | ||||
| import "github.com/gin-gonic/gin" | ||||
|  | ||||
| const Local = "LOCAL" | ||||
| const Minio = "MINIO" | ||||
| const QiNiu = "QINIU" | ||||
| const AliYun = "ALIYUN" | ||||
|  | ||||
| type File struct { | ||||
| 	Name   string `json:"name"` | ||||
| 	ObjKey string `json:"obj_key"` | ||||
| 	Size   int64  `json:"size"` | ||||
| 	URL    string `json:"url"` | ||||
| 	Ext    string `json:"ext"` | ||||
| } | ||||
| type Uploader interface { | ||||
| 	PutFile(ctx *gin.Context, name string) (string, error) | ||||
| 	PutFile(ctx *gin.Context, name string) (File, error) | ||||
| 	PutImg(imageURL string, useProxy bool) (string, error) | ||||
| 	Delete(fileURL string) error | ||||
| } | ||||
|   | ||||
| @@ -9,11 +9,6 @@ type UploaderManager struct { | ||||
| 	handler Uploader | ||||
| } | ||||
|  | ||||
| const Local = "LOCAL" | ||||
| const Minio = "MINIO" | ||||
| const QiNiu = "QINIU" | ||||
| const AliYun = "ALIYUN" | ||||
|  | ||||
| func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) { | ||||
| 	active := Local | ||||
| 	if config.OSS.Active != "" { | ||||
|   | ||||
							
								
								
									
										142
									
								
								api/service/payment/alipay_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								api/service/payment/alipay_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,142 @@ | ||||
| package payment | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"fmt" | ||||
| 	"github.com/smartwalle/alipay/v3" | ||||
| 	"log" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| ) | ||||
|  | ||||
| type AlipayService struct { | ||||
| 	config *types.AlipayConfig | ||||
| 	client *alipay.Client | ||||
| } | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) { | ||||
| 	config := appConfig.AlipayConfig | ||||
| 	if !config.Enabled { | ||||
| 		logger.Info("Disabled Alipay service") | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	priKey, err := readKey(config.PrivateKey) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error with read App Private key: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	xClient, err := alipay.New(config.AppId, priKey, !config.SandBox) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error with initialize alipay service: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if err = xClient.LoadAppCertPublicKeyFromFile(config.PublicKey); err != nil { | ||||
| 		return nil, fmt.Errorf("error with loading App PublicKey: %v", err) | ||||
| 	} | ||||
| 	if err = xClient.LoadAliPayRootCertFromFile(config.RootCert); err != nil { | ||||
| 		return nil, fmt.Errorf("error with loading alipay RootCert: %v", err) | ||||
| 	} | ||||
| 	if err = xClient.LoadAlipayCertPublicKeyFromFile(config.AlipayPublicKey); err != nil { | ||||
| 		return nil, fmt.Errorf("error with loading Alipay PublicKey: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return &AlipayService{config: &config, client: xClient}, nil | ||||
| } | ||||
|  | ||||
| func (s *AlipayService) PayUrlMobile(outTradeNo string, notifyURL string, returnURL string, Amount string, subject string) (string, error) { | ||||
| 	var p = alipay.TradeWapPay{} | ||||
| 	p.NotifyURL = notifyURL | ||||
| 	p.ReturnURL = returnURL | ||||
| 	p.Subject = subject | ||||
| 	p.OutTradeNo = outTradeNo | ||||
| 	p.TotalAmount = Amount | ||||
| 	p.ProductCode = "QUICK_WAP_WAY" | ||||
| 	res, err := s.client.TradeWapPay(p) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	return res.String(), err | ||||
| } | ||||
|  | ||||
| func (s *AlipayService) PayUrlPc(outTradeNo string, notifyURL string, returnURL string, amount string, subject string) (string, error) { | ||||
| 	var p = alipay.TradePagePay{} | ||||
| 	p.NotifyURL = notifyURL | ||||
| 	p.ReturnURL = returnURL | ||||
| 	p.Subject = subject | ||||
| 	p.OutTradeNo = outTradeNo | ||||
| 	p.TotalAmount = amount | ||||
| 	p.ProductCode = "FAST_INSTANT_TRADE_PAY" | ||||
| 	res, err := s.client.TradePagePay(p) | ||||
| 	if err != nil { | ||||
| 		return "", nil | ||||
| 	} | ||||
|  | ||||
| 	return res.String(), err | ||||
| } | ||||
|  | ||||
| // TradeVerify 交易验证 | ||||
| func (s *AlipayService) TradeVerify(reqForm url.Values) NotifyVo { | ||||
| 	err := s.client.VerifySign(reqForm) | ||||
| 	if err != nil { | ||||
| 		log.Println("异步通知验证签名发生错误", err) | ||||
| 		return NotifyVo{ | ||||
| 			Status:  0, | ||||
| 			Message: "异步通知验证签名发生错误", | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s.TradeQuery(reqForm.Get("out_trade_no")) | ||||
| } | ||||
|  | ||||
| func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo { | ||||
| 	var p = alipay.TradeQuery{} | ||||
| 	p.OutTradeNo = outTradeNo | ||||
| 	rsp, err := s.client.TradeQuery(p) | ||||
| 	if err != nil { | ||||
| 		return NotifyVo{ | ||||
| 			Status:  0, | ||||
| 			Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if rsp.IsSuccess() == true && rsp.TradeStatus == "TRADE_SUCCESS" { | ||||
| 		return NotifyVo{ | ||||
| 			Status:     1, | ||||
| 			OutTradeNo: rsp.OutTradeNo, | ||||
| 			TradeNo:    rsp.TradeNo, | ||||
| 			Amount:     rsp.TotalAmount, | ||||
| 			Subject:    rsp.Subject, | ||||
| 			Message:    "OK", | ||||
| 		} | ||||
| 	} else { | ||||
| 		return NotifyVo{ | ||||
| 			Status:  0, | ||||
| 			Message: "异步查询验证订单信息发生错误" + outTradeNo, | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func readKey(filename string) (string, error) { | ||||
| 	data, err := os.ReadFile(filename) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return string(data), nil | ||||
| } | ||||
|  | ||||
| type NotifyVo struct { | ||||
| 	Status     int | ||||
| 	OutTradeNo string | ||||
| 	TradeNo    string | ||||
| 	Amount     string | ||||
| 	Message    string | ||||
| 	Subject    string | ||||
| } | ||||
|  | ||||
| func (v NotifyVo) Success() bool { | ||||
| 	return v.Status == 1 | ||||
| } | ||||
							
								
								
									
										162
									
								
								api/service/payment/hupipay_serive.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										162
									
								
								api/service/payment/hupipay_serive.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,162 @@ | ||||
| package payment | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"crypto/md5" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type HuPiPayService struct { | ||||
| 	appId     string | ||||
| 	appSecret string | ||||
| 	apiURL    string | ||||
| } | ||||
|  | ||||
| func NewHuPiPay(config *types.AppConfig) *HuPiPayService { | ||||
| 	return &HuPiPayService{ | ||||
| 		appId:     config.HuPiPayConfig.AppId, | ||||
| 		appSecret: config.HuPiPayConfig.AppSecret, | ||||
| 		apiURL:    config.HuPiPayConfig.ApiURL, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type HuPiPayReq struct { | ||||
| 	AppId        string `json:"appid"` | ||||
| 	Version      string `json:"version"` | ||||
| 	TradeOrderId string `json:"trade_order_id"` | ||||
| 	TotalFee     string `json:"total_fee"` | ||||
| 	Title        string `json:"title"` | ||||
| 	NotifyURL    string `json:"notify_url"` | ||||
| 	ReturnURL    string `json:"return_url"` | ||||
| 	WapName      string `json:"wap_name"` | ||||
| 	CallbackURL  string `json:"callback_url"` | ||||
| 	Time         string `json:"time"` | ||||
| 	NonceStr     string `json:"nonce_str"` | ||||
| } | ||||
|  | ||||
| type HuPiResp struct { | ||||
| 	Openid    interface{} `json:"openid"` | ||||
| 	UrlQrcode string      `json:"url_qrcode"` | ||||
| 	URL       string      `json:"url"` | ||||
| 	ErrCode   int         `json:"errcode"` | ||||
| 	ErrMsg    string      `json:"errmsg,omitempty"` | ||||
| } | ||||
|  | ||||
| // Pay 执行支付请求操作 | ||||
| func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) { | ||||
| 	data := url.Values{} | ||||
| 	simple := strconv.FormatInt(time.Now().Unix(), 10) | ||||
| 	params.AppId = s.appId | ||||
| 	params.Time = simple | ||||
| 	params.NonceStr = simple | ||||
| 	encode := utils.JsonEncode(params) | ||||
| 	m := make(map[string]string) | ||||
| 	_ = utils.JsonDecode(encode, &m) | ||||
| 	for k, v := range m { | ||||
| 		data.Add(k, fmt.Sprintf("%v", v)) | ||||
| 	} | ||||
| 	// 生成签名 | ||||
| 	data.Add("hash", s.Sign(data)) | ||||
| 	// 发送支付请求 | ||||
| 	apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL) | ||||
| 	resp, err := http.PostForm(apiURL, data) | ||||
| 	if err != nil { | ||||
| 		return HuPiResp{}, fmt.Errorf("error with requst api: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	all, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return HuPiResp{}, fmt.Errorf("error with reading response: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	var res HuPiResp | ||||
| 	err = utils.JsonDecode(string(all), &res) | ||||
| 	if err != nil { | ||||
| 		return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if res.ErrCode != 0 { | ||||
| 		return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| } | ||||
|  | ||||
| // Sign 签名方法 | ||||
| func (s *HuPiPayService) Sign(params url.Values) string { | ||||
| 	params.Del(`Sign`) | ||||
| 	var keys = make([]string, 0, 0) | ||||
| 	for key := range params { | ||||
| 		if params.Get(key) != `` { | ||||
| 			keys = append(keys, key) | ||||
| 		} | ||||
| 	} | ||||
| 	sort.Strings(keys) | ||||
|  | ||||
| 	var pList = make([]string, 0, 0) | ||||
| 	for _, key := range keys { | ||||
| 		var value = strings.TrimSpace(params.Get(key)) | ||||
| 		if len(value) > 0 { | ||||
| 			pList = append(pList, key+"="+value) | ||||
| 		} | ||||
| 	} | ||||
| 	var src = strings.Join(pList, "&") | ||||
| 	src += s.appSecret | ||||
|  | ||||
| 	md5bs := md5.Sum([]byte(src)) | ||||
| 	return hex.EncodeToString(md5bs[:]) | ||||
| } | ||||
|  | ||||
| // Check 校验订单状态 | ||||
| func (s *HuPiPayService) Check(tradeNo string) error { | ||||
| 	data := url.Values{} | ||||
| 	data.Add("appid", s.appId) | ||||
| 	data.Add("open_order_id", tradeNo) | ||||
| 	stamp := strconv.FormatInt(time.Now().Unix(), 10) | ||||
| 	data.Add("time", stamp) | ||||
| 	data.Add("nonce_str", stamp) | ||||
| 	data.Add("hash", s.Sign(data)) | ||||
|  | ||||
| 	apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL) | ||||
| 	resp, err := http.PostForm(apiURL, data) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error with http reqeust: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error with reading response: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	var r struct { | ||||
| 		ErrCode int `json:"errcode"` | ||||
| 		Data    struct { | ||||
| 			Status      string `json:"status"` | ||||
| 			OpenOrderId string `json:"open_order_id"` | ||||
| 		} `json:"data,omitempty"` | ||||
| 		ErrMsg string `json:"errmsg"` | ||||
| 		Hash   string `json:"hash"` | ||||
| 	} | ||||
| 	err = utils.JsonDecode(string(body), &r) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error with decode response: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if r.ErrCode == 0 && r.Data.Status == "OD" { | ||||
| 		return nil | ||||
| 	} else { | ||||
| 		logger.Debugf("%+v", r) | ||||
| 		return errors.New("order not paid:" + r.ErrMsg) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										148
									
								
								api/service/payment/payjs_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								api/service/payment/payjs_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,148 @@ | ||||
| package payment | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"crypto/md5" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type PayJS struct { | ||||
| 	config *types.JPayConfig | ||||
| } | ||||
|  | ||||
| func NewPayJS(appConfig *types.AppConfig) *PayJS { | ||||
| 	return &PayJS{ | ||||
| 		config: &appConfig.JPayConfig, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type JPayReq struct { | ||||
| 	TotalFee   int    `json:"total_fee"` | ||||
| 	OutTradeNo string `json:"out_trade_no"` | ||||
| 	Subject    string `json:"body"` | ||||
| 	NotifyURL  string `json:"notify_url"` | ||||
| 	ReturnURL  string `json:"callback_url"` | ||||
| } | ||||
| type JPayReps struct { | ||||
| 	OutTradeNo string `json:"out_trade_no"` | ||||
| 	OrderId    string `json:"payjs_order_id"` | ||||
| 	ReturnCode int    `json:"return_code"` | ||||
| 	ReturnMsg  string `json:"return_msg"` | ||||
| 	Sign       string `json:"Sign"` | ||||
| 	TotalFee   string `json:"total_fee"` | ||||
| 	CodeUrl    string `json:"code_url,omitempty"` | ||||
| 	Qrcode     string `json:"qrcode,omitempty"` | ||||
| } | ||||
|  | ||||
| func (r JPayReps) IsOK() bool { | ||||
| 	return r.ReturnMsg == "SUCCESS" | ||||
| } | ||||
|  | ||||
| func (js *PayJS) Pay(param JPayReq) JPayReps { | ||||
| 	param.NotifyURL = js.config.NotifyURL | ||||
| 	var p = url.Values{} | ||||
| 	encode := utils.JsonEncode(param) | ||||
| 	m := make(map[string]interface{}) | ||||
| 	_ = utils.JsonDecode(encode, &m) | ||||
| 	for k, v := range m { | ||||
| 		p.Add(k, fmt.Sprintf("%v", v)) | ||||
| 	} | ||||
| 	p.Add("mchid", js.config.AppId) | ||||
|  | ||||
| 	p.Add("sign", js.sign(p)) | ||||
|  | ||||
| 	cli := http.Client{} | ||||
| 	apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL) | ||||
| 	r, err := cli.PostForm(apiURL, p) | ||||
| 	if err != nil { | ||||
| 		return JPayReps{ReturnMsg: err.Error()} | ||||
| 	} | ||||
| 	defer r.Body.Close() | ||||
| 	bs, err := io.ReadAll(r.Body) | ||||
| 	if err != nil { | ||||
| 		return JPayReps{ReturnMsg: err.Error()} | ||||
| 	} | ||||
|  | ||||
| 	var data JPayReps | ||||
| 	err = utils.JsonDecode(string(bs), &data) | ||||
| 	if err != nil { | ||||
| 		return JPayReps{ReturnMsg: err.Error()} | ||||
| 	} | ||||
| 	return data | ||||
| } | ||||
|  | ||||
| func (js *PayJS) PayH5(p url.Values) string { | ||||
| 	p.Add("mchid", js.config.AppId) | ||||
| 	p.Add("sign", js.sign(p)) | ||||
| 	return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode()) | ||||
| } | ||||
|  | ||||
| func (js *PayJS) sign(params url.Values) string { | ||||
| 	params.Del(`sign`) | ||||
| 	var keys = make([]string, 0, 0) | ||||
| 	for key := range params { | ||||
| 		if params.Get(key) != `` { | ||||
| 			keys = append(keys, key) | ||||
| 		} | ||||
| 	} | ||||
| 	sort.Strings(keys) | ||||
|  | ||||
| 	var pList = make([]string, 0, 0) | ||||
| 	for _, key := range keys { | ||||
| 		var value = strings.TrimSpace(params.Get(key)) | ||||
| 		if len(value) > 0 { | ||||
| 			pList = append(pList, key+"="+value) | ||||
| 		} | ||||
| 	} | ||||
| 	var src = strings.Join(pList, "&") | ||||
| 	src += "&key=" + js.config.PrivateKey | ||||
|  | ||||
| 	md5bs := md5.Sum([]byte(src)) | ||||
| 	md5res := hex.EncodeToString(md5bs[:]) | ||||
| 	return strings.ToUpper(md5res) | ||||
| } | ||||
|  | ||||
| // Check 查询订单支付状态 | ||||
| // @param tradeNo 支付平台交易 ID | ||||
| func (js *PayJS) Check(tradeNo string) error { | ||||
| 	apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL) | ||||
| 	params := url.Values{} | ||||
| 	params.Add("payjs_order_id", tradeNo) | ||||
| 	params.Add("sign", js.sign(params)) | ||||
| 	data := strings.NewReader(params.Encode()) | ||||
| 	resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data) | ||||
| 	defer resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error with http reqeust: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error with reading response: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	var r struct { | ||||
| 		ReturnCode int `json:"return_code"` | ||||
| 		Status     int `json:"status"` | ||||
| 	} | ||||
| 	err = utils.JsonDecode(string(body), &r) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error with decode response: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if r.ReturnCode == 1 && r.Status == 1 { | ||||
| 		return nil | ||||
| 	} else { | ||||
| 		logger.Errorf("PayJs 支付验证响应:%s", string(body)) | ||||
| 		return errors.New("order not paid") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										122
									
								
								api/service/sd/pool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										122
									
								
								api/service/sd/pool.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,122 @@ | ||||
| package sd | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type ServicePool struct { | ||||
| 	services    []*Service | ||||
| 	taskQueue   *store.RedisQueue | ||||
| 	notifyQueue *store.RedisQueue | ||||
| 	db          *gorm.DB | ||||
| 	Clients     *types.LMap[uint, *types.WsClient] // UserId => Client | ||||
| } | ||||
|  | ||||
| func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { | ||||
| 	services := make([]*Service, 0) | ||||
| 	taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) | ||||
| 	notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli) | ||||
| 	// create mj client and service | ||||
| 	for k, config := range appConfig.SdConfigs { | ||||
| 		if config.Enabled == false { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// create sd service | ||||
| 		name := fmt.Sprintf("StableDifffusion Service-%d", k) | ||||
| 		service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager) | ||||
| 		// run sd service | ||||
| 		go func() { | ||||
| 			service.Run() | ||||
| 		}() | ||||
|  | ||||
| 		services = append(services, service) | ||||
| 	} | ||||
|  | ||||
| 	return &ServicePool{ | ||||
| 		taskQueue:   taskQueue, | ||||
| 		notifyQueue: notifyQueue, | ||||
| 		services:    services, | ||||
| 		db:          db, | ||||
| 		Clients:     types.NewLMap[uint, *types.WsClient](), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // PushTask push a new mj task in to task queue | ||||
| func (p *ServicePool) PushTask(task types.SdTask) { | ||||
| 	logger.Debugf("add a new MidJourney task to the task list: %+v", task) | ||||
| 	p.taskQueue.RPush(task) | ||||
| } | ||||
|  | ||||
| func (p *ServicePool) CheckTaskNotify() { | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			var userId uint | ||||
| 			err := p.notifyQueue.LPop(&userId) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			client := p.Clients.Get(userId) | ||||
| 			if client == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			err = client.Send([]byte("Task Updated")) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| // CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 | ||||
| func (p *ServicePool) CheckTaskStatus() { | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			var jobs []model.SdJob | ||||
| 			res := p.db.Where("progress < ?", 100).Find(&jobs) | ||||
| 			if res.Error != nil { | ||||
| 				time.Sleep(5 * time.Second) | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			for _, job := range jobs { | ||||
| 				// 5 分钟还没完成的任务直接删除 | ||||
| 				if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 { | ||||
| 					p.db.Delete(&job) | ||||
| 					var user model.User | ||||
| 					p.db.Where("id = ?", job.UserId).First(&user) | ||||
| 					// 退回绘图次数 | ||||
| 					res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) | ||||
| 					if res.Error == nil && res.RowsAffected > 0 { | ||||
| 						p.db.Create(&model.PowerLog{ | ||||
| 							UserId:    user.Id, | ||||
| 							Username:  user.Username, | ||||
| 							Type:      types.PowerConsume, | ||||
| 							Amount:    job.Power, | ||||
| 							Balance:   user.Power + job.Power, | ||||
| 							Mark:      types.PowerAdd, | ||||
| 							Model:     "stable-diffusion", | ||||
| 							Remark:    fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId), | ||||
| 							CreatedAt: time.Now(), | ||||
| 						}) | ||||
| 					} | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| // HasAvailableService check if it has available mj service in pool | ||||
| func (p *ServicePool) HasAvailableService() bool { | ||||
| 	return len(p.services) > 0 | ||||
| } | ||||
| @@ -5,84 +5,104 @@ import ( | ||||
| 	"chatplus/service/oss" | ||||
| 	"chatplus/store" | ||||
| 	"chatplus/store/model" | ||||
| 	"chatplus/store/vo" | ||||
| 	"chatplus/utils" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"gorm.io/gorm" | ||||
| 	"io" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/imroc/req/v3" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // SD 绘画服务 | ||||
|  | ||||
| const RunningJobKey = "StableDiffusion_Running_Job" | ||||
|  | ||||
| type Service struct { | ||||
| 	httpClient    *req.Client | ||||
| 	config        *types.StableDiffusionConfig | ||||
| 	taskQueue     *store.RedisQueue | ||||
| 	redis         *redis.Client | ||||
| 	db            *gorm.DB | ||||
| 	uploadManager *oss.UploaderManager | ||||
| 	Clients       *types.LMap[string, *types.WsClient] // SD 绘画页面 websocket 连接池 | ||||
| 	httpClient       *req.Client | ||||
| 	config           types.StableDiffusionConfig | ||||
| 	taskQueue        *store.RedisQueue | ||||
| 	notifyQueue      *store.RedisQueue | ||||
| 	db               *gorm.DB | ||||
| 	uploadManager    *oss.UploaderManager | ||||
| 	name             string            // service name | ||||
| 	maxHandleTaskNum int32             // max task number current service can handle | ||||
| 	handledTaskNum   int32             // already handled task number | ||||
| 	taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout | ||||
| 	taskTimeout      int64 | ||||
| } | ||||
|  | ||||
| func NewService(config *types.AppConfig, redisCli *redis.Client, db *gorm.DB, manager *oss.UploaderManager) *Service { | ||||
| func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service { | ||||
| 	return &Service{ | ||||
| 		config:        &config.SdConfig, | ||||
| 		httpClient:    req.C(), | ||||
| 		redis:         redisCli, | ||||
| 		db:            db, | ||||
| 		uploadManager: manager, | ||||
| 		Clients:       types.NewLMap[string, *types.WsClient](), | ||||
| 		taskQueue:     store.NewRedisQueue("stable_diffusion_task_queue", redisCli), | ||||
| 		name:             name, | ||||
| 		config:           config, | ||||
| 		httpClient:       req.C(), | ||||
| 		taskQueue:        taskQueue, | ||||
| 		notifyQueue:      notifyQueue, | ||||
| 		db:               db, | ||||
| 		uploadManager:    manager, | ||||
| 		taskTimeout:      timeout, | ||||
| 		maxHandleTaskNum: maxTaskNum, | ||||
| 		taskStartTimes:   make(map[int]time.Time), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Service) Run() { | ||||
| 	logger.Info("Starting StableDiffusion job consumer.") | ||||
| 	ctx := context.Background() | ||||
| 	for { | ||||
| 		_, err := s.redis.Get(ctx, RunningJobKey).Result() | ||||
| 		if err == nil { // 队列串行执行 | ||||
| 		s.checkTasks() | ||||
| 		if !s.canHandleTask() { | ||||
| 			// current service is full, can not handle more task | ||||
| 			// waiting for running task finish | ||||
| 			time.Sleep(time.Second * 3) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		var task types.SdTask | ||||
| 		err = s.taskQueue.LPop(&task) | ||||
| 		err := s.taskQueue.LPop(&task) | ||||
| 		if err != nil { | ||||
| 			logger.Errorf("taking task with error: %v", err) | ||||
| 			continue | ||||
| 		} | ||||
| 		logger.Infof("Consuming Task: %+v", task) | ||||
| 		logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task) | ||||
| 		err = s.Txt2Img(task) | ||||
| 		if err != nil { | ||||
| 			logger.Error("绘画任务执行失败:", err) | ||||
| 			if task.RetryCount <= 5 { | ||||
| 				s.taskQueue.RPush(task) | ||||
| 			} | ||||
| 			task.RetryCount += 1 | ||||
| 			time.Sleep(time.Second * 3) | ||||
| 			logger.Error("绘画任务执行失败:", err.Error()) | ||||
| 			// update the task progress | ||||
| 			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{ | ||||
| 				"progress": -1, | ||||
| 				"err_msg":  err.Error(), | ||||
| 			}) | ||||
| 			// release task num | ||||
| 			atomic.AddInt32(&s.handledTaskNum, -1) | ||||
| 			// 通知前端,任务失败 | ||||
| 			s.notifyQueue.RPush(task.UserId) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// 更新任务的执行状态 | ||||
| 		s.db.Model(&model.SdJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) | ||||
| 		// 锁定任务执行通道,直到任务超时(5分钟) | ||||
| 		s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5) | ||||
| 		// lock the task until the execute timeout | ||||
| 		s.taskStartTimes[task.Id] = time.Now() | ||||
| 		atomic.AddInt32(&s.handledTaskNum, 1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // PushTask 推送任务到队列 | ||||
| func (s *Service) PushTask(task types.SdTask) { | ||||
| 	logger.Infof("add a new MidJourney Task: %+v", task) | ||||
| 	s.taskQueue.RPush(task) | ||||
| // check if current service instance can handle more task | ||||
| func (s *Service) canHandleTask() bool { | ||||
| 	handledNum := atomic.LoadInt32(&s.handledTaskNum) | ||||
| 	return handledNum < s.maxHandleTaskNum | ||||
| } | ||||
|  | ||||
| // remove the expired tasks | ||||
| func (s *Service) checkTasks() { | ||||
| 	for k, t := range s.taskStartTimes { | ||||
| 		if time.Now().Unix()-t.Unix() > s.taskTimeout { | ||||
| 			delete(s.taskStartTimes, k) | ||||
| 			atomic.AddInt32(&s.handledTaskNum, -1) | ||||
| 			// delete task from database | ||||
| 			s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Txt2Img 文生图 API | ||||
| @@ -105,7 +125,8 @@ func (s *Service) Txt2Img(task types.SdTask) error { | ||||
| 	data[ParamKeys["negative_prompt"]] = params.NegativePrompt | ||||
| 	data[ParamKeys["steps"]] = params.Steps | ||||
| 	data[ParamKeys["sampler"]] = params.Sampler | ||||
| 	data[ParamKeys["face_fix"]] = params.FaceFix | ||||
| 	// @fix bug: 有些 stable diffusion 没有面部修复功能 | ||||
| 	//data[ParamKeys["face_fix"]] = params.FaceFix | ||||
| 	data[ParamKeys["cfg_scale"]] = params.CfgScale | ||||
| 	data[ParamKeys["seed"]] = params.Seed | ||||
| 	data[ParamKeys["height"]] = params.Height | ||||
| @@ -120,6 +141,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { | ||||
| 	taskInfo.TaskId = params.TaskId | ||||
| 	taskInfo.Data = data | ||||
| 	taskInfo.JobId = task.Id | ||||
| 	taskInfo.UserId = uint(task.UserId) | ||||
| 	go func() { | ||||
| 		s.runTask(taskInfo, s.httpClient) | ||||
| 	}() | ||||
| @@ -134,7 +156,6 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { | ||||
| 		"fn_index":     taskInfo.FnIndex, | ||||
| 		"session_hash": taskInfo.SessionHash, | ||||
| 	} | ||||
| 	logger.Debug(utils.JsonEncode(body)) | ||||
| 	var result = make(chan CBReq) | ||||
| 	go func() { | ||||
| 		var res struct { | ||||
| @@ -143,7 +164,7 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { | ||||
| 			Duration        float64       `json:"duration"` | ||||
| 			AverageDuration float64       `json:"average_duration"` | ||||
| 		} | ||||
| 		var cbReq = CBReq{TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} | ||||
| 		var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} | ||||
| 		response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict") | ||||
| 		if err != nil { | ||||
| 			cbReq.Message = "error with send request: " + err.Error() | ||||
| @@ -176,7 +197,8 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { | ||||
| 		var info map[string]any | ||||
| 		err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info) | ||||
| 		if err != nil { | ||||
| 			cbReq.Message = err.Error() | ||||
| 			logger.Error(res.Data) | ||||
| 			cbReq.Message = "error with decode image url:" + err.Error() | ||||
| 			cbReq.Success = false | ||||
| 			result <- cbReq | ||||
| 			return | ||||
| @@ -215,7 +237,7 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { | ||||
| 				TextInfo      interface{} `json:"textinfo"` | ||||
| 			} | ||||
| 			response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress") | ||||
| 			var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} | ||||
| 			var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} | ||||
| 			if err != nil { // TODO: 这里可以考虑设置失败重试次数 | ||||
| 				logger.Error(err) | ||||
| 				return | ||||
| @@ -236,9 +258,8 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { | ||||
| } | ||||
|  | ||||
| func (s *Service) callback(data CBReq) { | ||||
| 	// 释放任务锁 | ||||
| 	s.redis.Del(context.Background(), RunningJobKey) | ||||
| 	client := s.Clients.Get(data.SessionId) | ||||
| 	// release task num | ||||
| 	atomic.AddInt32(&s.handledTaskNum, -1) | ||||
| 	if data.Success { // 任务成功 | ||||
| 		var job model.SdJob | ||||
| 		res := s.db.Where("id = ?", data.JobId).First(&job) | ||||
| @@ -258,13 +279,15 @@ func (s *Service) callback(data CBReq) { | ||||
|  | ||||
| 		params.Seed = data.Seed | ||||
| 		if data.ImageName != "" { // 下载图片 | ||||
| 			imageURL := fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName) | ||||
| 			imageURL, err := s.uploadManager.GetUploadHandler().PutImg(imageURL, false) | ||||
| 			if err != nil { | ||||
| 				logger.Error("error with download img: ", err.Error()) | ||||
| 				return | ||||
| 			job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName) | ||||
| 			if data.Progress == 100 { | ||||
| 				imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false) | ||||
| 				if err != nil { | ||||
| 					logger.Error("error with download img: ", err.Error()) | ||||
| 					return | ||||
| 				} | ||||
| 				job.ImgURL = imageURL | ||||
| 			} | ||||
| 			job.ImgURL = imageURL | ||||
| 		} | ||||
|  | ||||
| 		job.Params = utils.JsonEncode(params) | ||||
| @@ -274,32 +297,16 @@ func (s *Service) callback(data CBReq) { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		var jobVo vo.SdJob | ||||
| 		err = utils.CopyObject(job, &jobVo) | ||||
| 		if err != nil { | ||||
| 			logger.Error("error with copy object: ", err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if data.Progress < 100 && data.ImageData != "" { | ||||
| 			jobVo.ImgURL = data.ImageData | ||||
| 		} | ||||
|  | ||||
| 		// 推送任务到前端 | ||||
| 		if client != nil { | ||||
| 			utils.ReplyChunkMessage(client, jobVo) | ||||
| 		} | ||||
| 		logger.Debugf("绘图进度:%d", data.Progress) | ||||
| 	} else { // 任务失败 | ||||
| 		logger.Error("任务执行失败:", data.Message) | ||||
| 		// 删除任务 | ||||
| 		s.db.Delete(&model.SdJob{Id: uint(data.JobId)}) | ||||
| 		// 推送消息到前端 | ||||
| 		if client != nil { | ||||
| 			utils.ReplyChunkMessage(client, vo.SdJob{ | ||||
| 				Id:       uint(data.JobId), | ||||
| 				Progress: -1, | ||||
| 				TaskId:   data.TaskId, | ||||
| 			}) | ||||
| 		} | ||||
| 		// update the task progress | ||||
| 		s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{ | ||||
| 			"progress": -1, | ||||
| 			"err_msg":  data.Message, | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	// 发送更新状态信号 | ||||
| 	s.notifyQueue.RPush(data.UserId) | ||||
| } | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import logger2 "chatplus/logger" | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| type TaskInfo struct { | ||||
| 	UserId      uint          `json:"user_id"` | ||||
| 	SessionId   string        `json:"session_id"` | ||||
| 	JobId       int           `json:"job_id"` | ||||
| 	TaskId      string        `json:"task_id"` | ||||
| @@ -15,6 +16,7 @@ type TaskInfo struct { | ||||
| } | ||||
|  | ||||
| type CBReq struct { | ||||
| 	UserId    uint | ||||
| 	SessionId string | ||||
| 	JobId     int | ||||
| 	TaskId    string | ||||
| @@ -32,14 +34,14 @@ var ParamKeys = map[string]int{ | ||||
| 	"negative_prompt": 2, | ||||
| 	"steps":           4, | ||||
| 	"sampler":         5, | ||||
| 	"face_fix":        6, | ||||
| 	"cfg_scale":       10, | ||||
| 	"seed":            11, | ||||
| 	"height":          17, | ||||
| 	"width":           18, | ||||
| 	"hd_fix":          19, | ||||
| 	"hd_redraw_rate":  20, //高清修复重绘幅度 | ||||
| 	"hd_scale":        21, // 高清修复放大倍数 | ||||
| 	"hd_scale_alg":    22, // 高清修复放大算法 | ||||
| 	"hd_sample_num":   23, // 高清修复采样次数 | ||||
| 	"face_fix":        7, // 面部修复 | ||||
| 	"cfg_scale":       8, | ||||
| 	"seed":            27, | ||||
| 	"height":          10, | ||||
| 	"width":           9, | ||||
| 	"hd_fix":          11, | ||||
| 	"hd_redraw_rate":  12, //高清修复重绘幅度 | ||||
| 	"hd_scale":        13, // 高清修复放大倍数 | ||||
| 	"hd_scale_alg":    14, // 高清修复放大算法 | ||||
| 	"hd_sample_num":   15, // 高清修复采样次数 | ||||
| } | ||||
|   | ||||
| @@ -1,31 +1,29 @@ | ||||
| package service | ||||
| package sms | ||||
| 
 | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/store" | ||||
| 	"fmt" | ||||
| 	"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi" | ||||
| ) | ||||
| 
 | ||||
| type AliYunSmsService struct { | ||||
| 	config *types.AliYunSmsConfig | ||||
| 	db     *store.LevelDB | ||||
| 	config *types.SmsConfigAli | ||||
| 	client *dysmsapi.Client | ||||
| } | ||||
| 
 | ||||
| func NewAliYunSmsService(config *types.AppConfig, db *store.LevelDB) (*AliYunSmsService, error) { | ||||
| func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) { | ||||
| 	config := &appConfig.SMS.Ali | ||||
| 	// 创建阿里云短信客户端 | ||||
| 	client, err := dysmsapi.NewClientWithAccessKey( | ||||
| 		"cn-hangzhou", | ||||
| 		config.SmsConfig.AccessKey, | ||||
| 		config.SmsConfig.AccessSecret) | ||||
| 		config.AccessKey, | ||||
| 		config.AccessSecret) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create client: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return &AliYunSmsService{ | ||||
| 		config: &config.SmsConfig, | ||||
| 		db:     db, | ||||
| 		config: config, | ||||
| 		client: client, | ||||
| 	}, nil | ||||
| } | ||||
| @@ -49,6 +47,7 @@ func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error { | ||||
| 	if response.Code != "OK" { | ||||
| 		return fmt.Errorf("failed to send SMS:%v", response.Message) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| var _ Service = &AliYunSmsService{} | ||||
							
								
								
									
										72
									
								
								api/service/sms/bao.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								api/service/sms/bao.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,72 @@ | ||||
| package sms | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	"chatplus/utils" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type BaoSmsService struct { | ||||
| 	config *types.SmsConfigBao | ||||
| } | ||||
|  | ||||
| func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService { | ||||
| 	config := appConfig.SMS.Bao | ||||
| 	if config.Domain == "" { // use default domain | ||||
| 		config.Domain = "api.smsbao.com" | ||||
| 		logger.Infof("Using default domain for SMS-BAO: %s", config.Domain) | ||||
| 	} | ||||
| 	return &BaoSmsService{ | ||||
| 		config: &config, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| var errMsg = map[string]string{ | ||||
| 	"0":  "短信发送成功", | ||||
| 	"-1": "参数不全", | ||||
| 	"-2": "服务器空间不支持,请确认支持curl或者fsocket,联系您的空间商解决或者更换空间", | ||||
| 	"30": "密码错误", | ||||
| 	"40": "账号不存在", | ||||
| 	"41": "余额不足", | ||||
| 	"42": "账户已过期", | ||||
| 	"43": "IP地址限制", | ||||
| 	"50": "内容含有敏感词", | ||||
| } | ||||
|  | ||||
| func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error { | ||||
|  | ||||
| 	content := fmt.Sprintf("%s%s", s.config.Sign, s.config.CodeTemplate) | ||||
| 	content = strings.ReplaceAll(content, "{code}", strconv.Itoa(code)) | ||||
| 	password := utils.Md5(s.config.Password) | ||||
| 	params := url.Values{} | ||||
| 	params.Set("u", s.config.Username) | ||||
| 	params.Set("p", password) | ||||
| 	params.Set("m", mobile) | ||||
| 	params.Set("c", content) | ||||
|  | ||||
| 	apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode()) | ||||
| 	response, err := http.Get(apiURL) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer response.Body.Close() | ||||
|  | ||||
| 	body, err := io.ReadAll(response.Body) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	result := string(body) | ||||
| 	logger.Debugf("send SmsBao result: %v", errMsg[result]) | ||||
|  | ||||
| 	if result != "0" { | ||||
| 		return fmt.Errorf("failed to send SMS:%v", errMsg[result]) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| var _ Service = &BaoSmsService{} | ||||
							
								
								
									
										8
									
								
								api/service/sms/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								api/service/sms/service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| package sms | ||||
|  | ||||
| const Ali = "ALI" | ||||
| const Bao = "BAO" | ||||
|  | ||||
| type Service interface { | ||||
| 	SendVerifyCode(mobile string, code int) error | ||||
| } | ||||
							
								
								
									
										39
									
								
								api/service/sms/service_manager.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								api/service/sms/service_manager.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | ||||
| package sms | ||||
|  | ||||
| import ( | ||||
| 	"chatplus/core/types" | ||||
| 	logger2 "chatplus/logger" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type ServiceManager struct { | ||||
| 	handler Service | ||||
| } | ||||
|  | ||||
| var logger = logger2.GetLogger() | ||||
|  | ||||
| func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) { | ||||
| 	active := Ali | ||||
| 	if config.SMS.Active != "" { | ||||
| 		active = strings.ToUpper(config.SMS.Active) | ||||
| 	} | ||||
| 	var handler Service | ||||
| 	switch active { | ||||
| 	case Ali: | ||||
| 		client, err := NewAliYunSmsService(config) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		handler = client | ||||
| 		break | ||||
| 	case Bao: | ||||
| 		handler = NewSmsBaoSmsService(config) | ||||
| 		break | ||||
| 	} | ||||
|  | ||||
| 	return &ServiceManager{handler: handler}, nil | ||||
| } | ||||
|  | ||||
| func (m *ServiceManager) GetService() Service { | ||||
| 	return m.handler | ||||
| } | ||||
| @@ -1,5 +0,0 @@ | ||||
| package service | ||||
|  | ||||
| type SmsService interface { | ||||
| 	SendVerifyCode(mobile string, code int) error | ||||
| } | ||||
							
								
								
									
										44
									
								
								api/service/smtp_sms_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								api/service/smtp_sms_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| package service | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"chatplus/core/types" | ||||
| 	"fmt" | ||||
| 	"mime" | ||||
| 	"net/smtp" | ||||
| ) | ||||
|  | ||||
| type SmtpService struct { | ||||
| 	config *types.SmtpConfig | ||||
| } | ||||
|  | ||||
| func NewSmtpService(appConfig *types.AppConfig) *SmtpService { | ||||
| 	return &SmtpService{ | ||||
| 		config: &appConfig.SmtpConfig, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *SmtpService) SendVerifyCode(to string, code int) error { | ||||
| 	subject := "ChatPlus注册验证码" | ||||
| 	body := fmt.Sprintf("您正在注册 ChatPlus AI 助手账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", code) | ||||
|  | ||||
| 	// 设置SMTP客户端配置 | ||||
| 	auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host) | ||||
|  | ||||
| 	// 对主题进行MIME编码 | ||||
| 	encodedSubject := mime.QEncoding.Encode("UTF-8", subject) | ||||
| 	// 组装邮件 | ||||
| 	message := bytes.NewBuffer(nil) | ||||
| 	message.WriteString(fmt.Sprintf("From: \"%s\" <%s>\r\n", s.config.AppName, s.config.From)) | ||||
| 	message.WriteString(fmt.Sprintf("To: %s\r\n", to)) | ||||
| 	message.WriteString(fmt.Sprintf("Subject: %s\r\n", encodedSubject)) | ||||
| 	message.WriteString("\r\n" + body) | ||||
|  | ||||
| 	// 发送邮件 | ||||
| 	// 发送邮件 | ||||
| 	err := smtp.SendMail(s.config.Host+":"+fmt.Sprint(s.config.Port), auth, s.config.From, []string{to}, message.Bytes()) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error sending email: %v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										59
									
								
								api/service/snowflake.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								api/service/snowflake.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | ||||
| package service | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // Snowflake 雪花算法实现 | ||||
| type Snowflake struct { | ||||
| 	mu            sync.Mutex | ||||
| 	lastTimestamp int64 | ||||
| 	workerID      int | ||||
| 	sequence      int | ||||
| } | ||||
|  | ||||
| func NewSnowflake() *Snowflake { | ||||
| 	return &Snowflake{ | ||||
| 		lastTimestamp: -1, | ||||
| 		workerID:      0, // TODO: 增加 WorkID 参数 | ||||
| 		sequence:      0, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Next 生成一个新的唯一ID | ||||
| func (s *Snowflake) Next(raw bool) (string, error) { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
|  | ||||
| 	timestamp := time.Now().UnixNano() / 1000000 // 转换为毫秒 | ||||
| 	if timestamp < s.lastTimestamp { | ||||
| 		return "", fmt.Errorf("clock moved backwards. Refusing to generate id for %d milliseconds", s.lastTimestamp-timestamp) | ||||
| 	} | ||||
|  | ||||
| 	if timestamp == s.lastTimestamp { | ||||
| 		s.sequence = (s.sequence + 1) & 4095 | ||||
| 		if s.sequence == 0 { | ||||
| 			timestamp = s.waitNextMillis() | ||||
| 		} | ||||
| 	} else { | ||||
| 		s.sequence = 0 | ||||
| 	} | ||||
|  | ||||
| 	s.lastTimestamp = timestamp | ||||
| 	id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence) | ||||
| 	if raw { | ||||
| 		return fmt.Sprintf("%d", id), nil | ||||
| 	} | ||||
| 	now := time.Now() | ||||
| 	return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil | ||||
| } | ||||
|  | ||||
| func (s *Snowflake) waitNextMillis() int64 { | ||||
| 	timestamp := time.Now().UnixNano() / 1000000 | ||||
| 	for timestamp <= s.lastTimestamp { | ||||
| 		timestamp = time.Now().UnixNano() / 1000000 | ||||
| 	} | ||||
| 	return timestamp | ||||
| } | ||||
| @@ -6,6 +6,8 @@ import ( | ||||
| 	"github.com/eatmoreapple/openwechat" | ||||
| 	"github.com/skip2/go-qrcode" | ||||
| 	"gorm.io/gorm" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| // 微信收款机器人 | ||||
| @@ -34,8 +36,13 @@ func (b *Bot) Run() error { | ||||
| 	} | ||||
| 	// scan code login callback | ||||
| 	b.bot.UUIDCallback = b.qrCodeCallBack | ||||
|  | ||||
| 	err := b.bot.Login() | ||||
| 	debug, err := strconv.ParseBool(os.Getenv("APP_DEBUG")) | ||||
| 	if debug { | ||||
| 		reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json") | ||||
| 		err = b.bot.HotLogin(reloadStorage, true) | ||||
| 	} else { | ||||
| 		err = b.bot.Login() | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -56,13 +63,13 @@ func (b *Bot) messageHandler(msg *openwechat.Message) { | ||||
| 		msg.MsgType == openwechat.MsgTypeApp || | ||||
| 		msg.AppMsgType == openwechat.AppMsgTypeUrl { | ||||
| 		// 解析支付金额 | ||||
| 		message, err := parseTransactionMessage(msg.Content) | ||||
| 		if err == nil { | ||||
| 			transaction := extractTransaction(message) | ||||
| 			logger.Infof("解析到收款信息:%+v", transaction) | ||||
| 		message := parseTransactionMessage(msg.Content) | ||||
| 		transaction := extractTransaction(message) | ||||
| 		logger.Infof("解析到收款信息:%+v", transaction) | ||||
| 		if transaction.TransId != "" { | ||||
| 			var item model.Reward | ||||
| 			res := b.db.Where("tx_id = ?", transaction.TransId).First(&item) | ||||
| 			if res.Error == nil { | ||||
| 			if item.Id > 0 { | ||||
| 				logger.Error("当前交易 ID 己经存在!") | ||||
| 				return | ||||
| 			} | ||||
|   | ||||
| @@ -2,17 +2,15 @@ package wx | ||||
|  | ||||
| import ( | ||||
| 	"encoding/xml" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // Message 转账消息 | ||||
| type Message struct { | ||||
| 	XMLName xml.Name `xml:"msg"` | ||||
| 	AppMsg  struct { | ||||
| 		Des string `xml:"des"` | ||||
| 		Url string `xml:"url"` | ||||
| 	} `xml:"appmsg"` | ||||
| 	Des string | ||||
| 	Url string | ||||
| } | ||||
|  | ||||
| // Transaction 解析后的交易信息 | ||||
| @@ -23,20 +21,56 @@ type Transaction struct { | ||||
| } | ||||
|  | ||||
| // 解析微信转账消息 | ||||
| func parseTransactionMessage(xmlData string) (*Message, error) { | ||||
| 	var msg Message | ||||
| 	if err := xml.Unmarshal([]byte(xmlData), &msg); err != nil { | ||||
| 		return nil, err | ||||
| func parseTransactionMessage(xmlData string) *Message { | ||||
| 	decoder := xml.NewDecoder(strings.NewReader(xmlData)) | ||||
| 	message := Message{} | ||||
| 	for { | ||||
| 		token, err := decoder.Token() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		switch se := token.(type) { | ||||
| 		case xml.StartElement: | ||||
| 			var value string | ||||
| 			if se.Name.Local == "des" && message.Des == "" { | ||||
| 				if err := decoder.DecodeElement(&value, &se); err == nil { | ||||
| 					message.Des = strings.TrimSpace(value) | ||||
| 				} | ||||
| 				break | ||||
| 			} | ||||
| 			if se.Name.Local == "weapp_path" || se.Name.Local == "url" { | ||||
| 				if err := decoder.DecodeElement(&value, &se); err == nil { | ||||
| 					if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") { | ||||
| 						message.Url = value | ||||
| 					} | ||||
| 				} | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &msg, nil | ||||
| 	// 兼容旧版消息记录 | ||||
| 	if message.Url == "" { | ||||
| 		var msg struct { | ||||
| 			XMLName xml.Name `xml:"msg"` | ||||
| 			AppMsg  struct { | ||||
| 				Des string `xml:"des"` | ||||
| 				Url string `xml:"url"` | ||||
| 			} `xml:"appmsg"` | ||||
| 		} | ||||
| 		if err := xml.Unmarshal([]byte(xmlData), &msg); err == nil { | ||||
| 			message.Url = msg.AppMsg.Url | ||||
| 		} | ||||
| 	} | ||||
| 	return &message | ||||
| } | ||||
|  | ||||
| // 导出交易信息 | ||||
| func extractTransaction(message *Message) Transaction { | ||||
| 	var tx = Transaction{} | ||||
| 	// 导出交易金额和备注 | ||||
| 	lines := strings.Split(message.AppMsg.Des, "\n") | ||||
| 	lines := strings.Split(message.Des, "\n") | ||||
| 	for _, line := range lines { | ||||
| 		line = strings.TrimSpace(line) | ||||
| 		if len(line) == 0 { | ||||
| @@ -59,10 +93,13 @@ func extractTransaction(message *Message) Transaction { | ||||
| 	} | ||||
|  | ||||
| 	// 解析交易 ID | ||||
| 	index := strings.Index(message.AppMsg.Url, "trans_id=") | ||||
| 	if index != -1 { | ||||
| 		end := strings.LastIndex(message.AppMsg.Url, "&") | ||||
| 		tx.TransId = strings.TrimSpace(message.AppMsg.Url[index+9 : end]) | ||||
| 	parse, err := url.Parse(message.Url) | ||||
| 	if err == nil { | ||||
| 		tx.TransId = parse.Query().Get("id") | ||||
| 		if tx.TransId == "" { | ||||
| 			tx.TransId = parse.Query().Get("trans_id") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return tx | ||||
| } | ||||
|   | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user