mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-05-09 19:25:20 +08:00
merge v4.2.6
整合 v4.2.6 的后端中间件与服务层重构、前端样式体系迁移和管理端/移动端功能更新,统一清理历史冲突并完成版本升级。 Made-with: Cursor
This commit is contained in:
37
.claude/commands/frontend-developer.md
Normal file
37
.claude/commands/frontend-developer.md
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
---
|
||||||
|
name: frontend-developer
|
||||||
|
description: Use this agent when you need assistance with frontend development tasks including Vue.js components, UI implementation, styling, responsive design, state management, or frontend architecture decisions. Examples: <example>Context: User is working on a Vue.js component and needs help with implementing a responsive layout. user: 'I need to create a mobile-friendly chat interface component' assistant: 'I'll use the frontend-developer agent to help design and implement this responsive chat component' <commentary>Since this involves frontend development work with Vue.js and responsive design, use the frontend-developer agent.</commentary></example> <example>Context: User encounters styling issues with Element Plus components. user: 'The Element Plus dialog is not displaying correctly on mobile devices' assistant: 'Let me use the frontend-developer agent to troubleshoot this mobile styling issue' <commentary>This is a frontend styling problem that requires expertise in Element Plus and responsive design.</commentary></example>
|
||||||
|
color: purple
|
||||||
|
---
|
||||||
|
|
||||||
|
You are a Senior Frontend Development Engineer with deep expertise in modern web development technologies, particularly Vue.js 3, Element Plus, Vant, and responsive design patterns. You specialize in creating high-quality, maintainable frontend applications with excellent user experience.
|
||||||
|
|
||||||
|
Your core responsibilities include:
|
||||||
|
- Developing Vue.js 3 components using Composition API and best practices
|
||||||
|
- Implementing responsive designs that work seamlessly across desktop and mobile devices
|
||||||
|
- Working with Element Plus for desktop UI and Vant for mobile components
|
||||||
|
- Managing application state using Pinia store patterns
|
||||||
|
- Styling with Stylus preprocessor and Tailwind CSS utilities
|
||||||
|
- Optimizing build processes with Vite and ensuring proper code organization
|
||||||
|
- Implementing theme switching (dark/light mode) and accessibility features
|
||||||
|
- Follow decoupled development, with HTML, CSS, and JS codes placed in separate files for easier maintenance
|
||||||
|
|
||||||
|
When working on frontend tasks, you will:
|
||||||
|
1. Analyze requirements and suggest the most appropriate Vue.js patterns and component structures
|
||||||
|
2. Ensure responsive design principles are followed, considering both desktop and mobile viewports
|
||||||
|
3. Choose appropriate UI components from Element Plus (desktop) or Vant (mobile) libraries
|
||||||
|
4. Write clean, maintainable code following Vue.js 3 Composition API best practices
|
||||||
|
5. Consider performance implications and suggest optimizations when relevant
|
||||||
|
6. Ensure proper state management using Pinia when component state needs to be shared
|
||||||
|
7. Follow the project's established patterns for routing, API integration, and component organization
|
||||||
|
8. Provide specific code examples and explain the reasoning behind architectural decisions
|
||||||
|
|
||||||
|
You have deep knowledge of:
|
||||||
|
- Vue.js 3 ecosystem (Vue Router, Pinia, Composition API)
|
||||||
|
- Modern CSS techniques and preprocessors (Stylus, Tailwind)
|
||||||
|
- Component library integration (Element Plus, Vant)
|
||||||
|
- Build tools and development workflow (Vite, npm scripts)
|
||||||
|
- Cross-browser compatibility and mobile-first design principles
|
||||||
|
- Performance optimization and code splitting strategies
|
||||||
|
|
||||||
|
Always consider the user experience, code maintainability, and alignment with modern frontend development standards. When suggesting solutions, provide clear explanations and consider both immediate needs and long-term scalability.
|
||||||
6
.claude/commands/refactor.md
Normal file
6
.claude/commands/refactor.md
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
重构当前页面代码
|
||||||
|
|
||||||
|
1. 把当前页面 JS 代码全部抽离,然后是采用 Pinia 重构
|
||||||
|
2. 把当前页面 CSS 代码全部抽离,如果是 stylus 语法代码,则需要改成 SCSS 语法代码
|
||||||
|
3. 尽量做到代码的复用性,不要重复造轮子
|
||||||
|
4. 移动端的 css 和 js 分别放到对应的 mobile 目录下,不要覆盖 PC 端的代码
|
||||||
13
CHANGELOG.md
13
CHANGELOG.md
@@ -1,5 +1,18 @@
|
|||||||
# 更新日志
|
# 更新日志
|
||||||
|
|
||||||
|
## v4.2.6
|
||||||
|
|
||||||
|
- 功能重构:优化系统配置管理功能,把 OSS,支付,短信,邮件等配置全部迁移到管理后台,无需通过修改配置文档的方式修改 🎉🎉🎉
|
||||||
|
- 功能优化:重构 API 授权代码,采用中间件鉴权方式,实现更加精准的 API 鉴权 🎉🎉🎉
|
||||||
|
- 功能优化:优化 PC 端的 Suno 音乐,视频生成,以及即梦 AI 页面 UI
|
||||||
|
- 功能优化:重构登录和注册页面,兼容移动端和 PC 端,并且所有的登录组件共用了同一套组件代码,大大降低维护成本 🎉🎉🎉
|
||||||
|
- 功能优化:管理后台增加模型批量删除功能
|
||||||
|
- 功能优化:优化 Table 组件 UI,并支持 dark 主题
|
||||||
|
- 功能优化:移动端对话页面支持上传文件和图片
|
||||||
|
- 功能新增:新增微信扫码登录支持
|
||||||
|
- 功能新增:新增安全监控,内容审核功能,支持敏感内容过滤拦截
|
||||||
|
- 功能新增:DALL-E 绘图支持参 Google Banana 图片编辑功能
|
||||||
|
|
||||||
## v4.2.5
|
## v4.2.5
|
||||||
|
|
||||||
- 功能优化:在代码右下角增加复制代码功能按钮,增加收起和展开代码功能
|
- 功能优化:在代码右下角增加复制代码功能按钮,增加收起和展开代码功能
|
||||||
|
|||||||
@@ -1,195 +0,0 @@
|
|||||||
# 即梦 AI 配置功能说明
|
|
||||||
|
|
||||||
## 功能概述
|
|
||||||
|
|
||||||
即梦 AI 配置功能允许管理员通过 Web 界面配置即梦 AI 的 API 密钥和算力消耗设置,支持动态配置更新,无需重启服务。
|
|
||||||
|
|
||||||
## 功能特性
|
|
||||||
|
|
||||||
### 1. 秘钥配置
|
|
||||||
|
|
||||||
- AccessKey 和 SecretKey 配置
|
|
||||||
- 支持密码显示/隐藏
|
|
||||||
- 连接测试功能
|
|
||||||
|
|
||||||
### 2. 算力配置
|
|
||||||
|
|
||||||
- 文生图算力消耗
|
|
||||||
- 图生图算力消耗
|
|
||||||
- 图片编辑算力消耗
|
|
||||||
- 图片特效算力消耗
|
|
||||||
- 文生视频算力消耗
|
|
||||||
- 图生视频算力消耗
|
|
||||||
|
|
||||||
### 3. 动态配置
|
|
||||||
|
|
||||||
- 配置实时生效
|
|
||||||
- 无需重启服务
|
|
||||||
- 支持配置验证
|
|
||||||
|
|
||||||
## API 接口
|
|
||||||
|
|
||||||
### 获取配置
|
|
||||||
|
|
||||||
```
|
|
||||||
GET /api/admin/jimeng/config
|
|
||||||
```
|
|
||||||
|
|
||||||
### 更新配置
|
|
||||||
|
|
||||||
```
|
|
||||||
POST /api/admin/jimeng/config
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"config": {
|
|
||||||
"access_key": "your_access_key",
|
|
||||||
"secret_key": "your_secret_key",
|
|
||||||
"power": {
|
|
||||||
"text_to_image": 10,
|
|
||||||
"image_to_image": 15,
|
|
||||||
"image_edit": 20,
|
|
||||||
"image_effects": 25,
|
|
||||||
"text_to_video": 30,
|
|
||||||
"image_to_video": 35
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 测试连接
|
|
||||||
|
|
||||||
```
|
|
||||||
POST /api/admin/jimeng/config/test
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"config": {
|
|
||||||
"access_key": "your_access_key",
|
|
||||||
"secret_key": "your_secret_key"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 前端页面
|
|
||||||
|
|
||||||
### 访问路径
|
|
||||||
|
|
||||||
管理后台 -> 即梦 AI -> 配置设置
|
|
||||||
|
|
||||||
### 页面功能
|
|
||||||
|
|
||||||
1. **秘钥配置标签页**
|
|
||||||
|
|
||||||
- AccessKey 输入框(密码模式)
|
|
||||||
- SecretKey 输入框(密码模式)
|
|
||||||
- 测试连接按钮
|
|
||||||
|
|
||||||
2. **算力配置标签页**
|
|
||||||
|
|
||||||
- 各种任务类型的算力消耗配置
|
|
||||||
- 数字输入框,支持 1-100 范围
|
|
||||||
- 提示信息说明
|
|
||||||
|
|
||||||
3. **操作按钮**
|
|
||||||
- 保存配置
|
|
||||||
- 重置配置
|
|
||||||
|
|
||||||
## 配置存储
|
|
||||||
|
|
||||||
配置存储在数据库的`config`表中:
|
|
||||||
|
|
||||||
- 配置键:`jimeng`
|
|
||||||
- 配置值:JSON 格式的即梦 AI 配置
|
|
||||||
|
|
||||||
## 默认配置
|
|
||||||
|
|
||||||
如果配置不存在,系统会使用以下默认值:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"access_key": "",
|
|
||||||
"secret_key": "",
|
|
||||||
"power": {
|
|
||||||
"text_to_image": 10,
|
|
||||||
"image_to_image": 15,
|
|
||||||
"image_edit": 20,
|
|
||||||
"image_effects": 25,
|
|
||||||
"text_to_video": 30,
|
|
||||||
"image_to_video": 35
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 使用流程
|
|
||||||
|
|
||||||
1. **初始配置**
|
|
||||||
|
|
||||||
- 访问管理后台即梦 AI 配置页面
|
|
||||||
- 填写 AccessKey 和 SecretKey
|
|
||||||
- 点击"测试连接"验证配置
|
|
||||||
- 调整各功能算力消耗
|
|
||||||
- 保存配置
|
|
||||||
|
|
||||||
2. **配置更新**
|
|
||||||
|
|
||||||
- 修改需要更新的配置项
|
|
||||||
- 保存配置
|
|
||||||
- 配置立即生效
|
|
||||||
|
|
||||||
3. **故障排查**
|
|
||||||
- 使用"测试连接"功能验证 API 密钥
|
|
||||||
- 检查配置是否正确保存
|
|
||||||
- 查看服务日志
|
|
||||||
|
|
||||||
## 注意事项
|
|
||||||
|
|
||||||
1. **权限要求**
|
|
||||||
|
|
||||||
- 只有管理员可以访问配置页面
|
|
||||||
- 需要有效的管理员登录会话
|
|
||||||
|
|
||||||
2. **配置验证**
|
|
||||||
|
|
||||||
- AccessKey 和 SecretKey 不能为空
|
|
||||||
- 算力消耗必须大于 0
|
|
||||||
- 建议先测试连接再保存配置
|
|
||||||
|
|
||||||
3. **服务影响**
|
|
||||||
- 配置更新不会影响正在进行的任务
|
|
||||||
- 新任务会使用更新后的配置
|
|
||||||
- 客户端配置会在下次请求时更新
|
|
||||||
|
|
||||||
## 错误处理
|
|
||||||
|
|
||||||
1. **配置加载失败**
|
|
||||||
|
|
||||||
- 使用默认配置
|
|
||||||
- 记录错误日志
|
|
||||||
|
|
||||||
2. **连接测试失败**
|
|
||||||
|
|
||||||
- 显示具体错误信息
|
|
||||||
- 建议检查 API 密钥
|
|
||||||
|
|
||||||
3. **配置保存失败**
|
|
||||||
- 显示错误信息
|
|
||||||
- 保留原有配置
|
|
||||||
|
|
||||||
## 开发说明
|
|
||||||
|
|
||||||
### 后端文件
|
|
||||||
|
|
||||||
- `api/handler/admin/jimeng_handler.go` - 配置管理 API
|
|
||||||
- `api/service/jimeng/service.go` - 配置服务逻辑
|
|
||||||
- `api/core/types/jimeng.go` - 配置类型定义
|
|
||||||
|
|
||||||
### 前端文件
|
|
||||||
|
|
||||||
- `web/src/views/admin/jimeng/JimengSetting.vue` - 配置页面
|
|
||||||
|
|
||||||
### 数据库
|
|
||||||
|
|
||||||
- `config`表存储配置信息
|
|
||||||
- 配置键:`jimeng`
|
|
||||||
- 配置值:JSON 格式
|
|
||||||
145
README.md
145
README.md
@@ -1,90 +1,77 @@
|
|||||||
# GeekAI
|
# 🚀 GeekAI-PLUS:一站式 AI 创意生产力平台
|
||||||
|
|
||||||
> 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
**重新定义 AI 创作体验,让每个人都能成为内容创作大师**
|
||||||
|
|
||||||
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Claude, 通义千问,Kimi,DeepSeek,Gitee AI 等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI 绘画功能。
|
基于 GeekAI 项目开发的高级版,增加了很多高级功能,比如思维导图,Dalle 绘画等。**高级版源码不会一次性开放,只提供镜像给大家免费使用**,源码会逐步逐步按照版同步迁移到[社区版(GeekAI)](https://github.com/yangjian102621/geekai)。所以如果大家想要二次开发,请移步去社区版。
|
||||||
|
|
||||||
主要特性:
|
## ✨ 核心特色
|
||||||
|
|
||||||
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
### 🎨 **全能 AI 创作矩阵**
|
||||||
- 基于 Websocket 实现,完美的打字机体验。
|
|
||||||
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
|
|
||||||
- 支持 OpenAI, Claude, 通义千问,Kimi,DeepSeek 等多个大语言模型,**支持 Gitee AI Serverless 大模型 API**。
|
|
||||||
- 支持 Suno 文生音乐
|
|
||||||
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
|
|
||||||
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
|
||||||
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
|
||||||
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
|
|
||||||
绘画函数插件。
|
|
||||||
|
|
||||||
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus)
|
- **智能对话**:集成 ChatGPT、Claude 等多款顶级 AI 模型,支持角色扮演和专业对话
|
||||||
|
- **图像生成**:整合 MidJourney、DALL-E、Stable Diffusion 三大主流 AI 绘画引擎
|
||||||
|
- **音频创作**:Suno AI 音乐生成,从旋律到歌词一键创作专属音乐
|
||||||
|
- **视频制作**:Luma 和 KeLing,即梦,Veo3 视频 AI,文本到视频,创意无限
|
||||||
|
- **思维导图**:AI 辅助思维整理,复杂想法可视化呈现
|
||||||
|
|
||||||
- [x] 更友好的 UI 界面
|
### 🏗️ **企业级技术架构**
|
||||||
- [x] 支持 Dall-E 文生图功能
|
|
||||||
- [x] 支持文生思维导图
|
- **高性能后端**:Go + Gin + MySQL + Redis,支持高并发访问
|
||||||
- [x] 支持为模型绑定指定的 API KEY,支持为角色绑定指定的模型等功能
|
- **现代化前端**:Vue3 + Element Plus + Vant,桌面移动双端适配
|
||||||
- [x] 支持网站 Logo 版权等信息的修改
|
- **智能缓存**:多层缓存策略,响应速度提升 80%
|
||||||
|
- **弹性部署**:Docker 容器化部署,一键启动,轻松扩展
|
||||||
|
- **私有化部署**:支持私有化部署,私有化部署不支持升级,需要手动升级
|
||||||
|
- **文档支持**:丰富且详细的部署和 API 开发文档支持,二次开发轻松上手
|
||||||
|
|
||||||
|
### 💼 **商业化就绪**
|
||||||
|
|
||||||
|
- **完整用户系统**:注册登录、权限管理、积分充值
|
||||||
|
- **灵活计费模式**:支持按次付费、包月订阅等多种商业模式
|
||||||
|
- **数据统计分析**:用户行为、消费记录、系统性能全方位监控
|
||||||
|
- **管理后台**:功能完备的管理员界面,运营数据一目了然
|
||||||
|
|
||||||
|
### 🎯 **用户体验优势**
|
||||||
|
|
||||||
|
- **响应式设计**:完美适配桌面、平板、手机等全终端设备
|
||||||
|
- **暗黑模式**:支持明暗主题切换,护眼舒适
|
||||||
|
- **实时交互**:WebSocket 实时通信,创作过程流畅无卡顿
|
||||||
|
- **文件管理**:支持多种云存储,作品安全可靠
|
||||||
|
|
||||||
|
## 🎪 **应用场景**
|
||||||
|
|
||||||
|
- **内容创作者**:博客写作、社交媒体素材、短视频制作
|
||||||
|
- **企业营销**:品牌宣传材料、产品介绍、创意广告
|
||||||
|
- **教育培训**:课件制作、知识图谱、互动内容
|
||||||
|
- **个人娱乐**:AI 聊天、创意绘画、音乐创作
|
||||||
|
|
||||||
|
## 🔥 **为什么选择 GeekAI-PLUS?**
|
||||||
|
|
||||||
|
1. **技术领先**:集成当前最先进的 AI 技术,始终保持创新前沿
|
||||||
|
2. **开箱即用**:完整的商业化解决方案,无需从零开发
|
||||||
|
3. **高度定制**:模块化架构设计,支持个性化功能扩展
|
||||||
|
4. **稳定可靠**:经过大量用户验证,性能稳定,安全可信
|
||||||
|
5. **持续更新**:紧跟 AI 技术发展,功能持续迭代升级
|
||||||
|
|
||||||
|
## 演示站点
|
||||||
|
|
||||||
|
[Geek-AI 创作系统](https://www.geekai.me)
|
||||||
|
|
||||||
|
## 文档地址
|
||||||
|
|
||||||
|
[Geek-AI 文档](https://www.geekai.me/docs/)
|
||||||
|
|
||||||
|
## 部署
|
||||||
|
|
||||||
|
1. 安装 docker 和 docker-compose 程序,这个自行解决。
|
||||||
|
2. 直接在项目根目录运行启动命令:
|
||||||
|
```shell
|
||||||
|
docker-compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
## 功能截图
|
## 功能截图
|
||||||
|
|
||||||
请参考 [GeekAI 项目介绍](https://docs.geekai.me/plus/info/)。
|
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
|
||||||
|
|
||||||
### 体验地址
|
---
|
||||||
|
|
||||||
> 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/> > **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
|
_让 AI 成为你最强大的创作伙伴,开启无限创意可能!_
|
||||||
|
|
||||||
## 快速部署体验
|
|
||||||
|
|
||||||
您可以通过 EazyDevelop 平台体验-键私有化部署 **GeekAI 创作助手**,只需一分钟即可部署成功。
|
|
||||||
|
|
||||||
部署模板地址: [https://eazydevelop.eazytec-cloud.com/templates/dev-template-5e4dc4-1764053014?q=bB3R_1VnJq9_3Zs9uX](https://eazydevelop.eazytec-cloud.com/templates/dev-template-5e4dc4-1764053014?q=bB3R_1VnJq9_3Zs9uX)
|
|
||||||
|
|
||||||
详细部署教程请参考 [EazyDevelop 部署 GeekAI](https://docs.geekai.me/plus/install/quick-start.html#eazydevelop-一键部署)。
|
|
||||||
|
|
||||||
## 使用须知
|
|
||||||
|
|
||||||
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
|
|
||||||
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
|
|
||||||
|
|
||||||
## 项目地址
|
|
||||||
|
|
||||||
- Github 地址:https://github.com/yangjian102621/geekai
|
|
||||||
- 码云地址:https://gitee.com/blackfox/geekai
|
|
||||||
|
|
||||||
## 客户端下载
|
|
||||||
|
|
||||||
目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
|
|
||||||
|
|
||||||
## 项目文档
|
|
||||||
|
|
||||||
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
|
|
||||||
|
|
||||||
详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
|
|
||||||
|
|
||||||
加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自 Github!!!)。**
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
## 参与贡献
|
|
||||||
|
|
||||||
个人的力量始终有限,任何形式的贡献都是欢迎的,包括但不限于贡献代码,优化文档,提交 issue 和 PR 等。
|
|
||||||
|
|
||||||
#### 特此声明:由于个人时间有限,不接受在微信或者微信群给开发者提 Bug,有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合!
|
|
||||||
|
|
||||||
### Commit 类型
|
|
||||||
|
|
||||||
- feat: 新特性或功能
|
|
||||||
- fix: 缺陷修复
|
|
||||||
- docs: 文档更新
|
|
||||||
- style: 代码风格或者组件样式更新
|
|
||||||
- refactor: 代码重构,不引入新功能和缺陷修复
|
|
||||||
- opt: 性能优化
|
|
||||||
- chore: 一些不涉及到功能变动的小提交,比如修改文字表述,修改注释等
|
|
||||||
|
|
||||||
## 打赏
|
|
||||||
|
|
||||||
如果你觉得这个项目对你有帮助,并且情况允许的话,可以请作者喝杯咖啡,非常感谢你的支持~
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|||||||
@@ -8,118 +8,50 @@ package core
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
|
||||||
"image"
|
|
||||||
"image/jpeg"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"github.com/nfnt/resize"
|
|
||||||
"github.com/shirou/gopsutil/host"
|
"github.com/shirou/gopsutil/host"
|
||||||
"golang.org/x/image/webp"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthConfig 定义授权配置
|
|
||||||
type AuthConfig struct {
|
|
||||||
ExactPaths map[string]bool // 精确匹配的路径
|
|
||||||
PrefixPaths map[string]bool // 前缀匹配的路径
|
|
||||||
}
|
|
||||||
|
|
||||||
var authConfig = &AuthConfig{
|
|
||||||
ExactPaths: map[string]bool{
|
|
||||||
"/api/user/login": false,
|
|
||||||
"/api/user/logout": false,
|
|
||||||
"/api/user/resetPass": false,
|
|
||||||
"/api/user/register": false,
|
|
||||||
"/api/admin/login": false,
|
|
||||||
"/api/admin/logout": false,
|
|
||||||
"/api/admin/login/captcha": false,
|
|
||||||
"/api/app/list": false,
|
|
||||||
"/api/app/type/list": false,
|
|
||||||
"/api/app/list/user": false,
|
|
||||||
"/api/model/list": false,
|
|
||||||
"/api/mj/imgWall": false,
|
|
||||||
"/api/mj/notify": false,
|
|
||||||
"/api/invite/hits": false,
|
|
||||||
"/api/sd/imgWall": false,
|
|
||||||
"/api/dall/imgWall": false,
|
|
||||||
"/api/product/list": false,
|
|
||||||
"/api/menu/list": false,
|
|
||||||
"/api/markMap/client": false,
|
|
||||||
"/api/payment/doPay": false,
|
|
||||||
"/api/payment/payWays": false,
|
|
||||||
"/api/download": false,
|
|
||||||
"/api/dall/models": false,
|
|
||||||
},
|
|
||||||
PrefixPaths: map[string]bool{
|
|
||||||
"/api/test/": false,
|
|
||||||
"/api/payment/notify/": false,
|
|
||||||
"/api/user/clogin": false,
|
|
||||||
"/api/config/": false,
|
|
||||||
"/api/function/": false,
|
|
||||||
"/api/sms/": false,
|
|
||||||
"/api/captcha/": false,
|
|
||||||
"/static/": false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
type AppServer struct {
|
type AppServer struct {
|
||||||
Config *types.AppConfig
|
Config *types.AppConfig
|
||||||
Engine *gin.Engine
|
Engine *gin.Engine
|
||||||
SysConfig *types.SystemConfig // system config cache
|
SysConfig *types.SystemConfig // system config cache
|
||||||
|
Redis *redis.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(appConfig *types.AppConfig) *AppServer {
|
func NewServer(appConfig *types.AppConfig, redis *redis.Client, sysConfig *types.SystemConfig) *AppServer {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
gin.DefaultWriter = io.Discard
|
gin.DefaultWriter = io.Discard
|
||||||
return &AppServer{
|
return &AppServer{
|
||||||
Config: appConfig,
|
Config: appConfig,
|
||||||
Engine: gin.Default(),
|
Redis: redis,
|
||||||
|
Engine: gin.Default(),
|
||||||
|
SysConfig: sysConfig,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AppServer) Init(debug bool, client *redis.Client) {
|
func (s *AppServer) Init(client *redis.Client) {
|
||||||
// 允许跨域请求 API
|
s.Engine.Use(middleware.ParameterHandlerMiddleware())
|
||||||
s.Engine.Use(corsMiddleware())
|
|
||||||
s.Engine.Use(staticResourceMiddleware())
|
|
||||||
s.Engine.Use(authorizeMiddleware(s, client))
|
|
||||||
s.Engine.Use(parameterHandlerMiddleware())
|
|
||||||
s.Engine.Use(errorHandler)
|
s.Engine.Use(errorHandler)
|
||||||
// 添加静态资源访问
|
// 添加静态资源访问
|
||||||
s.Engine.Static("/static", s.Config.StaticDir)
|
s.Engine.Static("/static", s.Config.StaticDir)
|
||||||
|
s.Engine.Use(middleware.StaticMiddleware())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AppServer) Run(db *gorm.DB) error {
|
func (s *AppServer) Run(db *gorm.DB) error {
|
||||||
|
|
||||||
// 重命名 config 表字段
|
|
||||||
if db.Migrator().HasColumn(&model.Config{}, "config_json") {
|
|
||||||
db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
|
|
||||||
}
|
|
||||||
if db.Migrator().HasColumn(&model.Config{}, "marker") {
|
|
||||||
db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
|
|
||||||
}
|
|
||||||
if db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
|
|
||||||
db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
|
|
||||||
}
|
|
||||||
if db.Migrator().HasIndex(&model.Config{}, "marker") {
|
|
||||||
db.Migrator().DropIndex(&model.Config{}, "marker")
|
|
||||||
}
|
|
||||||
|
|
||||||
// load system configs
|
// load system configs
|
||||||
var sysConfig model.Config
|
var sysConfig model.Config
|
||||||
err := db.Where("name", "system").First(&sysConfig).Error
|
err := db.Where("name", "system").First(&sysConfig).Error
|
||||||
@@ -131,94 +63,22 @@ func (s *AppServer) Run(db *gorm.DB) error {
|
|||||||
return fmt.Errorf("failed to decode system config: %v", err)
|
return fmt.Errorf("failed to decode system config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 迁移数据表
|
|
||||||
logger.Info("Migrating database tables...")
|
|
||||||
db.AutoMigrate(
|
|
||||||
&model.ChatItem{},
|
|
||||||
&model.ChatMessage{},
|
|
||||||
&model.ChatRole{},
|
|
||||||
&model.ChatModel{},
|
|
||||||
&model.InviteCode{},
|
|
||||||
&model.InviteLog{},
|
|
||||||
&model.Menu{},
|
|
||||||
&model.Order{},
|
|
||||||
&model.Product{},
|
|
||||||
&model.User{},
|
|
||||||
&model.Function{},
|
|
||||||
&model.File{},
|
|
||||||
&model.Redeem{},
|
|
||||||
&model.Config{},
|
|
||||||
&model.ApiKey{},
|
|
||||||
&model.AdminUser{},
|
|
||||||
&model.AppType{},
|
|
||||||
&model.SdJob{},
|
|
||||||
&model.SunoJob{},
|
|
||||||
&model.PowerLog{},
|
|
||||||
&model.VideoJob{},
|
|
||||||
&model.MidJourneyJob{},
|
|
||||||
&model.UserLoginLog{},
|
|
||||||
&model.DallJob{},
|
|
||||||
&model.JimengJob{},
|
|
||||||
)
|
|
||||||
// 手动删除字段
|
|
||||||
if db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
|
|
||||||
db.Migrator().DropColumn(&model.Order{}, "deleted_at")
|
|
||||||
}
|
|
||||||
if db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
|
|
||||||
db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
|
|
||||||
}
|
|
||||||
if db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
|
|
||||||
db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
|
|
||||||
}
|
|
||||||
if db.Migrator().HasColumn(&model.User{}, "chat_config") {
|
|
||||||
db.Migrator().DropColumn(&model.User{}, "chat_config")
|
|
||||||
}
|
|
||||||
if db.Migrator().HasColumn(&model.ChatModel{}, "category") {
|
|
||||||
db.Migrator().DropColumn(&model.ChatModel{}, "category")
|
|
||||||
}
|
|
||||||
if db.Migrator().HasColumn(&model.ChatModel{}, "description") {
|
|
||||||
db.Migrator().DropColumn(&model.ChatModel{}, "description")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Database tables migrated successfully")
|
|
||||||
|
|
||||||
// 统计安装信息
|
// 统计安装信息
|
||||||
go func() {
|
go func() {
|
||||||
info, err := host.Info()
|
info, err := host.Info()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push")
|
apiURL := fmt.Sprintf("%s/api/installs/push", types.GeekAPIURL)
|
||||||
timestamp := time.Now().Unix()
|
timestamp := time.Now().Unix()
|
||||||
product := "geekai-plus"
|
product := "geekai-plus"
|
||||||
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
|
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
|
||||||
sign := utils.Sha256(signStr)
|
sign := utils.Sha256(signStr)
|
||||||
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
|
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
|
||||||
if err != nil {
|
if err == nil {
|
||||||
logger.Errorf("register install info failed: %v", err)
|
|
||||||
} else {
|
|
||||||
logger.Debugf("register install info success: %v", resp.String())
|
logger.Debugf("register install info success: %v", resp.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
logger.Infof("http://%s", s.Config.Listen)
|
logger.Infof("http://%s", s.Config.Listen)
|
||||||
|
|
||||||
// 统计安装信息
|
|
||||||
go func() {
|
|
||||||
info, err := host.Info()
|
|
||||||
if err == nil {
|
|
||||||
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push")
|
|
||||||
timestamp := time.Now().Unix()
|
|
||||||
product := "geekai-plus"
|
|
||||||
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
|
|
||||||
sign := utils.Sha256(signStr)
|
|
||||||
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("register install info failed: %v", err)
|
|
||||||
} else {
|
|
||||||
logger.Debugf("register install info success: %v", resp.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s.Engine.Run(s.Config.Listen)
|
return s.Engine.Run(s.Config.Listen)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,283 +95,3 @@ func errorHandler(c *gin.Context) {
|
|||||||
//加载完 defer recover,继续后续接口调用
|
//加载完 defer recover,继续后续接口调用
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 跨域中间件设置
|
|
||||||
func corsMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
method := c.Request.Method
|
|
||||||
origin := c.Request.Header.Get("Origin")
|
|
||||||
|
|
||||||
// 设置允许的请求源
|
|
||||||
if origin != "" {
|
|
||||||
c.Header("Access-Control-Allow-Origin", origin)
|
|
||||||
} else {
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
|
||||||
//允许跨域设置可以返回其他子段,可以自定义字段
|
|
||||||
c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
|
|
||||||
// 允许浏览器(客户端)可以解析的头部 (重要)
|
|
||||||
c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
|
||||||
//设置缓存时间
|
|
||||||
c.Header("Access-Control-Max-Age", "172800")
|
|
||||||
//允许客户端传递校验信息比如 cookie (重要)
|
|
||||||
c.Header("Access-Control-Allow-Credentials", "true")
|
|
||||||
|
|
||||||
if method == http.MethodOptions {
|
|
||||||
c.JSON(http.StatusOK, "ok!")
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
logger.Info("Panic info is: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 用户授权验证
|
|
||||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
if !needLogin(c) {
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
|
||||||
var tokenString string
|
|
||||||
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
|
||||||
if isAdminApi { // 后台管理 API
|
|
||||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
|
||||||
} else if clientProtocols != "" { // Websocket 连接
|
|
||||||
// 解析子协议内容
|
|
||||||
protocols := strings.Split(clientProtocols, ",")
|
|
||||||
if protocols[0] == "realtime" {
|
|
||||||
tokenString = strings.TrimSpace(protocols[1][25:])
|
|
||||||
} else if protocols[0] == "token" {
|
|
||||||
tokenString = strings.TrimSpace(protocols[1])
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tokenString == "" {
|
|
||||||
resp.NotAuth(c, "You should put Authorization in request headers")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
||||||
}
|
|
||||||
if isAdminApi {
|
|
||||||
return []byte(s.Config.AdminSession.SecretKey), nil
|
|
||||||
} else {
|
|
||||||
return []byte(s.Config.Session.SecretKey), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
|
||||||
if !ok || !token.Valid {
|
|
||||||
resp.NotAuth(c, "Token is invalid")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
|
||||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
|
||||||
resp.NotAuth(c, "Token is expired")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
key := fmt.Sprintf("users/%v", claims["user_id"])
|
|
||||||
if isAdminApi {
|
|
||||||
key = fmt.Sprintf("admin/%v", claims["user_id"])
|
|
||||||
}
|
|
||||||
if _, err := client.Get(context.Background(), key).Result(); err != nil {
|
|
||||||
resp.NotAuth(c, "Token is not found in redis")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set(types.LoginUserID, claims["user_id"])
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func needLogin(c *gin.Context) bool {
|
|
||||||
path := c.Request.URL.Path
|
|
||||||
|
|
||||||
// 如果不是 API 路径,不需要登录
|
|
||||||
if !strings.HasPrefix(path, "/api") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查精确匹配的路径
|
|
||||||
if skip, exists := authConfig.ExactPaths[path]; exists {
|
|
||||||
return skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查前缀匹配的路径
|
|
||||||
for prefix, skip := range authConfig.PrefixPaths {
|
|
||||||
if strings.HasPrefix(path, prefix) {
|
|
||||||
return skip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// 跳过授权
|
|
||||||
func (s *AppServer) SkipAuth(url string, prefix bool) {
|
|
||||||
if prefix {
|
|
||||||
authConfig.PrefixPaths[url] = false
|
|
||||||
} else {
|
|
||||||
authConfig.ExactPaths[url] = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 统一参数处理
|
|
||||||
func parameterHandlerMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
// GET 参数处理
|
|
||||||
params := c.Request.URL.Query()
|
|
||||||
for key, values := range params {
|
|
||||||
for i, value := range values {
|
|
||||||
params[key][i] = strings.TrimSpace(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// update get parameters
|
|
||||||
c.Request.URL.RawQuery = params.Encode()
|
|
||||||
// skip file upload requests
|
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
|
||||||
if strings.Contains(contentType, "multipart/form-data") {
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(contentType, "application/json") {
|
|
||||||
// process POST JSON request body
|
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
||||||
if err != nil {
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 还原请求体
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
||||||
// 将请求体解析为 JSON
|
|
||||||
var jsonData map[string]interface{}
|
|
||||||
if err := c.ShouldBindJSON(&jsonData); err != nil {
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 对 JSON 数据中的字符串值去除两端空格
|
|
||||||
trimJSONStrings(jsonData)
|
|
||||||
// 更新请求体
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 递归对 JSON 数据中的字符串值去除两端空格
|
|
||||||
func trimJSONStrings(data interface{}) {
|
|
||||||
switch v := data.(type) {
|
|
||||||
case map[string]interface{}:
|
|
||||||
for key, value := range v {
|
|
||||||
switch valueType := value.(type) {
|
|
||||||
case string:
|
|
||||||
v[key] = strings.TrimSpace(valueType)
|
|
||||||
case map[string]interface{}, []interface{}:
|
|
||||||
trimJSONStrings(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case []interface{}:
|
|
||||||
for i, value := range v {
|
|
||||||
switch valueType := value.(type) {
|
|
||||||
case string:
|
|
||||||
v[i] = strings.TrimSpace(valueType)
|
|
||||||
case map[string]interface{}, []interface{}:
|
|
||||||
trimJSONStrings(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 静态资源中间件
|
|
||||||
func staticResourceMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
|
|
||||||
url := c.Request.URL.String()
|
|
||||||
// 拦截生成缩略图请求
|
|
||||||
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
|
|
||||||
r := strings.SplitAfter(url, "imageView2")
|
|
||||||
size := strings.Split(r[1], "/")
|
|
||||||
if len(size) != 8 {
|
|
||||||
c.String(http.StatusNotFound, "invalid thumb args")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
with := utils.IntValue(size[3], 0)
|
|
||||||
height := utils.IntValue(size[5], 0)
|
|
||||||
quality := utils.IntValue(size[7], 75)
|
|
||||||
|
|
||||||
// 打开图片文件
|
|
||||||
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
|
|
||||||
file, err := os.Open(filePath)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusNotFound, "Image not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
// 解码图片
|
|
||||||
img, _, err := image.Decode(file)
|
|
||||||
// for .webp image
|
|
||||||
if err != nil {
|
|
||||||
img, err = webp.Decode(file)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, "Error decoding image")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var newImg image.Image
|
|
||||||
if height == 0 || with == 0 {
|
|
||||||
// 固定宽度,高度自适应
|
|
||||||
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
|
|
||||||
} else {
|
|
||||||
// 生成缩略图
|
|
||||||
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
|
|
||||||
}
|
|
||||||
var buffer bytes.Buffer
|
|
||||||
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
c.String(http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置图片缓存有效期为一年 (365天)
|
|
||||||
c.Header("Cache-Control", "max-age=31536000, public")
|
|
||||||
// 直接输出图像数据流
|
|
||||||
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
|
|
||||||
c.Abort() // 中断请求
|
|
||||||
|
|
||||||
}
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,10 +11,12 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
logger2 "geekai/logger"
|
logger2 "geekai/logger"
|
||||||
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
@@ -30,7 +32,6 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
SecretKey: utils.RandString(64),
|
SecretKey: utils.RandString(64),
|
||||||
MaxAge: 86400,
|
MaxAge: 86400,
|
||||||
},
|
},
|
||||||
ApiConfig: types.ApiConfig{},
|
|
||||||
OSS: types.OSSConfig{
|
OSS: types.OSSConfig{
|
||||||
Active: "local",
|
Active: "local",
|
||||||
Local: types.LocalStorageConfig{
|
Local: types.LocalStorageConfig{
|
||||||
@@ -38,7 +39,6 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
BasePath: "./static/upload",
|
BasePath: "./static/upload",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,3 +74,108 @@ func SaveConfig(config *types.AppConfig) error {
|
|||||||
|
|
||||||
return os.WriteFile(config.Path, buf.Bytes(), 0644)
|
return os.WriteFile(config.Path, buf.Bytes(), 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LoadSystemConfig(db *gorm.DB) *types.SystemConfig {
|
||||||
|
// 加载系统配置
|
||||||
|
var sysConfig model.Config
|
||||||
|
var baseConfig types.BaseConfig
|
||||||
|
db.Where("name", "system").First(&sysConfig)
|
||||||
|
err := utils.JsonDecode(sysConfig.Value, &baseConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load system config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载许可证配置
|
||||||
|
var license types.License
|
||||||
|
sysConfig.Id = 0
|
||||||
|
db.Where("name", types.ConfigKeyLicense).First(&sysConfig)
|
||||||
|
err = utils.JsonDecode(sysConfig.Value, &license)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load license config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载验证码配置
|
||||||
|
var captchaConfig types.CaptchaConfig
|
||||||
|
sysConfig.Id = 0
|
||||||
|
db.Where("name", types.ConfigKeyCaptcha).First(&sysConfig)
|
||||||
|
err = utils.JsonDecode(sysConfig.Value, &captchaConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load geek service config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载微信登录配置
|
||||||
|
var wxLoginConfig types.WxLoginConfig
|
||||||
|
sysConfig.Id = 0
|
||||||
|
db.Where("name", types.ConfigKeyWxLogin).First(&sysConfig)
|
||||||
|
err = utils.JsonDecode(sysConfig.Value, &wxLoginConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load wx login config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载短信配置
|
||||||
|
var smsConfig types.SMSConfig
|
||||||
|
sysConfig.Id = 0
|
||||||
|
db.Where("name", types.ConfigKeySms).First(&sysConfig)
|
||||||
|
err = utils.JsonDecode(sysConfig.Value, &smsConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load sms config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载 OSS 配置
|
||||||
|
var ossConfig types.OSSConfig
|
||||||
|
sysConfig.Id = 0
|
||||||
|
db.Where("name", types.ConfigKeyOss).First(&sysConfig)
|
||||||
|
err = utils.JsonDecode(sysConfig.Value, &ossConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load oss config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载 SMTP 配置
|
||||||
|
var smtpConfig types.SmtpConfig
|
||||||
|
sysConfig.Id = 0
|
||||||
|
db.Where("name", types.ConfigKeySmtp).First(&sysConfig)
|
||||||
|
err = utils.JsonDecode(sysConfig.Value, &smtpConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load smtp config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载支付配置
|
||||||
|
var paymentConfig types.PaymentConfig
|
||||||
|
sysConfig.Id = 0
|
||||||
|
db.Where("name", types.ConfigKeyPayment).First(&sysConfig)
|
||||||
|
err = utils.JsonDecode(sysConfig.Value, &paymentConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load payment config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载文本审查配置
|
||||||
|
var moderationConfig types.ModerationConfig
|
||||||
|
sysConfig.Id = 0
|
||||||
|
db.Where("name", types.ConfigKeyModeration).First(&sysConfig)
|
||||||
|
err = utils.JsonDecode(sysConfig.Value, &moderationConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load moderation config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载即梦AI配置
|
||||||
|
var jimengConfig types.JimengConfig
|
||||||
|
sysConfig.Id = 0
|
||||||
|
db.Where("name", types.ConfigKeyJimeng).First(&sysConfig)
|
||||||
|
err = utils.JsonDecode(sysConfig.Value, &jimengConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("load jimeng config error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.SystemConfig{
|
||||||
|
Base: baseConfig,
|
||||||
|
License: license,
|
||||||
|
SMS: smsConfig,
|
||||||
|
OSS: ossConfig,
|
||||||
|
SMTP: smtpConfig,
|
||||||
|
Payment: paymentConfig,
|
||||||
|
Captcha: captchaConfig,
|
||||||
|
WxLogin: wxLoginConfig,
|
||||||
|
Moderation: moderationConfig,
|
||||||
|
Jimeng: jimengConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
109
api/core/middleware/auth.go
Normal file
109
api/core/middleware/auth.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 前端用户授权验证
|
||||||
|
func UserAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
tokenString := c.GetHeader(types.UserAuthHeader)
|
||||||
|
if tokenString == "" {
|
||||||
|
resp.NotAuth(c, "无效的授权令牌")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
return []byte(secretKey), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok || !token.Valid {
|
||||||
|
resp.NotAuth(c, "令牌无效")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||||
|
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||||||
|
resp.NotAuth(c, "令牌过期")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("users/%v", claims["user_id"])
|
||||||
|
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
|
||||||
|
resp.NotAuth(c, "当前用户已退出登录")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(types.LoginUserID, claims["user_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 管理后台用户授权验证
|
||||||
|
func AdminAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
tokenString := c.GetHeader(types.AdminAuthHeader)
|
||||||
|
if tokenString == "" {
|
||||||
|
resp.NotAuth(c, "无效的授权令牌")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
return []byte(secretKey), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok || !token.Valid {
|
||||||
|
resp.NotAuth(c, "令牌无效")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||||
|
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||||||
|
resp.NotAuth(c, "令牌过期")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("admin/%v", claims["user_id"])
|
||||||
|
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
|
||||||
|
resp.NotAuth(c, "当前用户已退出登录")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(types.AdminUserID, claims["user_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
80
api/core/middleware/parameter.go
Normal file
80
api/core/middleware/parameter.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"geekai/utils"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 统一参数处理
|
||||||
|
func ParameterHandlerMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// GET 参数处理
|
||||||
|
params := c.Request.URL.Query()
|
||||||
|
for key, values := range params {
|
||||||
|
for i, value := range values {
|
||||||
|
params[key][i] = strings.TrimSpace(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// update get parameters
|
||||||
|
c.Request.URL.RawQuery = params.Encode()
|
||||||
|
// skip file upload requests
|
||||||
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
if strings.Contains(contentType, "multipart/form-data") {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(contentType, "application/json") {
|
||||||
|
// process POST JSON request body
|
||||||
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 还原请求体
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
// 将请求体解析为 JSON
|
||||||
|
var jsonData map[string]any
|
||||||
|
if err := c.ShouldBindJSON(&jsonData); err != nil {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对 JSON 数据中的字符串值去除两端空格
|
||||||
|
trimJSONStrings(jsonData)
|
||||||
|
// 更新请求体
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 递归对 JSON 数据中的字符串值去除两端空格
|
||||||
|
func trimJSONStrings(data any) {
|
||||||
|
switch v := data.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
for key, value := range v {
|
||||||
|
switch valueType := value.(type) {
|
||||||
|
case string:
|
||||||
|
v[key] = strings.TrimSpace(valueType)
|
||||||
|
case map[string]any, []any:
|
||||||
|
trimJSONStrings(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for i, value := range v {
|
||||||
|
switch valueType := value.(type) {
|
||||||
|
case string:
|
||||||
|
v[i] = strings.TrimSpace(valueType)
|
||||||
|
case map[string]any, []any:
|
||||||
|
trimJSONStrings(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
43
api/core/middleware/rate_limit.go
Normal file
43
api/core/middleware/rate_limit.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimitEvery 使用 Redis 做固定间隔限流:在 interval 内仅允许一次请求
|
||||||
|
// Key 优先使用登录用户ID,若没有则退化为 route + IP
|
||||||
|
func RateLimitEvery(redisClient *redis.Client, interval time.Duration) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
keyID := ""
|
||||||
|
if userID, ok := c.Get(types.LoginUserID); ok {
|
||||||
|
keyID = fmt.Sprintf("user:%s", utils.InterfaceToString(userID))
|
||||||
|
} else {
|
||||||
|
keyID = fmt.Sprintf("ip:%s", c.ClientIP())
|
||||||
|
}
|
||||||
|
|
||||||
|
fullPath := c.FullPath()
|
||||||
|
if fullPath == "" {
|
||||||
|
fullPath = c.Request.URL.Path
|
||||||
|
}
|
||||||
|
key := fmt.Sprintf("rl:%s:%s", fullPath, keyID)
|
||||||
|
|
||||||
|
okSet, err := redisClient.SetNX(context.Background(), key, 1, interval).Result()
|
||||||
|
if err != nil {
|
||||||
|
// Redis 异常时放行,避免误伤可用性
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !okSet {
|
||||||
|
c.JSON(http.StatusTooManyRequests, types.BizVo{Code: types.Failed, Message: "请求过于频繁,请稍后重试"})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
78
api/core/middleware/static.go
Normal file
78
api/core/middleware/static.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"geekai/utils"
|
||||||
|
"image"
|
||||||
|
"image/jpeg"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/nfnt/resize"
|
||||||
|
"golang.org/x/image/webp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 静态资源中间件
|
||||||
|
func StaticMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
|
||||||
|
url := c.Request.URL.String()
|
||||||
|
// 拦截生成缩略图请求
|
||||||
|
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
|
||||||
|
r := strings.SplitAfter(url, "imageView2")
|
||||||
|
size := strings.Split(r[1], "/")
|
||||||
|
if len(size) != 8 {
|
||||||
|
c.String(http.StatusNotFound, "invalid thumb args")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
with := utils.IntValue(size[3], 0)
|
||||||
|
height := utils.IntValue(size[5], 0)
|
||||||
|
quality := utils.IntValue(size[7], 75)
|
||||||
|
|
||||||
|
// 打开图片文件
|
||||||
|
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
|
||||||
|
file, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusNotFound, "Image not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// 解码图片
|
||||||
|
img, _, err := image.Decode(file)
|
||||||
|
// for .webp image
|
||||||
|
if err != nil {
|
||||||
|
img, err = webp.Decode(file)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, "Error decoding image")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var newImg image.Image
|
||||||
|
if height == 0 || with == 0 {
|
||||||
|
// 固定宽度,高度自适应
|
||||||
|
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
|
||||||
|
} else {
|
||||||
|
// 生成缩略图
|
||||||
|
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
|
||||||
|
}
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置图片缓存有效期为一年 (365天)
|
||||||
|
c.Header("Cache-Control", "max-age=31536000, public")
|
||||||
|
// 直接输出图像数据流
|
||||||
|
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
|
||||||
|
c.Abort() // 中断请求
|
||||||
|
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -17,88 +17,17 @@ type AppConfig struct {
|
|||||||
Session Session
|
Session Session
|
||||||
AdminSession Session
|
AdminSession Session
|
||||||
ProxyURL string
|
ProxyURL string
|
||||||
MysqlDns string // mysql 连接地址
|
MysqlDns string // mysql 连接地址
|
||||||
StaticDir string // 静态资源目录
|
StaticDir string // 静态资源目录
|
||||||
StaticUrl string // 静态资源 URL
|
StaticUrl string // 静态资源 URL
|
||||||
Redis RedisConfig // redis 连接信息
|
Redis RedisConfig // redis 连接信息
|
||||||
ApiConfig ApiConfig // ChatPlus API authorization configs
|
SMS SMSConfig // send mobile message config
|
||||||
SMS SMSConfig // send mobile message config
|
OSS OSSConfig // OSS config
|
||||||
OSS OSSConfig // OSS config
|
SmtpConfig SmtpConfig // 邮件发送配置
|
||||||
SmtpConfig SmtpConfig // 邮件发送配置
|
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||||
XXLConfig XXLConfig
|
GeekPayConfig EpayConfig // GEEK 支付配置
|
||||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
WechatPayConfig WxPayConfig // 微信支付渠道配置
|
||||||
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
|
TikaHost string // TiKa 服务器地址
|
||||||
GeekPayConfig GeekPayConfig // GEEK 支付配置
|
|
||||||
WechatPayConfig WechatPayConfig // 微信支付渠道配置
|
|
||||||
TikaHost string // TiKa 服务器地址
|
|
||||||
}
|
|
||||||
|
|
||||||
type SmtpConfig struct {
|
|
||||||
UseTls bool // 是否使用 TLS 发送
|
|
||||||
Host string
|
|
||||||
Port int
|
|
||||||
AppName string // 应用名称
|
|
||||||
From string // 发件人邮箱地址
|
|
||||||
Password string // 发件人邮箱密码
|
|
||||||
}
|
|
||||||
|
|
||||||
type ApiConfig struct {
|
|
||||||
ApiURL string
|
|
||||||
AppId string
|
|
||||||
Token string
|
|
||||||
JimengConfig JimengConfig // 即梦AI配置
|
|
||||||
}
|
|
||||||
|
|
||||||
type AlipayConfig struct {
|
|
||||||
Enabled bool // 是否启用该支付通道
|
|
||||||
SandBox bool // 是否沙盒环境
|
|
||||||
AppId string // 应用 ID
|
|
||||||
UserId string // 支付宝用户 ID
|
|
||||||
PrivateKey string // 用户私钥文件路径
|
|
||||||
PublicKey string // 用户公钥文件路径
|
|
||||||
AlipayPublicKey string // 支付宝公钥文件路径
|
|
||||||
RootCert string // Root 秘钥路径
|
|
||||||
NotifyURL string // 异步通知地址
|
|
||||||
ReturnURL string // 同步通知地址
|
|
||||||
}
|
|
||||||
|
|
||||||
type WechatPayConfig struct {
|
|
||||||
Enabled bool // 是否启用该支付通道
|
|
||||||
AppId string // 公众号的APPID,如:wxd678efh567hg6787
|
|
||||||
MchId string // 直连商户的商户号,由微信支付生成并下发
|
|
||||||
SerialNo string // 商户证书的证书序列号
|
|
||||||
PrivateKey string // 用户私钥文件路径
|
|
||||||
ApiV3Key string // API V3 秘钥
|
|
||||||
NotifyURL string // 异步通知地址
|
|
||||||
}
|
|
||||||
|
|
||||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
|
||||||
Enabled bool // 是否启用该支付通道
|
|
||||||
AppId string // App ID
|
|
||||||
AppSecret string // app 密钥
|
|
||||||
ApiURL string // 支付网关
|
|
||||||
NotifyURL string // 异步通知地址
|
|
||||||
ReturnURL string // 同步通知地址
|
|
||||||
}
|
|
||||||
|
|
||||||
// GeekPayConfig GEEK支付配置
|
|
||||||
type GeekPayConfig struct {
|
|
||||||
Enabled bool
|
|
||||||
AppId string // 商户 ID
|
|
||||||
PrivateKey string // 私钥
|
|
||||||
ApiURL string // API 网关
|
|
||||||
NotifyURL string // 异步通知地址
|
|
||||||
ReturnURL string // 同步通知地址
|
|
||||||
Methods []string // 支付方式
|
|
||||||
}
|
|
||||||
|
|
||||||
type XXLConfig struct { // XXL 任务调度配置
|
|
||||||
Enabled bool
|
|
||||||
ServerAddr string
|
|
||||||
ExecutorIp string
|
|
||||||
ExecutorPort string
|
|
||||||
AccessToken string
|
|
||||||
RegistryKey string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RedisConfig struct {
|
type RedisConfig struct {
|
||||||
@@ -128,32 +57,28 @@ func (c RedisConfig) Url() string {
|
|||||||
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
type SystemConfig struct {
|
type BaseConfig struct {
|
||||||
Title string `json:"title,omitempty"` // 网站标题
|
Title string `json:"title,omitempty"` // 网站标题
|
||||||
Slogan string `json:"slogan,omitempty"` // 网站 slogan
|
Slogan string `json:"slogan,omitempty"` // 网站 slogan
|
||||||
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
|
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
|
||||||
Logo string `json:"logo,omitempty"` // 圆形 Logo
|
Logo string `json:"logo,omitempty"` // 圆形 Logo
|
||||||
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
|
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
|
||||||
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
|
||||||
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
|
|
||||||
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
|
||||||
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
|
|
||||||
|
|
||||||
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
|
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
|
||||||
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
|
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
|
||||||
|
|
||||||
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
|
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间,单位:分钟
|
||||||
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
|
|
||||||
|
|
||||||
|
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||||
|
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
|
||||||
|
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||||
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
|
|
||||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||||
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||||
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
|
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
|
||||||
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
|
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
|
||||||
PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力
|
|
||||||
|
|
||||||
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
||||||
|
|
||||||
@@ -163,15 +88,44 @@ type SystemConfig struct {
|
|||||||
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
|
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
|
||||||
MjMode string `json:"mj_mode"` // midjourney 默认的API模式,relax, fast, turbo
|
MjMode string `json:"mj_mode"` // midjourney 默认的API模式,relax, fast, turbo
|
||||||
|
|
||||||
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
||||||
Copyright string `json:"copyright"` // 版权信息
|
Copyright string `json:"copyright"` // 版权信息
|
||||||
DefaultNickname string `json:"default_nickname"` // 默认昵称
|
ICP string `json:"icp"` // ICP 备案号
|
||||||
ICP string `json:"icp"` // ICP 备案号
|
GaBeian string `json:"ga_beian"` // 公安备案号
|
||||||
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
|
|
||||||
|
|
||||||
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
|
|
||||||
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
|
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
|
||||||
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
|
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
|
||||||
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位:MB
|
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位:MB
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SystemConfig struct {
|
||||||
|
Base BaseConfig
|
||||||
|
Payment PaymentConfig
|
||||||
|
OSS OSSConfig
|
||||||
|
SMS SMSConfig
|
||||||
|
SMTP SmtpConfig
|
||||||
|
Captcha CaptchaConfig
|
||||||
|
WxLogin WxLoginConfig
|
||||||
|
Jimeng JimengConfig
|
||||||
|
License License
|
||||||
|
Moderation ModerationConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// 配置键名常量
|
||||||
|
const (
|
||||||
|
ConfigKeySystem = "system"
|
||||||
|
ConfigKeyNotice = "notice"
|
||||||
|
ConfigKeyAgreement = "agreement"
|
||||||
|
ConfigKeyPrivacy = "privacy"
|
||||||
|
ConfigKeyMarkMap = "mark_map"
|
||||||
|
ConfigKeyCaptcha = "captcha"
|
||||||
|
ConfigKeyWxLogin = "wx_login"
|
||||||
|
ConfigKeyLicense = "license"
|
||||||
|
ConfigKeySms = "sms"
|
||||||
|
ConfigKeySmtp = "smtp"
|
||||||
|
ConfigKeyOss = "oss"
|
||||||
|
ConfigKeyPayment = "payment"
|
||||||
|
ConfigKeyModeration = "moderation"
|
||||||
|
ConfigKeyAI3D = "ai3d"
|
||||||
|
ConfigKeyJimeng = "jimeng"
|
||||||
|
)
|
||||||
|
|||||||
33
api/core/types/geekai.go
Normal file
33
api/core/types/geekai.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import "os"
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
// GeekAI 增值服务
|
||||||
|
var GeekAPIURL = "https://sapi.geekai.me"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if os.Getenv("GEEK_API_URL") != "" {
|
||||||
|
GeekAPIURL = os.Getenv("GEEK_API_URL")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CaptchaConfig 行为验证码配置
|
||||||
|
type CaptchaConfig struct {
|
||||||
|
ApiKey string `json:"api_key"`
|
||||||
|
Type string `json:"type"` // 验证码类型, 可选值: "dot" 或 "slide"
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WxLoginConfig 微信登录配置
|
||||||
|
type WxLoginConfig struct {
|
||||||
|
ApiKey string `json:"api_key"`
|
||||||
|
NotifyURL string `json:"notify_url"` // 登录成功回调 URL
|
||||||
|
Enabled bool `json:"enabled"` // 是否启用微信登录
|
||||||
|
}
|
||||||
73
api/core/types/moderation.go
Normal file
73
api/core/types/moderation.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
// 文本审查
|
||||||
|
type ModerationConfig struct {
|
||||||
|
Enable bool `json:"enable"` // 是否启用文本审查
|
||||||
|
Active string `json:"active"`
|
||||||
|
EnableGuide bool `json:"enable_guide"` // 是否启用模型引导提示词
|
||||||
|
GuidePrompt string `json:"guide_prompt"` // 模型引导提示词
|
||||||
|
Gitee ModerationGiteeConfig `json:"gitee"`
|
||||||
|
Baidu ModerationBaiduConfig `json:"baidu"`
|
||||||
|
Tencent ModerationTencentConfig `json:"tencent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModerationGitee = "gitee"
|
||||||
|
ModerationBaidu = "baidu"
|
||||||
|
ModerationTencent = "tencent"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GiteeAI 文本审查配置
|
||||||
|
type ModerationGiteeConfig struct {
|
||||||
|
ApiKey string `json:"api_key"`
|
||||||
|
Model string `json:"model"` // 文本审核模型
|
||||||
|
}
|
||||||
|
|
||||||
|
// 百度文本审查配置
|
||||||
|
type ModerationBaiduConfig struct {
|
||||||
|
AccessKey string `json:"access_key"`
|
||||||
|
SecretKey string `json:"secret_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 腾讯云文本审查配置
|
||||||
|
type ModerationTencentConfig struct {
|
||||||
|
AccessKey string `json:"access_key"`
|
||||||
|
SecretKey string `json:"secret_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModerationResult struct {
|
||||||
|
Flagged bool `json:"flagged"`
|
||||||
|
Categories map[string]bool `json:"categories"`
|
||||||
|
CategoryScores map[string]float64 `json:"category_scores"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var ModerationCategories = map[string]string{
|
||||||
|
"politic": "内容涉及人物、事件或敏感的政治观点",
|
||||||
|
"porn": "明确的色情内容",
|
||||||
|
"insult": "具有侮辱、攻击性语言、人身攻击或冒犯性表达",
|
||||||
|
"violence": "包含暴力、血腥、攻击行为或煽动暴力的言论",
|
||||||
|
"illegal": "涉及违法活动的内容,如诈骗、赌博等",
|
||||||
|
"terror": "宣扬恐怖主义、极端暴力或煽动恐怖行为的内容",
|
||||||
|
"ad": "垃圾广告或未经许可的推广内容",
|
||||||
|
"spam": "无意义重复内容或诱导性信息",
|
||||||
|
"abuse": "人身攻击、恶意辱骂或侮辱性言论",
|
||||||
|
"polity": "涉及国家政治、领导人或政策的违规讨论内容",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 敏感词来源
|
||||||
|
const (
|
||||||
|
ModerationSourceChat = "chat"
|
||||||
|
ModerationSourceMJ = "mj"
|
||||||
|
ModerationSourceDalle = "dalle"
|
||||||
|
ModerationSourceSD = "sd"
|
||||||
|
ModerationSourceSuno = "suno"
|
||||||
|
ModerationSourceVideo = "video"
|
||||||
|
ModerationSourceJiMeng = "jimeng"
|
||||||
|
)
|
||||||
@@ -11,29 +11,25 @@ type OrderStatus int
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
OrderNotPaid = OrderStatus(0)
|
OrderNotPaid = OrderStatus(0)
|
||||||
OrderScanned = OrderStatus(1) // 已扫码
|
OrderPaidSuccess = OrderStatus(2) // 已支付
|
||||||
OrderPaidSuccess = OrderStatus(2)
|
OrderPaidFailed = OrderStatus(3) // 已关闭
|
||||||
)
|
)
|
||||||
|
|
||||||
type OrderRemark struct {
|
type OrderRemark struct {
|
||||||
Days int `json:"days"` // 有效期
|
Days int `json:"days"` // 有效期
|
||||||
Power int `json:"power"` // 增加算力点数
|
Power int `json:"power"` // 增加算力点数
|
||||||
Name string `json:"name"` // 产品名称
|
Name string `json:"name"` // 产品名称
|
||||||
Price float64 `json:"price"`
|
Price float64 `json:"price"`
|
||||||
Discount float64 `json:"discount"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var PayMethods = map[string]string{
|
// PayChannel 支付渠道
|
||||||
|
var PayChannel = map[string]string{
|
||||||
"alipay": "支付宝商号",
|
"alipay": "支付宝商号",
|
||||||
"wechat": "微信商号",
|
"wxpay": "微信商号",
|
||||||
"hupi": "虎皮椒",
|
"epay": "易支付",
|
||||||
"geek": "易支付",
|
|
||||||
}
|
}
|
||||||
var PayNames = map[string]string{
|
|
||||||
|
var PayWays = map[string]string{
|
||||||
"alipay": "支付宝",
|
"alipay": "支付宝",
|
||||||
"wxpay": "微信支付",
|
"wxpay": "微信支付",
|
||||||
"qqpay": "QQ钱包",
|
|
||||||
"jdpay": "京东支付",
|
|
||||||
"douyin": "抖音支付",
|
|
||||||
"paypal": "PayPal支付",
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,41 +8,39 @@ package types
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
type OSSConfig struct {
|
type OSSConfig struct {
|
||||||
Active string
|
Active string `json:"active"`
|
||||||
Local LocalStorageConfig
|
Local LocalStorageConfig `json:"local"`
|
||||||
Minio MiniOssConfig
|
Minio MiniOssConfig `json:"minio"`
|
||||||
QiNiu QiNiuOssConfig
|
QiNiu QiNiuOssConfig `json:"qiniu"`
|
||||||
AliYun AliYunOssConfig
|
AliYun AliYunOssConfig `json:"aliyun"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MiniOssConfig struct {
|
type MiniOssConfig struct {
|
||||||
Endpoint string
|
Endpoint string `json:"endpoint"`
|
||||||
AccessKey string
|
AccessKey string `json:"access_key"`
|
||||||
AccessSecret string
|
AccessSecret string `json:"access_secret"`
|
||||||
Bucket string
|
Bucket string `json:"bucket"`
|
||||||
SubDir string
|
UseSSL bool `json:"use_ssl"`
|
||||||
UseSSL bool
|
Domain string `json:"domain"`
|
||||||
Domain string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type QiNiuOssConfig struct {
|
type QiNiuOssConfig struct {
|
||||||
Zone string
|
Zone string `json:"zone"`
|
||||||
AccessKey string
|
AccessKey string `json:"access_key"`
|
||||||
AccessSecret string
|
AccessSecret string `json:"access_secret"`
|
||||||
Bucket string
|
Bucket string `json:"bucket"`
|
||||||
SubDir string
|
Domain string `json:"domain"`
|
||||||
Domain string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliYunOssConfig struct {
|
type AliYunOssConfig struct {
|
||||||
Endpoint string
|
Endpoint string `json:"endpoint"`
|
||||||
AccessKey string
|
AccessKey string `json:"access_key"`
|
||||||
AccessSecret string
|
AccessSecret string `json:"access_secret"`
|
||||||
Bucket string
|
Bucket string `json:"bucket"`
|
||||||
SubDir string
|
Domain string `json:"domain"`
|
||||||
Domain string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type LocalStorageConfig struct {
|
type LocalStorageConfig struct {
|
||||||
BasePath string
|
BasePath string `json:"base_path"`
|
||||||
BaseURL string
|
BaseURL string `json:"base_url"`
|
||||||
}
|
}
|
||||||
|
|||||||
60
api/core/types/payment.go
Normal file
60
api/core/types/payment.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type PaymentConfig struct {
|
||||||
|
Alipay AlipayConfig `json:"alipay"` // 支付宝支付渠道配置
|
||||||
|
Epay EpayConfig `json:"epay"` // 易支付配置
|
||||||
|
WxPay WxPayConfig `json:"wxpay"` // 微信支付渠道配置
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlipayConfig 支付宝支付配置
|
||||||
|
type AlipayConfig struct {
|
||||||
|
Enabled bool `json:"enabled"` // 是否启用该支付通道
|
||||||
|
SandBox bool `json:"sandbox"` // 是否沙盒环境
|
||||||
|
AppId string `json:"app_id"` // 应用 ID
|
||||||
|
PrivateKey string `json:"private_key"` // 应用私钥
|
||||||
|
AlipayPublicKey string `json:"alipay_public_key"` // 支付宝公钥
|
||||||
|
Domain string `json:"domain"` // 支付回调域名
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AlipayConfig) Equal(other *AlipayConfig) bool {
|
||||||
|
return c.AppId == other.AppId &&
|
||||||
|
c.PrivateKey == other.PrivateKey &&
|
||||||
|
c.AlipayPublicKey == other.AlipayPublicKey &&
|
||||||
|
c.Domain == other.Domain
|
||||||
|
}
|
||||||
|
|
||||||
|
// WxPayConfig 微信支付配置
|
||||||
|
type WxPayConfig struct {
|
||||||
|
Enabled bool `json:"enabled"` // 是否启用该支付通道
|
||||||
|
AppId string `json:"app_id"` // 公众号的APPID,如:wxd678efh567hg6787
|
||||||
|
MchId string `json:"mch_id"` // 直连商户的商户号,由微信支付生成并下发
|
||||||
|
SerialNo string `json:"serial_no"` // 商户证书的证书序列号
|
||||||
|
PrivateKey string `json:"private_key"` // 商户证书私钥
|
||||||
|
ApiV3Key string `json:"api_v3_key"` // API V3 秘钥
|
||||||
|
Domain string `json:"domain"` // 支付回调域名
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WxPayConfig) Equal(other *WxPayConfig) bool {
|
||||||
|
return c.AppId == other.AppId &&
|
||||||
|
c.MchId == other.MchId &&
|
||||||
|
c.SerialNo == other.SerialNo &&
|
||||||
|
c.PrivateKey == other.PrivateKey &&
|
||||||
|
c.ApiV3Key == other.ApiV3Key &&
|
||||||
|
c.Domain == other.Domain
|
||||||
|
}
|
||||||
|
|
||||||
|
// EpayConfig 易支付配置
|
||||||
|
type EpayConfig struct {
|
||||||
|
Enabled bool `json:"enabled"` // 是否启用该支付通道
|
||||||
|
AppId string `json:"app_id"` // 商户 ID
|
||||||
|
PrivateKey string `json:"private_key"` // 私钥
|
||||||
|
ApiURL string `json:"api_url"` // z支付 API 网关
|
||||||
|
Domain string `json:"domain"` // 支付回调域名
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EpayConfig) Equal(other *EpayConfig) bool {
|
||||||
|
return c.AppId == other.AppId &&
|
||||||
|
c.PrivateKey == other.PrivateKey &&
|
||||||
|
c.ApiURL == other.ApiURL &&
|
||||||
|
c.Domain == other.Domain
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ package types
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
const LoginUserID = "LOGIN_USER_ID"
|
const LoginUserID = "LOGIN_USER_ID"
|
||||||
|
const AdminUserID = "ADMIN_USER_ID"
|
||||||
const LoginUserCache = "LOGIN_USER_CACHE"
|
const LoginUserCache = "LOGIN_USER_CACHE"
|
||||||
|
|
||||||
const UserAuthHeader = "Authorization"
|
const UserAuthHeader = "Authorization"
|
||||||
|
|||||||
@@ -8,26 +8,23 @@ package types
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
type SMSConfig struct {
|
type SMSConfig struct {
|
||||||
Active string
|
Active string `json:"active"`
|
||||||
Ali SmsConfigAli
|
Ali SmsConfigAli `json:"aliyun"`
|
||||||
Bao SmsConfigBao
|
Bao SmsConfigBao `json:"bao"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SmsConfigAli 阿里云短信平台配置
|
// SmsConfigAli 阿里云短信平台配置
|
||||||
type SmsConfigAli struct {
|
type SmsConfigAli struct {
|
||||||
AccessKey string
|
AccessKey string `json:"access_key"`
|
||||||
AccessSecret string
|
AccessSecret string `json:"access_secret"`
|
||||||
Product string
|
Sign string `json:"sign"` // 短信签名
|
||||||
Domain string
|
CodeTempId string `json:"code_temp_id"` // 验证码短信模板 ID
|
||||||
Sign string // 短信签名
|
|
||||||
CodeTempId string // 验证码短信模板 ID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SmsConfigBao 短信宝平台配置
|
// SmsConfigBao 短信宝平台配置
|
||||||
type SmsConfigBao struct {
|
type SmsConfigBao struct {
|
||||||
Username string //短信宝平台注册的用户名
|
Username string `json:"username"` //短信宝平台注册的用户名
|
||||||
Password string //短信宝平台注册的密码
|
Password string `json:"password"` //短信宝平台注册的密码
|
||||||
Domain string //域名
|
Sign string `json:"sign"` // 短信签名
|
||||||
Sign string // 短信签名
|
CodeTemplate string `json:"code_template"` // 验证码短信模板 匹配
|
||||||
CodeTemplate string // 验证码短信模板 匹配
|
|
||||||
}
|
}
|
||||||
|
|||||||
26
api/core/types/smtp.go
Normal file
26
api/core/types/smtp.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
type SmtpConfig struct {
|
||||||
|
UseTls bool `json:"use_tls"` // 是否使用 TLS 发送
|
||||||
|
Host string `json:"host"` // 邮件服务器地址
|
||||||
|
Port int `json:"port"` // 邮件服务器端口
|
||||||
|
AppName string `json:"app_name"` // 应用名称
|
||||||
|
From string `json:"from"` // 发件人邮箱地址
|
||||||
|
Password string `json:"password"` // 发件人邮箱密码
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SmtpConfig) Equal(other *SmtpConfig) bool {
|
||||||
|
return s.UseTls == other.UseTls &&
|
||||||
|
s.Host == other.Host &&
|
||||||
|
s.Port == other.Port &&
|
||||||
|
s.AppName == other.AppName &&
|
||||||
|
s.From == other.From &&
|
||||||
|
s.Password == other.Password
|
||||||
|
}
|
||||||
@@ -70,17 +70,18 @@ type SdTaskParams struct {
|
|||||||
|
|
||||||
// DallTask DALL-E task
|
// DallTask DALL-E task
|
||||||
type DallTask struct {
|
type DallTask struct {
|
||||||
ModelId uint `json:"model_id"`
|
ModelId uint `json:"model_id"`
|
||||||
ModelName string `json:"model_name"`
|
ModelName string `json:"model_name"`
|
||||||
Id uint `json:"id"`
|
Image []string `json:"image,omitempty"`
|
||||||
UserId uint `json:"user_id"`
|
Id uint `json:"id"`
|
||||||
Prompt string `json:"prompt"`
|
UserId uint `json:"user_id"`
|
||||||
N int `json:"n"`
|
Prompt string `json:"prompt"`
|
||||||
Quality string `json:"quality"`
|
N int `json:"n"`
|
||||||
Size string `json:"size"`
|
Quality string `json:"quality"`
|
||||||
Style string `json:"style"`
|
Size string `json:"size"`
|
||||||
Power int `json:"power"`
|
Style string `json:"style"`
|
||||||
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
Power int `json:"power"`
|
||||||
|
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type SunoTask struct {
|
type SunoTask struct {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build_name: runner-build
|
|||||||
build_log: runner-build-errors.log
|
build_log: runner-build-errors.log
|
||||||
valid_ext: .go, .tpl, .tmpl, .html
|
valid_ext: .go, .tpl, .tmpl, .html
|
||||||
no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue
|
no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue
|
||||||
ignored: assets, tmp, web, .git, .idea, test, data
|
ignored: assets, tmp, web, .git, .idea, test, data, static
|
||||||
build_delay: 600
|
build_delay: 600
|
||||||
colors: 1
|
colors: 1
|
||||||
log_color_main: cyan
|
log_color_main: cyan
|
||||||
|
|||||||
10
api/go.mod
10
api/go.mod
@@ -24,11 +24,9 @@ require (
|
|||||||
gorm.io/driver/mysql v1.4.7
|
gorm.io/driver/mysql v1.4.7
|
||||||
)
|
)
|
||||||
|
|
||||||
require github.com/xxl-job/xxl-job-executor-go v1.2.0
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-pay/gopay v1.5.101
|
github.com/go-pay/gopay v1.5.101
|
||||||
github.com/go-rod/rod v0.116.2
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
github.com/google/go-tika v0.3.1
|
github.com/google/go-tika v0.3.1
|
||||||
github.com/microcosm-cc/bluemonday v1.0.26
|
github.com/microcosm-cc/bluemonday v1.0.26
|
||||||
github.com/sashabaranov/go-openai v1.38.1
|
github.com/sashabaranov/go-openai v1.38.1
|
||||||
@@ -50,11 +48,6 @@ require (
|
|||||||
github.com/gorilla/css v1.0.0 // indirect
|
github.com/gorilla/css v1.0.0 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
||||||
github.com/tklauser/numcpus v0.7.0 // indirect
|
github.com/tklauser/numcpus v0.7.0 // indirect
|
||||||
github.com/ysmood/fetchup v0.3.0 // indirect
|
|
||||||
github.com/ysmood/goob v0.4.0 // indirect
|
|
||||||
github.com/ysmood/got v0.40.0 // indirect
|
|
||||||
github.com/ysmood/gson v0.7.3 // indirect
|
|
||||||
github.com/ysmood/leakless v0.9.0 // indirect
|
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
go.uber.org/mock v0.4.0 // indirect
|
go.uber.org/mock v0.4.0 // indirect
|
||||||
)
|
)
|
||||||
@@ -69,7 +62,6 @@ require (
|
|||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||||
github.com/gaukas/godicttls v0.0.3 // indirect
|
github.com/gaukas/godicttls v0.0.3 // indirect
|
||||||
github.com/go-basic/ipv4 v1.0.0 // indirect
|
|
||||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
|||||||
22
api/go.sum
22
api/go.sum
@@ -46,8 +46,6 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
|
|||||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||||
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
|
|
||||||
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
|
|
||||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
@@ -80,8 +78,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
|
|||||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||||
github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA=
|
|
||||||
github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg=
|
|
||||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||||
@@ -89,6 +85,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
|
|||||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||||
@@ -261,22 +259,6 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
|
|||||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
|
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
|
||||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||||
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
|
|
||||||
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
|
|
||||||
github.com/ysmood/fetchup v0.3.0 h1:UhYz9xnLEVn2ukSuK3KCgcznWpHMdrmbsPpllcylyu8=
|
|
||||||
github.com/ysmood/fetchup v0.3.0/go.mod h1:hbysoq65PXL0NQeNzUczNYIKpwpkwFL4LXMDEvIQq9A=
|
|
||||||
github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ=
|
|
||||||
github.com/ysmood/goob v0.4.0/go.mod h1:u6yx7ZhS4Exf2MwciFr6nIM8knHQIE22lFpWHnfql18=
|
|
||||||
github.com/ysmood/gop v0.2.0 h1:+tFrG0TWPxT6p9ZaZs+VY+opCvHU8/3Fk6BaNv6kqKg=
|
|
||||||
github.com/ysmood/gop v0.2.0/go.mod h1:rr5z2z27oGEbyB787hpEcx4ab8cCiPnKxn0SUHt6xzk=
|
|
||||||
github.com/ysmood/got v0.40.0 h1:ZQk1B55zIvS7zflRrkGfPDrPG3d7+JOza1ZkNxcc74Q=
|
|
||||||
github.com/ysmood/got v0.40.0/go.mod h1:W7DdpuX6skL3NszLmAsC5hT7JAhuLZhByVzHTq874Qg=
|
|
||||||
github.com/ysmood/gotrace v0.6.0 h1:SyI1d4jclswLhg7SWTL6os3L1WOKeNn/ZtzVQF8QmdY=
|
|
||||||
github.com/ysmood/gotrace v0.6.0/go.mod h1:TzhIG7nHDry5//eYZDYcTzuJLYQIkykJzCRIo4/dzQM=
|
|
||||||
github.com/ysmood/gson v0.7.3 h1:QFkWbTH8MxyUTKPkVWAENJhxqdBa4lYTQWqZCiLG6kE=
|
|
||||||
github.com/ysmood/gson v0.7.3/go.mod h1:3Kzs5zDl21g5F/BlLTNcuAGAYLKt2lV5G8D1zF3RNmg=
|
|
||||||
github.com/ysmood/leakless v0.9.0 h1:qxCG5VirSBvmi3uynXFkcnLMzkphdh3xx5FtrORwDCU=
|
|
||||||
github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY8q0JvMQ=
|
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
logger2 "geekai/logger"
|
logger2 "geekai/logger"
|
||||||
@@ -19,9 +20,10 @@ import (
|
|||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -45,6 +47,26 @@ func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, cap
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ManagerHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/")
|
||||||
|
|
||||||
|
// 公开接口,不需要授权
|
||||||
|
group.POST("login", h.Login)
|
||||||
|
group.GET("logout", h.Logout)
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("session", h.Session)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.POST("enable", h.Enable)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.POST("resetPass", h.ResetPass)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Login 登录
|
// Login 登录
|
||||||
func (h *ManagerHandler) Login(c *gin.Context) {
|
func (h *ManagerHandler) Login(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
@@ -59,19 +81,6 @@ func (h *ManagerHandler) Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.App.SysConfig.EnabledVerify {
|
|
||||||
var check bool
|
|
||||||
if data.X != 0 {
|
|
||||||
check = h.captcha.SlideCheck(data)
|
|
||||||
} else {
|
|
||||||
check = h.captcha.Check(data)
|
|
||||||
}
|
|
||||||
if !check {
|
|
||||||
resp.ERROR(c, "请先完人机验证")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var manager model.AdminUser
|
var manager model.AdminUser
|
||||||
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
|
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
@@ -135,16 +144,15 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
|
|||||||
|
|
||||||
// Session 会话检测
|
// Session 会话检测
|
||||||
func (h *ManagerHandler) Session(c *gin.Context) {
|
func (h *ManagerHandler) Session(c *gin.Context) {
|
||||||
id := h.GetLoginUserId(c)
|
id := h.GetAdminId(c)
|
||||||
key := fmt.Sprintf("admin/%d", id)
|
if id == 0 {
|
||||||
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
|
resp.NotAuth(c, "当前用户已退出登录")
|
||||||
resp.NotAuth(c)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var manager model.AdminUser
|
var manager model.AdminUser
|
||||||
res := h.DB.Where("id", id).First(&manager)
|
err := h.DB.Where("id", id).First(&manager).Error
|
||||||
if res.Error != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c, "当前用户已退出登录")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -30,6 +31,20 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
|||||||
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
|
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ApiKeyHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/apikey/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.POST("set", h.Set)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *ApiKeyHandler) Save(c *gin.Context) {
|
func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -30,14 +31,29 @@ func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
|
|||||||
return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ChatAppHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/role/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.POST("sort", h.Sort)
|
||||||
|
group.POST("set", h.Set)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Save 创建或者更新某个角色
|
// Save 创建或者更新某个角色
|
||||||
func (h *ChatAppHandler) Save(c *gin.Context) {
|
func (h *ChatAppHandler) Save(c *gin.Context) {
|
||||||
var data vo.ChatRole
|
var data vo.ChatApp
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var role model.ChatRole
|
var role model.ChatApp
|
||||||
err := utils.CopyObject(data, &role)
|
err := utils.CopyObject(data, &role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -65,8 +81,8 @@ func (h *ChatAppHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatAppHandler) List(c *gin.Context) {
|
func (h *ChatAppHandler) List(c *gin.Context) {
|
||||||
var items []model.ChatRole
|
var items []model.ChatApp
|
||||||
var roles = make([]vo.ChatRole, 0)
|
var roles = make([]vo.ChatApp, 0)
|
||||||
res := h.DB.Order("sort_num ASC").Find(&items)
|
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No data found")
|
resp.ERROR(c, "No data found")
|
||||||
@@ -107,7 +123,7 @@ func (h *ChatAppHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
var role vo.ChatRole
|
var role vo.ChatApp
|
||||||
err := utils.CopyObject(v, &role)
|
err := utils.CopyObject(v, &role)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
role.Id = v.Id
|
role.Id = v.Id
|
||||||
@@ -135,7 +151,7 @@ func (h *ChatAppHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
|
err := h.DB.Model(&model.ChatApp{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
@@ -157,7 +173,7 @@ func (h *ChatAppHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
err := h.DB.Model(&model.ChatApp{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
@@ -172,9 +188,8 @@ func (h *ChatAppHandler) Remove(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res := h.DB.Where("id", id).Delete(&model.ChatRole{})
|
res := h.DB.Where("id", id).Delete(&model.ChatApp{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("error with update database:", res.Error)
|
|
||||||
resp.ERROR(c, "删除失败!")
|
resp.ERROR(c, "删除失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,14 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -20,6 +22,21 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
|
|||||||
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ChatAppTypeHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/app/type/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.POST("enable", h.Enable)
|
||||||
|
group.POST("sort", h.Sort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Save 创建或更新App类型
|
// Save 创建或更新App类型
|
||||||
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
|
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -28,16 +29,31 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
|||||||
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ChatHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/chat/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("list", h.List)
|
||||||
|
group.POST("message", h.Messages)
|
||||||
|
group.GET("history", h.History)
|
||||||
|
group.GET("remove", h.RemoveChat)
|
||||||
|
group.GET("message/remove", h.RemoveMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type chatItemVo struct {
|
type chatItemVo struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
UserId uint `json:"user_id"`
|
UserId uint `json:"user_id"`
|
||||||
ChatId string `json:"chat_id"`
|
ChatId string `json:"chat_id"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Role vo.ChatRole `json:"role"`
|
Role vo.ChatApp `json:"role"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Token int `json:"token"`
|
Token int `json:"token"`
|
||||||
CreatedAt int64 `json:"created_at"`
|
CreatedAt int64 `json:"created_at"`
|
||||||
MsgNum int `json:"msg_num"` // 消息数量
|
MsgNum int `json:"msg_num"` // 消息数量
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) List(c *gin.Context) {
|
func (h *ChatHandler) List(c *gin.Context) {
|
||||||
@@ -87,7 +103,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var messages []model.ChatMessage
|
var messages []model.ChatMessage
|
||||||
var users []model.User
|
var users []model.User
|
||||||
var roles []model.ChatRole
|
var roles []model.ChatApp
|
||||||
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
|
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
|
||||||
h.DB.Where("id IN ?", userIds).Find(&users)
|
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||||
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||||
@@ -95,7 +111,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
tokenMap := make(map[string]int)
|
tokenMap := make(map[string]int)
|
||||||
userMap := make(map[uint]string)
|
userMap := make(map[uint]string)
|
||||||
msgMap := make(map[string]int)
|
msgMap := make(map[string]int)
|
||||||
roleMap := make(map[uint]vo.ChatRole)
|
roleMap := make(map[uint]vo.ChatApp)
|
||||||
for _, msg := range messages {
|
for _, msg := range messages {
|
||||||
tokenMap[msg.ChatId] += msg.Tokens
|
tokenMap[msg.ChatId] += msg.Tokens
|
||||||
msgMap[msg.ChatId] += 1
|
msgMap[msg.ChatId] += 1
|
||||||
@@ -104,7 +120,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
userMap[user.Id] = user.Username
|
userMap[user.Id] = user.Username
|
||||||
}
|
}
|
||||||
for _, r := range roles {
|
for _, r := range roles {
|
||||||
var roleVo vo.ChatRole
|
var roleVo vo.ChatApp
|
||||||
err := utils.CopyObject(r, &roleVo)
|
err := utils.CopyObject(r, &roleVo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ package admin
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -28,6 +30,22 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
|||||||
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ChatModelHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/model/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.POST("set", h.Set)
|
||||||
|
group.POST("sort", h.Sort)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.POST("batch-remove", h.BatchRemove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *ChatModelHandler) Save(c *gin.Context) {
|
func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
@@ -201,3 +219,33 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchRemove 批量删除模型
|
||||||
|
func (h *ChatModelHandler) BatchRemove(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Ids []uint `json:"ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data.Ids) == 0 {
|
||||||
|
resp.ERROR(c, "请选择要删除的模型")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行批量删除
|
||||||
|
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.ChatModel{}).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("批量删除模型失败:", err)
|
||||||
|
resp.ERROR(c, "批量删除失败:"+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, gin.H{
|
||||||
|
"message": fmt.Sprintf("成功删除 %d 个模型", len(data.Ids)),
|
||||||
|
"deleted_count": len(data.Ids),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,106 +9,399 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/store"
|
"geekai/service/oss"
|
||||||
|
"geekai/service/payment"
|
||||||
|
"geekai/service/sms"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/shirou/gopsutil/host"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConfigHandler struct {
|
type ConfigHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
levelDB *store.LevelDB
|
licenseService *service.LicenseService
|
||||||
licenseService *service.LicenseService
|
sysConfig *types.SystemConfig
|
||||||
|
alipayService *payment.AlipayService
|
||||||
|
wxpayService *payment.WxPayService
|
||||||
|
epayService *payment.EPayService
|
||||||
|
smsManager *sms.SmsManager
|
||||||
|
uploaderManager *oss.UploaderManager
|
||||||
|
smtpService *service.SmtpService
|
||||||
|
captchaService *service.CaptchaService
|
||||||
|
wxLoginService *service.WxLoginService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
|
func NewConfigHandler(
|
||||||
|
app *core.AppServer,
|
||||||
|
db *gorm.DB,
|
||||||
|
licenseService *service.LicenseService,
|
||||||
|
sysConfig *types.SystemConfig,
|
||||||
|
alipayService *payment.AlipayService,
|
||||||
|
wxpayService *payment.WxPayService,
|
||||||
|
epayService *payment.EPayService,
|
||||||
|
smsManager *sms.SmsManager,
|
||||||
|
uploaderManager *oss.UploaderManager,
|
||||||
|
smtpService *service.SmtpService,
|
||||||
|
captchaService *service.CaptchaService,
|
||||||
|
wxLoginService *service.WxLoginService,
|
||||||
|
) *ConfigHandler {
|
||||||
return &ConfigHandler{
|
return &ConfigHandler{
|
||||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||||
levelDB: levelDB,
|
licenseService: licenseService,
|
||||||
licenseService: licenseService,
|
sysConfig: sysConfig,
|
||||||
|
alipayService: alipayService,
|
||||||
|
wxpayService: wxpayService,
|
||||||
|
epayService: epayService,
|
||||||
|
smsManager: smsManager,
|
||||||
|
uploaderManager: uploaderManager,
|
||||||
|
smtpService: smtpService,
|
||||||
|
captchaService: captchaService,
|
||||||
|
wxLoginService: wxLoginService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ConfigHandler) Update(c *gin.Context) {
|
// RegisterRoutes 注册路由
|
||||||
var data struct {
|
func (h *ConfigHandler) RegisterRoutes() {
|
||||||
Key string `json:"key"`
|
rg := h.App.Engine.Group("/api/admin/config")
|
||||||
Config struct {
|
|
||||||
types.SystemConfig
|
// 需要管理员登录的接口
|
||||||
Content string `json:"content,omitempty"`
|
rg.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
Updated bool `json:"updated,omitempty"`
|
{
|
||||||
} `json:"config"`
|
rg.POST("update/base", h.UpdateBase)
|
||||||
ConfigBak types.SystemConfig `json:"config_bak,omitempty"`
|
rg.POST("update/power", h.UpdatePower)
|
||||||
|
rg.POST("update/notice", h.UpdateNotice)
|
||||||
|
rg.POST("update/agreement", h.UpdateAgreement)
|
||||||
|
rg.POST("update/privacy", h.UpdatePrivacy)
|
||||||
|
rg.POST("update/mark_map", h.UpdateMarkMap)
|
||||||
|
rg.POST("update/captcha", h.UpdateCaptcha)
|
||||||
|
rg.POST("update/wx_login", h.UpdateWxLogin)
|
||||||
|
rg.POST("update/payment", h.UpdatePayment)
|
||||||
|
rg.POST("update/sms", h.UpdateSms)
|
||||||
|
rg.POST("update/oss", h.UpdateOss)
|
||||||
|
rg.POST("update/smtp", h.UpdateStmp)
|
||||||
|
rg.GET("get", h.Get)
|
||||||
|
rg.POST("license/active", h.Active)
|
||||||
|
rg.GET("license/get", h.GetLicense)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBase 更新基础配置
|
||||||
|
func (h *ConfigHandler) UpdateBase(c *gin.Context) {
|
||||||
|
var data types.BaseConfig
|
||||||
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
logger.Errorf("Update config failed: %v", err)
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ONLY authorized user can change the copyright
|
// 未授权的话不允许修改版权
|
||||||
if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
|
license := h.licenseService.GetLicense()
|
||||||
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
|
if !license.IsActive && data.Copyright != h.sysConfig.Base.Copyright {
|
||||||
|
resp.ERROR(c, "未授权系统不允许修改版权信息")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果要启用图形验证码功能,则检查是否配置了 API 服务
|
// 未授权的话不允许修改 Logo
|
||||||
if data.Config.EnabledVerify && h.App.Config.ApiConfig.AppId == "" {
|
if !license.IsActive && data.Logo != h.sysConfig.Base.Logo {
|
||||||
resp.ERROR(c, "启用验证码服务需要先配置 GeekAI 官方 API 服务 AppId 和 Token")
|
resp.ERROR(c, "未授权系统不允许修改 Logo")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
value := utils.JsonEncode(&data.Config)
|
err := h.Update(types.ConfigKeySystem, data)
|
||||||
config := model.Config{Name: data.Key, Value: value}
|
if err != nil {
|
||||||
res := h.DB.FirstOrCreate(&config, model.Config{Name: data.Key})
|
resp.ERROR(c, err.Error())
|
||||||
if res.Error != nil {
|
|
||||||
resp.ERROR(c, res.Error.Error())
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Id > 0 {
|
h.sysConfig.Base = data
|
||||||
config.Value = value
|
|
||||||
res := h.DB.Updates(&config)
|
|
||||||
if res.Error != nil {
|
|
||||||
resp.ERROR(c, res.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// update config cache for AppServer
|
resp.SUCCESS(c, data)
|
||||||
var cfg model.Config
|
|
||||||
h.DB.Where("name", data.Key).First(&cfg)
|
|
||||||
var err error
|
|
||||||
if data.Key == "system" {
|
|
||||||
err = utils.JsonDecode(cfg.Value, &h.App.SysConfig)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "Failed to update config cache: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Infof("Update AppServer's config successfully: %v", config.Value)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, config)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get 获取指定的系统配置
|
// UpdatePower 更新系统配置
|
||||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
func (h *ConfigHandler) UpdatePower(c *gin.Context) {
|
||||||
key := c.Query("key")
|
var data struct {
|
||||||
|
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||||
|
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
|
||||||
|
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||||
|
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||||
|
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||||
|
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||||
|
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||||
|
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||||
|
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.sysConfig.Base.InitPower = data.InitPower
|
||||||
|
h.sysConfig.Base.DailyPower = data.DailyPower
|
||||||
|
h.sysConfig.Base.InvitePower = data.InvitePower
|
||||||
|
h.sysConfig.Base.MjPower = data.MjPower
|
||||||
|
h.sysConfig.Base.MjActionPower = data.MjActionPower
|
||||||
|
h.sysConfig.Base.SdPower = data.SdPower
|
||||||
|
h.sysConfig.Base.SunoPower = data.SunoPower
|
||||||
|
h.sysConfig.Base.LumaPower = data.LumaPower
|
||||||
|
h.sysConfig.Base.KeLingPowers = data.KeLingPowers
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeySystem, h.sysConfig.Base)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, h.sysConfig.Base)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateNotice 更新公告配置
|
||||||
|
func (h *ConfigHandler) UpdateNotice(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeyNotice, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAgreement 更新用户协议配置
|
||||||
|
func (h *ConfigHandler) UpdateAgreement(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeyAgreement, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePrivacy 更新隐私政策配置
|
||||||
|
func (h *ConfigHandler) UpdatePrivacy(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeyPrivacy, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMarkMap 更新思维导图配置
|
||||||
|
func (h *ConfigHandler) UpdateMarkMap(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeyMarkMap, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCaptcha 更新行为验证码配置
|
||||||
|
func (h *ConfigHandler) UpdateCaptcha(c *gin.Context) {
|
||||||
|
var data types.CaptchaConfig
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeyCaptcha, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.captchaService.UpdateConfig(data)
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePayment 更新支付配置
|
||||||
|
func (h *ConfigHandler) UpdatePayment(c *gin.Context) {
|
||||||
|
var data types.PaymentConfig
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeyPayment, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果启用状态发生改变,则需要更新支付服务配置
|
||||||
|
if data.WxPay.Enabled {
|
||||||
|
err = h.wxpayService.UpdateConfig(&data.WxPay)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if data.Epay.Enabled {
|
||||||
|
h.epayService.UpdateConfig(&data.Epay)
|
||||||
|
}
|
||||||
|
if data.Alipay.Enabled {
|
||||||
|
err = h.alipayService.UpdateConfig(&data.Alipay)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.sysConfig.Payment = data
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSms 更新短信配置
|
||||||
|
func (h *ConfigHandler) UpdateSms(c *gin.Context) {
|
||||||
|
var data types.SMSConfig
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeySms, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新服务配置
|
||||||
|
h.smsManager.UpdateConfig(data)
|
||||||
|
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOss 更新 Oss 配置
|
||||||
|
func (h *ConfigHandler) UpdateOss(c *gin.Context) {
|
||||||
|
var data types.OSSConfig
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeyOss, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新服务配置
|
||||||
|
h.uploaderManager.UpdateConfig(data)
|
||||||
|
h.sysConfig.OSS = data
|
||||||
|
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStmp 更新 Stmp 配置
|
||||||
|
func (h *ConfigHandler) UpdateStmp(c *gin.Context) {
|
||||||
|
var data types.SmtpConfig
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.Update(types.ConfigKeySmtp, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新服务配置
|
||||||
|
h.smtpService.UpdateConfig(&data)
|
||||||
|
h.sysConfig.SMTP = data
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateWxLogin 更新微信登录配置
|
||||||
|
func (h *ConfigHandler) UpdateWxLogin(c *gin.Context) {
|
||||||
|
var data types.WxLoginConfig
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := h.Update(types.ConfigKeyWxLogin, data)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Enabled {
|
||||||
|
h.wxLoginService.UpdateConfig(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.sysConfig.WxLogin = data
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新系统配置
|
||||||
|
func (h *ConfigHandler) Update(name string, value any) error {
|
||||||
var config model.Config
|
var config model.Config
|
||||||
res := h.DB.Where("name", key).First(&config)
|
err := h.DB.Where("name", name).First(&config).Error
|
||||||
|
if err != nil { // 不存在则创建
|
||||||
|
config.Name = name
|
||||||
|
config.Value = utils.JsonEncode(value)
|
||||||
|
return h.DB.Create(&config).Error
|
||||||
|
} else { // 存在则更新
|
||||||
|
config.Value = utils.JsonEncode(value)
|
||||||
|
return h.DB.Updates(&config).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取指定名称的系统配置
|
||||||
|
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||||
|
name := c.Query("key")
|
||||||
|
var config model.Config
|
||||||
|
res := h.DB.Where("name", name).First(&config)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var value map[string]interface{}
|
var value map[string]any
|
||||||
err := utils.JsonDecode(config.Value, &value)
|
err := utils.JsonDecode(config.Value, &value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
@@ -127,19 +420,21 @@ func (h *ConfigHandler) Active(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
info, err := host.Info()
|
|
||||||
|
err := h.licenseService.ActiveLicense(data.License)
|
||||||
|
license := h.licenseService.GetLicense()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := h.Update(types.ConfigKeyLicense, license); err != nil {
|
||||||
err = h.licenseService.ActiveLicense(data.License, info.HostID)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 更新系统配置
|
||||||
|
h.sysConfig.License = *license
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c, license.MachineId)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,69 +443,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
|||||||
license := h.licenseService.GetLicense()
|
license := h.licenseService.GetLicense()
|
||||||
resp.SUCCESS(c, license)
|
resp.SUCCESS(c, license)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FixData 修复数据
|
|
||||||
func (h *ConfigHandler) FixData(c *gin.Context) {
|
|
||||||
resp.ERROR(c, "当前升级版本没有数据需要修正!")
|
|
||||||
//var fixed bool
|
|
||||||
//version := "data_fix_4.1.4"
|
|
||||||
//err := h.levelDB.Get(version, &fixed)
|
|
||||||
//if err == nil || fixed {
|
|
||||||
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
//tx := h.DB.Begin()
|
|
||||||
//var users []model.User
|
|
||||||
//err = tx.Find(&users).Error
|
|
||||||
//if err != nil {
|
|
||||||
// resp.ERROR(c, err.Error())
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
//for _, user := range users {
|
|
||||||
// if user.Email != "" || user.Mobile != "" {
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// if utils.IsValidEmail(user.Username) {
|
|
||||||
// user.Email = user.Username
|
|
||||||
// } else if utils.IsValidMobile(user.Username) {
|
|
||||||
// user.Mobile = user.Username
|
|
||||||
// }
|
|
||||||
// err = tx.Save(&user).Error
|
|
||||||
// if err != nil {
|
|
||||||
// resp.ERROR(c, err.Error())
|
|
||||||
// tx.Rollback()
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//var orders []model.Order
|
|
||||||
//err = h.DB.Find(&orders).Error
|
|
||||||
//if err != nil {
|
|
||||||
// resp.ERROR(c, err.Error())
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
//for _, order := range orders {
|
|
||||||
// if order.PayWay == "支付宝" {
|
|
||||||
// order.PayWay = "alipay"
|
|
||||||
// order.PayType = "alipay"
|
|
||||||
// } else if order.PayWay == "微信支付" {
|
|
||||||
// order.PayWay = "wechat"
|
|
||||||
// order.PayType = "wxpay"
|
|
||||||
// } else if order.PayWay == "hupi" {
|
|
||||||
// order.PayType = "wxpay"
|
|
||||||
// }
|
|
||||||
// err = tx.Save(&order).Error
|
|
||||||
// if err != nil {
|
|
||||||
// resp.ERROR(c, err.Error())
|
|
||||||
// tx.Rollback()
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//tx.Commit()
|
|
||||||
//err = h.levelDB.Put(version, true)
|
|
||||||
//if err != nil {
|
|
||||||
// resp.ERROR(c, err.Error())
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
//resp.SUCCESS(c)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,10 +13,11 @@ import (
|
|||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/shopspring/decimal"
|
"github.com/shopspring/decimal"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type DashboardHandler struct {
|
type DashboardHandler struct {
|
||||||
@@ -27,46 +28,161 @@ func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
|
|||||||
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *DashboardHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/dashboard/")
|
||||||
|
group.GET("stats", h.Stats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// statsVo 增加 recentOrders、recentUsers 字段
|
||||||
|
// 最近订单
|
||||||
|
type OrderBrief struct {
|
||||||
|
OrderNo string `json:"order_no"`
|
||||||
|
Amount float64 `json:"amount"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最近用户
|
||||||
|
type UserBrief struct {
|
||||||
|
Nickname string `json:"nickname"`
|
||||||
|
Avatar string `json:"avatar"`
|
||||||
|
LastActive time.Time `json:"last_active"`
|
||||||
|
}
|
||||||
|
|
||||||
type statsVo struct {
|
type statsVo struct {
|
||||||
Users int64 `json:"users"`
|
Users int64 `json:"users"`
|
||||||
Chats int64 `json:"chats"`
|
Chats int64 `json:"chats"`
|
||||||
Tokens int `json:"tokens"`
|
Tokens int `json:"tokens"`
|
||||||
Income float64 `json:"income"`
|
Income float64 `json:"income"`
|
||||||
Chart map[string]map[string]float64 `json:"chart"`
|
Chart map[string]map[string]float64 `json:"chart"`
|
||||||
|
TodayUsers int64 `json:"todayUsers"`
|
||||||
|
TodayChats int64 `json:"todayChats"`
|
||||||
|
TodayTokens int `json:"todayTokens"`
|
||||||
|
TodayIncome float64 `json:"todayIncome"`
|
||||||
|
TodayOrders int64 `json:"todayOrders"`
|
||||||
|
TodayImageJobs int64 `json:"todayImageJobs"`
|
||||||
|
TodayVideoJobs int64 `json:"todayVideoJobs"`
|
||||||
|
TodayMusicJobs int64 `json:"todayMusicJobs"`
|
||||||
|
Orders int64 `json:"orders"`
|
||||||
|
ImageJobs int64 `json:"imageJobs"`
|
||||||
|
VideoJobs int64 `json:"videoJobs"`
|
||||||
|
MusicJobs int64 `json:"musicJobs"`
|
||||||
|
RecentOrders []OrderBrief `json:"recentOrders"`
|
||||||
|
RecentUsers []UserBrief `json:"recentUsers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DashboardHandler) Stats(c *gin.Context) {
|
func (h *DashboardHandler) Stats(c *gin.Context) {
|
||||||
stats := statsVo{}
|
stats := statsVo{}
|
||||||
// new users statistic
|
|
||||||
var userCount int64
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||||
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
|
|
||||||
if res.Error == nil {
|
// 总用户数
|
||||||
stats.Users = userCount
|
h.DB.Model(&model.User{}).Count(&stats.Users)
|
||||||
|
|
||||||
|
// 今日新增用户
|
||||||
|
h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&stats.TodayUsers)
|
||||||
|
|
||||||
|
// 总对话数
|
||||||
|
h.DB.Model(&model.ChatItem{}).Count(&stats.Chats)
|
||||||
|
|
||||||
|
// 今日新增对话
|
||||||
|
h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&stats.TodayChats)
|
||||||
|
|
||||||
|
// 总算力消耗
|
||||||
|
var powerLogs []model.PowerLog
|
||||||
|
h.DB.Where("mark = ?", types.PowerSub).Find(&powerLogs)
|
||||||
|
for _, item := range powerLogs {
|
||||||
|
stats.Tokens += item.Amount
|
||||||
}
|
}
|
||||||
|
|
||||||
// new chats statistic
|
// 今日算力消耗
|
||||||
var chatCount int64
|
var todayPowerLogs []model.PowerLog
|
||||||
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
|
h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", zeroTime).Find(&todayPowerLogs)
|
||||||
if res.Error == nil {
|
for _, item := range todayPowerLogs {
|
||||||
stats.Chats = chatCount
|
stats.TodayTokens += item.Amount
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokens took stats
|
// 总收入
|
||||||
var historyMessages []model.ChatMessage
|
var allOrders []model.Order
|
||||||
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
|
h.DB.Where("status = ?", types.OrderPaidSuccess).Find(&allOrders)
|
||||||
for _, item := range historyMessages {
|
for _, item := range allOrders {
|
||||||
stats.Tokens += item.Tokens
|
|
||||||
}
|
|
||||||
|
|
||||||
// 订单收入
|
|
||||||
var orders []model.Order
|
|
||||||
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
|
|
||||||
for _, item := range orders {
|
|
||||||
stats.Income += item.Amount
|
stats.Income += item.Amount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 今日收入
|
||||||
|
var todayOrders []model.Order
|
||||||
|
h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&todayOrders)
|
||||||
|
for _, item := range todayOrders {
|
||||||
|
stats.TodayIncome += item.Amount
|
||||||
|
}
|
||||||
|
|
||||||
|
// 订单总数
|
||||||
|
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Count(&stats.Orders)
|
||||||
|
|
||||||
|
// 今日订单数
|
||||||
|
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Count(&stats.TodayOrders)
|
||||||
|
|
||||||
|
// 图片生成任务统计
|
||||||
|
var mjJobs, sdJobs, dallJobs, jimengImageJobs int64
|
||||||
|
h.DB.Model(&model.MidJourneyJob{}).Count(&mjJobs)
|
||||||
|
h.DB.Model(&model.SdJob{}).Count(&sdJobs)
|
||||||
|
h.DB.Model(&model.DallJob{}).Count(&dallJobs)
|
||||||
|
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Count(&jimengImageJobs)
|
||||||
|
stats.ImageJobs = mjJobs + sdJobs + dallJobs + jimengImageJobs
|
||||||
|
|
||||||
|
logger.Info("stats.ImageJobs", stats.ImageJobs)
|
||||||
|
|
||||||
|
// 今日图片生成任务统计
|
||||||
|
var todayMjJobs, todaySdJobs, todayDallJobs, todayJimengImageJobs int64
|
||||||
|
h.DB.Model(&model.MidJourneyJob{}).Where("created_at > ?", zeroTime).Count(&todayMjJobs)
|
||||||
|
h.DB.Model(&model.SdJob{}).Where("created_at > ?", zeroTime).Count(&todaySdJobs)
|
||||||
|
h.DB.Model(&model.DallJob{}).Where("created_at > ?", zeroTime).Count(&todayDallJobs)
|
||||||
|
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Where("created_at > ?", zeroTime).Count(&todayJimengImageJobs)
|
||||||
|
stats.TodayImageJobs = todayMjJobs + todaySdJobs + todayDallJobs + todayJimengImageJobs
|
||||||
|
|
||||||
|
// 视频生成任务统计
|
||||||
|
var videoJobs, jimengVideoJobs int64
|
||||||
|
h.DB.Model(&model.VideoJob{}).Count(&videoJobs)
|
||||||
|
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Count(&jimengVideoJobs)
|
||||||
|
stats.VideoJobs = videoJobs + jimengVideoJobs
|
||||||
|
|
||||||
|
// 今日视频生成任务统计
|
||||||
|
var todayVideoJobs, todayJimengVideoJobs int64
|
||||||
|
h.DB.Model(&model.VideoJob{}).Where("created_at > ?", zeroTime).Count(&todayVideoJobs)
|
||||||
|
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Where("created_at > ?", zeroTime).Count(&todayJimengVideoJobs)
|
||||||
|
stats.TodayVideoJobs = todayVideoJobs + todayJimengVideoJobs
|
||||||
|
|
||||||
|
// 音乐生成任务统计
|
||||||
|
h.DB.Model(&model.SunoJob{}).Count(&stats.MusicJobs)
|
||||||
|
|
||||||
|
// 今日音乐生成任务统计
|
||||||
|
h.DB.Model(&model.SunoJob{}).Where("created_at > ?", zeroTime).Count(&stats.TodayMusicJobs)
|
||||||
|
|
||||||
|
// recentOrders: 最近10条已支付订单
|
||||||
|
var orderList []model.Order
|
||||||
|
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Order("created_at desc").Limit(10).Find(&orderList)
|
||||||
|
for _, o := range orderList {
|
||||||
|
stats.RecentOrders = append(stats.RecentOrders, OrderBrief{
|
||||||
|
OrderNo: o.OrderNo,
|
||||||
|
Amount: o.Amount,
|
||||||
|
CreatedAt: o.CreatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// recentUsers: 最近10个注册用户
|
||||||
|
var userList []model.User
|
||||||
|
h.DB.Model(&model.User{}).Order("created_at desc").Limit(10).Find(&userList)
|
||||||
|
for _, u := range userList {
|
||||||
|
lastActive := u.UpdatedAt
|
||||||
|
if lastActive.IsZero() {
|
||||||
|
lastActive = u.CreatedAt
|
||||||
|
}
|
||||||
|
stats.RecentUsers = append(stats.RecentUsers, UserBrief{
|
||||||
|
Nickname: u.Nickname,
|
||||||
|
Avatar: u.Avatar,
|
||||||
|
LastActive: lastActive,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// 统计7天的订单的图表
|
// 统计7天的订单的图表
|
||||||
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
|
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
|
||||||
var statsChart = make(map[string]map[string]float64)
|
var statsChart = make(map[string]map[string]float64)
|
||||||
@@ -81,23 +197,29 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
|
|||||||
|
|
||||||
// 统计用户7天增加的曲线
|
// 统计用户7天增加的曲线
|
||||||
var users []model.User
|
var users []model.User
|
||||||
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
|
err := h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users).Error
|
||||||
if res.Error == nil {
|
if err == nil {
|
||||||
for _, item := range users {
|
for _, item := range users {
|
||||||
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
|
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 统计7天Token 消耗
|
// 统计7天算力消耗
|
||||||
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
|
var chartPowerLogs []model.PowerLog
|
||||||
for _, item := range historyMessages {
|
err = h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", startDate).Find(&chartPowerLogs).Error
|
||||||
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
|
if err == nil {
|
||||||
|
for _, item := range chartPowerLogs {
|
||||||
|
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Amount)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 统计最近7天的订单
|
// 统计最近7天的订单
|
||||||
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
|
var orders []model.Order
|
||||||
for _, item := range orders {
|
err = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders).Error
|
||||||
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
if err == nil {
|
||||||
|
for _, item := range orders {
|
||||||
|
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
statsChart["users"] = userStatistic
|
statsChart["users"] = userStatistic
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -30,6 +31,21 @@ func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
|
|||||||
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *FunctionHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/function/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.POST("set", h.Set)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("token", h.GenToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *FunctionHandler) Save(c *gin.Context) {
|
func (h *FunctionHandler) Save(c *gin.Context) {
|
||||||
var data vo.Function
|
var data vo.Function
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
@@ -119,7 +135,6 @@ func (h *FunctionHandler) GenToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("error with generate token", err)
|
|
||||||
resp.ERROR(c)
|
resp.ERROR(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
@@ -33,6 +34,20 @@ func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.User
|
|||||||
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ImageHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/image/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("list/mj", h.MjList)
|
||||||
|
group.POST("list/sd", h.SdList)
|
||||||
|
group.POST("list/dall", h.DallList)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type imageQuery struct {
|
type imageQuery struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
|||||||
@@ -21,18 +21,18 @@ import (
|
|||||||
// AdminJimengHandler 管理后台即梦AI处理器
|
// AdminJimengHandler 管理后台即梦AI处理器
|
||||||
type AdminJimengHandler struct {
|
type AdminJimengHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
jimengService *jimeng.Service
|
jimengClient *jimeng.Client
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAdminJimengHandler 创建管理后台即梦AI处理器
|
// NewAdminJimengHandler 创建管理后台即梦AI处理器
|
||||||
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengService *jimeng.Service, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
|
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengClient *jimeng.Client, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
|
||||||
return &AdminJimengHandler{
|
return &AdminJimengHandler{
|
||||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||||
jimengService: jimengService,
|
jimengClient: jimengClient,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
uploader: uploader,
|
uploader: uploader,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,7 +43,6 @@ func (h *AdminJimengHandler) RegisterRoutes() {
|
|||||||
rg.GET("/jobs/:id", h.JobDetail)
|
rg.GET("/jobs/:id", h.JobDetail)
|
||||||
rg.POST("/jobs/remove", h.BatchRemove)
|
rg.POST("/jobs/remove", h.BatchRemove)
|
||||||
rg.GET("/stats", h.Stats)
|
rg.GET("/stats", h.Stats)
|
||||||
rg.GET("/config", h.GetConfig)
|
|
||||||
rg.POST("/config/update", h.UpdateConfig)
|
rg.POST("/config/update", h.UpdateConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,12 +212,6 @@ func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, result)
|
resp.SUCCESS(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConfig 获取即梦AI配置
|
|
||||||
func (h *AdminJimengHandler) GetConfig(c *gin.Context) {
|
|
||||||
jimengConfig := h.jimengService.GetConfig()
|
|
||||||
resp.SUCCESS(c, jimengConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateConfig 更新即梦AI配置
|
// UpdateConfig 更新即梦AI配置
|
||||||
func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||||
var req types.JimengConfig
|
var req types.JimengConfig
|
||||||
@@ -266,31 +259,35 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
|||||||
// 保存配置
|
// 保存配置
|
||||||
tx := h.DB.Begin()
|
tx := h.DB.Begin()
|
||||||
value := utils.JsonEncode(&req)
|
value := utils.JsonEncode(&req)
|
||||||
config := model.Config{Name: "jimeng", Value: value}
|
var exist model.Config
|
||||||
|
tx.Where("name", types.ConfigKeyJimeng).First(&exist)
|
||||||
|
|
||||||
err := tx.FirstOrCreate(&config, model.Config{Name: "jimeng"}).Error
|
if exist.Id > 0 {
|
||||||
if err != nil {
|
exist.Value = value
|
||||||
resp.ERROR(c, "保存配置失败: "+err.Error())
|
err := tx.Updates(&exist).Error
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Id > 0 {
|
|
||||||
config.Value = value
|
|
||||||
err = tx.Updates(&config).Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "更新配置失败: "+err.Error())
|
resp.ERROR(c, "更新配置失败: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
exist.Name = types.ConfigKeyJimeng
|
||||||
|
exist.Value = value
|
||||||
|
err := tx.Create(&exist).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "创建配置失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新服务中的客户端配置
|
// 更新服务中的客户端配置
|
||||||
updateErr := h.jimengService.UpdateClientConfig(req.AccessKey, req.SecretKey)
|
err := h.jimengClient.UpdateConfig(req)
|
||||||
if updateErr != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, updateErr.Error())
|
resp.ERROR(c, err.Error())
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
|
h.App.SysConfig.Jimeng = req
|
||||||
|
|
||||||
resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
|
resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
@@ -33,6 +34,19 @@ func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.User
|
|||||||
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *MediaHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/media/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("suno", h.SunoList)
|
||||||
|
group.POST("videos", h.Videos)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type mediaQuery struct {
|
type mediaQuery struct {
|
||||||
Type string `json:"type"` // 任务类型 luma, keling
|
Type string `json:"type"` // 任务类型 luma, keling
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
|||||||
@@ -27,6 +27,16 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
|||||||
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *MenuHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/menu/")
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("enable", h.Enable)
|
||||||
|
group.POST("sort", h.Sort)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *MenuHandler) Save(c *gin.Context) {
|
func (h *MenuHandler) Save(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
|
|||||||
333
api/handler/admin/moderation_handler.go
Normal file
333
api/handler/admin/moderation_handler.go
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/handler"
|
||||||
|
"geekai/service/moderation"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"geekai/utils/resp"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModerationHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
sysConfig *types.SystemConfig
|
||||||
|
moderationManager *moderation.ServiceManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewModerationHandler(app *core.AppServer, db *gorm.DB, sysConfig *types.SystemConfig, moderationManager *moderation.ServiceManager) *ModerationHandler {
|
||||||
|
return &ModerationHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, sysConfig: sysConfig, moderationManager: moderationManager}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ModerationHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/moderation/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("list", h.List)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.POST("batch-remove", h.BatchRemove)
|
||||||
|
group.GET("source-list", h.GetSourceList)
|
||||||
|
group.POST("config", h.UpdateModeration)
|
||||||
|
group.POST("test", h.TestModeration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 获取文本审核记录列表
|
||||||
|
func (h *ModerationHandler) List(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Source string `json:"source"`
|
||||||
|
StartDate string `json:"start_date"`
|
||||||
|
EndDate string `json:"end_date"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
|
||||||
|
// 构建查询条件
|
||||||
|
if data.Username != "" {
|
||||||
|
// 通过用户名查找用户ID
|
||||||
|
var user model.User
|
||||||
|
if err := h.DB.Where("username LIKE ?", "%"+data.Username+"%").First(&user).Error; err == nil {
|
||||||
|
session = session.Where("user_id", user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Source != "" {
|
||||||
|
session = session.Where("source", data.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.StartDate != "" && data.EndDate != "" {
|
||||||
|
startTime := data.StartDate + " 00:00:00"
|
||||||
|
endTime := data.EndDate + " 23:59:59"
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", startTime, endTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计总数
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.Moderation{}).Count(&total)
|
||||||
|
|
||||||
|
// 分页
|
||||||
|
page := data.Page
|
||||||
|
pageSize := data.PageSize
|
||||||
|
if page <= 0 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize <= 0 {
|
||||||
|
pageSize = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
session = session.Offset(offset).Limit(pageSize)
|
||||||
|
|
||||||
|
// 查询数据
|
||||||
|
var items []model.Moderation
|
||||||
|
err := session.Order("id DESC").Find(&items).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户信息
|
||||||
|
userIds := make([]uint, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
userIds = append(userIds, item.UserId)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []model.User
|
||||||
|
if len(userIds) > 0 {
|
||||||
|
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||||
|
}
|
||||||
|
|
||||||
|
userMap := make(map[uint]string)
|
||||||
|
for _, user := range users {
|
||||||
|
userMap[user.Id] = user.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应数据
|
||||||
|
list := make([]map[string]any, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
var moderation types.ModerationResult
|
||||||
|
err := utils.JsonDecode(item.Result, &moderation)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var result []string
|
||||||
|
for value, label := range types.ModerationCategories {
|
||||||
|
if moderation.Categories[value] {
|
||||||
|
result = append(result, label)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
list = append(list, map[string]any{
|
||||||
|
"id": item.Id,
|
||||||
|
"user_id": item.UserId,
|
||||||
|
"username": userMap[item.UserId],
|
||||||
|
"source": item.Source,
|
||||||
|
"input": item.Input,
|
||||||
|
"output": item.Output,
|
||||||
|
"result": result,
|
||||||
|
"created_at": item.CreatedAt.Unix(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, map[string]any{
|
||||||
|
"items": list,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": pageSize,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ModerationHandler) Remove(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
if id <= 0 {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.DB.Where("id", id).Delete(&model.Moderation{}).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchRemove 批量删除文本审核记录
|
||||||
|
func (h *ModerationHandler) BatchRemove(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Ids []uint `json:"ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data.Ids) == 0 {
|
||||||
|
resp.ERROR(c, "请选择要删除的记录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.Moderation{}).Error
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 source 列表
|
||||||
|
func (h *ModerationHandler) GetSourceList(c *gin.Context) {
|
||||||
|
sources := []gin.H{
|
||||||
|
{
|
||||||
|
"id": types.ModerationSourceChat,
|
||||||
|
"name": "AI对话",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": types.ModerationSourceMJ,
|
||||||
|
"name": "Midjourney 绘图",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": types.ModerationSourceDalle,
|
||||||
|
"name": "Dalle 绘图",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": types.ModerationSourceSD,
|
||||||
|
"name": "StableDiffusion 绘图",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": types.ModerationSourceSuno,
|
||||||
|
"name": "Suno 音乐",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": types.ModerationSourceVideo,
|
||||||
|
"name": "视频生成",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": types.ModerationSourceJiMeng,
|
||||||
|
"name": "即梦AI",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, sources)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModeration 更新文本审查配置
|
||||||
|
func (h *ModerationHandler) UpdateModeration(c *gin.Context) {
|
||||||
|
var data types.ModerationConfig
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var config model.Config
|
||||||
|
err := h.DB.Where("name", types.ConfigKeyModeration).First(&config).Error
|
||||||
|
if err != nil {
|
||||||
|
config.Name = types.ConfigKeyModeration
|
||||||
|
config.Value = utils.JsonEncode(data)
|
||||||
|
err = h.DB.Create(&config).Error
|
||||||
|
} else {
|
||||||
|
config.Value = utils.JsonEncode(data)
|
||||||
|
err = h.DB.Updates(&config).Error
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.moderationManager.UpdateConfig(data)
|
||||||
|
h.sysConfig.Moderation = data
|
||||||
|
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试结果类型,用于前端显示
|
||||||
|
type ModerationTestResult struct {
|
||||||
|
IsAbnormal bool `json:"isAbnormal"`
|
||||||
|
Details []ModerationTestDetail `json:"details"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModerationTestDetail struct {
|
||||||
|
Category string `json:"category"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
IsCategory bool `json:"isCategory"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModeration 测试文本审查服务
|
||||||
|
func (h *ModerationHandler) TestModeration(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Service string `json:"service"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Text == "" {
|
||||||
|
resp.ERROR(c, "测试文本不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否启用了文本审查
|
||||||
|
if !h.sysConfig.Moderation.Enable {
|
||||||
|
resp.ERROR(c, "文本审查服务未启用")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前激活的审核服务
|
||||||
|
service := h.moderationManager.GetService()
|
||||||
|
// 执行文本审核
|
||||||
|
result, err := service.Moderate(data.Text)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "审核服务调用失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为前端需要的格式
|
||||||
|
testResult := ModerationTestResult{
|
||||||
|
IsAbnormal: result.Flagged,
|
||||||
|
Details: make([]ModerationTestDetail, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建详细信息
|
||||||
|
for category, description := range types.ModerationCategories {
|
||||||
|
score := result.CategoryScores[category]
|
||||||
|
isCategory := result.Categories[category]
|
||||||
|
|
||||||
|
testResult.Details = append(testResult.Details, ModerationTestDetail{
|
||||||
|
Category: category,
|
||||||
|
Description: description,
|
||||||
|
Confidence: fmt.Sprintf("%.2f", score),
|
||||||
|
IsCategory: isCategory,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, testResult)
|
||||||
|
}
|
||||||
@@ -29,6 +29,14 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
|||||||
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *OrderHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/order/")
|
||||||
|
group.POST("list", h.List)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("clear", h.Clear)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *OrderHandler) List(c *gin.Context) {
|
func (h *OrderHandler) List(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
OrderNo string `json:"order_no"`
|
OrderNo string `json:"order_no"`
|
||||||
@@ -68,16 +76,16 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
order.Id = item.Id
|
order.Id = item.Id
|
||||||
order.CreatedAt = item.CreatedAt.Unix()
|
order.CreatedAt = item.CreatedAt.Unix()
|
||||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||||
payMethod, ok := types.PayMethods[item.PayWay]
|
payChannel, ok := types.PayChannel[item.Channel]
|
||||||
if !ok {
|
if !ok {
|
||||||
payMethod = item.PayWay
|
payChannel = item.Channel
|
||||||
}
|
}
|
||||||
payName, ok := types.PayNames[item.PayType]
|
payWays, ok := types.PayWays[item.PayWay]
|
||||||
if !ok {
|
if !ok {
|
||||||
payName = item.PayWay
|
payWays = item.PayWay
|
||||||
}
|
}
|
||||||
order.PayMethod = payMethod
|
order.ChannelName = payChannel
|
||||||
order.PayName = payName
|
order.PayName = payWays
|
||||||
list = append(list, order)
|
list = append(list, order)
|
||||||
} else {
|
} else {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -121,8 +129,8 @@ func (h *OrderHandler) Clear(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
deleteIds := make([]uint, 0)
|
deleteIds := make([]uint, 0)
|
||||||
for _, order := range orders {
|
for _, order := range orders {
|
||||||
// 只删除 15 分钟内的未支付订单
|
// 只删除超时的未支付订单
|
||||||
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
|
if time.Now().After(order.CreatedAt.Add(time.Minute * time.Duration(h.App.SysConfig.Base.OrderPayTimeout))) {
|
||||||
deleteIds = append(deleteIds, order.Id)
|
deleteIds = append(deleteIds, order.Id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,12 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
|||||||
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *PowerLogHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/powerLog/")
|
||||||
|
group.POST("list", h.List)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *PowerLogHandler) List(c *gin.Context) {
|
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
|||||||
@@ -15,9 +15,10 @@ import (
|
|||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProductHandler struct {
|
type ProductHandler struct {
|
||||||
@@ -28,14 +29,22 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
|||||||
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ProductHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/product/")
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("enable", h.Enable)
|
||||||
|
group.POST("sort", h.Sort)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *ProductHandler) Save(c *gin.Context) {
|
func (h *ProductHandler) Save(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Price float64 `json:"price"`
|
Price float64 `json:"price"`
|
||||||
Discount float64 `json:"discount"`
|
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Days int `json:"days"`
|
|
||||||
Power int `json:"power"`
|
Power int `json:"power"`
|
||||||
CreatedAt int64 `json:"created_at"`
|
CreatedAt int64 `json:"created_at"`
|
||||||
}
|
}
|
||||||
@@ -45,12 +54,10 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
item := model.Product{
|
item := model.Product{
|
||||||
Name: data.Name,
|
Name: data.Name,
|
||||||
Price: data.Price,
|
Price: data.Price,
|
||||||
Discount: data.Discount,
|
Power: data.Power,
|
||||||
Days: data.Days,
|
Enabled: data.Enabled}
|
||||||
Power: data.Power,
|
|
||||||
Enabled: data.Enabled}
|
|
||||||
item.Id = data.Id
|
item.Id = data.Id
|
||||||
if item.Id > 0 {
|
if item.Id > 0 {
|
||||||
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||||
|
|||||||
@@ -29,6 +29,16 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
|
|||||||
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *RedeemHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/redeem/")
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("create", h.Create)
|
||||||
|
group.POST("set", h.Set)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.POST("export", h.Export)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *RedeemHandler) List(c *gin.Context) {
|
func (h *RedeemHandler) List(c *gin.Context) {
|
||||||
page := h.GetInt(c, "page", 1)
|
page := h.GetInt(c, "page", 1)
|
||||||
pageSize := h.GetInt(c, "page_size", 20)
|
pageSize := h.GetInt(c, "page_size", 20)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -28,6 +29,17 @@ func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderMan
|
|||||||
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
|
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *UploadHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/upload")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("", h.Upload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *UploadHandler) Upload(c *gin.Context) {
|
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||||
// 判断文件大小
|
// 判断文件大小
|
||||||
f, err := c.FormFile("file")
|
f, err := c.FormFile("file")
|
||||||
@@ -36,7 +48,7 @@ func (h *UploadHandler) Upload(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.App.SysConfig.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.MaxFileSize)*1024*1024 {
|
if h.App.SysConfig.Base.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.Base.MaxFileSize)*1024*1024 {
|
||||||
resp.ERROR(c, "文件大小超过限制")
|
resp.ERROR(c, "文件大小超过限制")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
@@ -19,10 +20,9 @@ import (
|
|||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,6 +36,22 @@ func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.Li
|
|||||||
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
|
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *UserHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/admin/user/")
|
||||||
|
|
||||||
|
// 需要管理员授权的接口
|
||||||
|
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("loginLog", h.LoginLog)
|
||||||
|
group.GET("genLoginLink", h.GenLoginLink)
|
||||||
|
group.POST("resetPass", h.ResetPass)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// List 用户列表
|
// List 用户列表
|
||||||
func (h *UserHandler) List(c *gin.Context) {
|
func (h *UserHandler) List(c *gin.Context) {
|
||||||
page := h.GetInt(c, "page", 1)
|
page := h.GetInt(c, "page", 1)
|
||||||
|
|||||||
@@ -15,9 +15,10 @@ import (
|
|||||||
logger2 "geekai/logger"
|
logger2 "geekai/logger"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"gorm.io/gorm"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -69,6 +70,14 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
|
|||||||
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *BaseHandler) GetAdminId(c *gin.Context) uint {
|
||||||
|
userId, ok := c.Get(types.AdminUserID)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
||||||
|
}
|
||||||
|
|
||||||
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
|
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
|
||||||
return h.GetLoginUserId(c) > 0
|
return h.GetLoginUserId(c) > 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,23 +8,45 @@ package handler
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 今日头条函数实现
|
|
||||||
|
|
||||||
type CaptchaHandler struct {
|
type CaptchaHandler struct {
|
||||||
|
App *core.AppServer
|
||||||
service *service.CaptchaService
|
service *service.CaptchaService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCaptchaHandler(s *service.CaptchaService) *CaptchaHandler {
|
func NewCaptchaHandler(app *core.AppServer, s *service.CaptchaService, sysConfig *types.SystemConfig) *CaptchaHandler {
|
||||||
return &CaptchaHandler{service: s}
|
return &CaptchaHandler{App: app, service: s}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *CaptchaHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/captcha/")
|
||||||
|
|
||||||
|
// 无需授权的接口
|
||||||
|
group.GET("get", h.Get)
|
||||||
|
group.POST("check", h.Check)
|
||||||
|
group.GET("slide/get", h.SlideGet)
|
||||||
|
group.POST("slide/check", h.SlideCheck)
|
||||||
|
group.GET("config", h.GetConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *CaptchaHandler) GetConfig(c *gin.Context) {
|
||||||
|
resp.SUCCESS(c, gin.H{"enabled": h.service.GetConfig().Enabled, "type": h.service.GetConfig().Type})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *CaptchaHandler) Get(c *gin.Context) {
|
func (h *CaptchaHandler) Get(c *gin.Context) {
|
||||||
|
if !h.service.GetConfig().Enabled {
|
||||||
|
resp.ERROR(c, "验证码服务未启用")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
data, err := h.service.Get()
|
data, err := h.service.Get()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
@@ -36,6 +58,11 @@ func (h *CaptchaHandler) Get(c *gin.Context) {
|
|||||||
|
|
||||||
// Check verify the captcha data
|
// Check verify the captcha data
|
||||||
func (h *CaptchaHandler) Check(c *gin.Context) {
|
func (h *CaptchaHandler) Check(c *gin.Context) {
|
||||||
|
if !h.service.GetConfig().Enabled {
|
||||||
|
resp.ERROR(c, "验证码服务未启用")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
Dots string `json:"dots"`
|
Dots string `json:"dots"`
|
||||||
@@ -55,6 +82,11 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
|
|||||||
|
|
||||||
// SlideGet 获取滑动验证图片
|
// SlideGet 获取滑动验证图片
|
||||||
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
|
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
|
||||||
|
if !h.service.GetConfig().Enabled {
|
||||||
|
resp.ERROR(c, "验证码服务未启用")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
data, err := h.service.SlideGet()
|
data, err := h.service.SlideGet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
@@ -66,6 +98,11 @@ func (h *CaptchaHandler) SlideGet(c *gin.Context) {
|
|||||||
|
|
||||||
// SlideCheck 滑动验证结果校验
|
// SlideCheck 滑动验证结果校验
|
||||||
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
|
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
|
||||||
|
if !h.service.GetConfig().Enabled {
|
||||||
|
resp.ERROR(c, "验证码服务未启用")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
X int `json:"x"`
|
X int `json:"x"`
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
@@ -19,18 +20,31 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatRoleHandler struct {
|
type ChatAppHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
|
||||||
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
return &ChatAppHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ChatAppHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/app/")
|
||||||
|
group.GET("list", h.List)
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("list/user", h.ListByUser)
|
||||||
|
group.POST("update", h.UpdateApp)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 获取用户聊天应用列表
|
// List 获取用户聊天应用列表
|
||||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
func (h *ChatAppHandler) List(c *gin.Context) {
|
||||||
tid := h.GetInt(c, "tid", 0)
|
tid := h.GetInt(c, "tid", 0)
|
||||||
var roles []model.ChatRole
|
var roles []model.ChatApp
|
||||||
session := h.DB.Where("enable", true)
|
session := h.DB.Where("enable", true)
|
||||||
if tid > 0 {
|
if tid > 0 {
|
||||||
session = session.Where("tid", tid)
|
session = session.Where("tid", tid)
|
||||||
@@ -41,9 +55,9 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var roleVos = make([]vo.ChatRole, 0)
|
var roleVos = make([]vo.ChatApp, 0)
|
||||||
for _, r := range roles {
|
for _, r := range roles {
|
||||||
var v vo.ChatRole
|
var v vo.ChatApp
|
||||||
err := utils.CopyObject(r, &v)
|
err := utils.CopyObject(r, &v)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
v.Id = r.Id
|
v.Id = r.Id
|
||||||
@@ -54,10 +68,10 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListByUser 获取用户添加的角色列表
|
// ListByUser 获取用户添加的角色列表
|
||||||
func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
func (h *ChatAppHandler) ListByUser(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
var roles []model.ChatRole
|
var roles []model.ChatApp
|
||||||
session := h.DB.Where("enable", true)
|
session := h.DB.Where("enable", true)
|
||||||
// 如果用户没登录,则获取所有角色
|
// 如果用户没登录,则获取所有角色
|
||||||
if userId > 0 {
|
if userId > 0 {
|
||||||
@@ -86,9 +100,9 @@ func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var roleVos = make([]vo.ChatRole, 0)
|
var roleVos = make([]vo.ChatApp, 0)
|
||||||
for _, r := range roles {
|
for _, r := range roles {
|
||||||
var v vo.ChatRole
|
var v vo.ChatApp
|
||||||
err := utils.CopyObject(r, &v)
|
err := utils.CopyObject(r, &v)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
v.Id = r.Id
|
v.Id = r.Id
|
||||||
@@ -98,8 +112,8 @@ func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, roleVos)
|
resp.SUCCESS(c, roleVos)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRole 更新用户聊天角色
|
// UpdateApp 更新用户聊天应用
|
||||||
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
func (h *ChatAppHandler) UpdateApp(c *gin.Context) {
|
||||||
user, err := h.GetLoginUser(c)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
@@ -19,6 +19,12 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
|
|||||||
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ChatAppTypeHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/app/type/")
|
||||||
|
group.GET("list", h.List)
|
||||||
|
}
|
||||||
|
|
||||||
// List 获取App类型列表
|
// List 获取App类型列表
|
||||||
func (h *ChatAppTypeHandler) List(c *gin.Context) {
|
func (h *ChatAppTypeHandler) List(c *gin.Context) {
|
||||||
var items []model.AppType
|
var items []model.AppType
|
||||||
|
|||||||
@@ -14,8 +14,10 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
|
"geekai/service/moderation"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
@@ -39,6 +41,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
ChatEventStart = "start"
|
ChatEventStart = "start"
|
||||||
ChatEventEnd = "end"
|
ChatEventEnd = "end"
|
||||||
|
ChatEventComplete = "complete"
|
||||||
ChatEventError = "error"
|
ChatEventError = "error"
|
||||||
ChatEventMessageDelta = "message_delta"
|
ChatEventMessageDelta = "message_delta"
|
||||||
ChatEventTitle = "title"
|
ChatEventTitle = "title"
|
||||||
@@ -54,44 +57,69 @@ type ChatInput struct {
|
|||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Files []vo.File `json:"files"`
|
Files []vo.File `json:"files"`
|
||||||
ChatModel model.ChatModel `json:"chat_model,omitempty"`
|
ChatModel model.ChatModel `json:"chat_model,omitempty"`
|
||||||
ChatRole model.ChatRole `json:"chat_role,omitempty"`
|
ChatRole model.ChatApp `json:"chat_role,omitempty"`
|
||||||
LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID,用于重新生成答案的时候过滤上下文
|
LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID,用于重新生成答案的时候过滤上下文
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatHandler struct {
|
type ChatHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
licenseService *service.LicenseService
|
licenseService *service.LicenseService
|
||||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
moderationManager *moderation.ServiceManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService, moderationManager *moderation.ServiceManager) *ChatHandler {
|
||||||
return &ChatHandler{
|
return &ChatHandler{
|
||||||
BaseHandler: BaseHandler{App: app, DB: db},
|
BaseHandler: BaseHandler{App: app, DB: db},
|
||||||
redis: redis,
|
redis: redis,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
licenseService: licenseService,
|
licenseService: licenseService,
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
moderationManager: moderationManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ChatHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/chat/")
|
||||||
|
|
||||||
|
// 聊天接口不需要授权(已在authConfig中配置)
|
||||||
|
group.Any("message", h.Chat)
|
||||||
|
|
||||||
|
// 其他接口需要用户授权
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.GET("detail", h.Detail)
|
||||||
|
group.POST("update", h.Update)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("history", h.History)
|
||||||
|
group.GET("clear", h.Clear)
|
||||||
|
group.POST("tokens", h.Tokens)
|
||||||
|
group.GET("stop", h.StopGenerate)
|
||||||
|
group.POST("tts", h.TextToSpeech)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Chat 处理聊天请求
|
// Chat 处理聊天请求
|
||||||
func (h *ChatHandler) Chat(c *gin.Context) {
|
func (h *ChatHandler) Chat(c *gin.Context) {
|
||||||
var input ChatInput
|
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置SSE响应头
|
// 设置SSE响应头
|
||||||
c.Header("Prompt-Type", "text/event-stream")
|
c.Header("Prompt-Type", "text/event-stream")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
c.Header("X-Accel-Buffering", "no")
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
|
var input ChatInput
|
||||||
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
|
pushMessage(c, ChatEventError, types.InvalidArgs)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -113,7 +141,7 @@ func (h *ChatHandler) Chat(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证聊天角色
|
// 验证聊天角色
|
||||||
var chatRole model.ChatRole
|
var chatRole model.ChatApp
|
||||||
err := h.DB.First(&chatRole, input.RoleId).Error
|
err := h.DB.First(&chatRole, input.RoleId).Error
|
||||||
if err != nil || !chatRole.Enable {
|
if err != nil || !chatRole.Enable {
|
||||||
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
|
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
|
||||||
@@ -166,7 +194,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userVo.Power < input.ChatModel.Power {
|
if userVo.Power < input.ChatModel.Power {
|
||||||
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, input.ChatModel.Power)
|
return fmt.Errorf("您的算力不足,请购买算力。")
|
||||||
}
|
}
|
||||||
|
|
||||||
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
||||||
@@ -229,17 +257,24 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
|||||||
// 加载聊天上下文
|
// 加载聊天上下文
|
||||||
chatCtx := make([]any, 0)
|
chatCtx := make([]any, 0)
|
||||||
messages := make([]any, 0)
|
messages := make([]any, 0)
|
||||||
if h.App.SysConfig.EnableContext {
|
if h.App.SysConfig.Base.EnableContext {
|
||||||
_ = utils.JsonDecode(input.ChatRole.Context, &messages)
|
_ = utils.JsonDecode(input.ChatRole.Context, &messages)
|
||||||
if h.App.SysConfig.ContextDeep > 0 {
|
if h.App.SysConfig.Base.ContextDeep > 0 {
|
||||||
var historyMessages []model.ChatMessage
|
var historyMessages []model.ChatMessage
|
||||||
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
|
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
|
||||||
if input.LastMsgId > 0 { // 重新生成逻辑
|
if input.LastMsgId > 0 { // 重新生成逻辑
|
||||||
|
var lastMessage model.ChatMessage
|
||||||
|
err = dbSession.Where("id <= ?", input.LastMsgId).Where("type", types.PromptMsg).First(&lastMessage).Error
|
||||||
|
if err != nil {
|
||||||
|
input.LastMsgId = 0
|
||||||
|
} else {
|
||||||
|
input.LastMsgId = lastMessage.Id
|
||||||
|
}
|
||||||
dbSession = dbSession.Where("id < ?", input.LastMsgId)
|
dbSession = dbSession.Where("id < ?", input.LastMsgId)
|
||||||
// 删除对应的聊天记录
|
// 删除对应的聊天记录
|
||||||
h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{})
|
h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{})
|
||||||
}
|
}
|
||||||
err = dbSession.Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages).Error
|
err = dbSession.Limit(h.App.SysConfig.Base.ContextDeep).Order("id DESC").Find(&historyMessages).Error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||||
msg := historyMessages[i]
|
msg := historyMessages[i]
|
||||||
@@ -267,7 +302,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 上下文的深度超出了模型的最大上下文深度
|
// 上下文的深度超出了模型的最大上下文深度
|
||||||
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
|
if len(chatCtx) >= h.App.SysConfig.Base.ContextDeep {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,6 +312,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
|||||||
}
|
}
|
||||||
reqMgs := make([]any, 0)
|
reqMgs := make([]any, 0)
|
||||||
|
|
||||||
|
// 添加引导提示词,防止模型生成违规内容
|
||||||
|
if h.App.SysConfig.Moderation.EnableGuide {
|
||||||
|
reqMgs = append(reqMgs, map[string]any{
|
||||||
|
"role": "system",
|
||||||
|
"content": h.App.SysConfig.Moderation.GuidePrompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
for i := len(chatCtx) - 1; i >= 0; i-- {
|
for i := len(chatCtx) - 1; i >= 0; i-- {
|
||||||
reqMgs = append(reqMgs, chatCtx[i])
|
reqMgs = append(reqMgs, chatCtx[i])
|
||||||
}
|
}
|
||||||
@@ -295,16 +338,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
// 如果不是逆向模型,则提取文件内容
|
// 处理文件,提取文件内容
|
||||||
modelValue := input.ChatModel.Value
|
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
|
||||||
if !(strings.Contains(modelValue, "-all") || strings.HasPrefix(modelValue, "gpt-4-gizmo") || strings.HasPrefix(modelValue, "claude")) {
|
if err != nil {
|
||||||
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
|
logger.Error("error with read file: ", err)
|
||||||
if err != nil {
|
continue
|
||||||
logger.Error("error with read file: ", err)
|
} else {
|
||||||
continue
|
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||||
} else {
|
logger.Debugf("fileContents: %s", fileContents)
|
||||||
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -320,16 +361,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(imgList) > 0 {
|
if len(imgList) > 0 {
|
||||||
imgList = append(imgList, map[string]interface{}{
|
imgList = append(imgList, map[string]any{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": input.Prompt,
|
"text": input.Prompt,
|
||||||
})
|
})
|
||||||
req.Messages = append(reqMgs, map[string]interface{}{
|
req.Messages = append(reqMgs, map[string]any{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": imgList,
|
"content": imgList,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
req.Messages = append(reqMgs, map[string]interface{}{
|
req.Messages = append(reqMgs, map[string]any{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": finalPrompt,
|
"content": finalPrompt,
|
||||||
})
|
})
|
||||||
@@ -445,7 +486,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
|||||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) {
|
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) {
|
||||||
// if the chat model bind a KEY, use it directly
|
// if the chat model bind a KEY, use it directly
|
||||||
if input.ChatModel.KeyId > 0 {
|
if input.ChatModel.KeyId > 0 {
|
||||||
h.DB.Where("id", input.ChatModel.KeyId).Find(apiKey)
|
h.DB.Where("id", input.ChatModel.KeyId).Where("enabled", true).Find(apiKey)
|
||||||
} else { // use the last unused key
|
} else { // use the last unused key
|
||||||
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
|
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
|
||||||
}
|
}
|
||||||
@@ -516,6 +557,7 @@ func (h *ChatHandler) subUserPower(userVo vo.User, input ChatInput, promptTokens
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) saveChatHistory(
|
func (h *ChatHandler) saveChatHistory(
|
||||||
|
c *gin.Context,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
usage Usage,
|
usage Usage,
|
||||||
message types.Message,
|
message types.Message,
|
||||||
@@ -524,6 +566,34 @@ func (h *ChatHandler) saveChatHistory(
|
|||||||
promptCreatedAt time.Time,
|
promptCreatedAt time.Time,
|
||||||
replyCreatedAt time.Time) {
|
replyCreatedAt time.Time) {
|
||||||
|
|
||||||
|
// 文本审核
|
||||||
|
if h.App.SysConfig.Moderation.Enable {
|
||||||
|
moderationResult, err := h.moderationManager.GetService().Moderate(usage.Content)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to moderate content: ", err)
|
||||||
|
}
|
||||||
|
logger.Debugf("moderationResult: %+v", moderationResult)
|
||||||
|
if moderationResult.Flagged {
|
||||||
|
// 记录违规内容
|
||||||
|
moderation := model.Moderation{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
Source: types.ModerationSourceChat,
|
||||||
|
Input: usage.Prompt,
|
||||||
|
Output: usage.Content,
|
||||||
|
Result: utils.JsonEncode(moderationResult),
|
||||||
|
}
|
||||||
|
err = h.DB.Create(&moderation).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to save moderation: ", err)
|
||||||
|
}
|
||||||
|
pushMessage(c, ChatEventError, "很抱歉,内容触发敏感词预警,AI 无法回答!!!")
|
||||||
|
// 更新用户算力
|
||||||
|
if input.ChatModel.Power > 0 {
|
||||||
|
h.subUserPower(userVo, input, 0, 0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
// for prompt
|
// for prompt
|
||||||
var promptTokens, replyTokens, totalTokens int
|
var promptTokens, replyTokens, totalTokens int
|
||||||
@@ -586,6 +656,22 @@ func (h *ChatHandler) saveChatHistory(
|
|||||||
logger.Error("failed to save reply history message: ", err)
|
logger.Error("failed to save reply history message: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 发送完整聊天记录给前端
|
||||||
|
var messageVo vo.ChatMessage
|
||||||
|
err = utils.CopyObject(historyReplyMsg, &messageVo)
|
||||||
|
if err == nil {
|
||||||
|
// 解析内容
|
||||||
|
var content vo.MsgContent
|
||||||
|
err = utils.JsonDecode(historyReplyMsg.Content, &content)
|
||||||
|
if err != nil {
|
||||||
|
content.Text = historyReplyMsg.Content
|
||||||
|
}
|
||||||
|
messageVo.Content = content
|
||||||
|
messageVo.CreatedAt = historyReplyMsg.CreatedAt.Unix()
|
||||||
|
messageVo.UpdatedAt = historyReplyMsg.UpdatedAt.Unix()
|
||||||
|
pushMessage(c, ChatEventComplete, messageVo)
|
||||||
|
}
|
||||||
|
|
||||||
// 更新用户算力
|
// 更新用户算力
|
||||||
if input.ChatModel.Power > 0 {
|
if input.ChatModel.Power > 0 {
|
||||||
h.subUserPower(userVo, input, promptTokens, replyTokens)
|
h.subUserPower(userVo, input, promptTokens, replyTokens)
|
||||||
@@ -710,221 +796,3 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
|||||||
logger.Error("写入音频数据到响应失败:", err)
|
logger.Error("写入音频数据到响应失败:", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// // OPenAI 消息发送实现
|
|
||||||
// func (h *ChatHandler) sendOpenAiMessage(
|
|
||||||
// req types.ApiRequest,
|
|
||||||
// userVo vo.User,
|
|
||||||
// ctx context.Context,
|
|
||||||
// session *types.ChatSession,
|
|
||||||
// role model.ChatRole,
|
|
||||||
// prompt string,
|
|
||||||
// c *gin.Context) error {
|
|
||||||
// promptCreatedAt := time.Now() // 记录提问时间
|
|
||||||
// start := time.Now()
|
|
||||||
// var apiKey = model.ApiKey{}
|
|
||||||
// response, err := h.doRequest(ctx, req, session, &apiKey)
|
|
||||||
// logger.Info("HTTP请求完成,耗时:", time.Since(start))
|
|
||||||
// if err != nil {
|
|
||||||
// if strings.Contains(err.Error(), "context canceled") {
|
|
||||||
// return fmt.Errorf("用户取消了请求:%s", prompt)
|
|
||||||
// } else if strings.Contains(err.Error(), "no available key") {
|
|
||||||
// return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
|
||||||
// }
|
|
||||||
// return err
|
|
||||||
// } else {
|
|
||||||
// defer response.Body.Close()
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if response.StatusCode != 200 {
|
|
||||||
// body, _ := io.ReadAll(response.Body)
|
|
||||||
// return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
|
|
||||||
// }
|
|
||||||
|
|
||||||
// contentType := response.Header.Get("Prompt-Type")
|
|
||||||
// if strings.Contains(contentType, "text/event-stream") {
|
|
||||||
// replyCreatedAt := time.Now() // 记录回复时间
|
|
||||||
// // 循环读取 Chunk 消息
|
|
||||||
// var message = types.Message{Role: "assistant"}
|
|
||||||
// var contents = make([]string, 0)
|
|
||||||
// var function model.Function
|
|
||||||
// var toolCall = false
|
|
||||||
// var arguments = make([]string, 0)
|
|
||||||
// var reasoning = false
|
|
||||||
|
|
||||||
// pushMessage(c, ChatEventStart, "开始响应")
|
|
||||||
// scanner := bufio.NewScanner(response.Body)
|
|
||||||
// for scanner.Scan() {
|
|
||||||
// line := scanner.Text()
|
|
||||||
// if !strings.Contains(line, "data:") || len(line) < 30 {
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// var responseBody = types.ApiResponse{}
|
|
||||||
// err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
|
||||||
// if err != nil { // 数据解析出错
|
|
||||||
// return errors.New(line)
|
|
||||||
// }
|
|
||||||
// if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// if responseBody.Choices[0].Delta.Prompt == nil &&
|
|
||||||
// responseBody.Choices[0].Delta.ToolCalls == nil &&
|
|
||||||
// responseBody.Choices[0].Delta.ReasoningContent == "" {
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
|
||||||
// pushMessage(c, ChatEventError, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
|
|
||||||
// var tool types.ToolCall
|
|
||||||
// if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
|
|
||||||
// tool = responseBody.Choices[0].Delta.ToolCalls[0]
|
|
||||||
// if toolCall && tool.Function.Name == "" {
|
|
||||||
// arguments = append(arguments, tool.Function.Arguments)
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // 兼容 Function Call
|
|
||||||
// fun := responseBody.Choices[0].Delta.FunctionCall
|
|
||||||
// if fun.Name != "" {
|
|
||||||
// tool = *new(types.ToolCall)
|
|
||||||
// tool.Function.Name = fun.Name
|
|
||||||
// } else if toolCall {
|
|
||||||
// arguments = append(arguments, fun.Arguments)
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if !utils.IsEmptyValue(tool) {
|
|
||||||
// res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
|
|
||||||
// if res.Error == nil {
|
|
||||||
// toolCall = true
|
|
||||||
// callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
|
||||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
|
||||||
// "type": "text",
|
|
||||||
// "content": callMsg,
|
|
||||||
// })
|
|
||||||
// contents = append(contents, callMsg)
|
|
||||||
// }
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if responseBody.Choices[0].FinishReason == "tool_calls" ||
|
|
||||||
// responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // output stopped
|
|
||||||
// if responseBody.Choices[0].FinishReason != "" {
|
|
||||||
// break // 输出完成或者输出中断了
|
|
||||||
// } else { // 正常输出结果
|
|
||||||
// // 兼容思考过程
|
|
||||||
// if responseBody.Choices[0].Delta.ReasoningContent != "" {
|
|
||||||
// reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
|
|
||||||
// if !reasoning {
|
|
||||||
// reasoningContent = fmt.Sprintf("<think>%s", reasoningContent)
|
|
||||||
// reasoning = true
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
|
||||||
// "type": "text",
|
|
||||||
// "content": reasoningContent,
|
|
||||||
// })
|
|
||||||
// contents = append(contents, reasoningContent)
|
|
||||||
// } else if responseBody.Choices[0].Delta.Prompt != "" {
|
|
||||||
// finalContent := responseBody.Choices[0].Delta.Prompt
|
|
||||||
// if reasoning {
|
|
||||||
// finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Prompt)
|
|
||||||
// reasoning = false
|
|
||||||
// }
|
|
||||||
// contents = append(contents, utils.InterfaceToString(finalContent))
|
|
||||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
|
||||||
// "type": "text",
|
|
||||||
// "content": finalContent,
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// } // end for
|
|
||||||
|
|
||||||
// if err := scanner.Err(); err != nil {
|
|
||||||
// if strings.Contains(err.Error(), "context canceled") {
|
|
||||||
// logger.Info("用户取消了请求:", prompt)
|
|
||||||
// } else {
|
|
||||||
// logger.Error("信息读取出错:", err)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if toolCall { // 调用函数完成任务
|
|
||||||
// params := make(map[string]any)
|
|
||||||
// _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
|
||||||
// logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
|
||||||
// params["user_id"] = userVo.Id
|
|
||||||
// var apiRes types.BizVo
|
|
||||||
// r, err := req2.C().R().SetHeader("Body-Type", "application/json").
|
|
||||||
// SetHeader("Authorization", function.Token).
|
|
||||||
// SetBody(params).Post(function.Action)
|
|
||||||
// errMsg := ""
|
|
||||||
// if err != nil {
|
|
||||||
// errMsg = err.Error()
|
|
||||||
// } else {
|
|
||||||
// all, _ := io.ReadAll(r.Body)
|
|
||||||
// err = json.Unmarshal(all, &apiRes)
|
|
||||||
// if err != nil {
|
|
||||||
// errMsg = err.Error()
|
|
||||||
// } else if apiRes.Code != types.Success {
|
|
||||||
// errMsg = apiRes.Message
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if errMsg != "" {
|
|
||||||
// errMsg = "调用函数工具出错:" + errMsg
|
|
||||||
// contents = append(contents, errMsg)
|
|
||||||
// } else {
|
|
||||||
// errMsg = utils.InterfaceToString(apiRes.Data)
|
|
||||||
// contents = append(contents, errMsg)
|
|
||||||
// }
|
|
||||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
|
||||||
// "type": "text",
|
|
||||||
// "content": errMsg,
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // 消息发送成功
|
|
||||||
// if len(contents) > 0 {
|
|
||||||
// usage := Usage{
|
|
||||||
// Prompt: prompt,
|
|
||||||
// Prompt: strings.Join(contents, ""),
|
|
||||||
// PromptTokens: 0,
|
|
||||||
// CompletionTokens: 0,
|
|
||||||
// TotalTokens: 0,
|
|
||||||
// }
|
|
||||||
// message.Prompt = usage.Prompt
|
|
||||||
// h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
|
||||||
// }
|
|
||||||
// } else {
|
|
||||||
// var respVo OpenAIResVo
|
|
||||||
// body, err := io.ReadAll(response.Body)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("读取响应失败:%v", body)
|
|
||||||
// }
|
|
||||||
// err = json.Unmarshal(body, &respVo)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("解析响应失败:%v", body)
|
|
||||||
// }
|
|
||||||
// content := respVo.Choices[0].Message.Prompt
|
|
||||||
// if strings.HasPrefix(req.Model, "o1-") {
|
|
||||||
// content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Prompt)
|
|
||||||
// }
|
|
||||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
|
||||||
// "type": "text",
|
|
||||||
// "content": content,
|
|
||||||
// })
|
|
||||||
// respVo.Usage.Prompt = prompt
|
|
||||||
// respVo.Usage.Prompt = content
|
|
||||||
// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
|
|||||||
@@ -42,9 +42,9 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
modelValues = append(modelValues, chat.Model)
|
modelValues = append(modelValues, chat.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
var roles []model.ChatRole
|
var roles []model.ChatApp
|
||||||
var models []model.ChatModel
|
var models []model.ChatModel
|
||||||
roleMap := make(map[uint]model.ChatRole)
|
roleMap := make(map[uint]model.ChatApp)
|
||||||
modelMap := make(map[string]model.ChatModel)
|
modelMap := make(map[string]model.ChatModel)
|
||||||
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||||
h.DB.Where("value IN ?", modelValues).Find(&models)
|
h.DB.Where("value IN ?", modelValues).Find(&models)
|
||||||
@@ -205,7 +205,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 填充角色名称
|
// 填充角色名称
|
||||||
var role model.ChatRole
|
var role model.ChatApp
|
||||||
res = h.DB.Where("id", chatItem.RoleId).First(&role)
|
res = h.DB.Where("id", chatItem.RoleId).First(&role)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Role not found")
|
resp.ERROR(c, "Role not found")
|
||||||
|
|||||||
@@ -26,6 +26,12 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
|||||||
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ChatModelHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/model/")
|
||||||
|
group.GET("list", h.List)
|
||||||
|
}
|
||||||
|
|
||||||
// List 模型列表
|
// List 模型列表
|
||||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||||
var items []model.ChatModel
|
var items []model.ChatModel
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
TotalTokens: 0,
|
TotalTokens: 0,
|
||||||
}
|
}
|
||||||
message.Content = usage.Content
|
message.Content = usage.Content
|
||||||
h.saveChatHistory(req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt)
|
h.saveChatHistory(c, req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt)
|
||||||
}
|
}
|
||||||
} else { // 非流式输出
|
} else { // 非流式输出
|
||||||
var respVo OpenAIResVo
|
var respVo OpenAIResVo
|
||||||
@@ -242,7 +242,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
pushMessage(c, "text", content)
|
pushMessage(c, "text", content)
|
||||||
respVo.Usage.Prompt = input.Prompt
|
respVo.Usage.Prompt = input.Prompt
|
||||||
respVo.Usage.Content = content
|
respVo.Usage.Content = content
|
||||||
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now())
|
h.saveChatHistory(c, req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -27,6 +27,15 @@ func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.
|
|||||||
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ConfigHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/config/")
|
||||||
|
|
||||||
|
// 无需授权的接口
|
||||||
|
group.GET("get", h.Get)
|
||||||
|
group.GET("license", h.License)
|
||||||
|
}
|
||||||
|
|
||||||
// Get 获取指定的系统配置
|
// Get 获取指定的系统配置
|
||||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||||
key := c.Query("key")
|
key := c.Query("key")
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/dalle"
|
"geekai/service/dalle"
|
||||||
|
"geekai/service/moderation"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
@@ -25,16 +27,18 @@ import (
|
|||||||
|
|
||||||
type DallJobHandler struct {
|
type DallJobHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
dallService *dalle.Service
|
dallService *dalle.Service
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
moderationManager *moderation.ServiceManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
|
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *DallJobHandler {
|
||||||
return &DallJobHandler{
|
return &DallJobHandler{
|
||||||
dallService: service,
|
dallService: service,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
moderationManager: moderationManager,
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: app,
|
App: app,
|
||||||
DB: db,
|
DB: db,
|
||||||
@@ -42,6 +46,24 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *DallJobHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/dall/")
|
||||||
|
|
||||||
|
// 公开接口,不需要授权
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
|
group.GET("models", h.GetModels)
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("image", h.Image)
|
||||||
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("publish", h.Publish)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Image 创建一个绘画任务
|
// Image 创建一个绘画任务
|
||||||
func (h *DallJobHandler) Image(c *gin.Context) {
|
func (h *DallJobHandler) Image(c *gin.Context) {
|
||||||
var data types.DallTask
|
var data types.DallTask
|
||||||
@@ -50,6 +72,29 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 文本审核
|
||||||
|
if h.App.SysConfig.Moderation.Enable {
|
||||||
|
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to moderate content: ", err)
|
||||||
|
}
|
||||||
|
if moderationResult.Flagged {
|
||||||
|
// 记录违规内容
|
||||||
|
moderation := model.Moderation{
|
||||||
|
UserId: h.GetLoginUserId(c),
|
||||||
|
Source: types.ModerationSourceDalle,
|
||||||
|
Input: data.Prompt,
|
||||||
|
Result: utils.JsonEncode(moderationResult),
|
||||||
|
}
|
||||||
|
err = h.DB.Create(&moderation).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to save moderation: ", err)
|
||||||
|
}
|
||||||
|
resp.ERROR(c, "当前创作内容包含敏感词,提示词未通过文本审核,请重新输入!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var chatModel model.ChatModel
|
var chatModel model.ChatModel
|
||||||
if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil {
|
if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil {
|
||||||
resp.ERROR(c, "模型不存在")
|
resp.ERROR(c, "模型不存在")
|
||||||
@@ -73,11 +118,12 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
|||||||
UserId: uint(userId),
|
UserId: uint(userId),
|
||||||
ModelId: chatModel.Id,
|
ModelId: chatModel.Id,
|
||||||
ModelName: chatModel.Value,
|
ModelName: chatModel.Value,
|
||||||
|
Image: data.Image,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Quality: data.Quality,
|
Quality: data.Quality,
|
||||||
Size: data.Size,
|
Size: data.Size,
|
||||||
Style: data.Style,
|
Style: data.Style,
|
||||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||||
Power: chatModel.Power,
|
Power: chatModel.Power,
|
||||||
}
|
}
|
||||||
job := model.DallJob{
|
job := model.DallJob{
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/crawler"
|
|
||||||
"geekai/service/dalle"
|
"geekai/service/dalle"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -31,7 +30,6 @@ import (
|
|||||||
|
|
||||||
type FunctionHandler struct {
|
type FunctionHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
config types.ApiConfig
|
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
dallService *dalle.Service
|
dallService *dalle.Service
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
@@ -49,13 +47,23 @@ func NewFunctionHandler(
|
|||||||
App: server,
|
App: server,
|
||||||
DB: db,
|
DB: db,
|
||||||
},
|
},
|
||||||
config: config.ApiConfig,
|
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
dallService: dallService,
|
dallService: dallService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *FunctionHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/function/")
|
||||||
|
group.GET("list", h.List)
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.POST("weibo", h.WeiBo)
|
||||||
|
group.POST("zaobao", h.ZaoBao)
|
||||||
|
group.POST("dalle3", h.Dall3)
|
||||||
|
}
|
||||||
|
|
||||||
type resVo struct {
|
type resVo struct {
|
||||||
Code types.BizCode `json:"code"`
|
Code types.BizCode `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
@@ -107,16 +115,10 @@ func (h *FunctionHandler) WeiBo(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.config.Token == "" {
|
url := fmt.Sprintf("%s/api/weibo/fetch", types.GeekAPIURL)
|
||||||
resp.ERROR(c, "无效的 API Token")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
|
|
||||||
var res resVo
|
var res resVo
|
||||||
r, err := req.C().R().
|
r, err := req.C().R().
|
||||||
SetHeader("AppId", h.config.AppId).
|
SetHeader("Authorization", "Bearer geekai-plus").
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
|
||||||
SetSuccessResult(&res).Get(url)
|
SetSuccessResult(&res).Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, fmt.Sprintf("%v", err))
|
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||||||
@@ -146,16 +148,10 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.config.Token == "" {
|
url := fmt.Sprintf("%s/api/zaobao/fetch", types.GeekAPIURL)
|
||||||
resp.ERROR(c, "无效的 API Token")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
|
|
||||||
var res resVo
|
var res resVo
|
||||||
r, err := req.C().R().
|
r, err := req.C().R().
|
||||||
SetHeader("AppId", h.config.AppId).
|
SetHeader("Authorization", "Bearer geekai-plus").
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
|
||||||
SetSuccessResult(&res).Get(url)
|
SetSuccessResult(&res).Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, fmt.Sprintf("%v", err))
|
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||||||
@@ -193,16 +189,23 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var chatModel model.ChatModel
|
||||||
|
res := h.DB.Where("type = ?", "img").Where("enabled", true).First(&chatModel)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "没有找到可用的AI绘图模型!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debugf("绘画参数:%+v", params)
|
logger.Debugf("绘画参数:%+v", params)
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
|
res = h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "当前用户不存在!")
|
resp.ERROR(c, "当前用户不存在!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Power < h.App.SysConfig.DallPower {
|
if user.Power < chatModel.Power {
|
||||||
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
|
resp.ERROR(c, "创建绘图任务失败,算力不足")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,24 +214,24 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
|||||||
task := types.DallTask{
|
task := types.DallTask{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
ModelId: 0,
|
ModelId: chatModel.Id,
|
||||||
ModelName: "dall-e-3",
|
ModelName: chatModel.Value,
|
||||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||||
N: 1,
|
N: 1,
|
||||||
Quality: "standard",
|
Quality: "standard",
|
||||||
Size: "1024x1024",
|
Size: "1024x1024",
|
||||||
Style: "vivid",
|
Style: "vivid",
|
||||||
Power: h.App.SysConfig.DallPower,
|
Power: chatModel.Power,
|
||||||
}
|
}
|
||||||
job := model.DallJob{
|
job := model.DallJob{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Power: h.App.SysConfig.DallPower,
|
Power: chatModel.Power,
|
||||||
TaskInfo: utils.JsonEncode(task),
|
TaskInfo: utils.JsonEncode(task),
|
||||||
}
|
}
|
||||||
err := h.DB.Create(&job).Error
|
err := h.DB.Create(&job).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+err.Error())
|
resp.ERROR(c, "创建绘图任务失败:"+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -253,76 +256,6 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, content)
|
resp.SUCCESS(c, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 实现一个联网搜索的函数工具,采用爬虫实现
|
|
||||||
func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
|
||||||
if err := h.checkAuth(c); err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var params map[string]interface{}
|
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从参数中获取搜索关键词
|
|
||||||
keyword, ok := params["keyword"].(string)
|
|
||||||
if !ok || keyword == "" {
|
|
||||||
resp.ERROR(c, "搜索关键词不能为空")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从参数中获取最大页数,默认为1页
|
|
||||||
maxPages := 1
|
|
||||||
if pages, ok := params["max_pages"].(float64); ok {
|
|
||||||
maxPages = int(pages)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取用户ID
|
|
||||||
userID, ok := params["user_id"].(float64)
|
|
||||||
if !ok {
|
|
||||||
resp.ERROR(c, "用户ID不能为空")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查询用户信息
|
|
||||||
var user model.User
|
|
||||||
res := h.DB.Where("id = ?", int(userID)).First(&user)
|
|
||||||
if res.Error != nil {
|
|
||||||
resp.ERROR(c, "用户不存在")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查用户算力是否足够
|
|
||||||
searchPower := 1 // 每次搜索消耗1点算力
|
|
||||||
if user.Power < searchPower {
|
|
||||||
resp.ERROR(c, "算力不足,无法执行网络搜索")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 执行网络搜索
|
|
||||||
searchResults, err := crawler.SearchWeb(keyword, maxPages)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 扣减用户算力
|
|
||||||
err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Model: "web_search",
|
|
||||||
Remark: fmt.Sprintf("网络搜索:%s", utils.CutWords(keyword, 10)),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "扣减算力失败:"+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 返回搜索结果
|
|
||||||
resp.SUCCESS(c, searchResults)
|
|
||||||
}
|
|
||||||
|
|
||||||
// List 获取所有的工具函数列表
|
// List 获取所有的工具函数列表
|
||||||
func (h *FunctionHandler) List(c *gin.Context) {
|
func (h *FunctionHandler) List(c *gin.Context) {
|
||||||
var items []model.Function
|
var items []model.Function
|
||||||
|
|||||||
@@ -8,14 +8,18 @@ package handler
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// InviteHandler 用户邀请
|
// InviteHandler 用户邀请
|
||||||
@@ -27,6 +31,23 @@ func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
|
|||||||
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *InviteHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/invite/")
|
||||||
|
|
||||||
|
// 公开接口,不需要授权
|
||||||
|
group.GET("hits", h.Hits)
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("code", h.Code)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.GET("stats", h.Stats)
|
||||||
|
group.GET("rules", h.Rules)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Code 获取当前用户邀请码
|
// Code 获取当前用户邀请码
|
||||||
func (h *InviteHandler) Code(c *gin.Context) {
|
func (h *InviteHandler) Code(c *gin.Context) {
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
@@ -65,21 +86,34 @@ func (h *InviteHandler) List(c *gin.Context) {
|
|||||||
var total int64
|
var total int64
|
||||||
session.Model(&model.InviteLog{}).Count(&total)
|
session.Model(&model.InviteLog{}).Count(&total)
|
||||||
var items []model.InviteLog
|
var items []model.InviteLog
|
||||||
var list = make([]vo.InviteLog, 0)
|
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
|
err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items).Error
|
||||||
if res.Error == nil {
|
if err != nil {
|
||||||
for _, item := range items {
|
resp.ERROR(c, err.Error())
|
||||||
var v vo.InviteLog
|
return
|
||||||
err := utils.CopyObject(item, &v)
|
}
|
||||||
if err == nil {
|
|
||||||
v.Id = item.Id
|
userIds := make([]uint, 0)
|
||||||
v.CreatedAt = item.CreatedAt.Unix()
|
for _, item := range items {
|
||||||
list = append(list, v)
|
userIds = append(userIds, item.UserId)
|
||||||
} else {
|
}
|
||||||
logger.Error(err)
|
userMap := make(map[uint]model.User)
|
||||||
}
|
var users []model.User
|
||||||
|
h.DB.Model(&model.User{}).Where("id IN (?)", userIds).Find(&users)
|
||||||
|
for _, user := range users {
|
||||||
|
userMap[user.Id] = user
|
||||||
|
}
|
||||||
|
|
||||||
|
var list = make([]vo.InviteLog, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
var v vo.InviteLog
|
||||||
|
err := utils.CopyObject(item, &v)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
v.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
v.Avatar = userMap[item.UserId].Avatar
|
||||||
|
list = append(list, v)
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
||||||
}
|
}
|
||||||
@@ -90,3 +124,89 @@ func (h *InviteHandler) Hits(c *gin.Context) {
|
|||||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
|
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stats 获取邀请统计
|
||||||
|
func (h *InviteHandler) Stats(c *gin.Context) {
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
|
||||||
|
// 获取邀请码
|
||||||
|
var inviteCode model.InviteCode
|
||||||
|
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "邀请码不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计累计邀请数
|
||||||
|
var totalInvite int64
|
||||||
|
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ?", userId).Count(&totalInvite)
|
||||||
|
|
||||||
|
// 统计今日邀请数
|
||||||
|
today := time.Now().Format("2006-01-02")
|
||||||
|
var todayInvite int64
|
||||||
|
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ? AND DATE(created_at) = ?", userId, today).Count(&todayInvite)
|
||||||
|
|
||||||
|
// 获取系统配置中的邀请奖励
|
||||||
|
var config model.Config
|
||||||
|
var invitePower int = 200 // 默认值
|
||||||
|
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
|
||||||
|
var configMap map[string]any
|
||||||
|
if utils.JsonDecode(config.Value, &configMap) == nil {
|
||||||
|
if power, ok := configMap["invite_power"].(float64); ok {
|
||||||
|
invitePower = int(power)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算获得奖励总数
|
||||||
|
rewardTotal := int(totalInvite) * invitePower
|
||||||
|
|
||||||
|
// 构建邀请链接
|
||||||
|
inviteLink := fmt.Sprintf("%s/register?invite=%s", h.App.Config.StaticUrl, inviteCode.Code)
|
||||||
|
|
||||||
|
stats := vo.InviteStats{
|
||||||
|
InviteCount: int(totalInvite),
|
||||||
|
RewardTotal: rewardTotal,
|
||||||
|
TodayInvite: int(todayInvite),
|
||||||
|
InviteCode: inviteCode.Code,
|
||||||
|
InviteLink: inviteLink,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, stats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rules 获取奖励规则
|
||||||
|
func (h *InviteHandler) Rules(c *gin.Context) {
|
||||||
|
// 获取系统配置中的邀请奖励
|
||||||
|
var config model.Config
|
||||||
|
var invitePower int = 200 // 默认值
|
||||||
|
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
|
||||||
|
var configMap map[string]interface{}
|
||||||
|
if utils.JsonDecode(config.Value, &configMap) == nil {
|
||||||
|
if power, ok := configMap["invite_power"].(float64); ok {
|
||||||
|
invitePower = int(power)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := []vo.RewardRule{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Title: "好友注册",
|
||||||
|
Desc: "好友通过邀请链接成功注册",
|
||||||
|
Icon: "icon-user-fill",
|
||||||
|
Color: "#1989fa",
|
||||||
|
Reward: invitePower,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 2,
|
||||||
|
Title: "好友首次充值",
|
||||||
|
Desc: "好友首次充值任意金额",
|
||||||
|
Icon: "icon-money",
|
||||||
|
Color: "#07c160",
|
||||||
|
Reward: invitePower * 2, // 假设首次充值奖励是注册奖励的2倍
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, rules)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,11 +2,12 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/jimeng"
|
"geekai/service/jimeng"
|
||||||
|
"geekai/service/moderation"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
@@ -19,27 +20,34 @@ import (
|
|||||||
// JimengHandler 即梦AI处理器
|
// JimengHandler 即梦AI处理器
|
||||||
type JimengHandler struct {
|
type JimengHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
jimengService *jimeng.Service
|
jimengService *jimeng.Service
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
moderationManager *moderation.ServiceManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewJimengHandler 创建即梦AI处理器
|
// NewJimengHandler 创建即梦AI处理器
|
||||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService) *JimengHandler {
|
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService, moderationManager *moderation.ServiceManager) *JimengHandler {
|
||||||
return &JimengHandler{
|
return &JimengHandler{
|
||||||
BaseHandler: BaseHandler{App: app, DB: db},
|
BaseHandler: BaseHandler{App: app, DB: db},
|
||||||
jimengService: jimengService,
|
jimengService: jimengService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
moderationManager: moderationManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterRoutes 注册路由,新增统一任务接口
|
// RegisterRoutes 注册路由,新增统一任务接口
|
||||||
func (h *JimengHandler) RegisterRoutes() {
|
func (h *JimengHandler) RegisterRoutes() {
|
||||||
rg := h.App.Engine.Group("/api/jimeng")
|
group := h.App.Engine.Group("/api/jimeng/")
|
||||||
rg.POST("task", h.CreateTask) // 只保留统一任务接口
|
|
||||||
rg.GET("power-config", h.GetPowerConfig) // 新增算力配置接口
|
// 需要用户授权的接口
|
||||||
rg.POST("jobs", h.Jobs)
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
rg.GET("remove", h.Remove)
|
{
|
||||||
rg.GET("retry", h.Retry)
|
group.POST("task", h.CreateTask)
|
||||||
|
group.GET("power-config", h.GetPowerConfig)
|
||||||
|
group.POST("jobs", h.Jobs)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("retry", h.Retry)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// JimengTaskRequest 统一任务请求结构体
|
// JimengTaskRequest 统一任务请求结构体
|
||||||
@@ -70,6 +78,31 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 文本审核
|
||||||
|
if h.App.SysConfig.Moderation.Enable {
|
||||||
|
moderationResult, err := h.moderationManager.GetService().Moderate(req.Prompt)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to moderate content: ", err)
|
||||||
|
}
|
||||||
|
if moderationResult.Flagged {
|
||||||
|
// 记录违规内容
|
||||||
|
moderation := model.Moderation{
|
||||||
|
UserId: h.GetLoginUserId(c),
|
||||||
|
Source: types.ModerationSourceJiMeng,
|
||||||
|
Input: req.Prompt,
|
||||||
|
Result: utils.JsonEncode(moderationResult),
|
||||||
|
}
|
||||||
|
err = h.DB.Create(&moderation).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to save moderation: ", err)
|
||||||
|
}
|
||||||
|
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// 新增:除图像特效外,其他任务类型必须有提示词
|
// 新增:除图像特效外,其他任务类型必须有提示词
|
||||||
if req.TaskType != "image_effects" && req.Prompt == "" {
|
if req.TaskType != "image_effects" && req.Prompt == "" {
|
||||||
resp.ERROR(c, "提示词不能为空")
|
resp.ERROR(c, "提示词不能为空")
|
||||||
@@ -153,12 +186,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
"seed": req.Seed,
|
"seed": req.Seed,
|
||||||
"scale": req.Scale,
|
"scale": req.Scale,
|
||||||
}
|
}
|
||||||
if len(req.ImageUrls) > 0 {
|
params["image_urls"] = []string{req.ImageInput}
|
||||||
params["image_urls"] = req.ImageUrls
|
|
||||||
}
|
|
||||||
if len(req.BinaryDataBase64) > 0 {
|
|
||||||
params["binary_data_base64"] = req.BinaryDataBase64
|
|
||||||
}
|
|
||||||
case "image_effects":
|
case "image_effects":
|
||||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
|
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
|
||||||
taskType = model.JMTaskTypeImageEffects
|
taskType = model.JMTaskTypeImageEffects
|
||||||
@@ -181,9 +209,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
taskType = model.JMTaskTypeTextToVideo
|
taskType = model.JMTaskTypeTextToVideo
|
||||||
reqKey = jimeng.ReqKeyTextToVideo
|
reqKey = jimeng.ReqKeyTextToVideo
|
||||||
modelName = "即梦文生视频"
|
modelName = "即梦文生视频"
|
||||||
if req.Seed == 0 {
|
|
||||||
req.Seed = -1
|
|
||||||
}
|
|
||||||
if req.AspectRatio == "" {
|
if req.AspectRatio == "" {
|
||||||
req.AspectRatio = jimeng.AspectRatio16_9
|
req.AspectRatio = jimeng.AspectRatio16_9
|
||||||
}
|
}
|
||||||
@@ -196,9 +221,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
taskType = model.JMTaskTypeImageToVideo
|
taskType = model.JMTaskTypeImageToVideo
|
||||||
reqKey = jimeng.ReqKeyImageToVideo
|
reqKey = jimeng.ReqKeyImageToVideo
|
||||||
modelName = "即梦图生视频"
|
modelName = "即梦图生视频"
|
||||||
if req.Seed == 0 {
|
|
||||||
req.Seed = -1
|
|
||||||
}
|
|
||||||
params = map[string]any{
|
params = map[string]any{
|
||||||
"seed": req.Seed,
|
"seed": req.Seed,
|
||||||
"aspect_ratio": req.AspectRatio,
|
"aspect_ratio": req.AspectRatio,
|
||||||
@@ -333,8 +355,10 @@ func (h *JimengHandler) Remove(c *gin.Context) {
|
|||||||
resp.ERROR(c, "无权限操作")
|
resp.ERROR(c, "无权限操作")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if job.Status != model.JMTaskStatusFailed {
|
|
||||||
resp.ERROR(c, "只有失败的任务才能删除")
|
// 正在运行中的任务不能删除
|
||||||
|
if job.Status == model.JMTaskStatusGenerating || job.Status == model.JMTaskStatusInQueue {
|
||||||
|
resp.ERROR(c, "正在运行中的任务不能删除,否则无法退回算力")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -345,17 +369,20 @@ func (h *JimengHandler) Remove(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 退回算力
|
// 失败任务删除后退回算力
|
||||||
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
if job.Status != model.JMTaskStatusFailed {
|
||||||
Type: types.PowerRefund,
|
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
||||||
Model: "jimeng",
|
Type: types.PowerRefund,
|
||||||
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
Model: "jimeng",
|
||||||
})
|
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
||||||
if err != nil {
|
})
|
||||||
resp.ERROR(c, "退回算力失败")
|
if err != nil {
|
||||||
tx.Rollback()
|
resp.ERROR(c, "退回算力失败")
|
||||||
return
|
tx.Rollback()
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
|
|
||||||
resp.SUCCESS(c, gin.H{})
|
resp.SUCCESS(c, gin.H{})
|
||||||
@@ -408,7 +435,7 @@ func (h *JimengHandler) Retry(c *gin.Context) {
|
|||||||
|
|
||||||
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
||||||
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||||
config := h.jimengService.GetConfig()
|
config := h.App.SysConfig.Jimeng
|
||||||
|
|
||||||
switch taskType {
|
switch taskType {
|
||||||
case model.JMTaskTypeTextToImage:
|
case model.JMTaskTypeTextToImage:
|
||||||
@@ -430,7 +457,7 @@ func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
|||||||
|
|
||||||
// GetPowerConfig 获取即梦各任务类型算力消耗配置
|
// GetPowerConfig 获取即梦各任务类型算力消耗配置
|
||||||
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
||||||
config := h.jimengService.GetConfig()
|
config := h.App.SysConfig.Jimeng
|
||||||
resp.SUCCESS(c, gin.H{
|
resp.SUCCESS(c, gin.H{
|
||||||
"text_to_image": config.Power.TextToImage,
|
"text_to_image": config.Power.TextToImage,
|
||||||
"image_to_image": config.Power.ImageToImage,
|
"image_to_image": config.Power.ImageToImage,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -35,6 +36,17 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *MarkMapHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/markMap/")
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("gen", h.Generate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Generate 生成思维导图
|
// Generate 生成思维导图
|
||||||
func (h *MarkMapHandler) Generate(c *gin.Context) {
|
func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -25,6 +26,12 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
|||||||
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *MenuHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/menu/")
|
||||||
|
group.GET("list", h.List)
|
||||||
|
}
|
||||||
|
|
||||||
// List 数据列表
|
// List 数据列表
|
||||||
func (h *MenuHandler) List(c *gin.Context) {
|
func (h *MenuHandler) List(c *gin.Context) {
|
||||||
index := h.GetBool(c, "index")
|
index := h.GetBool(c, "index")
|
||||||
@@ -33,7 +40,7 @@ func (h *MenuHandler) List(c *gin.Context) {
|
|||||||
session := h.DB.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
session = session.Where("enabled", true)
|
session = session.Where("enabled", true)
|
||||||
if index {
|
if index {
|
||||||
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
|
session = session.Where("id IN ?", h.App.SysConfig.Base.IndexNavs)
|
||||||
}
|
}
|
||||||
res := session.Order("sort_num ASC").Find(&items)
|
res := session.Order("sort_num ASC").Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/mj"
|
"geekai/service/mj"
|
||||||
|
"geekai/service/moderation"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
@@ -27,18 +29,20 @@ import (
|
|||||||
|
|
||||||
type MidJourneyHandler struct {
|
type MidJourneyHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
mjService *mj.Service
|
mjService *mj.Service
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
moderationManager *moderation.ServiceManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
|
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *MidJourneyHandler {
|
||||||
return &MidJourneyHandler{
|
return &MidJourneyHandler{
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
mjService: service,
|
mjService: service,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
moderationManager: moderationManager,
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: app,
|
App: app,
|
||||||
DB: db,
|
DB: db,
|
||||||
@@ -46,6 +50,25 @@ func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.S
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *MidJourneyHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/mj/")
|
||||||
|
|
||||||
|
// 公开接口,不需要授权
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("image", h.Image)
|
||||||
|
group.POST("upscale", h.Upscale)
|
||||||
|
group.POST("variation", h.Variation)
|
||||||
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("publish", h.Publish)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||||
user, err := h.GetLoginUser(c)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -53,7 +76,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Power < h.App.SysConfig.MjPower {
|
if user.Power < h.App.SysConfig.Base.MjPower {
|
||||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -90,6 +113,29 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 文本审核
|
||||||
|
if h.App.SysConfig.Moderation.Enable {
|
||||||
|
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to moderate content: ", err)
|
||||||
|
}
|
||||||
|
if moderationResult.Flagged {
|
||||||
|
// 记录违规内容
|
||||||
|
moderation := model.Moderation{
|
||||||
|
UserId: h.GetLoginUserId(c),
|
||||||
|
Source: types.ModerationSourceMJ,
|
||||||
|
Input: data.Prompt,
|
||||||
|
Result: utils.JsonEncode(moderationResult),
|
||||||
|
}
|
||||||
|
err = h.DB.Create(&moderation).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to save moderation: ", err)
|
||||||
|
}
|
||||||
|
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var params = ""
|
var params = ""
|
||||||
if data.Rate != "" && !strings.Contains(params, "--ar") {
|
if data.Rate != "" && !strings.Contains(params, "--ar") {
|
||||||
params += " --ar " + data.Rate
|
params += " --ar " + data.Rate
|
||||||
@@ -159,8 +205,8 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
Params: params,
|
Params: params,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
ImgArr: data.ImgArr,
|
ImgArr: data.ImgArr,
|
||||||
Mode: h.App.SysConfig.MjMode,
|
Mode: h.App.SysConfig.Base.MjMode,
|
||||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||||
}
|
}
|
||||||
job := model.MidJourneyJob{
|
job := model.MidJourneyJob{
|
||||||
Type: data.TaskType,
|
Type: data.TaskType,
|
||||||
@@ -169,7 +215,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
TaskInfo: utils.JsonEncode(task),
|
TaskInfo: utils.JsonEncode(task),
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
|
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
|
||||||
Power: h.App.SysConfig.MjPower,
|
Power: h.App.SysConfig.Base.MjPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
opt := "绘图"
|
opt := "绘图"
|
||||||
@@ -232,7 +278,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
Index: data.Index,
|
Index: data.Index,
|
||||||
MessageId: data.MessageId,
|
MessageId: data.MessageId,
|
||||||
MessageHash: data.MessageHash,
|
MessageHash: data.MessageHash,
|
||||||
Mode: h.App.SysConfig.MjMode,
|
Mode: h.App.SysConfig.Base.MjMode,
|
||||||
}
|
}
|
||||||
job := model.MidJourneyJob{
|
job := model.MidJourneyJob{
|
||||||
Type: types.TaskUpscale.String(),
|
Type: types.TaskUpscale.String(),
|
||||||
@@ -240,7 +286,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
TaskInfo: utils.JsonEncode(task),
|
TaskInfo: utils.JsonEncode(task),
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Power: h.App.SysConfig.MjActionPower,
|
Power: h.App.SysConfig.Base.MjActionPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
@@ -287,7 +333,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
ChannelId: data.ChannelId,
|
ChannelId: data.ChannelId,
|
||||||
MessageId: data.MessageId,
|
MessageId: data.MessageId,
|
||||||
MessageHash: data.MessageHash,
|
MessageHash: data.MessageHash,
|
||||||
Mode: h.App.SysConfig.MjMode,
|
Mode: h.App.SysConfig.Base.MjMode,
|
||||||
}
|
}
|
||||||
job := model.MidJourneyJob{
|
job := model.MidJourneyJob{
|
||||||
Type: types.TaskVariation.String(),
|
Type: types.TaskVariation.String(),
|
||||||
@@ -296,7 +342,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
TaskInfo: utils.JsonEncode(task),
|
TaskInfo: utils.JsonEncode(task),
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Power: h.App.SysConfig.MjActionPower,
|
Power: h.App.SysConfig.Base.MjActionPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -32,6 +33,22 @@ func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManage
|
|||||||
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
|
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *NetHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/upload")
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("", h.Upload)
|
||||||
|
group.POST("list", h.List)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 公开接口,不需要授权
|
||||||
|
h.App.Engine.GET("/api/download", h.Download)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *NetHandler) Upload(c *gin.Context) {
|
func (h *NetHandler) Upload(c *gin.Context) {
|
||||||
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -9,12 +9,12 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -28,6 +28,18 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
|||||||
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *OrderHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/order/")
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.GET("query", h.Query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// List 订单列表
|
// List 订单列表
|
||||||
func (h *OrderHandler) List(c *gin.Context) {
|
func (h *OrderHandler) List(c *gin.Context) {
|
||||||
page := h.GetInt(c, "page", 1)
|
page := h.GetInt(c, "page", 1)
|
||||||
@@ -48,20 +60,21 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
order.Id = item.Id
|
order.Id = item.Id
|
||||||
order.CreatedAt = item.CreatedAt.Unix()
|
order.CreatedAt = item.CreatedAt.Unix()
|
||||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||||
payMethod, ok := types.PayMethods[item.PayWay]
|
payChannel, ok := types.PayChannel[item.Channel]
|
||||||
if !ok {
|
if !ok {
|
||||||
payMethod = item.PayWay
|
payChannel = item.PayWay
|
||||||
}
|
}
|
||||||
payName, ok := types.PayNames[item.PayType]
|
payWays, ok := types.PayWays[item.PayWay]
|
||||||
if !ok {
|
if !ok {
|
||||||
payName = item.PayWay
|
payWays = item.PayWay
|
||||||
}
|
}
|
||||||
order.PayMethod = payMethod
|
order.ChannelName = payChannel
|
||||||
order.PayName = payName
|
order.PayName = payWays
|
||||||
list = append(list, order)
|
list = append(list, order)
|
||||||
} else {
|
} else {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
||||||
@@ -82,17 +95,8 @@ func (h *OrderHandler) Query(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
counter := 0
|
var item model.Order
|
||||||
for {
|
h.DB.Where("order_no = ?", orderNo).First(&item)
|
||||||
time.Sleep(time.Second)
|
|
||||||
var item model.Order
|
|
||||||
h.DB.Where("order_no = ?", orderNo).First(&item)
|
|
||||||
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
|
|
||||||
order.Status = item.Status
|
|
||||||
break
|
|
||||||
}
|
|
||||||
counter++
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, gin.H{"status": order.Status})
|
resp.SUCCESS(c, gin.H{"status": order.Status})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/payment"
|
"geekai/service/payment"
|
||||||
@@ -33,52 +34,148 @@ type PayWay struct {
|
|||||||
// PaymentHandler 支付服务回调 handler
|
// PaymentHandler 支付服务回调 handler
|
||||||
type PaymentHandler struct {
|
type PaymentHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
alipayService *payment.AlipayService
|
alipayService *payment.AlipayService
|
||||||
huPiPayService *payment.HuPiPayService
|
epayService *payment.EPayService
|
||||||
geekPayService *payment.GeekPayService
|
wxpayService *payment.WxPayService
|
||||||
wechatPayService *payment.WechatPayService
|
snowflake *service.Snowflake
|
||||||
snowflake *service.Snowflake
|
userService *service.UserService
|
||||||
userService *service.UserService
|
fs embed.FS
|
||||||
fs embed.FS
|
lock sync.Mutex
|
||||||
lock sync.Mutex
|
config *types.PaymentConfig
|
||||||
signKey string // 用来签名的随机秘钥
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPaymentHandler(
|
func NewPaymentHandler(
|
||||||
server *core.AppServer,
|
server *core.AppServer,
|
||||||
alipayService *payment.AlipayService,
|
alipayService *payment.AlipayService,
|
||||||
huPiPayService *payment.HuPiPayService,
|
geekPayService *payment.EPayService,
|
||||||
geekPayService *payment.GeekPayService,
|
wxpayService *payment.WxPayService,
|
||||||
wechatPayService *payment.WechatPayService,
|
|
||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
userService *service.UserService,
|
userService *service.UserService,
|
||||||
snowflake *service.Snowflake,
|
snowflake *service.Snowflake,
|
||||||
fs embed.FS) *PaymentHandler {
|
fs embed.FS,
|
||||||
|
sysConfig *types.SystemConfig) *PaymentHandler {
|
||||||
return &PaymentHandler{
|
return &PaymentHandler{
|
||||||
alipayService: alipayService,
|
alipayService: alipayService,
|
||||||
huPiPayService: huPiPayService,
|
epayService: geekPayService,
|
||||||
geekPayService: geekPayService,
|
wxpayService: wxpayService,
|
||||||
wechatPayService: wechatPayService,
|
snowflake: snowflake,
|
||||||
snowflake: snowflake,
|
userService: userService,
|
||||||
userService: userService,
|
fs: fs,
|
||||||
fs: fs,
|
lock: sync.Mutex{},
|
||||||
lock: sync.Mutex{},
|
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: server,
|
App: server,
|
||||||
DB: db,
|
DB: db,
|
||||||
},
|
},
|
||||||
signKey: utils.RandString(32),
|
config: &sysConfig.Payment,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PaymentHandler) Pay(c *gin.Context) {
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *PaymentHandler) RegisterRoutes() {
|
||||||
|
rg := h.App.Engine.Group("/api/payment/")
|
||||||
|
|
||||||
|
// 支付回调接口(公开)
|
||||||
|
rg.POST("notify/alipay", h.AlipayNotify)
|
||||||
|
rg.GET("notify/epay", h.EPayNotify)
|
||||||
|
rg.POST("notify/wxpay", h.WxpayNotify)
|
||||||
|
|
||||||
|
// 需要用户登录的接口
|
||||||
|
rg.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
rg.POST("create", h.CreateOrder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PaymentHandler) StartSyncOrders() {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
err := h.SyncOrders()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 5)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncOrders 同步订单状态
|
||||||
|
func (h *PaymentHandler) SyncOrders() error {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
logger.Errorf("同步订单状态发生异常: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
var orders []model.Order
|
||||||
|
err := h.DB.Where("status", types.OrderNotPaid).Where("checked", false).Find(&orders).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, order := range orders {
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
|
//超时15分钟的订单,直接标记为已关闭
|
||||||
|
if time.Now().After(order.CreatedAt.Add(time.Minute * 5)) {
|
||||||
|
h.DB.Model(&model.Order{}).Where("id", order.Id).Update("checked", true)
|
||||||
|
logger.Errorf("订单超时:%v", order)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 查询订单状态
|
||||||
|
var res payment.OrderInfo
|
||||||
|
switch order.Channel {
|
||||||
|
case payment.PayChannelEpay:
|
||||||
|
res, err = h.epayService.Query(order.OrderNo)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("error with query order info: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 微信支付
|
||||||
|
case payment.PayChannelWX:
|
||||||
|
res, err = h.wxpayService.Query(order.OrderNo)
|
||||||
|
logger.Debugf("微信支付订单状态:%+v", res)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("error with query order info: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
case payment.PayChannelAL:
|
||||||
|
res, err = h.alipayService.Query(order.OrderNo)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("error with query order info: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 订单已关闭
|
||||||
|
if res.Closed() {
|
||||||
|
h.DB.Model(&model.Order{}).Where("id", order.Id).Updates(map[string]any{
|
||||||
|
"checked": true,
|
||||||
|
"status": types.OrderPaidFailed,
|
||||||
|
})
|
||||||
|
logger.Errorf("订单已关闭:%v", order)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 订单未支付,不处理,继续轮询
|
||||||
|
if !res.Success() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 订单支付成功
|
||||||
|
err = h.paySuccess(res)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("error with deal order: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
PayWay string `json:"pay_way"`
|
PayWay string `json:"pay_way,omitempty"` // 支付方式:支付宝,微信
|
||||||
PayType string `json:"pay_type"`
|
Pid int `json:"pid,omitempty"`
|
||||||
ProductId int `json:"product_id"`
|
Device string `json:"device,omitempty"`
|
||||||
UserId int `json:"user_id"`
|
Domain string `json:"domain,omitempty"` // 支付回调域名
|
||||||
Device string `json:"device"`
|
Channel string `json:"channel,omitempty"`
|
||||||
Host string `json:"host"`
|
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -86,7 +183,7 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var product model.Product
|
var product model.Product
|
||||||
err := h.DB.Where("id", data.ProductId).First(&product).Error
|
err := h.DB.Where("id", data.Pid).First(&product).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "Product not found")
|
resp.ERROR(c, "Product not found")
|
||||||
return
|
return
|
||||||
@@ -97,136 +194,118 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
|
|||||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
var user model.User
|
var user model.User
|
||||||
err = h.DB.Where("id", data.UserId).First(&user).Error
|
err = h.DB.Where("id", userId).First(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
amount := product.Discount
|
amount := product.Price
|
||||||
var payURL, returnURL, notifyURL string
|
var payURL, notifyURL string
|
||||||
switch data.PayWay {
|
switch data.PayWay {
|
||||||
case "alipay":
|
case "wxpay":
|
||||||
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
|
logger.Debugf("微信支付,%+v", data)
|
||||||
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
data.Channel = payment.PayChannelWX
|
||||||
} else {
|
// 优先使用微信官方支付
|
||||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
|
if h.config.WxPay.Enabled {
|
||||||
}
|
data.Channel = "wxpay"
|
||||||
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
|
if h.config.WxPay.Domain != "" {
|
||||||
returnURL = h.App.Config.AlipayConfig.ReturnURL
|
data.Domain = h.config.WxPay.Domain
|
||||||
} else {
|
}
|
||||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
notifyURL = fmt.Sprintf("%s/api/payment/notify/wxpay", data.Domain)
|
||||||
}
|
payURL, err = h.wxpayService.Pay(payment.PayRequest{
|
||||||
money := fmt.Sprintf("%.2f", amount)
|
|
||||||
if data.Device == "wechat" {
|
|
||||||
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
|
|
||||||
OutTradeNo: orderNo,
|
OutTradeNo: orderNo,
|
||||||
Subject: product.Name,
|
TotalFee: fmt.Sprintf("%d", int(amount*100)),
|
||||||
TotalFee: money,
|
|
||||||
ReturnURL: returnURL,
|
|
||||||
NotifyURL: notifyURL,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
|
|
||||||
OutTradeNo: orderNo,
|
|
||||||
Subject: product.Name,
|
|
||||||
TotalFee: money,
|
|
||||||
ReturnURL: returnURL,
|
|
||||||
NotifyURL: notifyURL,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
break
|
|
||||||
case "wechat":
|
|
||||||
if h.App.Config.WechatPayConfig.NotifyURL != "" {
|
|
||||||
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
|
|
||||||
} else {
|
|
||||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
|
|
||||||
}
|
|
||||||
if data.Device == "wechat" {
|
|
||||||
payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
|
|
||||||
OutTradeNo: orderNo,
|
|
||||||
TotalFee: int(amount * 100),
|
|
||||||
Subject: product.Name,
|
Subject: product.Name,
|
||||||
NotifyURL: notifyURL,
|
NotifyURL: notifyURL,
|
||||||
ClientIP: c.ClientIP(),
|
ClientIP: c.ClientIP(),
|
||||||
|
Device: data.Device,
|
||||||
|
PayWay: payment.PayWayWX,
|
||||||
})
|
})
|
||||||
} else {
|
if err != nil {
|
||||||
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if h.config.Epay.Enabled { // 聚合支付
|
||||||
|
logger.Debugf("聚合支付%+v", data)
|
||||||
|
data.Channel = payment.PayChannelEpay
|
||||||
|
if h.config.Epay.Domain != "" {
|
||||||
|
data.Domain = h.config.Epay.Domain
|
||||||
|
}
|
||||||
|
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
|
||||||
|
params := payment.PayRequest{
|
||||||
OutTradeNo: orderNo,
|
OutTradeNo: orderNo,
|
||||||
TotalFee: int(amount * 100),
|
|
||||||
Subject: product.Name,
|
Subject: product.Name,
|
||||||
|
TotalFee: fmt.Sprintf("%f", amount),
|
||||||
|
ClientIP: c.ClientIP(),
|
||||||
|
Device: data.Device,
|
||||||
|
PayWay: payment.PayWayWX,
|
||||||
|
NotifyURL: notifyURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := h.epayService.Pay(params)
|
||||||
|
logger.Debugf("请求支付结果,%+v", r)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
payURL = r
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resp.ERROR(c, "系统没有配置可用的支付渠道!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "alipay":
|
||||||
|
if h.config.Alipay.Enabled {
|
||||||
|
logger.Debugf("支付宝,%+v", data)
|
||||||
|
data.Channel = payment.PayChannelAL
|
||||||
|
if h.config.Alipay.Domain != "" { // 用于本地调试支付
|
||||||
|
data.Domain = h.config.Alipay.Domain
|
||||||
|
}
|
||||||
|
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Domain)
|
||||||
|
money := fmt.Sprintf("%.2f", amount)
|
||||||
|
payURL, err = h.alipayService.Pay(payment.PayRequest{
|
||||||
|
Device: data.Device,
|
||||||
|
OutTradeNo: orderNo,
|
||||||
|
Subject: product.Name,
|
||||||
|
TotalFee: money,
|
||||||
NotifyURL: notifyURL,
|
NotifyURL: notifyURL,
|
||||||
})
|
})
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
break
|
|
||||||
case "hupi":
|
|
||||||
if h.App.Config.HuPiPayConfig.NotifyURL != "" {
|
|
||||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
|
||||||
} else {
|
|
||||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
|
|
||||||
}
|
|
||||||
if h.App.Config.HuPiPayConfig.ReturnURL != "" {
|
|
||||||
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
|
||||||
} else {
|
|
||||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
|
||||||
}
|
|
||||||
r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
|
|
||||||
Version: "1.1",
|
|
||||||
TradeOrderId: orderNo,
|
|
||||||
TotalFee: fmt.Sprintf("%f", amount),
|
|
||||||
Title: product.Name,
|
|
||||||
NotifyURL: notifyURL,
|
|
||||||
ReturnURL: returnURL,
|
|
||||||
WapName: "GeekAI助手",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
payURL = r.URL
|
|
||||||
break
|
|
||||||
case "geek":
|
|
||||||
if h.App.Config.GeekPayConfig.NotifyURL != "" {
|
|
||||||
notifyURL = h.App.Config.GeekPayConfig.NotifyURL
|
|
||||||
} else {
|
|
||||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
|
|
||||||
}
|
|
||||||
if h.App.Config.GeekPayConfig.ReturnURL != "" {
|
|
||||||
data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
|
|
||||||
}
|
|
||||||
if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
|
|
||||||
returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
|
|
||||||
} else {
|
|
||||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
|
||||||
}
|
|
||||||
params := payment.GeekPayParams{
|
|
||||||
OutTradeNo: orderNo,
|
|
||||||
Method: "web",
|
|
||||||
Name: product.Name,
|
|
||||||
Money: fmt.Sprintf("%f", amount),
|
|
||||||
ClientIP: c.ClientIP(),
|
|
||||||
Device: data.Device,
|
|
||||||
Type: data.PayType,
|
|
||||||
ReturnURL: returnURL,
|
|
||||||
NotifyURL: notifyURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := h.geekPayService.Pay(params)
|
if err != nil {
|
||||||
if err != nil {
|
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
||||||
resp.ERROR(c, err.Error())
|
return
|
||||||
|
}
|
||||||
|
} else if h.config.Epay.Enabled { // 聚合支付
|
||||||
|
logger.Debugf("聚合支付,%+v", data)
|
||||||
|
data.Channel = payment.PayChannelEpay
|
||||||
|
if h.config.Epay.Domain != "" {
|
||||||
|
data.Domain = h.config.Epay.Domain
|
||||||
|
}
|
||||||
|
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
|
||||||
|
params := payment.PayRequest{
|
||||||
|
OutTradeNo: orderNo,
|
||||||
|
Subject: product.Name,
|
||||||
|
TotalFee: fmt.Sprintf("%f", amount),
|
||||||
|
ClientIP: c.ClientIP(),
|
||||||
|
Device: data.Device,
|
||||||
|
PayWay: data.PayWay,
|
||||||
|
NotifyURL: notifyURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := h.epayService.Pay(params)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
payURL = r
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resp.ERROR(c, "系统没有配置可用的支付渠道!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payURL = res.PayURL
|
|
||||||
default:
|
default:
|
||||||
resp.ERROR(c, "不支持的支付渠道")
|
resp.ERROR(c, "不支持的支付渠道")
|
||||||
return
|
return
|
||||||
@@ -234,43 +313,40 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
|
|||||||
|
|
||||||
// 创建订单
|
// 创建订单
|
||||||
remark := types.OrderRemark{
|
remark := types.OrderRemark{
|
||||||
Days: product.Days,
|
Power: product.Power,
|
||||||
Power: product.Power,
|
Name: product.Name,
|
||||||
Name: product.Name,
|
Price: product.Price,
|
||||||
Price: product.Price,
|
|
||||||
Discount: product.Discount,
|
|
||||||
}
|
}
|
||||||
order := model.Order{
|
order := model.Order{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
ProductId: product.Id,
|
OrderNo: orderNo,
|
||||||
OrderNo: orderNo,
|
Subject: product.Name,
|
||||||
Subject: product.Name,
|
Amount: amount,
|
||||||
Amount: amount,
|
Status: types.OrderNotPaid,
|
||||||
Status: types.OrderNotPaid,
|
PayWay: data.PayWay,
|
||||||
PayWay: data.PayWay,
|
Channel: data.Channel,
|
||||||
PayType: data.PayType,
|
Remark: utils.JsonEncode(remark),
|
||||||
Remark: utils.JsonEncode(remark),
|
|
||||||
}
|
}
|
||||||
err = h.DB.Create(&order).Error
|
err = h.DB.Create(&order).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "error with create order: "+err.Error())
|
resp.ERROR(c, "error with create order: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, payURL)
|
resp.SUCCESS(c, gin.H{"pay_url": payURL, "order_no": orderNo})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 异步通知回调公共逻辑
|
// 支付成功处理
|
||||||
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
func (h *PaymentHandler) paySuccess(info payment.OrderInfo) error {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
var order model.Order
|
var order model.Order
|
||||||
err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
|
err := h.DB.Where("order_no", info.OutTradeNo).First(&order).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error with fetch order: %v", err)
|
return fmt.Errorf("error with fetch order: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.lock.Lock()
|
|
||||||
defer h.lock.Unlock()
|
|
||||||
|
|
||||||
// 已支付订单,直接返回
|
// 已支付订单,直接返回
|
||||||
if order.Status == types.OrderPaidSuccess {
|
if order.Status == types.OrderPaidSuccess {
|
||||||
return nil
|
return nil
|
||||||
@@ -290,19 +366,21 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
|||||||
|
|
||||||
// 增加用户算力
|
// 增加用户算力
|
||||||
err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{
|
err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{
|
||||||
Type: types.PowerRecharge,
|
Type: types.PowerRecharge,
|
||||||
Model: order.PayWay,
|
Model: order.Subject,
|
||||||
Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo),
|
Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新订单状态
|
// 更新订单状态
|
||||||
order.PayTime = time.Now().Unix()
|
order.PayTime = utils.Str2stamp(info.PayTime)
|
||||||
order.Status = types.OrderPaidSuccess
|
order.Status = types.OrderPaidSuccess
|
||||||
order.TradeNo = tradeNo
|
order.TradeNo = info.TradeId
|
||||||
err = h.DB.Updates(&order).Error
|
order.Checked = true
|
||||||
|
err = h.DB.Debug().Updates(&order).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error with update order info: %v", err)
|
return fmt.Errorf("error with update order info: %v", err)
|
||||||
}
|
}
|
||||||
@@ -317,54 +395,6 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPayWays 获取支付方式
|
|
||||||
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
|
|
||||||
payWays := make([]gin.H, 0)
|
|
||||||
if h.App.Config.AlipayConfig.Enabled {
|
|
||||||
payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
|
|
||||||
}
|
|
||||||
if h.App.Config.HuPiPayConfig.Enabled {
|
|
||||||
payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
|
|
||||||
}
|
|
||||||
if h.App.Config.GeekPayConfig.Enabled {
|
|
||||||
for _, v := range h.App.Config.GeekPayConfig.Methods {
|
|
||||||
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if h.App.Config.WechatPayConfig.Enabled {
|
|
||||||
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c, payWays)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HuPiPayNotify 虎皮椒支付异步回调
|
|
||||||
func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
|
|
||||||
err := c.Request.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusOK, "fail")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
orderNo := c.Request.Form.Get("trade_order_id")
|
|
||||||
tradeNo := c.Request.Form.Get("open_order_id")
|
|
||||||
logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
|
|
||||||
|
|
||||||
if err = h.huPiPayService.Check(orderNo); err != nil {
|
|
||||||
logger.Error("订单校验失败:", err)
|
|
||||||
c.String(http.StatusOK, "fail")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = h.notify(orderNo, tradeNo)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
c.String(http.StatusOK, "fail")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.String(http.StatusOK, "success")
|
|
||||||
}
|
|
||||||
|
|
||||||
// AlipayNotify 支付宝支付回调
|
// AlipayNotify 支付宝支付回调
|
||||||
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||||
err := c.Request.ParseForm()
|
err := c.Request.ParseForm()
|
||||||
@@ -373,16 +403,15 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result := h.alipayService.TradeVerify(c.Request)
|
orderInfo, err := h.alipayService.Query(c.Request.Form.Get("out_trade_no"))
|
||||||
logger.Infof("收到支付宝商号订单支付回调:%+v", result)
|
logger.Infof("收到支付宝商号订单支付回调:%+v", orderInfo)
|
||||||
if !result.Success() {
|
if !orderInfo.Success() {
|
||||||
logger.Error("订单校验失败:", result.Message)
|
logger.Errorf("订单校验失败:%v", err)
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tradeNo := c.Request.Form.Get("trade_no")
|
err = h.paySuccess(orderInfo)
|
||||||
err = h.notify(result.OutTradeNo, tradeNo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
@@ -392,28 +421,35 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
|||||||
c.String(http.StatusOK, "success")
|
c.String(http.StatusOK, "success")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GeekPayNotify 支付异步回调
|
// EPayNotify 易支付支付异步回调
|
||||||
func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
|
func (h *PaymentHandler) EPayNotify(c *gin.Context) {
|
||||||
var params = make(map[string]string)
|
var params = make(map[string]string)
|
||||||
for k := range c.Request.URL.Query() {
|
for k := range c.Request.URL.Query() {
|
||||||
params[k] = c.Query(k)
|
params[k] = c.Query(k)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("收到GeekPay订单支付回调:%+v", params)
|
logger.Infof("收到易支付订单支付回调:%+v", params)
|
||||||
// 检查支付状态
|
// 检查支付状态, 如果未支付,则返回成功
|
||||||
if params["trade_status"] != "TRADE_SUCCESS" {
|
if params["trade_status"] != "TRADE_SUCCESS" {
|
||||||
c.String(http.StatusOK, "success")
|
c.String(http.StatusOK, "success")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sign := h.geekPayService.Sign(params)
|
sign := h.epayService.Sign(params)
|
||||||
if sign != c.Query("sign") {
|
if sign != c.Query("sign") {
|
||||||
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
|
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 查询订单状态
|
||||||
|
order, err := h.epayService.Query(params["out_trade_no"])
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
c.String(http.StatusOK, "fail")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err := h.notify(params["out_trade_no"], params["trade_no"])
|
err = h.paySuccess(order)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
@@ -423,26 +459,23 @@ func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
|
|||||||
c.String(http.StatusOK, "success")
|
c.String(http.StatusOK, "success")
|
||||||
}
|
}
|
||||||
|
|
||||||
// WechatPayNotify 微信商户支付异步回调
|
// WxpayNotify 微信商户支付异步回调
|
||||||
func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
|
func (h *PaymentHandler) WxpayNotify(c *gin.Context) {
|
||||||
err := c.Request.ParseForm()
|
err := c.Request.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result := h.wechatPayService.TradeVerify(c.Request)
|
orderInfo, err := h.wxpayService.TradeVerify(c.Request)
|
||||||
logger.Infof("收到微信商号订单支付回调:%+v", result)
|
logger.Infof("收到微信商号订单支付回调:%+v", orderInfo)
|
||||||
if !result.Success() {
|
if err != nil {
|
||||||
logger.Error("订单校验失败:", err)
|
logger.Errorf("订单校验失败:%v", err)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{"code": "FAIL"})
|
||||||
"code": "FAIL",
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.notify(result.OutTradeNo, result.TradeId)
|
err = h.paySuccess(orderInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
c.String(http.StatusOK, "fail")
|
c.String(http.StatusOK, "fail")
|
||||||
|
|||||||
@@ -9,11 +9,13 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -27,6 +29,18 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
|||||||
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *PowerLogHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/powerLog/")
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("list", h.List)
|
||||||
|
group.GET("stats", h.Stats)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *PowerLogHandler) List(c *gin.Context) {
|
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
@@ -72,3 +86,45 @@ func (h *PowerLogHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stats 获取用户算力统计
|
||||||
|
func (h *PowerLogHandler) Stats(c *gin.Context) {
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
if userId == 0 {
|
||||||
|
resp.NotAuth(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户信息(包含余额)
|
||||||
|
var user model.User
|
||||||
|
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||||
|
resp.ERROR(c, "用户不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算总消费(所有支出记录)
|
||||||
|
var totalConsume int64
|
||||||
|
h.DB.Model(&model.PowerLog{}).
|
||||||
|
Where("user_id", userId).
|
||||||
|
Where("mark", types.PowerSub).
|
||||||
|
Select("COALESCE(SUM(amount), 0)").
|
||||||
|
Scan(&totalConsume)
|
||||||
|
|
||||||
|
// 计算今日消费
|
||||||
|
today := time.Now().Format("2006-01-02")
|
||||||
|
var todayConsume int64
|
||||||
|
h.DB.Model(&model.PowerLog{}).
|
||||||
|
Where("user_id", userId).
|
||||||
|
Where("mark", types.PowerSub).
|
||||||
|
Where("DATE(created_at) = ?", today).
|
||||||
|
Select("COALESCE(SUM(amount), 0)").
|
||||||
|
Scan(&todayConsume)
|
||||||
|
|
||||||
|
stats := map[string]interface{}{
|
||||||
|
"total": totalConsume,
|
||||||
|
"today": todayConsume,
|
||||||
|
"balance": user.Power,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, stats)
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -25,6 +26,12 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
|||||||
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *ProductHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/product/")
|
||||||
|
group.GET("list", h.List)
|
||||||
|
}
|
||||||
|
|
||||||
// List 模型列表
|
// List 模型列表
|
||||||
func (h *ProductHandler) List(c *gin.Context) {
|
func (h *ProductHandler) List(c *gin.Context) {
|
||||||
var items []model.Product
|
var items []model.Product
|
||||||
|
|||||||
@@ -10,12 +10,14 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -39,6 +41,20 @@ func NewPromptHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *PromptHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/prompt/")
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis)).Use(middleware.RateLimitEvery(h.App.Redis, 30*time.Second))
|
||||||
|
{
|
||||||
|
group.POST("lyric", h.Lyric)
|
||||||
|
group.POST("image", h.Image)
|
||||||
|
group.POST("video", h.Video)
|
||||||
|
group.POST("meta", h.MetaPrompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Lyric 生成歌词
|
// Lyric 生成歌词
|
||||||
func (h *PromptHandler) Lyric(c *gin.Context) {
|
func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
@@ -48,25 +64,12 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.App.SysConfig.PromptPower > 0 {
|
|
||||||
userId := h.GetLoginUserId(c)
|
|
||||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Model: h.getPromptModel(),
|
|
||||||
Remark: "生成歌词",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, content)
|
resp.SUCCESS(c, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,23 +82,12 @@ func (h *PromptHandler) Image(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.App.SysConfig.PromptPower > 0 {
|
|
||||||
userId := h.GetLoginUserId(c)
|
|
||||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Model: h.getPromptModel(),
|
|
||||||
Remark: "生成绘画提示词",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c, strings.Trim(content, `"`))
|
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,25 +100,12 @@ func (h *PromptHandler) Video(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.App.SysConfig.PromptPower > 0 {
|
|
||||||
userId := h.GetLoginUserId(c)
|
|
||||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
|
|
||||||
Type: types.PowerConsume,
|
|
||||||
Model: h.getPromptModel(),
|
|
||||||
Remark: "生成视频脚本",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, strings.Trim(content, `"`))
|
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,9 +137,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *PromptHandler) getPromptModel() string {
|
func (h *PromptHandler) getPromptModel() string {
|
||||||
if h.App.SysConfig.AssistantModelId > 0 {
|
if h.App.SysConfig.Base.AssistantModelId > 0 {
|
||||||
var chatModel model.ChatModel
|
var chatModel model.ChatModel
|
||||||
h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel)
|
h.DB.Where("id", h.App.SysConfig.Base.AssistantModelId).First(&chatModel)
|
||||||
return chatModel.Value
|
return chatModel.Value
|
||||||
}
|
}
|
||||||
return "gpt-4o"
|
return "gpt-4o"
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -39,6 +40,18 @@ func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *servic
|
|||||||
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
|
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *RealtimeHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/realtime/")
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.Any("", h.Connection)
|
||||||
|
group.POST("voice", h.VoiceChat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *RealtimeHandler) Connection(c *gin.Context) {
|
func (h *RealtimeHandler) Connection(c *gin.Context) {
|
||||||
// 获取客户端请求中指定的子协议
|
// 获取客户端请求中指定的子协议
|
||||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||||
@@ -154,7 +167,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Power < h.App.SysConfig.AdvanceVoicePower {
|
if user.Power < h.App.SysConfig.Base.AdvanceVoicePower {
|
||||||
resp.ERROR(c, "当前用户算力不足,无法使用该功能")
|
resp.ERROR(c, "当前用户算力不足,无法使用该功能")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -198,7 +211,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
|
|||||||
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
|
|
||||||
// 扣减算力
|
// 扣减算力
|
||||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
|
err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.AdvanceVoicePower, model.PowerLog{
|
||||||
Type: types.PowerConsume,
|
Type: types.PowerConsume,
|
||||||
Model: "advanced-voice",
|
Model: "advanced-voice",
|
||||||
Remark: "实时语音通话",
|
Remark: "实时语音通话",
|
||||||
|
|||||||
@@ -10,14 +10,16 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils/resp"
|
"geekai/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RedeemHandler struct {
|
type RedeemHandler struct {
|
||||||
@@ -30,6 +32,17 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
|
|||||||
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
|
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *RedeemHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/redeem/")
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("verify", h.Verify)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *RedeemHandler) Verify(c *gin.Context) {
|
func (h *RedeemHandler) Verify(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
|
"geekai/service/moderation"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/service/sd"
|
"geekai/service/sd"
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
@@ -28,12 +30,13 @@ import (
|
|||||||
|
|
||||||
type SdJobHandler struct {
|
type SdJobHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
sdService *sd.Service
|
sdService *sd.Service
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
leveldb *store.LevelDB
|
leveldb *store.LevelDB
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
moderationManager *moderation.ServiceManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSdJobHandler(app *core.AppServer,
|
func NewSdJobHandler(app *core.AppServer,
|
||||||
@@ -42,13 +45,15 @@ func NewSdJobHandler(app *core.AppServer,
|
|||||||
manager *oss.UploaderManager,
|
manager *oss.UploaderManager,
|
||||||
snowflake *service.Snowflake,
|
snowflake *service.Snowflake,
|
||||||
userService *service.UserService,
|
userService *service.UserService,
|
||||||
levelDB *store.LevelDB) *SdJobHandler {
|
levelDB *store.LevelDB,
|
||||||
|
moderationManager *moderation.ServiceManager) *SdJobHandler {
|
||||||
return &SdJobHandler{
|
return &SdJobHandler{
|
||||||
sdService: service,
|
sdService: service,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
leveldb: levelDB,
|
leveldb: levelDB,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
moderationManager: moderationManager,
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: app,
|
App: app,
|
||||||
DB: db,
|
DB: db,
|
||||||
@@ -56,6 +61,23 @@ func NewSdJobHandler(app *core.AppServer,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *SdJobHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/sd/")
|
||||||
|
|
||||||
|
// 公开接口,不需要授权
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("image", h.Image)
|
||||||
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("publish", h.Publish)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||||
user, err := h.GetLoginUser(c)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -63,7 +85,7 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Power < h.App.SysConfig.SdPower {
|
if user.Power < h.App.SysConfig.Base.SdPower {
|
||||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -84,6 +106,29 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.App.SysConfig.Moderation.Enable {
|
||||||
|
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to moderate content: ", err)
|
||||||
|
}
|
||||||
|
if moderationResult.Flagged {
|
||||||
|
// 记录违规内容
|
||||||
|
moderation := model.Moderation{
|
||||||
|
UserId: h.GetLoginUserId(c),
|
||||||
|
Source: types.ModerationSourceSD,
|
||||||
|
Input: data.Prompt,
|
||||||
|
Result: utils.JsonEncode(moderationResult),
|
||||||
|
}
|
||||||
|
err = h.DB.Create(&moderation).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to save moderation: ", err)
|
||||||
|
}
|
||||||
|
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
if data.Width <= 0 {
|
if data.Width <= 0 {
|
||||||
data.Width = 512
|
data.Width = 512
|
||||||
}
|
}
|
||||||
@@ -131,7 +176,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
HdSteps: data.HdSteps,
|
HdSteps: data.HdSteps,
|
||||||
},
|
},
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||||
}
|
}
|
||||||
|
|
||||||
job := model.SdJob{
|
job := model.SdJob{
|
||||||
@@ -142,7 +187,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
TaskInfo: utils.JsonEncode(task),
|
TaskInfo: utils.JsonEncode(task),
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Power: h.App.SysConfig.SdPower,
|
Power: h.App.SysConfig.Base.SdPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
res := h.DB.Create(&job)
|
res := h.DB.Create(&job)
|
||||||
|
|||||||
@@ -24,24 +24,31 @@ const CodeStorePrefix = "/verify/codes/"
|
|||||||
|
|
||||||
type SmsHandler struct {
|
type SmsHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
sms *sms.ServiceManager
|
sms *sms.SmsManager
|
||||||
smtp *service.SmtpService
|
smtp *service.SmtpService
|
||||||
captcha *service.CaptchaService
|
captchaService *service.CaptchaService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSmsHandler(
|
func NewSmsHandler(
|
||||||
app *core.AppServer,
|
app *core.AppServer,
|
||||||
client *redis.Client,
|
client *redis.Client,
|
||||||
sms *sms.ServiceManager,
|
sms *sms.SmsManager,
|
||||||
smtp *service.SmtpService,
|
smtp *service.SmtpService,
|
||||||
captcha *service.CaptchaService) *SmsHandler {
|
captcha *service.CaptchaService) *SmsHandler {
|
||||||
return &SmsHandler{
|
return &SmsHandler{
|
||||||
redis: client,
|
redis: client,
|
||||||
sms: sms,
|
sms: sms,
|
||||||
captcha: captcha,
|
captchaService: captcha,
|
||||||
smtp: smtp,
|
smtp: smtp,
|
||||||
BaseHandler: BaseHandler{App: app}}
|
BaseHandler: BaseHandler{App: app}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *SmsHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/sms/")
|
||||||
|
// 无需授权的接口
|
||||||
|
group.POST("code", h.SendCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendCode 发送验证码
|
// SendCode 发送验证码
|
||||||
@@ -56,12 +63,12 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.App.SysConfig.EnabledVerify {
|
if h.captchaService.GetConfig().Enabled {
|
||||||
var check bool
|
var check bool
|
||||||
if data.X != 0 {
|
if data.X != 0 {
|
||||||
check = h.captcha.SlideCheck(data)
|
check = h.captchaService.SlideCheck(data)
|
||||||
} else {
|
} else {
|
||||||
check = h.captcha.Check(data)
|
check = h.captchaService.Check(data)
|
||||||
}
|
}
|
||||||
if !check {
|
if !check {
|
||||||
resp.ERROR(c, "请先完人机验证")
|
resp.ERROR(c, "请先完人机验证")
|
||||||
@@ -72,14 +79,14 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
|||||||
code := utils.RandomNumber(6)
|
code := utils.RandomNumber(6)
|
||||||
var err error
|
var err error
|
||||||
if strings.Contains(data.Receiver, "@") { // email
|
if strings.Contains(data.Receiver, "@") { // email
|
||||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "email") {
|
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "email") {
|
||||||
resp.ERROR(c, "系统已禁用邮箱注册!")
|
resp.ERROR(c, "系统已禁用邮箱注册!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 检查邮箱后缀是否在白名单
|
// 检查邮箱后缀是否在白名单
|
||||||
if len(h.App.SysConfig.EmailWhiteList) > 0 {
|
if len(h.App.SysConfig.Base.EmailWhiteList) > 0 {
|
||||||
inWhiteList := false
|
inWhiteList := false
|
||||||
for _, suffix := range h.App.SysConfig.EmailWhiteList {
|
for _, suffix := range h.App.SysConfig.Base.EmailWhiteList {
|
||||||
if strings.HasSuffix(data.Receiver, suffix) {
|
if strings.HasSuffix(data.Receiver, suffix) {
|
||||||
inWhiteList = true
|
inWhiteList = true
|
||||||
break
|
break
|
||||||
@@ -92,7 +99,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
||||||
} else {
|
} else {
|
||||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
|
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "mobile") {
|
||||||
resp.ERROR(c, "系统已禁用手机号注册!")
|
resp.ERROR(c, "系统已禁用手机号注册!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
|
"geekai/service/moderation"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/service/suno"
|
"geekai/service/suno"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -26,20 +28,41 @@ import (
|
|||||||
|
|
||||||
type SunoHandler struct {
|
type SunoHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
sunoService *suno.Service
|
sunoService *suno.Service
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
moderationManager *moderation.ServiceManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
|
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *SunoHandler {
|
||||||
return &SunoHandler{
|
return &SunoHandler{
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: app,
|
App: app,
|
||||||
DB: db,
|
DB: db,
|
||||||
},
|
},
|
||||||
sunoService: service,
|
sunoService: service,
|
||||||
uploader: uploader,
|
uploader: uploader,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
moderationManager: moderationManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *SunoHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/suno/")
|
||||||
|
|
||||||
|
// 公开接口,不需要授权
|
||||||
|
group.GET("play", h.Play)
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("create", h.Create)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("publish", h.Publish)
|
||||||
|
group.POST("update", h.Update)
|
||||||
|
group.GET("detail", h.Detail)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,13 +87,36 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.App.SysConfig.Moderation.Enable {
|
||||||
|
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to moderate content: ", err)
|
||||||
|
}
|
||||||
|
if moderationResult.Flagged {
|
||||||
|
// 记录违规内容
|
||||||
|
moderation := model.Moderation{
|
||||||
|
UserId: h.GetLoginUserId(c),
|
||||||
|
Source: types.ModerationSourceSuno,
|
||||||
|
Input: data.Prompt,
|
||||||
|
Result: utils.JsonEncode(moderationResult),
|
||||||
|
}
|
||||||
|
err = h.DB.Create(&moderation).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to save moderation: ", err)
|
||||||
|
}
|
||||||
|
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
user, err := h.GetLoginUser(c)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Power < h.App.SysConfig.SunoPower {
|
if user.Power < h.App.SysConfig.Base.SunoPower {
|
||||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -118,7 +164,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
|||||||
RefSongId: data.RefSongId,
|
RefSongId: data.RefSongId,
|
||||||
RefTaskId: data.RefTaskId,
|
RefTaskId: data.RefTaskId,
|
||||||
ExtendSecs: data.ExtendSecs,
|
ExtendSecs: data.ExtendSecs,
|
||||||
Power: h.App.SysConfig.SunoPower,
|
Power: h.App.SysConfig.Base.SunoPower,
|
||||||
SongId: utils.RandString(32),
|
SongId: utils.RandString(32),
|
||||||
}
|
}
|
||||||
if data.Lyrics != "" {
|
if data.Lyrics != "" {
|
||||||
|
|||||||
@@ -1,21 +1,36 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/payment"
|
"geekai/service/payment"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"net/http"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TestHandler struct {
|
type TestHandler struct {
|
||||||
|
App *core.AppServer
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
js *payment.GeekPayService
|
js *payment.EPayService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
|
func NewTestHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, js *payment.EPayService) *TestHandler {
|
||||||
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
return &TestHandler{App: app, db: db, snowflake: snowflake, js: js}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *TestHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/test/")
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.Any("sse", h.PostTest, h.SseTest)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TestHandler) SseTest(c *gin.Context) {
|
func (h *TestHandler) SseTest(c *gin.Context) {
|
||||||
|
|||||||
@@ -8,8 +8,10 @@ package handler
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
@@ -20,8 +22,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
@@ -36,8 +36,10 @@ type UserHandler struct {
|
|||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
levelDB *store.LevelDB
|
levelDB *store.LevelDB
|
||||||
licenseService *service.LicenseService
|
licenseService *service.LicenseService
|
||||||
captcha *service.CaptchaService
|
captchaService *service.CaptchaService
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
wxLoginService *service.WxLoginService
|
||||||
|
ipSearcher *xdb.Searcher
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserHandler(
|
func NewUserHandler(
|
||||||
@@ -48,15 +50,45 @@ func NewUserHandler(
|
|||||||
levelDB *store.LevelDB,
|
levelDB *store.LevelDB,
|
||||||
captcha *service.CaptchaService,
|
captcha *service.CaptchaService,
|
||||||
userService *service.UserService,
|
userService *service.UserService,
|
||||||
|
wxLoginService *service.WxLoginService,
|
||||||
|
ipSearcher *xdb.Searcher,
|
||||||
licenseService *service.LicenseService) *UserHandler {
|
licenseService *service.LicenseService) *UserHandler {
|
||||||
return &UserHandler{
|
return &UserHandler{
|
||||||
BaseHandler: BaseHandler{DB: db, App: app},
|
BaseHandler: BaseHandler{DB: db, App: app},
|
||||||
searcher: searcher,
|
searcher: searcher,
|
||||||
redis: client,
|
redis: client,
|
||||||
levelDB: levelDB,
|
levelDB: levelDB,
|
||||||
captcha: captcha,
|
captchaService: captcha,
|
||||||
licenseService: licenseService,
|
licenseService: licenseService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
wxLoginService: wxLoginService,
|
||||||
|
ipSearcher: ipSearcher,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *UserHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/user/")
|
||||||
|
|
||||||
|
// 公开接口,不需要授权
|
||||||
|
group.POST("register", h.Register)
|
||||||
|
group.POST("login", h.Login)
|
||||||
|
group.POST("resetPass", h.ResetPass)
|
||||||
|
group.GET("login/qrcode", h.GetWxLoginQRCode)
|
||||||
|
group.POST("login/callback", h.WxLoginCallback)
|
||||||
|
group.GET("login/status", h.GetWxLoginState)
|
||||||
|
group.GET("logout", h.Logout)
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.GET("session", h.Session)
|
||||||
|
group.GET("profile", h.Profile)
|
||||||
|
group.POST("profile/update", h.ProfileUpdate)
|
||||||
|
group.POST("password", h.UpdatePass)
|
||||||
|
group.POST("bind/mobile", h.BindMobile)
|
||||||
|
group.POST("bind/email", h.BindEmail)
|
||||||
|
group.GET("signin", h.SignIn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,12 +112,13 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
|
// 人机验证
|
||||||
|
if h.captchaService.GetConfig().Enabled {
|
||||||
var check bool
|
var check bool
|
||||||
if data.X != 0 {
|
if data.X != 0 {
|
||||||
check = h.captcha.SlideCheck(data)
|
check = h.captchaService.SlideCheck(data)
|
||||||
} else {
|
} else {
|
||||||
check = h.captcha.Check(data)
|
check = h.captchaService.Check(data)
|
||||||
}
|
}
|
||||||
if !check {
|
if !check {
|
||||||
resp.ERROR(c, "请先完人机验证")
|
resp.ERROR(c, "请先完人机验证")
|
||||||
@@ -125,30 +158,8 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证邀请码
|
|
||||||
inviteCode := model.InviteCode{}
|
|
||||||
if data.InviteCode != "" {
|
|
||||||
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
|
|
||||||
if res.Error != nil {
|
|
||||||
resp.ERROR(c, "无效的邀请码")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
salt := utils.RandString(8)
|
|
||||||
user := model.User{
|
|
||||||
Username: data.Username,
|
|
||||||
Password: utils.GenPassword(data.Password, salt),
|
|
||||||
Avatar: "/images/avatar/user.png",
|
|
||||||
Salt: salt,
|
|
||||||
Status: true,
|
|
||||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
|
||||||
ChatConfig: "{}",
|
|
||||||
ChatModels: "{}",
|
|
||||||
Power: h.App.SysConfig.InitPower,
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if the username is existing
|
// check if the username is existing
|
||||||
|
user := model.User{Username: data.Username, Password: data.Password}
|
||||||
var item model.User
|
var item model.User
|
||||||
session := h.DB.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if data.Mobile != "" {
|
if data.Mobile != "" {
|
||||||
@@ -168,78 +179,19 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 被邀请人也获得赠送算力
|
user, err := h.createNewUser(user, data.InviteCode)
|
||||||
if data.InviteCode != "" {
|
if err != nil {
|
||||||
user.Power += h.App.SysConfig.InvitePower
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.licenseService.GetLicense().Configs.DeCopy {
|
|
||||||
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
|
|
||||||
} else {
|
|
||||||
defaultNickname := h.App.SysConfig.DefaultNickname
|
|
||||||
if defaultNickname == "" {
|
|
||||||
defaultNickname = "极客学长"
|
|
||||||
}
|
|
||||||
user.Nickname = fmt.Sprintf("%s@%d", defaultNickname, utils.RandomNumber(6))
|
|
||||||
}
|
|
||||||
|
|
||||||
tx := h.DB.Begin()
|
|
||||||
if err := tx.Create(&user).Error; err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 记录邀请关系
|
token, err := h.doLogin(&user, c.ClientIP())
|
||||||
if data.InviteCode != "" {
|
|
||||||
// 增加邀请数量
|
|
||||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
|
||||||
if h.App.SysConfig.InvitePower > 0 {
|
|
||||||
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.InvitePower, model.PowerLog{
|
|
||||||
Type: types.PowerInvite,
|
|
||||||
Model: "Invite",
|
|
||||||
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加邀请记录
|
|
||||||
err := tx.Create(&model.InviteLog{
|
|
||||||
InviterId: inviteCode.UserId,
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
InviteCode: inviteCode.Code,
|
|
||||||
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
|
|
||||||
}).Error
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
|
|
||||||
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
|
|
||||||
// 自动登录创建 token
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
||||||
"user_id": user.Id,
|
|
||||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
|
||||||
})
|
|
||||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 保存到 redis
|
|
||||||
key = fmt.Sprintf("users/%d", user.Id)
|
resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
|
||||||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
|
||||||
resp.ERROR(c, "error with save token: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login 用户登录
|
// Login 用户登录
|
||||||
@@ -255,15 +207,12 @@ func (h *UserHandler) Login(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
|
if h.captchaService.GetConfig().Enabled {
|
||||||
needVerify, err := h.redis.Get(c, verifyKey).Bool()
|
|
||||||
|
|
||||||
if h.App.SysConfig.EnabledVerify && needVerify {
|
|
||||||
var check bool
|
var check bool
|
||||||
if data.X != 0 {
|
if data.X != 0 {
|
||||||
check = h.captcha.SlideCheck(data)
|
check = h.captchaService.SlideCheck(data)
|
||||||
} else {
|
} else {
|
||||||
check = h.captcha.Check(data)
|
check = h.captchaService.Check(data)
|
||||||
}
|
}
|
||||||
if !check {
|
if !check {
|
||||||
resp.ERROR(c, "请先完人机验证")
|
resp.ERROR(c, "请先完人机验证")
|
||||||
@@ -274,54 +223,28 @@ func (h *UserHandler) Login(c *gin.Context) {
|
|||||||
var user model.User
|
var user model.User
|
||||||
res := h.DB.Where("username = ?", data.Username).First(&user)
|
res := h.DB.Where("username = ?", data.Username).First(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
h.redis.Set(c, verifyKey, true, 0)
|
|
||||||
resp.ERROR(c, "用户名不存在")
|
resp.ERROR(c, "用户名不存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
password := utils.GenPassword(data.Password, user.Salt)
|
password := utils.GenPassword(data.Password, user.Salt)
|
||||||
if password != user.Password {
|
if password != user.Password {
|
||||||
h.redis.Set(c, verifyKey, true, 0)
|
|
||||||
resp.ERROR(c, "用户名或密码错误")
|
resp.ERROR(c, "用户名或密码错误")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Status == false {
|
if !user.Status {
|
||||||
resp.ERROR(c, "该用户已被禁止登录,请联系管理员")
|
resp.ERROR(c, "该用户已被禁止登录,请联系管理员")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新最后登录时间和IP
|
token, err := h.doLogin(&user, c.ClientIP())
|
||||||
user.LastLoginIp = c.ClientIP()
|
|
||||||
user.LastLoginAt = time.Now().Unix()
|
|
||||||
h.DB.Model(&user).Updates(user)
|
|
||||||
|
|
||||||
h.DB.Create(&model.UserLoginLog{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
LoginIp: c.ClientIP(),
|
|
||||||
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
|
|
||||||
})
|
|
||||||
|
|
||||||
// 创建 token
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
||||||
"user_id": user.Id,
|
|
||||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
|
||||||
})
|
|
||||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 保存到 redis
|
|
||||||
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
|
||||||
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
|
|
||||||
resp.ERROR(c, "error with save token: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 移除登录行为验证码
|
|
||||||
h.redis.Del(c, verifyKey)
|
|
||||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logout 注 销
|
// Logout 注 销
|
||||||
@@ -333,134 +256,165 @@ func (h *UserHandler) Logout(c *gin.Context) {
|
|||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CLogin 第三方登录请求二维码
|
// GetWxLoginQRCode 获取微信登录二维码URL
|
||||||
func (h *UserHandler) CLogin(c *gin.Context) {
|
func (h *UserHandler) GetWxLoginQRCode(c *gin.Context) {
|
||||||
returnURL := h.GetTrim(c, "return_url")
|
if !h.wxLoginService.GetConfig().Enabled {
|
||||||
var res types.BizVo
|
resp.ERROR(c, "微信登录功能未启用")
|
||||||
apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL)
|
return
|
||||||
r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}).
|
}
|
||||||
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
|
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
|
if h.wxLoginService.GetConfig().ApiKey == "" {
|
||||||
SetSuccessResult(&res).
|
resp.ERROR(c, "微信登录服务令牌未配置")
|
||||||
Post(apiURL)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state := utils.RandString(32)
|
||||||
|
qrCodeURL, err := h.wxLoginService.GetLoginQrCodeUrl(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.IsErrorState() {
|
|
||||||
resp.ERROR(c, "error with login http status: "+r.Status)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.Code != types.Success {
|
resp.SUCCESS(c, gin.H{
|
||||||
resp.ERROR(c, "error with http response: "+res.Message)
|
"url": qrCodeURL,
|
||||||
return
|
"state": state,
|
||||||
}
|
})
|
||||||
|
|
||||||
resp.SUCCESS(c, res.Data)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CLoginCallback 第三方登录回调
|
// 查询微信登录状态
|
||||||
func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
func (h *UserHandler) GetWxLoginState(c *gin.Context) {
|
||||||
loginType := c.Query("login_type")
|
state := c.Query("state")
|
||||||
code := c.Query("code")
|
if state == "" {
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
resp.ERROR(c, "参数错误")
|
||||||
action := c.Query("action")
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var res types.BizVo
|
status, err := h.wxLoginService.GetLoginStatus(state)
|
||||||
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
|
|
||||||
r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
|
|
||||||
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
|
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
Post(apiURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.IsErrorState() {
|
|
||||||
resp.ERROR(c, "error with login http status: "+r.Status)
|
if status.Status != service.LoginStatusSuccess {
|
||||||
|
resp.SUCCESS(c, status)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.Code != types.Success {
|
// 登录成功
|
||||||
resp.ERROR(c, "error with http response: "+res.Message)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// login successfully
|
|
||||||
data := res.Data.(map[string]interface{})
|
|
||||||
var user model.User
|
var user model.User
|
||||||
if action == "bind" && userId > 0 {
|
h.DB.Where("openid = ?", status.OpenID).First(&user)
|
||||||
err = h.DB.Where("openid", data["openid"]).First(&user).Error
|
if user.Id == 0 {
|
||||||
if err == nil {
|
// 创建新用户
|
||||||
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
|
user, err = h.createNewUser(model.User{OpenId: status.OpenID}, "")
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = h.DB.Where("id", userId).First(&user).Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "绑定用户不存在")
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
|
token, err := h.doLogin(&user, c.ClientIP())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "更新用户信息失败,"+err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, gin.H{"token": ""})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session := gin.H{}
|
status.Status = service.LoginStatusExpired
|
||||||
tx := h.DB.Where("openid", data["openid"]).First(&user)
|
h.wxLoginService.SetLoginStatus(state, *status)
|
||||||
if tx.Error != nil {
|
|
||||||
// create new user
|
|
||||||
var totalUser int64
|
|
||||||
h.DB.Model(&model.User{}).Count(&totalUser)
|
|
||||||
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
|
||||||
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
salt := utils.RandString(8)
|
status.Status = service.LoginStatusSuccess
|
||||||
password := fmt.Sprintf("%d", utils.RandomNumber(8))
|
status.Token = token
|
||||||
user = model.User{
|
resp.SUCCESS(c, status)
|
||||||
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
|
}
|
||||||
Password: utils.GenPassword(password, salt),
|
|
||||||
Avatar: fmt.Sprintf("%s", data["avatar"]),
|
|
||||||
Salt: salt,
|
|
||||||
Status: true,
|
|
||||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
|
||||||
Power: h.App.SysConfig.InitPower,
|
|
||||||
OpenId: fmt.Sprintf("%s", data["openid"]),
|
|
||||||
Nickname: fmt.Sprintf("%s", data["nickname"]),
|
|
||||||
}
|
|
||||||
|
|
||||||
tx = h.DB.Create(&user)
|
// createNewUser 创建新用户
|
||||||
if tx.Error != nil {
|
func (h *UserHandler) createNewUser(user model.User, inviteCode string) (model.User, error) {
|
||||||
resp.ERROR(c, "保存数据失败")
|
if user.OpenId != "" {
|
||||||
logger.Error(tx.Error)
|
user.Platform = "wechat"
|
||||||
return
|
user.Nickname = fmt.Sprintf("微信用户@%d", utils.RandomNumber(6))
|
||||||
|
user.Username = fmt.Sprintf("wx@%d", utils.RandomNumber(8))
|
||||||
|
user.Password = "geekai123"
|
||||||
|
} else {
|
||||||
|
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
|
||||||
|
if user.Username == "" || user.Password == "" {
|
||||||
|
return user, fmt.Errorf("用户名或密码不能为空")
|
||||||
}
|
}
|
||||||
session["username"] = user.Username
|
|
||||||
session["password"] = password
|
|
||||||
} else { // login directly
|
|
||||||
// 更新最后登录时间和IP
|
|
||||||
user.LastLoginIp = c.ClientIP()
|
|
||||||
user.LastLoginAt = time.Now().Unix()
|
|
||||||
h.DB.Model(&user).Updates(user)
|
|
||||||
|
|
||||||
h.DB.Create(&model.UserLoginLog{
|
|
||||||
UserId: user.Id,
|
|
||||||
Username: user.Username,
|
|
||||||
LoginIp: c.ClientIP(),
|
|
||||||
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
salt := utils.RandString(8)
|
||||||
|
user.Salt = salt
|
||||||
|
user.Password = utils.GenPassword(user.Password, salt)
|
||||||
|
user.Avatar = "/images/avatar/user.png"
|
||||||
|
user.Status = true
|
||||||
|
user.ChatRoles = utils.JsonEncode([]string{"gpt"})
|
||||||
|
user.ChatConfig = "{}"
|
||||||
|
user.ChatModels = "{}"
|
||||||
|
user.Power = h.App.SysConfig.Base.InitPower
|
||||||
|
|
||||||
|
// 创建用户
|
||||||
|
tx := h.DB.Begin()
|
||||||
|
if err := tx.Create(&user).Error; err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录邀请关系
|
||||||
|
if inviteCode != "" {
|
||||||
|
inviteCode := model.InviteCode{}
|
||||||
|
err := h.DB.Where("code = ?", inviteCode).First(&inviteCode).Error
|
||||||
|
if err != nil {
|
||||||
|
return user, fmt.Errorf("无效的邀请码")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加邀请数量
|
||||||
|
h.DB.Model(&model.InviteCode{}).Where("code = ?", inviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||||||
|
if h.App.SysConfig.Base.InvitePower > 0 {
|
||||||
|
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.Base.InvitePower, model.PowerLog{
|
||||||
|
Type: types.PowerInvite,
|
||||||
|
Model: "Invite",
|
||||||
|
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.Base.InvitePower, inviteCode.Code, user.Username),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加邀请记录
|
||||||
|
err = tx.Create(&model.InviteLog{
|
||||||
|
InviterId: inviteCode.UserId,
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
InviteCode: inviteCode.Code,
|
||||||
|
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.Base.InvitePower),
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Commit()
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doLogin 执行登录操作
|
||||||
|
func (h *UserHandler) doLogin(user *model.User, ip string) (string, error) {
|
||||||
|
// 更新最后登录时间和IP
|
||||||
|
user.LastLoginIp = ip
|
||||||
|
user.LastLoginAt = time.Now().Unix()
|
||||||
|
err := h.DB.Model(user).Updates(user).Error
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to update user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录登录日志
|
||||||
|
h.DB.Create(&model.UserLoginLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
LoginIp: ip,
|
||||||
|
LoginAddress: utils.Ip2Region(h.ipSearcher, ip),
|
||||||
|
})
|
||||||
|
|
||||||
// 创建 token
|
// 创建 token
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
"user_id": user.Id,
|
"user_id": user.Id,
|
||||||
@@ -468,17 +422,42 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
return "", fmt.Errorf("failed to generate token: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存到 redis
|
// 保存到 redis
|
||||||
key := fmt.Sprintf("users/%d", user.Id)
|
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
||||||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
if _, err = h.redis.Set(context.Background(), sessionKey, tokenString, 0).Result(); err != nil {
|
||||||
resp.ERROR(c, "error with save token: "+err.Error())
|
return "", fmt.Errorf("error with save token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenString, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WxLoginCallback 微信登录回调处理
|
||||||
|
func (h *UserHandler) WxLoginCallback(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
OpenID string `json:"openid"`
|
||||||
|
State string `json:"state"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
session["token"] = tokenString
|
|
||||||
resp.SUCCESS(c, session)
|
if data.OpenID == "" || data.State == "" {
|
||||||
|
resp.ERROR(c, "参数错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置登录状态
|
||||||
|
status := service.LoginStatus{
|
||||||
|
Status: service.LoginStatusSuccess,
|
||||||
|
OpenID: data.OpenID,
|
||||||
|
}
|
||||||
|
h.wxLoginService.SetLoginStatus(data.State, status)
|
||||||
|
|
||||||
|
resp.SUCCESS(c, status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session 获取/验证会话
|
// Session 获取/验证会话
|
||||||
@@ -742,11 +721,11 @@ func (h *UserHandler) SignIn(c *gin.Context) {
|
|||||||
|
|
||||||
// 签到
|
// 签到
|
||||||
h.levelDB.Put(key, true)
|
h.levelDB.Put(key, true)
|
||||||
if h.App.SysConfig.DailyPower > 0 {
|
if h.App.SysConfig.Base.DailyPower > 0 {
|
||||||
h.userService.IncreasePower(userId, h.App.SysConfig.DailyPower, model.PowerLog{
|
h.userService.IncreasePower(userId, h.App.SysConfig.Base.DailyPower, model.PowerLog{
|
||||||
Type: types.PowerSignIn,
|
Type: types.PowerSignIn,
|
||||||
Model: "SignIn",
|
Model: "SignIn",
|
||||||
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.DailyPower),
|
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.Base.DailyPower),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
|
"geekai/core/middleware"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
|
"geekai/service/moderation"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/service/video"
|
"geekai/service/video"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
@@ -26,20 +28,37 @@ import (
|
|||||||
|
|
||||||
type VideoHandler struct {
|
type VideoHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
videoService *video.Service
|
videoService *video.Service
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
moderationManager *moderation.ServiceManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
|
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *VideoHandler {
|
||||||
return &VideoHandler{
|
return &VideoHandler{
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: app,
|
App: app,
|
||||||
DB: db,
|
DB: db,
|
||||||
},
|
},
|
||||||
videoService: service,
|
videoService: service,
|
||||||
uploader: uploader,
|
uploader: uploader,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
moderationManager: moderationManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes 注册路由
|
||||||
|
func (h *VideoHandler) RegisterRoutes() {
|
||||||
|
group := h.App.Engine.Group("/api/video/")
|
||||||
|
|
||||||
|
// 需要用户授权的接口
|
||||||
|
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||||
|
{
|
||||||
|
group.POST("luma/create", h.LumaCreate)
|
||||||
|
group.POST("keling/create", h.KeLingCreate)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.GET("publish", h.Publish)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,13 +81,36 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.App.SysConfig.Moderation.Enable {
|
||||||
|
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to moderate content: ", err)
|
||||||
|
}
|
||||||
|
if moderationResult.Flagged {
|
||||||
|
// 记录违规内容
|
||||||
|
moderation := model.Moderation{
|
||||||
|
UserId: h.GetLoginUserId(c),
|
||||||
|
Source: types.ModerationSourceVideo,
|
||||||
|
Input: data.Prompt,
|
||||||
|
Result: utils.JsonEncode(moderationResult),
|
||||||
|
}
|
||||||
|
err = h.DB.Create(&moderation).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to save moderation: ", err)
|
||||||
|
}
|
||||||
|
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
user, err := h.GetLoginUser(c)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Power < h.App.SysConfig.LumaPower {
|
if user.Power < h.App.SysConfig.Base.LumaPower {
|
||||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -85,14 +127,14 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
|||||||
Type: types.VideoLuma,
|
Type: types.VideoLuma,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Params: params,
|
Params: params,
|
||||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||||
}
|
}
|
||||||
// 插入数据库
|
// 插入数据库
|
||||||
job := model.VideoJob{
|
job := model.VideoJob{
|
||||||
UserId: uint(userId),
|
UserId: uint(userId),
|
||||||
Type: types.VideoLuma,
|
Type: types.VideoLuma,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Power: h.App.SysConfig.LumaPower,
|
Power: h.App.SysConfig.Base.LumaPower,
|
||||||
TaskInfo: utils.JsonEncode(task),
|
TaskInfo: utils.JsonEncode(task),
|
||||||
}
|
}
|
||||||
tx := h.DB.Create(&job)
|
tx := h.DB.Create(&job)
|
||||||
@@ -147,7 +189,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
|
|||||||
|
|
||||||
// 计算当前任务所需算力
|
// 计算当前任务所需算力
|
||||||
key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration)
|
key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration)
|
||||||
power := h.App.SysConfig.KeLingPowers[key]
|
power := h.App.SysConfig.Base.KeLingPowers[key]
|
||||||
if power == 0 {
|
if power == 0 {
|
||||||
resp.ERROR(c, "当前模型暂不支持")
|
resp.ERROR(c, "当前模型暂不支持")
|
||||||
return
|
return
|
||||||
@@ -181,7 +223,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
|
|||||||
Type: types.VideoKeLing,
|
Type: types.VideoKeLing,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Params: params,
|
Params: params,
|
||||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||||
Channel: data.Channel,
|
Channel: data.Channel,
|
||||||
}
|
}
|
||||||
// 插入数据库
|
// 插入数据库
|
||||||
|
|||||||
375
api/main.go
375
api/main.go
@@ -19,6 +19,7 @@ import (
|
|||||||
"geekai/service/dalle"
|
"geekai/service/dalle"
|
||||||
"geekai/service/jimeng"
|
"geekai/service/jimeng"
|
||||||
"geekai/service/mj"
|
"geekai/service/mj"
|
||||||
|
"geekai/service/moderation"
|
||||||
"geekai/service/oss"
|
"geekai/service/oss"
|
||||||
"geekai/service/payment"
|
"geekai/service/payment"
|
||||||
"geekai/service/sd"
|
"geekai/service/sd"
|
||||||
@@ -30,7 +31,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strconv"
|
"runtime/debug"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -71,15 +72,16 @@ func main() {
|
|||||||
if configFile == "" {
|
if configFile == "" {
|
||||||
configFile = "config.toml"
|
configFile = "config.toml"
|
||||||
}
|
}
|
||||||
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
|
|
||||||
logger.Info("Loading config file: ", configFile)
|
logger.Info("Loading config file: ", configFile)
|
||||||
if !debug {
|
defer func() {
|
||||||
defer func() {
|
if err := recover(); err != nil {
|
||||||
if err := recover(); err != nil {
|
logger.Error("Panic Error:", err)
|
||||||
logger.Error("Panic Error:", err)
|
// 打印堆栈信息
|
||||||
|
if os.Getenv("GEEKAI_DEBUG") == "true" {
|
||||||
|
debug.PrintStack()
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}()
|
||||||
|
|
||||||
app := fx.New(
|
app := fx.New(
|
||||||
// 初始化配置应用配置
|
// 初始化配置应用配置
|
||||||
@@ -89,16 +91,16 @@ func main() {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
config.Path = configFile
|
config.Path = configFile
|
||||||
if debug {
|
|
||||||
_ = core.SaveConfig(config)
|
|
||||||
}
|
|
||||||
return config
|
return config
|
||||||
}),
|
}),
|
||||||
// 创建应用服务
|
// 创建应用服务
|
||||||
fx.Provide(core.NewServer),
|
fx.Provide(core.NewServer),
|
||||||
// 初始化
|
// 初始化
|
||||||
fx.Invoke(func(s *core.AppServer, client *redis.Client) {
|
fx.Invoke(func(s *core.AppServer, client *redis.Client) {
|
||||||
s.Init(debug, client)
|
s.Init(client)
|
||||||
|
}),
|
||||||
|
fx.Provide(func(db *gorm.DB) *types.SystemConfig {
|
||||||
|
return core.LoadSystemConfig(db)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 初始化数据库
|
// 初始化数据库
|
||||||
@@ -126,7 +128,7 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
|
|
||||||
// 创建控制器
|
// 创建控制器
|
||||||
fx.Provide(handler.NewChatRoleHandler),
|
fx.Provide(handler.NewChatAppHandler),
|
||||||
fx.Provide(handler.NewUserHandler),
|
fx.Provide(handler.NewUserHandler),
|
||||||
fx.Provide(handler.NewChatHandler),
|
fx.Provide(handler.NewChatHandler),
|
||||||
fx.Provide(handler.NewNetHandler),
|
fx.Provide(handler.NewNetHandler),
|
||||||
@@ -143,6 +145,12 @@ func main() {
|
|||||||
fx.Provide(handler.NewPowerLogHandler),
|
fx.Provide(handler.NewPowerLogHandler),
|
||||||
fx.Provide(handler.NewJimengHandler),
|
fx.Provide(handler.NewJimengHandler),
|
||||||
|
|
||||||
|
fx.Provide(service.NewMigrationService),
|
||||||
|
fx.Invoke(func(migrationService *service.MigrationService) {
|
||||||
|
migrationService.StartMigrate()
|
||||||
|
}),
|
||||||
|
|
||||||
|
// 管理后台控制器
|
||||||
fx.Provide(admin.NewConfigHandler),
|
fx.Provide(admin.NewConfigHandler),
|
||||||
fx.Provide(admin.NewAdminHandler),
|
fx.Provide(admin.NewAdminHandler),
|
||||||
fx.Provide(admin.NewApiKeyHandler),
|
fx.Provide(admin.NewApiKeyHandler),
|
||||||
@@ -153,34 +161,23 @@ func main() {
|
|||||||
fx.Provide(admin.NewChatModelHandler),
|
fx.Provide(admin.NewChatModelHandler),
|
||||||
fx.Provide(admin.NewProductHandler),
|
fx.Provide(admin.NewProductHandler),
|
||||||
fx.Provide(admin.NewOrderHandler),
|
fx.Provide(admin.NewOrderHandler),
|
||||||
fx.Provide(admin.NewChatHandler),
|
|
||||||
fx.Provide(admin.NewPowerLogHandler),
|
fx.Provide(admin.NewPowerLogHandler),
|
||||||
fx.Provide(admin.NewAdminJimengHandler),
|
fx.Provide(admin.NewAdminJimengHandler),
|
||||||
|
|
||||||
// 创建服务
|
|
||||||
fx.Provide(sms.NewSendServiceManager),
|
|
||||||
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
|
|
||||||
return service.NewCaptchaService(config.ApiConfig)
|
|
||||||
}),
|
|
||||||
fx.Provide(oss.NewUploaderManager),
|
|
||||||
fx.Provide(dalle.NewService),
|
|
||||||
fx.Invoke(func(s *dalle.Service) {
|
|
||||||
s.Run()
|
|
||||||
s.DownloadImages()
|
|
||||||
s.CheckTaskStatus()
|
|
||||||
}),
|
|
||||||
|
|
||||||
fx.Provide(service.NewMigrationService),
|
|
||||||
fx.Invoke(func(s *service.MigrationService) {
|
|
||||||
s.Migrate()
|
|
||||||
}),
|
|
||||||
|
|
||||||
// 邮件服务
|
// 邮件服务
|
||||||
fx.Provide(service.NewSmtpService),
|
fx.Provide(service.NewSmtpService),
|
||||||
// License 服务
|
// License 服务
|
||||||
fx.Provide(service.NewLicenseService),
|
fx.Provide(service.NewLicenseService),
|
||||||
fx.Invoke(func(licenseService *service.LicenseService) {
|
fx.Invoke(func(licenseService *service.LicenseService) {
|
||||||
// licenseService.SyncLicense()
|
licenseService.SyncLicense()
|
||||||
|
}),
|
||||||
|
|
||||||
|
// Dalle 服务
|
||||||
|
fx.Provide(dalle.NewService),
|
||||||
|
fx.Invoke(func(s *dalle.Service) {
|
||||||
|
s.Run()
|
||||||
|
s.DownloadImages()
|
||||||
|
s.CheckTaskStatus()
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// MidJourney service pool
|
// MidJourney service pool
|
||||||
@@ -213,302 +210,179 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
|
|
||||||
// 即梦AI 服务
|
// 即梦AI 服务
|
||||||
|
fx.Provide(jimeng.NewClient),
|
||||||
fx.Provide(jimeng.NewService),
|
fx.Provide(jimeng.NewService),
|
||||||
fx.Invoke(func(service *jimeng.Service) {
|
fx.Invoke(func(service *jimeng.Service) {
|
||||||
service.Start()
|
service.Start()
|
||||||
}),
|
}),
|
||||||
fx.Provide(service.NewUserService),
|
|
||||||
fx.Provide(payment.NewAlipayService),
|
|
||||||
fx.Provide(payment.NewHuPiPay),
|
|
||||||
fx.Provide(payment.NewJPayService),
|
|
||||||
fx.Provide(payment.NewWechatService),
|
|
||||||
fx.Provide(service.NewSnowflake),
|
fx.Provide(service.NewSnowflake),
|
||||||
fx.Provide(service.NewXXLJobExecutor),
|
|
||||||
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
|
// 创建短信服务
|
||||||
if config.XXLConfig.Enabled {
|
fx.Provide(sms.NewAliYunSmsService),
|
||||||
go func() {
|
fx.Provide(sms.NewBaoSmsService),
|
||||||
log.Fatal(exec.Run())
|
fx.Provide(sms.NewSmsManager),
|
||||||
}()
|
fx.Provide(func(config *types.SystemConfig) *service.CaptchaService {
|
||||||
}
|
return service.NewCaptchaService(config.Captcha)
|
||||||
|
}),
|
||||||
|
fx.Provide(func(config *types.SystemConfig, client *redis.Client) *service.WxLoginService {
|
||||||
|
return service.NewWxLoginService(config.WxLogin, client)
|
||||||
|
}),
|
||||||
|
|
||||||
|
// 支付服务
|
||||||
|
fx.Provide(payment.NewAlipayService),
|
||||||
|
fx.Provide(payment.NewEPayService),
|
||||||
|
fx.Provide(payment.NewWxpayService),
|
||||||
|
|
||||||
|
// 文件上传服务
|
||||||
|
fx.Provide(oss.NewLocalStorage),
|
||||||
|
fx.Provide(oss.NewMiniOss),
|
||||||
|
fx.Provide(oss.NewQiNiuOss),
|
||||||
|
fx.Provide(oss.NewAliYunOss),
|
||||||
|
fx.Provide(oss.NewUploaderManager),
|
||||||
|
|
||||||
|
// 用户服务
|
||||||
|
fx.Provide(service.NewUserService),
|
||||||
|
|
||||||
|
// 文本审查服务
|
||||||
|
fx.Provide(moderation.NewGiteeAIModeration),
|
||||||
|
fx.Provide(moderation.NewBaiduAIModeration),
|
||||||
|
fx.Provide(moderation.NewTencentAIModeration),
|
||||||
|
fx.Provide(moderation.NewServiceManager),
|
||||||
|
fx.Provide(admin.NewModerationHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.ModerationHandler) {
|
||||||
|
h.RegisterRoutes()
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 注册路由
|
// 注册路由
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppHandler) {
|
||||||
group := s.Engine.Group("/api/app/")
|
h.RegisterRoutes()
|
||||||
group.GET("list", h.List)
|
|
||||||
group.GET("list/user", h.ListByUser)
|
|
||||||
group.POST("update", h.UpdateRole)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
|
||||||
group := s.Engine.Group("/api/user/")
|
h.RegisterRoutes()
|
||||||
group.POST("register", h.Register)
|
|
||||||
group.POST("login", h.Login)
|
|
||||||
group.GET("logout", h.Logout)
|
|
||||||
group.GET("session", h.Session)
|
|
||||||
group.GET("profile", h.Profile)
|
|
||||||
group.POST("profile/update", h.ProfileUpdate)
|
|
||||||
group.POST("password", h.UpdatePass)
|
|
||||||
group.POST("bind/mobile", h.BindMobile)
|
|
||||||
group.POST("bind/email", h.BindEmail)
|
|
||||||
group.POST("resetPass", h.ResetPass)
|
|
||||||
group.GET("clogin", h.CLogin)
|
|
||||||
group.GET("clogin/callback", h.CLoginCallback)
|
|
||||||
group.GET("signin", h.SignIn)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
|
||||||
group := s.Engine.Group("/api/chat/")
|
h.RegisterRoutes()
|
||||||
group.Any("message", h.Chat)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.GET("detail", h.Detail)
|
|
||||||
group.POST("update", h.Update)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.GET("history", h.History)
|
|
||||||
group.GET("clear", h.Clear)
|
|
||||||
group.POST("tokens", h.Tokens)
|
|
||||||
group.GET("stop", h.StopGenerate)
|
|
||||||
group.POST("tts", h.TextToSpeech)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
|
||||||
s.Engine.POST("/api/upload", h.Upload)
|
h.RegisterRoutes()
|
||||||
s.Engine.POST("/api/upload/list", h.List)
|
|
||||||
s.Engine.GET("/api/upload/remove", h.Remove)
|
|
||||||
s.Engine.GET("/api/download", h.Download)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
|
||||||
group := s.Engine.Group("/api/sms/")
|
h.RegisterRoutes()
|
||||||
group.POST("code", h.SendCode)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) {
|
||||||
group := s.Engine.Group("/api/captcha/")
|
h.RegisterRoutes()
|
||||||
group.GET("get", h.Get)
|
|
||||||
group.POST("check", h.Check)
|
|
||||||
group.GET("slide/get", h.SlideGet)
|
|
||||||
group.POST("slide/check", h.SlideCheck)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
|
||||||
group := s.Engine.Group("/api/redeem/")
|
h.RegisterRoutes()
|
||||||
group.POST("verify", h.Verify)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
||||||
group := s.Engine.Group("/api/mj/")
|
h.RegisterRoutes()
|
||||||
group.POST("image", h.Image)
|
|
||||||
group.POST("upscale", h.Upscale)
|
|
||||||
group.POST("variation", h.Variation)
|
|
||||||
group.GET("jobs", h.JobList)
|
|
||||||
group.GET("imgWall", h.ImgWall)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.GET("publish", h.Publish)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||||
group := s.Engine.Group("/api/sd")
|
h.RegisterRoutes()
|
||||||
group.POST("image", h.Image)
|
|
||||||
group.GET("jobs", h.JobList)
|
|
||||||
group.GET("imgWall", h.ImgWall)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.GET("publish", h.Publish)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
|
||||||
group := s.Engine.Group("/api/config/")
|
h.RegisterRoutes()
|
||||||
group.GET("get", h.Get)
|
|
||||||
group.GET("license", h.License)
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 管理后台控制器
|
// 管理后台路由注册
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||||
group := s.Engine.Group("/api/admin/config")
|
h.RegisterRoutes()
|
||||||
group.POST("update", h.Update)
|
|
||||||
group.GET("get", h.Get)
|
|
||||||
group.POST("active", h.Active)
|
|
||||||
group.GET("fixData", h.FixData)
|
|
||||||
group.GET("license", h.GetLicense)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
||||||
group := s.Engine.Group("/api/admin/")
|
h.RegisterRoutes()
|
||||||
group.POST("login", h.Login)
|
|
||||||
group.GET("logout", h.Logout)
|
|
||||||
group.GET("session", h.Session)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.POST("save", h.Save)
|
|
||||||
group.POST("enable", h.Enable)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.POST("resetPass", h.ResetPass)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
|
||||||
group := s.Engine.Group("/api/admin/apikey/")
|
h.RegisterRoutes()
|
||||||
group.POST("save", h.Save)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.POST("set", h.Set)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
|
||||||
group := s.Engine.Group("/api/admin/user/")
|
h.RegisterRoutes()
|
||||||
group.GET("list", h.List)
|
|
||||||
group.POST("save", h.Save)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.GET("loginLog", h.LoginLog)
|
|
||||||
group.GET("genLoginLink", h.GenLoginLink)
|
|
||||||
group.POST("resetPass", h.ResetPass)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
|
||||||
group := s.Engine.Group("/api/admin/role/")
|
h.RegisterRoutes()
|
||||||
group.GET("list", h.List)
|
|
||||||
group.POST("save", h.Save)
|
|
||||||
group.POST("sort", h.Sort)
|
|
||||||
group.POST("set", h.Set)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
|
||||||
group := s.Engine.Group("/api/admin/redeem/")
|
h.RegisterRoutes()
|
||||||
group.GET("list", h.List)
|
|
||||||
group.POST("create", h.Create)
|
|
||||||
group.POST("set", h.Set)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.POST("export", h.Export)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
||||||
group := s.Engine.Group("/api/admin/dashboard/")
|
h.RegisterRoutes()
|
||||||
group.GET("stats", h.Stats)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatModelHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.ChatModelHandler) {
|
||||||
group := s.Engine.Group("/api/model/")
|
h.RegisterRoutes()
|
||||||
group.GET("list", h.List)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatModelHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ChatModelHandler) {
|
||||||
group := s.Engine.Group("/api/admin/model/")
|
h.RegisterRoutes()
|
||||||
group.POST("save", h.Save)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.POST("set", h.Set)
|
|
||||||
group.POST("sort", h.Sort)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
|
||||||
group := s.Engine.Group("/api/payment/")
|
h.RegisterRoutes()
|
||||||
group.POST("doPay", h.Pay)
|
h.StartSyncOrders()
|
||||||
group.GET("payWays", h.GetPayWays)
|
|
||||||
group.POST("notify/alipay", h.AlipayNotify)
|
|
||||||
group.GET("notify/geek", h.GeekPayNotify)
|
|
||||||
group.POST("notify/wechat", h.WechatPayNotify)
|
|
||||||
group.POST("notify/hupi", h.HuPiPayNotify)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
||||||
group := s.Engine.Group("/api/admin/product/")
|
h.RegisterRoutes()
|
||||||
group.POST("save", h.Save)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.POST("enable", h.Enable)
|
|
||||||
group.POST("sort", h.Sort)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.OrderHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.OrderHandler) {
|
||||||
group := s.Engine.Group("/api/admin/order/")
|
h.RegisterRoutes()
|
||||||
group.POST("list", h.List)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.GET("clear", h.Clear)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
|
||||||
group := s.Engine.Group("/api/order/")
|
h.RegisterRoutes()
|
||||||
group.GET("list", h.List)
|
|
||||||
group.GET("query", h.Query)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
|
||||||
group := s.Engine.Group("/api/product/")
|
h.RegisterRoutes()
|
||||||
group.GET("list", h.List)
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
fx.Provide(handler.NewInviteHandler),
|
fx.Provide(handler.NewInviteHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
|
||||||
group := s.Engine.Group("/api/invite/")
|
h.RegisterRoutes()
|
||||||
group.GET("code", h.Code)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.GET("hits", h.Hits)
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
fx.Provide(admin.NewFunctionHandler),
|
fx.Provide(admin.NewFunctionHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
|
||||||
group := s.Engine.Group("/api/admin/function/")
|
h.RegisterRoutes()
|
||||||
group.POST("save", h.Save)
|
|
||||||
group.POST("set", h.Set)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.GET("token", h.GenToken)
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
fx.Provide(admin.NewUploadHandler),
|
fx.Provide(admin.NewUploadHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
|
||||||
s.Engine.POST("/api/admin/upload", h.Upload)
|
h.RegisterRoutes()
|
||||||
}),
|
}),
|
||||||
|
|
||||||
fx.Provide(handler.NewFunctionHandler),
|
fx.Provide(handler.NewFunctionHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
|
||||||
group := s.Engine.Group("/api/function/")
|
h.RegisterRoutes()
|
||||||
group.POST("weibo", h.WeiBo)
|
|
||||||
group.POST("zaobao", h.ZaoBao)
|
|
||||||
group.POST("dalle3", h.Dall3)
|
|
||||||
group.POST("websearch", h.WebSearch)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
}),
|
}),
|
||||||
|
fx.Provide(admin.NewChatHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
|
||||||
group := s.Engine.Group("/api/admin/chat/")
|
h.RegisterRoutes()
|
||||||
group.POST("list", h.List)
|
|
||||||
group.POST("message", h.Messages)
|
|
||||||
group.GET("history", h.History)
|
|
||||||
group.GET("remove", h.RemoveChat)
|
|
||||||
group.GET("message/remove", h.RemoveMessage)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
|
||||||
group := s.Engine.Group("/api/powerLog/")
|
h.RegisterRoutes()
|
||||||
group.POST("list", h.List)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
|
||||||
group := s.Engine.Group("/api/admin/powerLog/")
|
h.RegisterRoutes()
|
||||||
group.POST("list", h.List)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(admin.NewMenuHandler),
|
fx.Provide(admin.NewMenuHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
|
||||||
group := s.Engine.Group("/api/admin/menu/")
|
h.RegisterRoutes()
|
||||||
group.POST("save", h.Save)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.POST("enable", h.Enable)
|
|
||||||
group.POST("sort", h.Sort)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewMenuHandler),
|
fx.Provide(handler.NewMenuHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
|
||||||
group := s.Engine.Group("/api/menu/")
|
h.RegisterRoutes()
|
||||||
group.GET("list", h.List)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewMarkMapHandler),
|
fx.Provide(handler.NewMarkMapHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
|
||||||
s.Engine.POST("/api/markMap/gen", h.Generate)
|
h.RegisterRoutes()
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewDallJobHandler),
|
fx.Provide(handler.NewDallJobHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
|
||||||
group := s.Engine.Group("/api/dall")
|
h.RegisterRoutes()
|
||||||
group.POST("image", h.Image)
|
|
||||||
group.GET("jobs", h.JobList)
|
|
||||||
group.GET("imgWall", h.ImgWall)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.GET("publish", h.Publish)
|
|
||||||
group.GET("models", h.GetModels)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewSunoHandler),
|
fx.Provide(handler.NewSunoHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
|
||||||
group := s.Engine.Group("/api/suno")
|
h.RegisterRoutes()
|
||||||
group.POST("create", h.Create)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.GET("publish", h.Publish)
|
|
||||||
group.POST("update", h.Update)
|
|
||||||
group.GET("detail", h.Detail)
|
|
||||||
group.GET("play", h.Play)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewVideoHandler),
|
fx.Provide(handler.NewVideoHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
|
||||||
group := s.Engine.Group("/api/video")
|
h.RegisterRoutes()
|
||||||
group.POST("luma/create", h.LumaCreate)
|
|
||||||
group.POST("keling/create", h.KeLingCreate)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.GET("publish", h.Publish)
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 即梦AI 路由
|
// 即梦AI 路由
|
||||||
@@ -520,30 +394,19 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
fx.Provide(admin.NewChatAppTypeHandler),
|
fx.Provide(admin.NewChatAppTypeHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
|
||||||
group := s.Engine.Group("/api/admin/app/type")
|
h.RegisterRoutes()
|
||||||
group.POST("save", h.Save)
|
|
||||||
group.GET("list", h.List)
|
|
||||||
group.GET("remove", h.Remove)
|
|
||||||
group.POST("enable", h.Enable)
|
|
||||||
group.POST("sort", h.Sort)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewChatAppTypeHandler),
|
fx.Provide(handler.NewChatAppTypeHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
|
||||||
group := s.Engine.Group("/api/app/type")
|
h.RegisterRoutes()
|
||||||
group.GET("list", h.List)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewTestHandler),
|
fx.Provide(handler.NewTestHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
||||||
group := s.Engine.Group("/api/test")
|
h.RegisterRoutes()
|
||||||
group.Any("sse", h.PostTest, h.SseTest)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewPromptHandler),
|
fx.Provide(handler.NewPromptHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
||||||
group := s.Engine.Group("/api/prompt")
|
h.RegisterRoutes()
|
||||||
group.POST("/lyric", h.Lyric)
|
|
||||||
group.POST("/image", h.Image)
|
|
||||||
group.POST("/video", h.Video)
|
|
||||||
group.POST("/meta", h.MetaPrompt)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -568,23 +431,15 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
fx.Provide(admin.NewImageHandler),
|
fx.Provide(admin.NewImageHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) {
|
||||||
group := s.Engine.Group("/api/admin/image")
|
h.RegisterRoutes()
|
||||||
group.POST("/list/mj", h.MjList)
|
|
||||||
group.POST("/list/sd", h.SdList)
|
|
||||||
group.POST("/list/dall", h.DallList)
|
|
||||||
group.GET("/remove", h.Remove)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(admin.NewMediaHandler),
|
fx.Provide(admin.NewMediaHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
|
||||||
group := s.Engine.Group("/api/admin/media")
|
h.RegisterRoutes()
|
||||||
group.POST("/suno", h.SunoList)
|
|
||||||
group.POST("/videos", h.Videos)
|
|
||||||
group.GET("/remove", h.Remove)
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewRealtimeHandler),
|
fx.Provide(handler.NewRealtimeHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
|
||||||
s.Engine.Any("/api/realtime", h.Connection)
|
h.RegisterRoutes()
|
||||||
s.Engine.POST("/api/realtime/voice", h.VoiceChat)
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
// 启动应用程序
|
// 启动应用程序
|
||||||
|
|||||||
@@ -8,35 +8,38 @@ package service
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CaptchaService struct {
|
type CaptchaService struct {
|
||||||
config types.ApiConfig
|
config types.CaptchaConfig
|
||||||
client *req.Client
|
client *req.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
|
func NewCaptchaService(captchaConfig types.CaptchaConfig) *CaptchaService {
|
||||||
return &CaptchaService{
|
return &CaptchaService{
|
||||||
config: config,
|
config: captchaConfig,
|
||||||
client: req.C().SetTimeout(10 * time.Second),
|
client: req.C().SetTimeout(10 * time.Second),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *CaptchaService) UpdateConfig(config types.CaptchaConfig) {
|
||||||
|
s.config = config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CaptchaService) GetConfig() types.CaptchaConfig {
|
||||||
|
return s.config
|
||||||
|
}
|
||||||
|
|
||||||
func (s *CaptchaService) Get() (interface{}, error) {
|
func (s *CaptchaService) Get() (interface{}, error) {
|
||||||
if s.config.Token == "" {
|
url := fmt.Sprintf("%s/api/captcha/get", types.GeekAPIURL)
|
||||||
return nil, errors.New("无效的 API Token")
|
|
||||||
}
|
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/api/captcha/get", s.config.ApiURL)
|
|
||||||
var res types.BizVo
|
var res types.BizVo
|
||||||
r, err := s.client.R().
|
r, err := s.client.R().
|
||||||
SetHeader("AppId", s.config.AppId).
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
|
||||||
SetSuccessResult(&res).Get(url)
|
SetSuccessResult(&res).Get(url)
|
||||||
if err != nil || r.IsErrorState() {
|
if err != nil || r.IsErrorState() {
|
||||||
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||||
@@ -49,12 +52,11 @@ func (s *CaptchaService) Get() (interface{}, error) {
|
|||||||
return res.Data, nil
|
return res.Data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *CaptchaService) Check(data interface{}) bool {
|
func (s *CaptchaService) Check(data any) bool {
|
||||||
url := fmt.Sprintf("%s/api/captcha/check", s.config.ApiURL)
|
url := fmt.Sprintf("%s/api/captcha/check", types.GeekAPIURL)
|
||||||
var res types.BizVo
|
var res types.BizVo
|
||||||
r, err := s.client.R().
|
r, err := s.client.R().
|
||||||
SetHeader("AppId", s.config.AppId).
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
|
||||||
SetBodyJsonMarshal(data).
|
SetBodyJsonMarshal(data).
|
||||||
SetSuccessResult(&res).Post(url)
|
SetSuccessResult(&res).Post(url)
|
||||||
if err != nil || r.IsErrorState() {
|
if err != nil || r.IsErrorState() {
|
||||||
@@ -68,16 +70,11 @@ func (s *CaptchaService) Check(data interface{}) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *CaptchaService) SlideGet() (interface{}, error) {
|
func (s *CaptchaService) SlideGet() (any, error) {
|
||||||
if s.config.Token == "" {
|
url := fmt.Sprintf("%s/api/captcha/slide/get", types.GeekAPIURL)
|
||||||
return nil, errors.New("无效的 API Token")
|
|
||||||
}
|
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
|
|
||||||
var res types.BizVo
|
var res types.BizVo
|
||||||
r, err := s.client.R().
|
r, err := s.client.R().
|
||||||
SetHeader("AppId", s.config.AppId).
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
|
||||||
SetSuccessResult(&res).Get(url)
|
SetSuccessResult(&res).Get(url)
|
||||||
if err != nil || r.IsErrorState() {
|
if err != nil || r.IsErrorState() {
|
||||||
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||||
@@ -90,12 +87,11 @@ func (s *CaptchaService) SlideGet() (interface{}, error) {
|
|||||||
return res.Data, nil
|
return res.Data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *CaptchaService) SlideCheck(data interface{}) bool {
|
func (s *CaptchaService) SlideCheck(data any) bool {
|
||||||
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
|
url := fmt.Sprintf("%s/api/captcha/slide/check", types.GeekAPIURL)
|
||||||
var res types.BizVo
|
var res types.BizVo
|
||||||
r, err := s.client.R().
|
r, err := s.client.R().
|
||||||
SetHeader("AppId", s.config.AppId).
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
|
||||||
SetBodyJsonMarshal(data).
|
SetBodyJsonMarshal(data).
|
||||||
SetSuccessResult(&res).Post(url)
|
SetSuccessResult(&res).Post(url)
|
||||||
if err != nil || r.IsErrorState() {
|
if err != nil || r.IsErrorState() {
|
||||||
|
|||||||
@@ -1,333 +0,0 @@
|
|||||||
package crawler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"geekai/logger"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-rod/rod"
|
|
||||||
"github.com/go-rod/rod/lib/launcher"
|
|
||||||
"github.com/go-rod/rod/lib/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Service 网络爬虫服务
|
|
||||||
type Service struct {
|
|
||||||
browser *rod.Browser
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewService 创建一个新的爬虫服务
|
|
||||||
func NewService() (*Service, error) {
|
|
||||||
// 启动浏览器
|
|
||||||
path, _ := launcher.LookPath()
|
|
||||||
u := launcher.New().Bin(path).
|
|
||||||
Headless(true). // 无头模式
|
|
||||||
Set("disable-web-security", ""). // 禁用网络安全限制
|
|
||||||
Set("disable-gpu", ""). // 禁用 GPU 加速
|
|
||||||
Set("no-sandbox", ""). // 禁用沙箱模式
|
|
||||||
Set("disable-setuid-sandbox", ""). // 禁用 setuid 沙箱
|
|
||||||
MustLaunch()
|
|
||||||
|
|
||||||
browser := rod.New().ControlURL(u).MustConnect()
|
|
||||||
|
|
||||||
return &Service{
|
|
||||||
browser: browser,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SearchResult 搜索结果
|
|
||||||
type SearchResult struct {
|
|
||||||
Title string `json:"title"` // 标题
|
|
||||||
URL string `json:"url"` // 链接
|
|
||||||
Content string `json:"content"` // 内容摘要
|
|
||||||
}
|
|
||||||
|
|
||||||
// WebSearch 网络搜索
|
|
||||||
func (s *Service) WebSearch(keyword string, maxPages int) ([]SearchResult, error) {
|
|
||||||
if keyword == "" {
|
|
||||||
return nil, errors.New("搜索关键词不能为空")
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxPages <= 0 {
|
|
||||||
maxPages = 1
|
|
||||||
}
|
|
||||||
if maxPages > 10 {
|
|
||||||
maxPages = 10 // 最多搜索 10 页
|
|
||||||
}
|
|
||||||
|
|
||||||
results := make([]SearchResult, 0)
|
|
||||||
|
|
||||||
// 使用百度搜索
|
|
||||||
searchURL := fmt.Sprintf("https://www.baidu.com/s?wd=%s", url.QueryEscape(keyword))
|
|
||||||
|
|
||||||
// 设置页面超时
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// 创建页面
|
|
||||||
page := s.browser.MustPage()
|
|
||||||
defer page.MustClose()
|
|
||||||
|
|
||||||
// 设置视口大小
|
|
||||||
err := page.SetViewport(&proto.EmulationSetDeviceMetricsOverride{
|
|
||||||
Width: 1280,
|
|
||||||
Height: 800,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("设置视口失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 导航到搜索页面
|
|
||||||
err = page.Context(ctx).Navigate(searchURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("导航到搜索页面失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 等待搜索结果加载完成
|
|
||||||
err = page.WaitLoad()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("等待页面加载完成失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 分析当前页面的搜索结果
|
|
||||||
for i := 0; i < maxPages; i++ {
|
|
||||||
if i > 0 {
|
|
||||||
// 点击下一页按钮
|
|
||||||
nextPage, err := page.Element("a.n")
|
|
||||||
if err != nil || nextPage == nil {
|
|
||||||
break // 没有下一页
|
|
||||||
}
|
|
||||||
|
|
||||||
err = nextPage.Click(proto.InputMouseButtonLeft, 1)
|
|
||||||
if err != nil {
|
|
||||||
break // 点击下一页失败
|
|
||||||
}
|
|
||||||
|
|
||||||
// 等待新页面加载
|
|
||||||
err = page.WaitLoad()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提取搜索结果
|
|
||||||
resultElements, err := page.Elements(".result, .c-container")
|
|
||||||
if err != nil || resultElements == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, result := range resultElements {
|
|
||||||
// 获取标题
|
|
||||||
titleElement, err := result.Element("h3, .t")
|
|
||||||
if err != nil || titleElement == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
title, err := titleElement.Text()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取 URL
|
|
||||||
linkElement, err := titleElement.Element("a")
|
|
||||||
if err != nil || linkElement == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
href, err := linkElement.Attribute("href")
|
|
||||||
if err != nil || href == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取内容摘要 - 尝试多个可能的选择器
|
|
||||||
var contentElement *rod.Element
|
|
||||||
var content string
|
|
||||||
|
|
||||||
// 尝试多个可能的选择器来适应不同版本的百度搜索结果
|
|
||||||
selectors := []string{".content-right_8Zs40", ".c-abstract", ".content_LJ0WN", ".content"}
|
|
||||||
for _, selector := range selectors {
|
|
||||||
contentElement, err = result.Element(selector)
|
|
||||||
if err == nil && contentElement != nil {
|
|
||||||
content, _ = contentElement.Text()
|
|
||||||
if content != "" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果所有选择器都失败,尝试直接从结果块中提取文本
|
|
||||||
if content == "" {
|
|
||||||
// 获取结果元素的所有文本
|
|
||||||
fullText, err := result.Text()
|
|
||||||
if err == nil && fullText != "" {
|
|
||||||
// 简单处理:从全文中移除标题,剩下的可能是摘要
|
|
||||||
fullText = strings.Replace(fullText, title, "", 1)
|
|
||||||
// 清理文本
|
|
||||||
content = strings.TrimSpace(fullText)
|
|
||||||
// 限制内容长度
|
|
||||||
if len(content) > 200 {
|
|
||||||
content = content[:200] + "..."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加到结果集
|
|
||||||
results = append(results, SearchResult{
|
|
||||||
Title: title,
|
|
||||||
URL: *href,
|
|
||||||
Content: content,
|
|
||||||
})
|
|
||||||
|
|
||||||
// 限制结果数量,每页最多 10 条
|
|
||||||
if len(results) >= 10*maxPages {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取真实 URL(百度搜索结果中的 URL 是短链接,需要跳转获取真实 URL)
|
|
||||||
for i, result := range results {
|
|
||||||
realURL, err := s.getRedirectURL(result.URL)
|
|
||||||
if err == nil && realURL != "" {
|
|
||||||
results[i].URL = realURL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取真实 URL
|
|
||||||
func (s *Service) getRedirectURL(shortURL string) (string, error) {
|
|
||||||
// 创建页面
|
|
||||||
page, err := s.browser.Page(proto.TargetCreateTarget{URL: ""})
|
|
||||||
if err != nil {
|
|
||||||
return shortURL, err // 返回原始URL
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = page.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 导航到短链接
|
|
||||||
err = page.Navigate(shortURL)
|
|
||||||
if err != nil {
|
|
||||||
return shortURL, err // 返回原始URL
|
|
||||||
}
|
|
||||||
|
|
||||||
// 等待重定向完成
|
|
||||||
time.Sleep(2 * time.Second)
|
|
||||||
|
|
||||||
// 获取当前 URL
|
|
||||||
info, err := page.Info()
|
|
||||||
if err != nil {
|
|
||||||
return shortURL, err // 返回原始URL
|
|
||||||
}
|
|
||||||
|
|
||||||
return info.URL, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close 关闭浏览器
|
|
||||||
func (s *Service) Close() error {
|
|
||||||
if s.browser != nil {
|
|
||||||
err := s.browser.Close()
|
|
||||||
s.browser = nil
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SearchWeb 封装的搜索方法
|
|
||||||
func SearchWeb(keyword string, maxPages int) (string, error) {
|
|
||||||
// 添加panic恢复机制
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
log := logger.GetLogger()
|
|
||||||
log.Errorf("爬虫服务崩溃: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
service, err := NewService()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("创建爬虫服务失败: %v", err)
|
|
||||||
}
|
|
||||||
defer service.Close()
|
|
||||||
|
|
||||||
// 设置超时上下文
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// 使用goroutine和通道来处理超时
|
|
||||||
resultChan := make(chan []SearchResult, 1)
|
|
||||||
errChan := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
results, err := service.WebSearch(keyword, maxPages)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resultChan <- results
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 等待结果或超时
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return "", fmt.Errorf("搜索超时: %v", ctx.Err())
|
|
||||||
case err := <-errChan:
|
|
||||||
return "", fmt.Errorf("搜索失败: %v", err)
|
|
||||||
case results := <-resultChan:
|
|
||||||
if len(results) == 0 {
|
|
||||||
return "未找到关于 \"" + keyword + "\" 的相关搜索结果", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 格式化结果
|
|
||||||
var builder strings.Builder
|
|
||||||
builder.WriteString(fmt.Sprintf("为您找到关于 \"%s\" 的 %d 条搜索结果:\n\n", keyword, len(results)))
|
|
||||||
|
|
||||||
for i, result := range results {
|
|
||||||
// // 尝试打开链接获取实际内容
|
|
||||||
// page := service.browser.MustPage()
|
|
||||||
// defer page.MustClose()
|
|
||||||
|
|
||||||
// // 设置页面超时
|
|
||||||
// pageCtx, pageCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
// defer pageCancel()
|
|
||||||
|
|
||||||
// // 导航到目标页面
|
|
||||||
// err := page.Context(pageCtx).Navigate(result.URL)
|
|
||||||
// if err == nil {
|
|
||||||
// // 等待页面加载
|
|
||||||
// _ = page.WaitLoad()
|
|
||||||
|
|
||||||
// // 获取页面标题
|
|
||||||
// title, err := page.Eval("() => document.title")
|
|
||||||
// if err == nil && title.Value.String() != "" {
|
|
||||||
// result.Title = title.Value.String()
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // 获取页面主要内容
|
|
||||||
// if content, err := page.Element("body"); err == nil {
|
|
||||||
// if text, err := content.Text(); err == nil {
|
|
||||||
// // 清理并截取内容
|
|
||||||
// text = strings.TrimSpace(text)
|
|
||||||
// if len(text) > 200 {
|
|
||||||
// text = text[:200] + "..."
|
|
||||||
// }
|
|
||||||
// result.Prompt = text
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
builder.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, result.Title))
|
|
||||||
builder.WriteString(fmt.Sprintf(" 链接: %s\n", result.URL))
|
|
||||||
if result.Content != "" {
|
|
||||||
builder.WriteString(fmt.Sprintf(" 摘要: %s\n", result.Content))
|
|
||||||
}
|
|
||||||
builder.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
return builder.String(), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"geekai/store"
|
"geekai/store"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
@@ -94,12 +95,14 @@ func (s *Service) Run() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type imgReq struct {
|
type imgReq struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Image []string `json:"image,omitempty"`
|
||||||
N int `json:"n,omitempty"`
|
Prompt string `json:"prompt"`
|
||||||
Size string `json:"size,omitempty"`
|
N int `json:"n,omitempty"`
|
||||||
Quality string `json:"quality,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
Style string `json:"style,omitempty"`
|
Quality string `json:"quality,omitempty"`
|
||||||
|
Style string `json:"style,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type imgRes struct {
|
type imgRes struct {
|
||||||
@@ -122,15 +125,6 @@ type ErrRes struct {
|
|||||||
|
|
||||||
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||||
logger.Debugf("绘画参数:%+v", task)
|
logger.Debugf("绘画参数:%+v", task)
|
||||||
prompt := task.Prompt
|
|
||||||
// translate prompt
|
|
||||||
if utils.HasChinese(prompt) {
|
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, prompt), task.TranslateModelId)
|
|
||||||
if err == nil {
|
|
||||||
prompt = content
|
|
||||||
logger.Debugf("重写后提示词:%s", prompt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var chatModel model.ChatModel
|
var chatModel model.ChatModel
|
||||||
if task.ModelId > 0 {
|
if task.ModelId > 0 {
|
||||||
@@ -160,12 +154,17 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
|||||||
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
|
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
|
||||||
reqBody := imgReq{
|
reqBody := imgReq{
|
||||||
Model: chatModel.Value,
|
Model: chatModel.Value,
|
||||||
Prompt: prompt,
|
Prompt: task.Prompt,
|
||||||
N: 1,
|
N: 1,
|
||||||
Size: task.Size,
|
Size: task.Size,
|
||||||
Style: task.Style,
|
Style: task.Style,
|
||||||
Quality: task.Quality,
|
Quality: task.Quality,
|
||||||
}
|
}
|
||||||
|
// 图片编辑
|
||||||
|
if len(task.Image) > 0 {
|
||||||
|
reqBody.Prompt = fmt.Sprintf("%s, %s", strings.Join(task.Image, " "), task.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
|
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
|
||||||
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
|
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
|
||||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
@@ -188,7 +187,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
|||||||
var imgURL string
|
var imgURL string
|
||||||
var data = map[string]interface{}{
|
var data = map[string]interface{}{
|
||||||
"progress": 100,
|
"progress": 100,
|
||||||
"prompt": prompt,
|
"prompt": task.Prompt,
|
||||||
}
|
}
|
||||||
// 如果返回的是base64,则需要上传到oss
|
// 如果返回的是base64,则需要上传到oss
|
||||||
if res.Data[0].B64Json != "" {
|
if res.Data[0].B64Json != "" {
|
||||||
@@ -210,11 +209,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
|||||||
|
|
||||||
var content string
|
var content string
|
||||||
if sync {
|
if sync {
|
||||||
imgURL, err := s.downloadImage(task.Id, res.Data[0].Url)
|
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", task.Prompt, imgURL)
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error with download image: %v", err)
|
|
||||||
}
|
|
||||||
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", prompt, imgURL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return content, nil
|
return content, nil
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ package jimeng
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/volcengine/volc-sdk-golang/base"
|
"github.com/volcengine/volc-sdk-golang/base"
|
||||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||||
@@ -13,14 +15,22 @@ import (
|
|||||||
// Client 即梦API客户端
|
// Client 即梦API客户端
|
||||||
type Client struct {
|
type Client struct {
|
||||||
visual *visual.Visual
|
visual *visual.Visual
|
||||||
|
config types.JimengConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient 创建即梦API客户端
|
// NewClient 创建即梦API客户端
|
||||||
func NewClient(accessKey, secretKey string) *Client {
|
func NewClient(sysConfig *types.SystemConfig) *Client {
|
||||||
|
|
||||||
|
client := &Client{}
|
||||||
|
client.UpdateConfig(sysConfig.Jimeng)
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) UpdateConfig(config types.JimengConfig) error {
|
||||||
// 使用官方SDK的visual实例
|
// 使用官方SDK的visual实例
|
||||||
visualInstance := visual.NewInstance()
|
visualInstance := visual.NewInstance()
|
||||||
visualInstance.Client.SetAccessKey(accessKey)
|
visualInstance.Client.SetAccessKey(config.AccessKey)
|
||||||
visualInstance.Client.SetSecretKey(secretKey)
|
visualInstance.Client.SetSecretKey(config.SecretKey)
|
||||||
|
|
||||||
// 添加即梦AI专有的API配置
|
// 添加即梦AI专有的API配置
|
||||||
jimengApis := map[string]*base.ApiInfo{
|
jimengApis := map[string]*base.ApiInfo{
|
||||||
@@ -55,9 +65,32 @@ func NewClient(accessKey, secretKey string) *Client {
|
|||||||
visualInstance.Client.ApiInfoList[name] = info
|
visualInstance.Client.ApiInfoList[name] = info
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Client{
|
c.config = config
|
||||||
visual: visualInstance,
|
c.visual = visualInstance
|
||||||
|
|
||||||
|
return c.testConnection()
|
||||||
|
}
|
||||||
|
|
||||||
|
// testConnection 测试即梦AI连接
|
||||||
|
func (c *Client) testConnection() error {
|
||||||
|
|
||||||
|
// 使用一个简单的查询任务来测试连接
|
||||||
|
testReq := &QueryTaskRequest{
|
||||||
|
ReqKey: "test_connection",
|
||||||
|
TaskId: "test_task_id_12345",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err := c.QueryTask(testReq)
|
||||||
|
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||||
|
if err != nil {
|
||||||
|
// 检查是否是认证错误
|
||||||
|
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||||
|
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||||
|
}
|
||||||
|
// 其他错误(如任务不存在)说明连接正常
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubmitTask 提交异步任务
|
// SubmitTask 提交异步任务
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -16,8 +15,6 @@ import (
|
|||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
|
|
||||||
"geekai/core/types"
|
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,17 +33,8 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewService 创建即梦服务
|
// NewService 创建即梦服务
|
||||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
|
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager, client *Client) *Service {
|
||||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||||
// 从数据库加载配置
|
|
||||||
var config model.Config
|
|
||||||
db.Where("name = ?", "Jimeng").First(&config)
|
|
||||||
var jimengConfig types.JimengConfig
|
|
||||||
if config.Id > 0 {
|
|
||||||
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
|
||||||
}
|
|
||||||
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Service{
|
return &Service{
|
||||||
db: db,
|
db: db,
|
||||||
@@ -378,7 +366,7 @@ func (s *Service) pollTaskStatus() {
|
|||||||
|
|
||||||
for _, job := range jobs {
|
for _, job := range jobs {
|
||||||
// 任务超时处理
|
// 任务超时处理
|
||||||
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
|
if job.UpdatedAt.Before(time.Now().Add(-10 * time.Minute)) {
|
||||||
s.handleTaskError(job.Id, "task timeout")
|
s.handleTaskError(job.Id, "task timeout")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -391,7 +379,7 @@ func (s *Service) pollTaskStatus() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("query jimeng task status failed: %v", err)
|
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", err.Error()))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -446,9 +434,7 @@ func (s *Service) pollTaskStatus() {
|
|||||||
s.handleTaskError(job.Id, "task not found")
|
s.handleTaskError(job.Id, "task not found")
|
||||||
|
|
||||||
case model.JMTaskStatusExpired:
|
case model.JMTaskStatusExpired:
|
||||||
// 任务过期
|
continue
|
||||||
s.handleTaskError(job.Id, "task expired")
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||||
}
|
}
|
||||||
@@ -524,77 +510,3 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
|||||||
}
|
}
|
||||||
return &job, nil
|
return &job, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// testConnection 测试即梦AI连接
|
|
||||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
|
||||||
testClient := NewClient(accessKey, secretKey)
|
|
||||||
|
|
||||||
// 使用一个简单的查询任务来测试连接
|
|
||||||
testReq := &QueryTaskRequest{
|
|
||||||
ReqKey: "test_connection",
|
|
||||||
TaskId: "test_task_id_12345",
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := testClient.QueryTask(testReq)
|
|
||||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
|
||||||
if err != nil {
|
|
||||||
// 检查是否是认证错误
|
|
||||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
|
||||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
|
||||||
}
|
|
||||||
// 其他错误(如任务不存在)说明连接正常
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateClientConfig 更新客户端配置
|
|
||||||
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
|
|
||||||
// 创建新的客户端
|
|
||||||
newClient := NewClient(accessKey, secretKey)
|
|
||||||
|
|
||||||
// 测试新客户端是否可用
|
|
||||||
err := s.testConnection(accessKey, secretKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新客户端
|
|
||||||
s.client = newClient
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultPower = types.JimengPower{
|
|
||||||
TextToImage: 20,
|
|
||||||
ImageToImage: 20,
|
|
||||||
ImageEdit: 20,
|
|
||||||
ImageEffects: 20,
|
|
||||||
TextToVideo: 300,
|
|
||||||
ImageToVideo: 300,
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetConfig 获取即梦AI配置
|
|
||||||
func (s *Service) GetConfig() *types.JimengConfig {
|
|
||||||
var config model.Config
|
|
||||||
err := s.db.Where("name", "jimeng").First(&config).Error
|
|
||||||
if err != nil {
|
|
||||||
// 如果配置不存在,返回默认配置
|
|
||||||
return &types.JimengConfig{
|
|
||||||
AccessKey: "",
|
|
||||||
SecretKey: "",
|
|
||||||
Power: defaultPower,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var jimengConfig types.JimengConfig
|
|
||||||
err = utils.JsonDecode(config.Value, &jimengConfig)
|
|
||||||
if err != nil {
|
|
||||||
return &types.JimengConfig{
|
|
||||||
AccessKey: "",
|
|
||||||
SecretKey: "",
|
|
||||||
Power: defaultPower,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &jimengConfig
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,30 +8,37 @@ package service
|
|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/store"
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
|
"github.com/shirou/gopsutil/host"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LicenseService struct {
|
type LicenseService struct {
|
||||||
config types.ApiConfig
|
|
||||||
levelDB *store.LevelDB
|
|
||||||
license *types.License
|
license *types.License
|
||||||
urlWhiteList []string
|
urlWhiteList []string
|
||||||
machineId string
|
machineId string
|
||||||
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
|
func NewLicenseService(sysConfig *types.SystemConfig, db *gorm.DB) *LicenseService {
|
||||||
var license types.License
|
var machineId string
|
||||||
|
info, err := host.Info()
|
||||||
|
if err == nil {
|
||||||
|
machineId = info.HostID
|
||||||
|
}
|
||||||
|
logger.Infof("License: %+v", sysConfig.License)
|
||||||
return &LicenseService{
|
return &LicenseService{
|
||||||
config: server.Config.ApiConfig,
|
license: &sysConfig.License,
|
||||||
levelDB: levelDB,
|
machineId: machineId,
|
||||||
license: &license,
|
db: db,
|
||||||
machineId: "",
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,15 +53,15 @@ type License struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ActiveLicense 激活 License
|
// ActiveLicense 激活 License
|
||||||
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
|
func (s *LicenseService) ActiveLicense(license string) error {
|
||||||
var res struct {
|
var res struct {
|
||||||
Code types.BizCode `json:"code"`
|
Code types.BizCode `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Data License `json:"data"`
|
Data License `json:"data"`
|
||||||
}
|
}
|
||||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
|
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/active")
|
||||||
response, err := req.C().R().
|
response, err := req.C().R().
|
||||||
SetBody(map[string]string{"license": license, "machine_id": machineId}).
|
SetBody(map[string]string{"license": license, "machine_id": s.machineId}).
|
||||||
SetSuccessResult(&res).Post(apiURL)
|
SetSuccessResult(&res).Post(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("发送激活请求失败: %v", err)
|
return fmt.Errorf("发送激活请求失败: %v", err)
|
||||||
@@ -68,17 +75,24 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error {
|
|||||||
return fmt.Errorf("激活失败:%v", res.Message)
|
return fmt.Errorf("激活失败:%v", res.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Data.ExpiredAt > 0 && res.Data.ExpiredAt < time.Now().Unix() {
|
||||||
|
return fmt.Errorf("License 已过期")
|
||||||
|
}
|
||||||
|
|
||||||
s.license = &types.License{
|
s.license = &types.License{
|
||||||
Key: license,
|
Key: license,
|
||||||
MachineId: machineId,
|
MachineId: s.machineId,
|
||||||
Configs: res.Data.Configs,
|
Configs: res.Data.Configs,
|
||||||
ExpiredAt: res.Data.ExpiredAt,
|
ExpiredAt: res.Data.ExpiredAt,
|
||||||
IsActive: true,
|
IsActive: true,
|
||||||
}
|
}
|
||||||
err = s.levelDB.Put(types.LicenseKey, s.license)
|
|
||||||
|
// 保存 License 到数据库
|
||||||
|
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("保存许可证书失败:%v", err)
|
return fmt.Errorf("保存 License 到数据库失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,6 +110,11 @@ func (s *LicenseService) SyncLicense() {
|
|||||||
s.license.IsActive = false
|
s.license.IsActive = false
|
||||||
} else {
|
} else {
|
||||||
s.license = license
|
s.license = license
|
||||||
|
// 保存 License 到数据库
|
||||||
|
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("保存 License 到数据库失败: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
urls, err := s.fetchUrlWhiteList()
|
urls, err := s.fetchUrlWhiteList()
|
||||||
@@ -109,33 +128,30 @@ func (s *LicenseService) SyncLicense() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *LicenseService) fetchLicense() (*types.License, error) {
|
func (s *LicenseService) fetchLicense() (*types.License, error) {
|
||||||
//var res struct {
|
var res struct {
|
||||||
// Code types.BizCode `json:"code"`
|
Code types.BizCode `json:"code"`
|
||||||
// Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
// Data License `json:"data"`
|
Data License `json:"data"`
|
||||||
//}
|
}
|
||||||
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
|
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/check")
|
||||||
//response, err := req.C().R().
|
response, err := req.C().R().
|
||||||
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||||
// SetSuccessResult(&res).Post(apiURL)
|
SetSuccessResult(&res).Post(apiURL)
|
||||||
//if err != nil {
|
if err != nil {
|
||||||
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
|
return nil, fmt.Errorf("License 同步失败: %v", err)
|
||||||
//}
|
}
|
||||||
//if response.IsErrorState() {
|
if response.IsErrorState() {
|
||||||
// return nil, fmt.Errorf("激活失败:%v", response.Status)
|
return nil, fmt.Errorf("License 同步失败:%v", response.Status)
|
||||||
//}
|
}
|
||||||
//if res.Code != types.Success {
|
if res.Code != types.Success {
|
||||||
// return nil, fmt.Errorf("激活失败:%v", res.Message)
|
return nil, fmt.Errorf("License 同步失败:%v", res.Message)
|
||||||
//}
|
}
|
||||||
|
|
||||||
return &types.License{
|
return &types.License{
|
||||||
Key: "abc",
|
Key: res.Data.License,
|
||||||
MachineId: "abc",
|
MachineId: res.Data.MachineId,
|
||||||
Configs: types.LicenseConfig{
|
Configs: res.Data.Configs,
|
||||||
UserNum: 10000,
|
ExpiredAt: res.Data.ExpiredAt,
|
||||||
DeCopy: false,
|
|
||||||
},
|
|
||||||
ExpiredAt: 0,
|
|
||||||
IsActive: true,
|
IsActive: true,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -146,7 +162,7 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Data []string `json:"data"`
|
Data []string `json:"data"`
|
||||||
}
|
}
|
||||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
|
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/urls")
|
||||||
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
|
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("发送请求失败: %v", err)
|
return nil, fmt.Errorf("发送请求失败: %v", err)
|
||||||
@@ -163,35 +179,46 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
|||||||
|
|
||||||
// GetLicense 获取许可信息
|
// GetLicense 获取许可信息
|
||||||
func (s *LicenseService) GetLicense() *types.License {
|
func (s *LicenseService) GetLicense() *types.License {
|
||||||
|
if s.license == nil {
|
||||||
|
var config model.Config
|
||||||
|
s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).First(&config)
|
||||||
|
if config.Value != "" {
|
||||||
|
utils.JsonDecode(config.Value, &s.license)
|
||||||
|
}
|
||||||
|
}
|
||||||
return s.license
|
return s.license
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *LicenseService) SetLicense(licenseKey string) {
|
||||||
|
s.license.Key = licenseKey
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// IsValidApiURL 判断是否合法的中转 URL
|
// IsValidApiURL 判断是否合法的中转 URL
|
||||||
func (s *LicenseService) IsValidApiURL(uri string) error {
|
func (s *LicenseService) IsValidApiURL(uri string) error {
|
||||||
// 获得许可授权的直接放行
|
// 获得许可授权的直接放行
|
||||||
return nil
|
if s.license.IsActive {
|
||||||
//if s.license.IsActive {
|
if s.license.MachineId != s.machineId {
|
||||||
// if s.license.MachineId != s.machineId {
|
return errors.New("系统使用了盗版的许可证书")
|
||||||
// return errors.New("系统使用了盗版的许可证书")
|
}
|
||||||
// }
|
|
||||||
//
|
if time.Now().Unix() > s.license.ExpiredAt {
|
||||||
// if time.Now().Unix() > s.license.ExpiredAt {
|
return errors.New("系统许可证书已经过期")
|
||||||
// return errors.New("系统许可证书已经过期")
|
}
|
||||||
// }
|
return nil
|
||||||
// return nil
|
}
|
||||||
//}
|
|
||||||
//
|
if len(s.urlWhiteList) == 0 {
|
||||||
//if len(s.urlWhiteList) == 0 {
|
urls, err := s.fetchUrlWhiteList()
|
||||||
// urls, err := s.fetchUrlWhiteList()
|
if err == nil {
|
||||||
// if err == nil {
|
s.urlWhiteList = urls
|
||||||
// s.urlWhiteList = urls
|
}
|
||||||
// }
|
}
|
||||||
//}
|
|
||||||
//
|
for _, v := range s.urlWhiteList {
|
||||||
//for _, v := range s.urlWhiteList {
|
if strings.HasPrefix(uri, v) {
|
||||||
// if strings.HasPrefix(uri, v) {
|
return nil
|
||||||
// return nil
|
}
|
||||||
// }
|
}
|
||||||
//}
|
return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||||
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,52 +1,342 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
// Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
// Use of this source code is governed by a Apache-2.0 license
|
||||||
// * that can be found in the LICENSE file.
|
// that can be found in the LICENSE file.
|
||||||
// * @Author yangjian102621@163.com
|
// @Author yangjian102621@163.com
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/store"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 迁移状态Redis key
|
||||||
|
MigrationStatusKey = "config_migration:status"
|
||||||
|
// 迁移完成标志
|
||||||
|
MigrationCompleted = "completed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MigrationService 配置迁移服务
|
||||||
type MigrationService struct {
|
type MigrationService struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
redisClient *redis.Client
|
||||||
|
appConfig *types.AppConfig
|
||||||
|
levelDB *store.LevelDB
|
||||||
|
licenseService *LicenseService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMigrationService(db *gorm.DB) *MigrationService {
|
func NewMigrationService(db *gorm.DB, redisClient *redis.Client, appConfig *types.AppConfig, levelDB *store.LevelDB, licenseService *LicenseService) *MigrationService {
|
||||||
return &MigrationService{db: db}
|
return &MigrationService{
|
||||||
|
db: db,
|
||||||
|
redisClient: redisClient,
|
||||||
|
appConfig: appConfig,
|
||||||
|
levelDB: levelDB,
|
||||||
|
licenseService: licenseService,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MigrationService) Migrate() error {
|
func (s *MigrationService) StartMigrate() {
|
||||||
err := s.db.AutoMigrate(
|
go func() {
|
||||||
&model.AdminUser{},
|
s.MigrateConfig(s.appConfig)
|
||||||
&model.ApiKey{},
|
s.TableMigration()
|
||||||
&model.AppType{},
|
s.MigrateLicense()
|
||||||
&model.ChatItem{},
|
}()
|
||||||
&model.ChatMessage{},
|
}
|
||||||
&model.ChatModel{},
|
|
||||||
&model.ChatRole{},
|
// 迁移 License
|
||||||
&model.Config{},
|
func (s *MigrationService) MigrateLicense() {
|
||||||
&model.DallJob{},
|
key := "migrate:license"
|
||||||
&model.File{},
|
if s.redisClient.Get(context.Background(), key).Val() == "1" {
|
||||||
&model.Function{},
|
logger.Info("License 已迁移,跳过迁移")
|
||||||
&model.InviteCode{},
|
return
|
||||||
&model.InviteLog{},
|
}
|
||||||
&model.Menu{},
|
|
||||||
&model.MidJourneyJob{},
|
logger.Info("开始迁移 License...")
|
||||||
&model.Order{},
|
var license types.License
|
||||||
&model.PowerLog{},
|
err := s.levelDB.Get(types.LicenseKey, &license)
|
||||||
&model.Product{},
|
if err != nil {
|
||||||
&model.Redeem{},
|
license = types.License{
|
||||||
&model.SdJob{},
|
Key: "",
|
||||||
&model.SunoJob{},
|
MachineId: "",
|
||||||
&model.User{},
|
Configs: types.LicenseConfig{UserNum: 0, DeCopy: false},
|
||||||
&model.UserLoginLog{},
|
ExpiredAt: 0,
|
||||||
&model.VideoJob{},
|
IsActive: false,
|
||||||
)
|
}
|
||||||
return err
|
}
|
||||||
|
logger.Infof("迁移 License: %+v", license)
|
||||||
|
if err := s.saveConfig(types.ConfigKeyLicense, license); err != nil {
|
||||||
|
logger.Errorf("迁移 License 失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.licenseService.SetLicense(license.Key)
|
||||||
|
logger.Info("迁移 License 完成")
|
||||||
|
s.redisClient.Set(context.Background(), key, "1", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 迁移配置内容
|
||||||
|
func (s *MigrationService) MigrateConfigContent() error {
|
||||||
|
// 用户协议
|
||||||
|
if err := s.saveConfig(types.ConfigKeyPrivacy, map[string]string{
|
||||||
|
"content": "用户协议内容",
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||||
|
}
|
||||||
|
// 隐私政策
|
||||||
|
if err := s.saveConfig(types.ConfigKeyAgreement, map[string]string{
|
||||||
|
"content": "隐私政策内容",
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||||
|
}
|
||||||
|
// 思维导图
|
||||||
|
if err := s.saveConfig(types.ConfigKeyMarkMap, map[string]string{
|
||||||
|
"content": `# GeekAI 演示站
|
||||||
|
|
||||||
|
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
||||||
|
- 基于 Websocket 实现,完美的打字机体验。
|
||||||
|
- 内置了各种预训练好的角色应用,轻松满足你的各种聊天和应用需求。
|
||||||
|
- 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
|
||||||
|
- 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
|
||||||
|
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
||||||
|
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
||||||
|
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件。`,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 微信登录配置
|
||||||
|
if err := s.saveConfig(types.ConfigKeyWxLogin, map[string]string{
|
||||||
|
"api_key": "",
|
||||||
|
"notify_url": "",
|
||||||
|
"enabled": "false",
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证码配置
|
||||||
|
if err := s.saveConfig(types.ConfigKeyCaptcha, map[string]string{
|
||||||
|
"api_key": "",
|
||||||
|
"type": "dot",
|
||||||
|
"enabled": "false",
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 文本审核
|
||||||
|
if err := s.saveConfig(types.ConfigKeyModeration, map[string]any{
|
||||||
|
"enable": "false",
|
||||||
|
"active": "gitee",
|
||||||
|
"enable_guide": "false",
|
||||||
|
"guide_prompt": "",
|
||||||
|
"gitee": map[string]string{
|
||||||
|
"api_key": "",
|
||||||
|
"model": "Security-semantic-filtering",
|
||||||
|
},
|
||||||
|
"baidu": map[string]string{
|
||||||
|
"access_key": "",
|
||||||
|
"secret_key": "",
|
||||||
|
},
|
||||||
|
"tencent": map[string]string{
|
||||||
|
"access_key": "",
|
||||||
|
"secret_key": "",
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 数据表迁移
|
||||||
|
func (s *MigrationService) TableMigration() {
|
||||||
|
// 新数据表
|
||||||
|
s.db.AutoMigrate(&model.Moderation{})
|
||||||
|
|
||||||
|
// 订单字段整理
|
||||||
|
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
|
||||||
|
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")
|
||||||
|
}
|
||||||
|
if !s.db.Migrator().HasColumn(&model.Order{}, "checked") {
|
||||||
|
s.db.Migrator().AddColumn(&model.Order{}, "checked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重命名 config 表字段
|
||||||
|
if s.db.Migrator().HasColumn(&model.Config{}, "config_json") {
|
||||||
|
s.db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.Config{}, "marker") {
|
||||||
|
s.db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
|
||||||
|
s.db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasIndex(&model.Config{}, "marker") {
|
||||||
|
s.db.Migrator().DropIndex(&model.Config{}, "marker")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 手动删除字段
|
||||||
|
if s.db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
|
||||||
|
s.db.Migrator().DropColumn(&model.Order{}, "deleted_at")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
|
||||||
|
s.db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
|
||||||
|
s.db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.User{}, "chat_config") {
|
||||||
|
s.db.Migrator().DropColumn(&model.User{}, "chat_config")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.ChatModel{}, "category") {
|
||||||
|
s.db.Migrator().DropColumn(&model.ChatModel{}, "category")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.ChatModel{}, "description") {
|
||||||
|
s.db.Migrator().DropColumn(&model.ChatModel{}, "description")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.Product{}, "discount") {
|
||||||
|
s.db.Migrator().DropColumn(&model.Product{}, "discount")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.Product{}, "days") {
|
||||||
|
s.db.Migrator().DropColumn(&model.Product{}, "days")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.Product{}, "app_url") {
|
||||||
|
s.db.Migrator().DropColumn(&model.Product{}, "app_url")
|
||||||
|
}
|
||||||
|
if s.db.Migrator().HasColumn(&model.Product{}, "url") {
|
||||||
|
s.db.Migrator().DropColumn(&model.Product{}, "url")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 迁移配置数据
|
||||||
|
func (s *MigrationService) MigrateConfig(config *types.AppConfig) error {
|
||||||
|
|
||||||
|
logger.Info("开始迁移配置到数据库...")
|
||||||
|
|
||||||
|
// 迁移支付配置
|
||||||
|
if err := s.migratePaymentConfig(config); err != nil {
|
||||||
|
logger.Errorf("迁移支付配置失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 迁移存储配置
|
||||||
|
if err := s.migrateStorageConfig(config); err != nil {
|
||||||
|
logger.Errorf("迁移存储配置失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 迁移通信配置
|
||||||
|
if err := s.migrateCommunicationConfig(config); err != nil {
|
||||||
|
logger.Errorf("迁移通信配置失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 迁移配置内容
|
||||||
|
if err := s.MigrateConfigContent(); err != nil {
|
||||||
|
logger.Errorf("迁移配置内容失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("配置迁移完成")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 迁移支付配置
|
||||||
|
func (s *MigrationService) migratePaymentConfig(config *types.AppConfig) error {
|
||||||
|
|
||||||
|
paymentConfig := types.PaymentConfig{
|
||||||
|
Alipay: config.AlipayConfig,
|
||||||
|
Epay: config.GeekPayConfig,
|
||||||
|
WxPay: config.WechatPayConfig,
|
||||||
|
}
|
||||||
|
if err := s.saveConfig(types.ConfigKeyPayment, paymentConfig); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 迁移存储配置
|
||||||
|
func (s *MigrationService) migrateStorageConfig(config *types.AppConfig) error {
|
||||||
|
|
||||||
|
ossConfig := types.OSSConfig{
|
||||||
|
Active: config.OSS.Active,
|
||||||
|
Local: config.OSS.Local,
|
||||||
|
Minio: config.OSS.Minio,
|
||||||
|
QiNiu: config.OSS.QiNiu,
|
||||||
|
AliYun: config.OSS.AliYun,
|
||||||
|
}
|
||||||
|
return s.saveConfig(types.ConfigKeyOss, ossConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 迁移通信配置
|
||||||
|
func (s *MigrationService) migrateCommunicationConfig(config *types.AppConfig) error {
|
||||||
|
// SMTP配置
|
||||||
|
smtpConfig := map[string]any{
|
||||||
|
"use_tls": config.SmtpConfig.UseTls,
|
||||||
|
"host": config.SmtpConfig.Host,
|
||||||
|
"port": config.SmtpConfig.Port,
|
||||||
|
"app_name": config.SmtpConfig.AppName,
|
||||||
|
"from": config.SmtpConfig.From,
|
||||||
|
"password": config.SmtpConfig.Password,
|
||||||
|
}
|
||||||
|
if err := s.saveConfig(types.ConfigKeySmtp, smtpConfig); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 短信配置
|
||||||
|
smsConfig := map[string]any{
|
||||||
|
"active": strings.ToLower(config.SMS.Active),
|
||||||
|
"aliyun": map[string]any{
|
||||||
|
"access_key": config.SMS.Ali.AccessKey,
|
||||||
|
"access_secret": config.SMS.Ali.AccessSecret,
|
||||||
|
"sign": config.SMS.Ali.Sign,
|
||||||
|
"code_temp_id": config.SMS.Ali.CodeTempId,
|
||||||
|
},
|
||||||
|
"bao": map[string]any{
|
||||||
|
"username": config.SMS.Bao.Username,
|
||||||
|
"password": config.SMS.Bao.Password,
|
||||||
|
"sign": config.SMS.Bao.Sign,
|
||||||
|
"code_template": config.SMS.Bao.CodeTemplate,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return s.saveConfig(types.ConfigKeySms, smsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存配置到数据库
|
||||||
|
func (s *MigrationService) saveConfig(key string, config any) error {
|
||||||
|
// 检查是否已存在
|
||||||
|
var existingConfig model.Config
|
||||||
|
if err := s.db.Where("name", key).First(&existingConfig).Error; err == nil {
|
||||||
|
// 配置已存在,跳过
|
||||||
|
logger.Infof("配置 %s 已存在,跳过迁移", key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 序列化配置
|
||||||
|
configJSON, err := json.Marshal(config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存到数据库
|
||||||
|
newConfig := model.Config{
|
||||||
|
Name: key,
|
||||||
|
Value: string(configJSON),
|
||||||
|
}
|
||||||
|
if err := s.db.Create(&newConfig).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("成功迁移配置 %s", key)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,25 +67,6 @@ func (s *Service) Run() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// translate prompt
|
|
||||||
if utils.HasChinese(task.Prompt) {
|
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
|
|
||||||
if err == nil {
|
|
||||||
task.Prompt = content
|
|
||||||
} else {
|
|
||||||
logger.Warnf("error with translate prompt: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// translate negative prompt
|
|
||||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), task.TranslateModelId)
|
|
||||||
if err == nil {
|
|
||||||
task.NegPrompt = content
|
|
||||||
} else {
|
|
||||||
logger.Warnf("error with translate prompt: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// use fast mode as default
|
// use fast mode as default
|
||||||
if task.Mode == "" {
|
if task.Mode == "" {
|
||||||
task.Mode = "fast"
|
task.Mode = "fast"
|
||||||
|
|||||||
33
api/service/moderation/baidu_moderation.go
Normal file
33
api/service/moderation/baidu_moderation.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package moderation
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"geekai/core/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BaiduAIModeration struct {
|
||||||
|
config types.ModerationBaiduConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBaiduAIModeration(sysConfig *types.SystemConfig) *BaiduAIModeration {
|
||||||
|
return &BaiduAIModeration{
|
||||||
|
config: sysConfig.Moderation.Baidu,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaiduAIModeration) UpdateConfig(config types.ModerationBaiduConfig) {
|
||||||
|
s.config = config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BaiduAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||||
|
return types.ModerationResult{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Service = (*BaiduAIModeration)(nil)
|
||||||
58
api/service/moderation/gitee_moderation.go
Normal file
58
api/service/moderation/gitee_moderation.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package moderation
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"geekai/core/types"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GiteeAIModeration struct {
|
||||||
|
config types.ModerationGiteeConfig
|
||||||
|
apiURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGiteeAIModeration(sysConfig *types.SystemConfig) *GiteeAIModeration {
|
||||||
|
return &GiteeAIModeration{
|
||||||
|
config: sysConfig.Moderation.Gitee,
|
||||||
|
apiURL: "https://ai.gitee.com/v1/moderations",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GiteeAIModeration) UpdateConfig(config types.ModerationGiteeConfig) {
|
||||||
|
s.config = config
|
||||||
|
}
|
||||||
|
|
||||||
|
type GiteeAIModerationResult struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Results []types.ModerationResult `json:"results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GiteeAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||||
|
|
||||||
|
body := map[string]any{
|
||||||
|
"input": text,
|
||||||
|
"model": s.config.Model,
|
||||||
|
}
|
||||||
|
var res GiteeAIModerationResult
|
||||||
|
r, err := req.C().R().SetHeader("Authorization", "Bearer "+s.config.ApiKey).SetBody(body).SetSuccessResult(&res).Post(s.apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return types.ModerationResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return types.ModerationResult{}, errors.New(r.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Results[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Service = (*GiteeAIModeration)(nil)
|
||||||
58
api/service/moderation/moderation_manager.go
Normal file
58
api/service/moderation/moderation_manager.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package moderation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core/types"
|
||||||
|
|
||||||
|
logger2 "geekai/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
type Service interface {
|
||||||
|
Moderate(text string) (types.ModerationResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServiceManager struct {
|
||||||
|
gitee *GiteeAIModeration
|
||||||
|
baidu *BaiduAIModeration
|
||||||
|
tencent *TencentAIModeration
|
||||||
|
active string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServiceManager(gitee *GiteeAIModeration, baidu *BaiduAIModeration, tencent *TencentAIModeration) *ServiceManager {
|
||||||
|
return &ServiceManager{
|
||||||
|
gitee: gitee,
|
||||||
|
baidu: baidu,
|
||||||
|
tencent: tencent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServiceManager) GetService() Service {
|
||||||
|
switch s.active {
|
||||||
|
case types.ModerationBaidu:
|
||||||
|
return s.baidu
|
||||||
|
case types.ModerationTencent:
|
||||||
|
return s.tencent
|
||||||
|
default:
|
||||||
|
return s.gitee
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServiceManager) UpdateConfig(config types.ModerationConfig) {
|
||||||
|
switch config.Active {
|
||||||
|
case types.ModerationGitee:
|
||||||
|
s.gitee.UpdateConfig(config.Gitee)
|
||||||
|
case types.ModerationBaidu:
|
||||||
|
s.baidu.UpdateConfig(config.Baidu)
|
||||||
|
case types.ModerationTencent:
|
||||||
|
s.tencent.UpdateConfig(config.Tencent)
|
||||||
|
}
|
||||||
|
s.active = config.Active
|
||||||
|
}
|
||||||
33
api/service/moderation/tencent_moderation.go
Normal file
33
api/service/moderation/tencent_moderation.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package moderation
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"geekai/core/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TencentAIModeration struct {
|
||||||
|
config types.ModerationTencentConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTencentAIModeration(sysConfig *types.SystemConfig) *TencentAIModeration {
|
||||||
|
return &TencentAIModeration{
|
||||||
|
config: sysConfig.Moderation.Tencent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TencentAIModeration) UpdateConfig(config types.ModerationTencentConfig) {
|
||||||
|
s.config = config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TencentAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||||
|
return types.ModerationResult{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Service = (*TencentAIModeration)(nil)
|
||||||
@@ -23,35 +23,35 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AliYunOss struct {
|
type AliYunOss struct {
|
||||||
config *types.AliYunOssConfig
|
config types.AliYunOssConfig
|
||||||
bucket *oss.Bucket
|
bucket *oss.Bucket
|
||||||
proxyURL string
|
proxyURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
|
func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*AliYunOss, error) {
|
||||||
config := &appConfig.OSS.AliYun
|
s := &AliYunOss{
|
||||||
// 创建 OSS 客户端
|
proxyURL: appConfig.ProxyURL,
|
||||||
|
}
|
||||||
|
err := s.UpdateConfig(sysConfig.OSS.AliYun)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnf("阿里云OSS初始化失败: %v", err)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AliYunOss) UpdateConfig(config types.AliYunOssConfig) error {
|
||||||
client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret)
|
client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取存储空间
|
|
||||||
bucket, err := client.Bucket(config.Bucket)
|
bucket, err := client.Bucket(config.Bucket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
s.bucket = bucket
|
||||||
if config.SubDir == "" {
|
s.config = config
|
||||||
config.SubDir = "gpt"
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
return &AliYunOss{
|
|
||||||
config: config,
|
|
||||||
bucket: bucket,
|
|
||||||
proxyURL: appConfig.ProxyURL,
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||||
@@ -68,7 +68,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
|||||||
defer src.Close()
|
defer src.Close()
|
||||||
|
|
||||||
fileExt := filepath.Ext(file.Filename)
|
fileExt := filepath.Ext(file.Filename)
|
||||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||||
// 上传文件
|
// 上传文件
|
||||||
err = s.bucket.PutObject(objectKey, src)
|
err = s.bucket.PutObject(objectKey, src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -102,7 +102,7 @@ func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
|||||||
if ext == "" {
|
if ext == "" {
|
||||||
ext = filepath.Ext(parse.Path)
|
ext = filepath.Ext(parse.Path)
|
||||||
}
|
}
|
||||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||||
// 上传文件字节数据
|
// 上传文件字节数据
|
||||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -116,7 +116,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
}
|
}
|
||||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||||
// 上传文件字节数据
|
// 上传文件字节数据
|
||||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -128,8 +128,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
|||||||
func (s AliYunOss) Delete(fileURL string) error {
|
func (s AliYunOss) Delete(fileURL string) error {
|
||||||
var objectKey string
|
var objectKey string
|
||||||
if strings.HasPrefix(fileURL, "http") {
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
filename := filepath.Base(fileURL)
|
objectKey = filepath.Base(fileURL)
|
||||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
|
||||||
} else {
|
} else {
|
||||||
objectKey = fileURL
|
objectKey = fileURL
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,17 +21,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type LocalStorage struct {
|
type LocalStorage struct {
|
||||||
config *types.LocalStorageConfig
|
config types.LocalStorageConfig
|
||||||
proxyURL string
|
proxyURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLocalStorage(config *types.AppConfig) LocalStorage {
|
func NewLocalStorage(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *LocalStorage {
|
||||||
return LocalStorage{
|
return &LocalStorage{
|
||||||
config: &config.OSS.Local,
|
config: sysConfig.OSS.Local,
|
||||||
proxyURL: config.ProxyURL,
|
proxyURL: appConfig.ProxyURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *LocalStorage) UpdateConfig(config types.LocalStorageConfig) {
|
||||||
|
s.config = config
|
||||||
|
}
|
||||||
|
|
||||||
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||||
file, err := ctx.FormFile(name)
|
file, err := ctx.FormFile(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -24,24 +24,32 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type MiniOss struct {
|
type MiniOss struct {
|
||||||
config *types.MiniOssConfig
|
config types.MiniOssConfig
|
||||||
client *minio.Client
|
client *minio.Client
|
||||||
proxyURL string
|
proxyURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) {
|
||||||
config := &appConfig.OSS.Minio
|
|
||||||
|
s := &MiniOss{proxyURL: appConfig.ProxyURL}
|
||||||
|
err := s.UpdateConfig(sysConfig.OSS.Minio)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnf("MinioOSS初始化失败: %v", err)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MiniOss) UpdateConfig(config types.MiniOssConfig) error {
|
||||||
minioClient, err := minio.New(config.Endpoint, &minio.Options{
|
minioClient, err := minio.New(config.Endpoint, &minio.Options{
|
||||||
Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""),
|
Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""),
|
||||||
Secure: config.UseSSL,
|
Secure: config.UseSSL,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return MiniOss{}, err
|
return err
|
||||||
}
|
}
|
||||||
if config.SubDir == "" {
|
s.config = config
|
||||||
config.SubDir = "gpt"
|
s.client = minioClient
|
||||||
}
|
return nil
|
||||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||||
@@ -62,7 +70,7 @@ func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string,
|
|||||||
if ext == "" {
|
if ext == "" {
|
||||||
ext = filepath.Ext(parse.Path)
|
ext = filepath.Ext(parse.Path)
|
||||||
}
|
}
|
||||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||||
info, err := s.client.PutObject(
|
info, err := s.client.PutObject(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
s.config.Bucket,
|
s.config.Bucket,
|
||||||
@@ -89,7 +97,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
|||||||
defer fileReader.Close()
|
defer fileReader.Close()
|
||||||
|
|
||||||
fileExt := filepath.Ext(file.Filename)
|
fileExt := filepath.Ext(file.Filename)
|
||||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||||
ContentType: file.Header.Get("Body-Type"),
|
ContentType: file.Header.Get("Body-Type"),
|
||||||
})
|
})
|
||||||
@@ -111,7 +119,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
}
|
}
|
||||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||||
info, err := s.client.PutObject(
|
info, err := s.client.PutObject(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
s.config.Bucket,
|
s.config.Bucket,
|
||||||
@@ -128,8 +136,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
|||||||
func (s MiniOss) Delete(fileURL string) error {
|
func (s MiniOss) Delete(fileURL string) error {
|
||||||
var objectKey string
|
var objectKey string
|
||||||
if strings.HasPrefix(fileURL, "http") {
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
filename := filepath.Base(fileURL)
|
objectKey = filepath.Base(fileURL)
|
||||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
|
||||||
} else {
|
} else {
|
||||||
objectKey = fileURL
|
objectKey = fileURL
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,18 +24,24 @@ import (
|
|||||||
"github.com/qiniu/go-sdk/v7/storage"
|
"github.com/qiniu/go-sdk/v7/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
type QinNiuOss struct {
|
type QiNiuOss struct {
|
||||||
config *types.QiNiuOssConfig
|
config types.QiNiuOssConfig
|
||||||
mac *qbox.Mac
|
mac *qbox.Mac
|
||||||
putPolicy storage.PutPolicy
|
putPolicy storage.PutPolicy
|
||||||
uploader *storage.FormUploader
|
uploader *storage.FormUploader
|
||||||
manager *storage.BucketManager
|
bucket *storage.BucketManager
|
||||||
proxyURL string
|
proxyURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
|
func NewQiNiuOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *QiNiuOss {
|
||||||
config := &appConfig.OSS.QiNiu
|
s := &QiNiuOss{
|
||||||
// build storage uploader
|
proxyURL: appConfig.ProxyURL,
|
||||||
|
}
|
||||||
|
s.UpdateConfig(sysConfig.OSS.QiNiu)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *QiNiuOss) UpdateConfig(config types.QiNiuOssConfig) {
|
||||||
zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone))
|
zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone))
|
||||||
if !ok {
|
if !ok {
|
||||||
zone = storage.ZoneHuanan
|
zone = storage.ZoneHuanan
|
||||||
@@ -47,20 +53,13 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
|
|||||||
putPolicy := storage.PutPolicy{
|
putPolicy := storage.PutPolicy{
|
||||||
Scope: config.Bucket,
|
Scope: config.Bucket,
|
||||||
}
|
}
|
||||||
if config.SubDir == "" {
|
s.config = config
|
||||||
config.SubDir = "gpt"
|
s.mac = mac
|
||||||
}
|
s.putPolicy = putPolicy
|
||||||
return QinNiuOss{
|
s.uploader = formUploader
|
||||||
config: config,
|
s.bucket = storage.NewBucketManager(mac, &storeConfig)
|
||||||
mac: mac,
|
|
||||||
putPolicy: putPolicy,
|
|
||||||
uploader: formUploader,
|
|
||||||
manager: storage.NewBucketManager(mac, &storeConfig),
|
|
||||||
proxyURL: appConfig.ProxyURL,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
func (s QiNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||||
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
|
||||||
// 解析表单
|
// 解析表单
|
||||||
file, err := ctx.FormFile(name)
|
file, err := ctx.FormFile(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -74,7 +73,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
|||||||
defer src.Close()
|
defer src.Close()
|
||||||
|
|
||||||
fileExt := filepath.Ext(file.Filename)
|
fileExt := filepath.Ext(file.Filename)
|
||||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||||
// 上传文件
|
// 上传文件
|
||||||
ret := storage.PutRet{}
|
ret := storage.PutRet{}
|
||||||
extra := storage.PutExtra{}
|
extra := storage.PutExtra{}
|
||||||
@@ -93,7 +92,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
func (s QiNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||||
var fileData []byte
|
var fileData []byte
|
||||||
var err error
|
var err error
|
||||||
if useProxy {
|
if useProxy {
|
||||||
@@ -111,7 +110,7 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
|||||||
if ext == "" {
|
if ext == "" {
|
||||||
ext = filepath.Ext(parse.Path)
|
ext = filepath.Ext(parse.Path)
|
||||||
}
|
}
|
||||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||||
ret := storage.PutRet{}
|
ret := storage.PutRet{}
|
||||||
extra := storage.PutExtra{}
|
extra := storage.PutExtra{}
|
||||||
// 上传文件字节数据
|
// 上传文件字节数据
|
||||||
@@ -122,12 +121,12 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
|||||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
func (s QiNiuOss) PutBase64(base64Img string) (string, error) {
|
||||||
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
}
|
}
|
||||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||||
ret := storage.PutRet{}
|
ret := storage.PutRet{}
|
||||||
extra := storage.PutExtra{}
|
extra := storage.PutExtra{}
|
||||||
// 上传文件字节数据
|
// 上传文件字节数据
|
||||||
@@ -138,16 +137,15 @@ func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
|||||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s QinNiuOss) Delete(fileURL string) error {
|
func (s QiNiuOss) Delete(fileURL string) error {
|
||||||
var objectKey string
|
var objectKey string
|
||||||
if strings.HasPrefix(fileURL, "http") {
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
filename := filepath.Base(fileURL)
|
objectKey = filepath.Base(fileURL)
|
||||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
|
||||||
} else {
|
} else {
|
||||||
objectKey = fileURL
|
objectKey = fileURL
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.manager.Delete(s.config.Bucket, objectKey)
|
return s.bucket.Delete(s.config.Bucket, objectKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Uploader = QinNiuOss{}
|
var _ Uploader = QiNiuOss{}
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ package oss
|
|||||||
|
|
||||||
import "github.com/gin-gonic/gin"
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
const Local = "LOCAL"
|
const Local = "local"
|
||||||
const Minio = "MINIO"
|
const Minio = "minio"
|
||||||
const QiNiu = "QINIU"
|
const QiNiu = "qiniu"
|
||||||
const AliYun = "ALIYUN"
|
const AliYun = "aliyun"
|
||||||
|
|
||||||
type File struct {
|
type File struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|||||||
@@ -9,45 +9,58 @@ package oss
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"strings"
|
|
||||||
|
logger2 "geekai/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
type UploaderManager struct {
|
type UploaderManager struct {
|
||||||
handler Uploader
|
local *LocalStorage
|
||||||
|
aliyun *AliYunOss
|
||||||
|
mini *MiniOss
|
||||||
|
qiniu *QiNiuOss
|
||||||
|
active string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
|
func NewUploaderManager(sysConfig *types.SystemConfig, local *LocalStorage, aliyun *AliYunOss, mini *MiniOss, qiniu *QiNiuOss) (*UploaderManager, error) {
|
||||||
active := Local
|
if sysConfig.OSS.Active == "" {
|
||||||
if config.OSS.Active != "" {
|
sysConfig.OSS.Active = Local
|
||||||
active = strings.ToUpper(config.OSS.Active)
|
|
||||||
}
|
|
||||||
var handler Uploader
|
|
||||||
switch active {
|
|
||||||
case Local:
|
|
||||||
handler = NewLocalStorage(config)
|
|
||||||
break
|
|
||||||
case Minio:
|
|
||||||
client, err := NewMiniOss(config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
handler = client
|
|
||||||
break
|
|
||||||
case QiNiu:
|
|
||||||
handler = NewQiNiuOss(config)
|
|
||||||
break
|
|
||||||
case AliYun:
|
|
||||||
client, err := NewAliYunOss(config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
handler = client
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &UploaderManager{handler: handler}, nil
|
return &UploaderManager{
|
||||||
|
active: sysConfig.OSS.Active,
|
||||||
|
local: local,
|
||||||
|
aliyun: aliyun,
|
||||||
|
mini: mini,
|
||||||
|
qiniu: qiniu,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UploaderManager) GetUploadHandler() Uploader {
|
func (m *UploaderManager) GetUploadHandler() Uploader {
|
||||||
return m.handler
|
switch m.active {
|
||||||
|
case Local:
|
||||||
|
return m.local
|
||||||
|
case AliYun:
|
||||||
|
return m.aliyun
|
||||||
|
case Minio:
|
||||||
|
return m.mini
|
||||||
|
case QiNiu:
|
||||||
|
return m.qiniu
|
||||||
|
}
|
||||||
|
return m.local
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UploaderManager) UpdateConfig(config types.OSSConfig) {
|
||||||
|
switch config.Active {
|
||||||
|
case Local:
|
||||||
|
m.local.UpdateConfig(config.Local)
|
||||||
|
case AliYun:
|
||||||
|
m.aliyun.UpdateConfig(config.AliYun)
|
||||||
|
case Minio:
|
||||||
|
m.mini.UpdateConfig(config.Minio)
|
||||||
|
case QiNiu:
|
||||||
|
m.qiniu.UpdateConfig(config.QiNiu)
|
||||||
|
}
|
||||||
|
m.active = config.Active
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,129 +12,98 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
logger2 "geekai/logger"
|
logger2 "geekai/logger"
|
||||||
"github.com/go-pay/gopay"
|
|
||||||
"github.com/go-pay/gopay/alipay"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/go-pay/gopay"
|
||||||
|
"github.com/go-pay/gopay/alipay"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AlipayService struct {
|
type AlipayService struct {
|
||||||
config *types.AlipayConfig
|
|
||||||
client *alipay.Client
|
client *alipay.Client
|
||||||
|
config *types.AlipayConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
|
func NewAlipayService(sysConfig *types.SystemConfig) (*AlipayService, error) {
|
||||||
config := appConfig.AlipayConfig
|
config := sysConfig.Payment.Alipay
|
||||||
if !config.Enabled {
|
if !config.Enabled {
|
||||||
logger.Info("Disabled Alipay service")
|
logger.Debug("Disabled Alipay service")
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
priKey, err := readKey(config.PrivateKey)
|
|
||||||
|
service := &AlipayService{config: &config}
|
||||||
|
if config.Enabled {
|
||||||
|
err := service.UpdateConfig(&config)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("支付宝服务初始化失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AlipayService) UpdateConfig(config *types.AlipayConfig) error {
|
||||||
|
client, err := alipay.NewClient(config.AppId, config.PrivateKey, !config.SandBox)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error with read App Private key: %v", err)
|
return fmt.Errorf("error with initialize alipay service: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
|
s.client = client
|
||||||
if err != nil {
|
s.config = config
|
||||||
return nil, fmt.Errorf("error with initialize alipay service: %v", err)
|
if os.Getenv("GEEKAI_DEBUG") == "true" {
|
||||||
|
logger.Info("Alipay Debug mode is enabled")
|
||||||
|
client.DebugSwitch = gopay.DebugOn
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
|
|
||||||
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
|
|
||||||
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
|
|
||||||
SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
|
|
||||||
|
|
||||||
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
|
|
||||||
return nil, fmt.Errorf("error with load payment public key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &AlipayService{config: &config, client: client}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AlipayParams struct {
|
func (s *AlipayService) Pay(params PayRequest) (string, error) {
|
||||||
OutTradeNo string `json:"out_trade_no"`
|
|
||||||
Subject string `json:"subject"`
|
|
||||||
TotalFee string `json:"total_fee"`
|
|
||||||
ReturnURL string `json:"return_url"`
|
|
||||||
NotifyURL string `json:"notify_url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
|
|
||||||
bm := make(gopay.BodyMap)
|
|
||||||
bm.Set("subject", params.Subject)
|
|
||||||
bm.Set("out_trade_no", params.OutTradeNo)
|
|
||||||
bm.Set("quit_url", params.ReturnURL)
|
|
||||||
bm.Set("total_amount", params.TotalFee)
|
|
||||||
bm.Set("product_code", "QUICK_WAP_WAY")
|
|
||||||
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
|
|
||||||
bm := make(gopay.BodyMap)
|
bm := make(gopay.BodyMap)
|
||||||
bm.Set("subject", params.Subject)
|
bm.Set("subject", params.Subject)
|
||||||
bm.Set("out_trade_no", params.OutTradeNo)
|
bm.Set("out_trade_no", params.OutTradeNo)
|
||||||
bm.Set("total_amount", params.TotalFee)
|
bm.Set("total_amount", params.TotalFee)
|
||||||
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
|
return s.client.TradeWapPay(context.Background(), bm)
|
||||||
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
|
}
|
||||||
|
|
||||||
|
func (s *AlipayService) Query(outTradeNo string) (OrderInfo, error) {
|
||||||
|
bm := make(gopay.BodyMap)
|
||||||
|
bm.Set("out_trade_no", outTradeNo)
|
||||||
|
rsp, err := s.client.TradeQuery(context.Background(), bm)
|
||||||
|
if err != nil {
|
||||||
|
return OrderInfo{}, fmt.Errorf("error with trade query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch rsp.Response.TradeStatus {
|
||||||
|
case "TRADE_SUCCESS":
|
||||||
|
logger.Debugf("支付宝查询订单成功:%+v", rsp.Response)
|
||||||
|
return OrderInfo{
|
||||||
|
OutTradeNo: rsp.Response.OutTradeNo,
|
||||||
|
TradeId: rsp.Response.TradeNo,
|
||||||
|
Amount: rsp.Response.TotalAmount,
|
||||||
|
Status: Success,
|
||||||
|
PayTime: rsp.Response.SendPayDate,
|
||||||
|
}, nil
|
||||||
|
case "TRADE_CLOSED":
|
||||||
|
return OrderInfo{Status: Closed}, nil
|
||||||
|
default:
|
||||||
|
return OrderInfo{}, fmt.Errorf("error with trade query: %v", rsp.Response.TradeStatus)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TradeVerify 交易验证
|
// TradeVerify 交易验证
|
||||||
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
|
func (s *AlipayService) TradeVerify(request *http.Request) (OrderInfo, error) {
|
||||||
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
|
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NotifyVo{
|
return OrderInfo{}, fmt.Errorf("error with parse notify request: %v", err)
|
||||||
Status: Failure,
|
|
||||||
Message: "error with parse notify request: " + err.Error(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
|
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NotifyVo{
|
return OrderInfo{}, fmt.Errorf("error with verify sign: %v", err)
|
||||||
Status: Failure,
|
|
||||||
Message: "error with verify sign: " + err.Error(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.TradeQuery(request.Form.Get("out_trade_no"))
|
return s.Query(request.Form.Get("out_trade_no"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
|
var _ PayService = (*AlipayService)(nil)
|
||||||
bm := make(gopay.BodyMap)
|
|
||||||
bm.Set("out_trade_no", outTradeNo)
|
|
||||||
|
|
||||||
//查询订单
|
|
||||||
rsp, err := s.client.TradeQuery(context.Background(), bm)
|
|
||||||
if err != nil {
|
|
||||||
return NotifyVo{
|
|
||||||
Status: Failure,
|
|
||||||
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
|
|
||||||
return NotifyVo{
|
|
||||||
Status: Success,
|
|
||||||
OutTradeNo: rsp.Response.OutTradeNo,
|
|
||||||
TradeId: rsp.Response.TradeNo,
|
|
||||||
Amount: rsp.Response.TotalAmount,
|
|
||||||
Subject: rsp.Response.Subject,
|
|
||||||
Message: "OK",
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return NotifyVo{
|
|
||||||
Status: Failure,
|
|
||||||
Message: "异步查询验证订单信息发生错误" + outTradeNo,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func readKey(filename string) (string, error) {
|
|
||||||
data, err := os.ReadFile(filename)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return string(data), nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -22,41 +22,30 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GeekPayService Geek 支付服务
|
// EPayService 易支付服务
|
||||||
type GeekPayService struct {
|
type EPayService struct {
|
||||||
config *types.GeekPayConfig
|
config *types.EpayConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
|
func NewEPayService(sysConfig *types.SystemConfig) *EPayService {
|
||||||
return &GeekPayService{
|
return &EPayService{
|
||||||
config: &appConfig.GeekPayConfig,
|
config: &sysConfig.Payment.Epay,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeekPayParams struct {
|
func (s *EPayService) UpdateConfig(config *types.EpayConfig) {
|
||||||
Method string `json:"method"` // 接口类型
|
s.config = config
|
||||||
Device string `json:"device"` // 设备类型
|
|
||||||
Type string `json:"type"` // 支付方式
|
|
||||||
OutTradeNo string `json:"out_trade_no"` // 商户订单号
|
|
||||||
Name string `json:"name"` // 商品名称
|
|
||||||
Money string `json:"money"` // 商品金额
|
|
||||||
ClientIP string `json:"clientip"` //用户IP地址
|
|
||||||
SubOpenId string `json:"sub_openid"` // 微信用户 openid,仅小程序支付需要
|
|
||||||
SubAppId string `json:"sub_appid"` // 小程序 AppId,仅小程序支付需要
|
|
||||||
NotifyURL string `json:"notify_url"`
|
|
||||||
ReturnURL string `json:"return_url"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pay 支付订单
|
// Pay 支付订单
|
||||||
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
|
func (s *EPayService) Pay(params PayRequest) (string, error) {
|
||||||
p := map[string]string{
|
p := map[string]string{
|
||||||
"pid": s.config.AppId,
|
"pid": s.config.AppId,
|
||||||
//"method": params.Method,
|
|
||||||
"device": params.Device,
|
"device": params.Device,
|
||||||
"type": params.Type,
|
"type": params.PayWay,
|
||||||
"out_trade_no": params.OutTradeNo,
|
"out_trade_no": params.OutTradeNo,
|
||||||
"name": params.Name,
|
"name": params.Subject,
|
||||||
"money": params.Money,
|
"money": params.TotalFee,
|
||||||
"clientip": params.ClientIP,
|
"clientip": params.ClientIP,
|
||||||
"notify_url": params.NotifyURL,
|
"notify_url": params.NotifyURL,
|
||||||
"return_url": params.ReturnURL,
|
"return_url": params.ReturnURL,
|
||||||
@@ -64,10 +53,21 @@ func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
|
|||||||
}
|
}
|
||||||
p["sign"] = s.Sign(p)
|
p["sign"] = s.Sign(p)
|
||||||
p["sign_type"] = "MD5"
|
p["sign_type"] = "MD5"
|
||||||
return s.sendRequest(s.config.ApiURL, p)
|
resp, err := s.sendRequest(s.config.ApiURL, p)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if resp.Code != 1 {
|
||||||
|
return "", errors.New(resp.Msg)
|
||||||
|
}
|
||||||
|
if resp.PayURL != "" {
|
||||||
|
return resp.PayURL, nil
|
||||||
|
} else {
|
||||||
|
return resp.QrCode, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeekPayService) Sign(params map[string]string) string {
|
func (s *EPayService) Sign(params map[string]string) string {
|
||||||
// 按字母顺序排序参数
|
// 按字母顺序排序参数
|
||||||
var keys []string
|
var keys []string
|
||||||
for k := range params {
|
for k := range params {
|
||||||
@@ -100,7 +100,7 @@ type GeekPayResp struct {
|
|||||||
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
|
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
|
func (s *EPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
for k, v := range params {
|
for k, v := range params {
|
||||||
form.Add(k, v)
|
form.Add(k, v)
|
||||||
@@ -137,3 +137,61 @@ func (s *GeekPayService) sendRequest(endpoint string, params map[string]string)
|
|||||||
}
|
}
|
||||||
return &r, nil
|
return &r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *EPayService) Query(outTradeNo string) (OrderInfo, error) {
|
||||||
|
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("act", "order")
|
||||||
|
params.Set("pid", s.config.AppId)
|
||||||
|
params.Set("key", s.config.PrivateKey)
|
||||||
|
params.Set("out_trade_no", outTradeNo)
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/api.php?%s", s.config.ApiURL, params.Encode())
|
||||||
|
|
||||||
|
tr := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
resp, err := client.Get(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return OrderInfo{}, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return OrderInfo{}, err
|
||||||
|
}
|
||||||
|
logger.Debugf(string(body))
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Money string `json:"money"`
|
||||||
|
EndTime string `json:"endtime"`
|
||||||
|
TradeNo string `json:"trade_no"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return OrderInfo{}, errors.New("订单查询响应解析失败")
|
||||||
|
}
|
||||||
|
if result.Code != 1 {
|
||||||
|
return OrderInfo{}, errors.New(result.Msg)
|
||||||
|
}
|
||||||
|
logger.Debugf("订单信息:%+v", result)
|
||||||
|
orderInfo := OrderInfo{
|
||||||
|
OutTradeNo: outTradeNo,
|
||||||
|
TradeId: result.TradeNo,
|
||||||
|
Amount: result.Money,
|
||||||
|
PayTime: result.EndTime,
|
||||||
|
}
|
||||||
|
if result.Status == "1" {
|
||||||
|
orderInfo.Status = Success
|
||||||
|
} else {
|
||||||
|
orderInfo.Status = Failure
|
||||||
|
}
|
||||||
|
return orderInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ PayService = (*EPayService)(nil)
|
||||||
@@ -1,171 +0,0 @@
|
|||||||
package payment
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/md5"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"geekai/core/types"
|
|
||||||
"geekai/utils"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type HuPiPayService struct {
|
|
||||||
appId string
|
|
||||||
appSecret string
|
|
||||||
apiURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
|
|
||||||
return &HuPiPayService{
|
|
||||||
appId: config.HuPiPayConfig.AppId,
|
|
||||||
appSecret: config.HuPiPayConfig.AppSecret,
|
|
||||||
apiURL: config.HuPiPayConfig.ApiURL,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type HuPiPayParams struct {
|
|
||||||
AppId string `json:"appid"`
|
|
||||||
Version string `json:"version"`
|
|
||||||
TradeOrderId string `json:"trade_order_id"`
|
|
||||||
TotalFee string `json:"total_fee"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
NotifyURL string `json:"notify_url"`
|
|
||||||
ReturnURL string `json:"return_url"`
|
|
||||||
WapName string `json:"wap_name"`
|
|
||||||
CallbackURL string `json:"callback_url"`
|
|
||||||
Time string `json:"time"`
|
|
||||||
NonceStr string `json:"nonce_str"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
WapUrl string `json:"wap_url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type HuPiPayResp struct {
|
|
||||||
Openid interface{} `json:"openid"`
|
|
||||||
UrlQrcode string `json:"url_qrcode"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
ErrCode int `json:"errcode"`
|
|
||||||
ErrMsg string `json:"errmsg,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pay 执行支付请求操作
|
|
||||||
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
|
|
||||||
data := url.Values{}
|
|
||||||
simple := strconv.FormatInt(time.Now().Unix(), 10)
|
|
||||||
params.AppId = s.appId
|
|
||||||
params.Time = simple
|
|
||||||
params.NonceStr = simple
|
|
||||||
encode := utils.JsonEncode(params)
|
|
||||||
m := make(map[string]string)
|
|
||||||
_ = utils.JsonDecode(encode, &m)
|
|
||||||
for k, v := range m {
|
|
||||||
data.Add(k, fmt.Sprintf("%v", v))
|
|
||||||
}
|
|
||||||
// 生成签名
|
|
||||||
data.Add("hash", s.Sign(data))
|
|
||||||
// 发送支付请求
|
|
||||||
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
|
|
||||||
resp, err := http.PostForm(apiURL, data)
|
|
||||||
if err != nil {
|
|
||||||
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
all, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var res HuPiPayResp
|
|
||||||
err = utils.JsonDecode(string(all), &res)
|
|
||||||
if err != nil {
|
|
||||||
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.ErrCode != 0 {
|
|
||||||
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign 签名方法
|
|
||||||
func (s *HuPiPayService) Sign(params url.Values) string {
|
|
||||||
params.Del(`Sign`)
|
|
||||||
var keys = make([]string, 0, 0)
|
|
||||||
for key := range params {
|
|
||||||
if params.Get(key) != `` {
|
|
||||||
keys = append(keys, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sort.Strings(keys)
|
|
||||||
|
|
||||||
var pList = make([]string, 0, 0)
|
|
||||||
for _, key := range keys {
|
|
||||||
var value = strings.TrimSpace(params.Get(key))
|
|
||||||
if len(value) > 0 {
|
|
||||||
pList = append(pList, key+"="+value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var src = strings.Join(pList, "&")
|
|
||||||
src += s.appSecret
|
|
||||||
|
|
||||||
md5bs := md5.Sum([]byte(src))
|
|
||||||
return hex.EncodeToString(md5bs[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check 校验订单状态
|
|
||||||
func (s *HuPiPayService) Check(outTradeNo string) error {
|
|
||||||
data := url.Values{}
|
|
||||||
data.Add("appid", s.appId)
|
|
||||||
data.Add("out_trade_order", outTradeNo)
|
|
||||||
stamp := strconv.FormatInt(time.Now().Unix(), 10)
|
|
||||||
data.Add("time", stamp)
|
|
||||||
data.Add("nonce_str", stamp)
|
|
||||||
data.Add("hash", s.Sign(data))
|
|
||||||
|
|
||||||
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
|
|
||||||
resp, err := http.PostForm(apiURL, data)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with http reqeust: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Body.Close()
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with reading response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var r struct {
|
|
||||||
ErrCode int `json:"errcode"`
|
|
||||||
Data struct {
|
|
||||||
Status string `json:"status"`
|
|
||||||
OpenOrderId string `json:"open_order_id"`
|
|
||||||
} `json:"data,omitempty"`
|
|
||||||
ErrMsg string `json:"errmsg"`
|
|
||||||
Hash string `json:"hash"`
|
|
||||||
}
|
|
||||||
err = utils.JsonDecode(string(body), &r)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.ErrCode == 0 && r.Data.Status == "OD" {
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
logger.Debugf("%+v", r)
|
|
||||||
return errors.New("order not paid:" + r.ErrMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
54
api/service/payment/pay_service.go
Normal file
54
api/service/payment/pay_service.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package payment
|
||||||
|
|
||||||
|
// 支付渠道定义
|
||||||
|
const PayChannelAL = "alipay" // 支付宝
|
||||||
|
const PayChannelWX = "wxpay" // 微信支付
|
||||||
|
const PayChannelEpay = "epay" // 易支付
|
||||||
|
|
||||||
|
// 支付方式
|
||||||
|
const PayWayAL = "alipay"
|
||||||
|
const PayWayWX = "wxpay"
|
||||||
|
|
||||||
|
const (
|
||||||
|
Success = 0
|
||||||
|
Failure = 1
|
||||||
|
Closed = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type PayRequest struct {
|
||||||
|
OutTradeNo string // 商户订单号
|
||||||
|
Subject string // 商品名称
|
||||||
|
TotalFee string // 商品金额
|
||||||
|
ReturnURL string // 回调地址
|
||||||
|
NotifyURL string // 回调地址
|
||||||
|
|
||||||
|
// 易支付专有参数
|
||||||
|
Method string // 接口类型
|
||||||
|
Device string // 设备类型
|
||||||
|
PayWay string // 支付方式
|
||||||
|
ClientIP string //用户IP地址
|
||||||
|
OpenID string // 用户openid
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
type OrderInfo struct {
|
||||||
|
Mchid string // 商户号
|
||||||
|
OutTradeNo string // 商户订单号
|
||||||
|
TradeId string // 交易号
|
||||||
|
Amount string // 金额
|
||||||
|
Status int // 状态 0: 未支付 1: 已支付 2: 已关闭
|
||||||
|
PayTime string // 完成支付时间
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o OrderInfo) Closed() bool {
|
||||||
|
return o.Status == Closed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o OrderInfo) Success() bool {
|
||||||
|
return o.Status == Success
|
||||||
|
}
|
||||||
|
|
||||||
|
type PayService interface {
|
||||||
|
Pay(params PayRequest) (string, error) // 生成支付链接
|
||||||
|
Query(outTradeNo string) (OrderInfo, error) // 查询订单
|
||||||
|
}
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
package payment
|
|
||||||
|
|
||||||
type NotifyVo struct {
|
|
||||||
Status int
|
|
||||||
OutTradeNo string // 商户订单号
|
|
||||||
TradeId string // 交易ID
|
|
||||||
Amount string // 交易金额
|
|
||||||
Message string
|
|
||||||
Subject string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v NotifyVo) Success() bool {
|
|
||||||
return v.Status == Success
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
Success = 0
|
|
||||||
Failure = 1
|
|
||||||
)
|
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
package payment
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"geekai/core/types"
|
|
||||||
"github.com/go-pay/gopay"
|
|
||||||
"github.com/go-pay/gopay/wechat/v3"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type WechatPayService struct {
|
|
||||||
config *types.WechatPayConfig
|
|
||||||
client *wechat.ClientV3
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
|
|
||||||
config := appConfig.WechatPayConfig
|
|
||||||
if !config.Enabled {
|
|
||||||
logger.Info("Disabled WechatPay service")
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
priKey, err := readKey(config.PrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error with read App Private key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
|
|
||||||
}
|
|
||||||
err = client.AutoVerifySign()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
|
|
||||||
}
|
|
||||||
//client.DebugSwitch = gopay.DebugOn
|
|
||||||
|
|
||||||
return &WechatPayService{config: &config, client: client}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type WechatPayParams struct {
|
|
||||||
OutTradeNo string `json:"out_trade_no"`
|
|
||||||
TotalFee int `json:"total_fee"`
|
|
||||||
Subject string `json:"subject"`
|
|
||||||
ClientIP string `json:"client_ip"`
|
|
||||||
ReturnURL string `json:"return_url"`
|
|
||||||
NotifyURL string `json:"notify_url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
|
|
||||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
|
||||||
// 初始化 BodyMap
|
|
||||||
bm := make(gopay.BodyMap)
|
|
||||||
bm.Set("appid", s.config.AppId).
|
|
||||||
Set("mchid", s.config.MchId).
|
|
||||||
Set("description", params.Subject).
|
|
||||||
Set("out_trade_no", params.OutTradeNo).
|
|
||||||
Set("time_expire", expire).
|
|
||||||
Set("notify_url", params.NotifyURL).
|
|
||||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
|
||||||
bm.Set("total", params.TotalFee).
|
|
||||||
Set("currency", "CNY")
|
|
||||||
})
|
|
||||||
|
|
||||||
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
|
||||||
}
|
|
||||||
if wxRsp.Code != wechat.Success {
|
|
||||||
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
|
||||||
}
|
|
||||||
return wxRsp.Response.CodeUrl, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
|
|
||||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
|
||||||
// 初始化 BodyMap
|
|
||||||
bm := make(gopay.BodyMap)
|
|
||||||
bm.Set("appid", s.config.AppId).
|
|
||||||
Set("mchid", s.config.MchId).
|
|
||||||
Set("description", params.Subject).
|
|
||||||
Set("out_trade_no", params.OutTradeNo).
|
|
||||||
Set("time_expire", expire).
|
|
||||||
Set("notify_url", params.NotifyURL).
|
|
||||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
|
||||||
bm.Set("total", params.TotalFee).
|
|
||||||
Set("currency", "CNY")
|
|
||||||
}).
|
|
||||||
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
|
||||||
bm.Set("payer_client_ip", params.ClientIP).
|
|
||||||
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
|
||||||
bm.Set("type", "Wap")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
|
|
||||||
}
|
|
||||||
if wxRsp.Code != wechat.Success {
|
|
||||||
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
|
|
||||||
}
|
|
||||||
return wxRsp.Response.H5Url, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type NotifyResponse struct {
|
|
||||||
Code string `json:"code"`
|
|
||||||
Message string `xml:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TradeVerify 交易验证
|
|
||||||
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
|
|
||||||
notifyReq, err := wechat.V3ParseNotify(request)
|
|
||||||
if err != nil {
|
|
||||||
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: 这里验签程序有 Bug,一直报错:crypto/rsa: verification error,先暂时取消验签
|
|
||||||
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
|
|
||||||
//if err != nil {
|
|
||||||
// return fmt.Errorf("error with client v3 verify sign: %v", err)
|
|
||||||
//}
|
|
||||||
|
|
||||||
// 解密支付密文,验证订单信息
|
|
||||||
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
|
|
||||||
if err != nil {
|
|
||||||
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NotifyVo{
|
|
||||||
Status: Success,
|
|
||||||
OutTradeNo: result.OutTradeNo,
|
|
||||||
TradeId: result.TransactionId,
|
|
||||||
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
217
api/service/payment/wxpay_service.go
Normal file
217
api/service/payment/wxpay_service.go
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
package payment
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/utils"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-pay/gopay"
|
||||||
|
"github.com/go-pay/gopay/wechat/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WxPayService struct {
|
||||||
|
config *types.WxPayConfig
|
||||||
|
client *wechat.ClientV3
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWxpayService(sysConfig *types.SystemConfig) (*WxPayService, error) {
|
||||||
|
config := sysConfig.Payment.WxPay
|
||||||
|
if !config.Enabled {
|
||||||
|
logger.Debug("Disabled WechatPay service")
|
||||||
|
}
|
||||||
|
|
||||||
|
service := &WxPayService{config: &config}
|
||||||
|
if config.Enabled {
|
||||||
|
err := service.UpdateConfig(&config)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("微信支付服务初始化失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WxPayService) UpdateConfig(config *types.WxPayConfig) error {
|
||||||
|
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error with initialize WechatPay service: %v", err)
|
||||||
|
}
|
||||||
|
err = client.AutoVerifySign()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error with autoVerifySign: %v", err)
|
||||||
|
}
|
||||||
|
s.client = client
|
||||||
|
if os.Getenv("GEEKAI_DEBUG") == "true" {
|
||||||
|
logger.Info("WechatPay Debug mode is enabled")
|
||||||
|
client.DebugSwitch = gopay.DebugOn
|
||||||
|
}
|
||||||
|
s.config = config
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WxPayService) Pay(params PayRequest) (string, error) {
|
||||||
|
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||||
|
// 初始化 BodyMap
|
||||||
|
bm := make(gopay.BodyMap)
|
||||||
|
bm.Set("appid", s.config.AppId).
|
||||||
|
Set("mchid", s.config.MchId).
|
||||||
|
Set("description", params.Subject).
|
||||||
|
Set("out_trade_no", params.OutTradeNo).
|
||||||
|
Set("time_expire", expire).
|
||||||
|
Set("notify_url", params.NotifyURL).
|
||||||
|
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||||
|
bm.Set("total", utils.IntValue(params.TotalFee, 0)).
|
||||||
|
Set("currency", "CNY")
|
||||||
|
})
|
||||||
|
logger.Debugf("wxpay params: %+v", bm)
|
||||||
|
if params.Device == "mobile" {
|
||||||
|
bm.SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||||
|
bm.Set("payer_client_ip", params.ClientIP)
|
||||||
|
}).SetBodyMap("payer", func(bm gopay.BodyMap) {
|
||||||
|
bm.Set("openid", params.OpenID)
|
||||||
|
})
|
||||||
|
wxRsp, err := s.client.V3TransactionJsapi(context.Background(), bm)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with client v3 transaction Jsapi: %v", err)
|
||||||
|
}
|
||||||
|
if wxRsp.Code != wechat.Success {
|
||||||
|
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||||
|
}
|
||||||
|
return wxRsp.Response.PrepayId, nil
|
||||||
|
} else if params.Device == "pc" {
|
||||||
|
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
||||||
|
}
|
||||||
|
if wxRsp.Code != wechat.Success {
|
||||||
|
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||||
|
}
|
||||||
|
return wxRsp.Response.CodeUrl, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WxPayService) Query(outTradeNo string) (OrderInfo, error) {
|
||||||
|
wxRsp, err := s.client.V3TransactionQueryOrder(context.Background(), wechat.OutTradeNo, outTradeNo)
|
||||||
|
if err != nil {
|
||||||
|
return OrderInfo{}, fmt.Errorf("error with client v3 transaction query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if wxRsp.Code != wechat.Success {
|
||||||
|
return OrderInfo{}, fmt.Errorf("error status with querying order: %v", wxRsp.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if wxRsp.Response.TradeState == "CLOSED" {
|
||||||
|
return OrderInfo{Status: Closed}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
orderInfo := OrderInfo{
|
||||||
|
OutTradeNo: wxRsp.Response.OutTradeNo,
|
||||||
|
TradeId: wxRsp.Response.TransactionId,
|
||||||
|
Amount: fmt.Sprintf("%d", wxRsp.Response.Amount.Total/100),
|
||||||
|
PayTime: wxRsp.Response.SuccessTime,
|
||||||
|
}
|
||||||
|
if wxRsp.Response.TradeState == "SUCCESS" {
|
||||||
|
orderInfo.Status = Success
|
||||||
|
} else {
|
||||||
|
orderInfo.Status = Failure
|
||||||
|
}
|
||||||
|
return orderInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TradeVerify 交易验证
|
||||||
|
func (s *WxPayService) TradeVerify(request *http.Request) (OrderInfo, error) {
|
||||||
|
notifyReq, err := wechat.V3ParseNotify(request)
|
||||||
|
if err != nil {
|
||||||
|
return OrderInfo{}, fmt.Errorf("error with client v3 parse notify: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解密支付密文,验证订单信息
|
||||||
|
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
|
||||||
|
if err != nil {
|
||||||
|
return OrderInfo{}, fmt.Errorf("error with client v3 decrypt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return OrderInfo{
|
||||||
|
Status: Success,
|
||||||
|
OutTradeNo: result.OutTradeNo,
|
||||||
|
TradeId: result.TransactionId,
|
||||||
|
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
|
||||||
|
PayTime: result.SuccessTime,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
|
||||||
|
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||||
|
// // 初始化 BodyMap
|
||||||
|
// bm := make(gopay.BodyMap)
|
||||||
|
// bm.Set("appid", s.config.AppId).
|
||||||
|
// Set("mchid", s.config.MchId).
|
||||||
|
// Set("description", params.Subject).
|
||||||
|
// Set("out_trade_no", params.OutTradeNo).
|
||||||
|
// Set("time_expire", expire).
|
||||||
|
// Set("notify_url", params.NotifyURL).
|
||||||
|
// SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||||
|
// bm.Set("total", params.TotalFee).
|
||||||
|
// Set("currency", "CNY")
|
||||||
|
// })
|
||||||
|
|
||||||
|
// wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
||||||
|
// if err != nil {
|
||||||
|
// return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
||||||
|
// }
|
||||||
|
// if wxRsp.Code != wechat.Success {
|
||||||
|
// return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||||
|
// }
|
||||||
|
// return wxRsp.Response.CodeUrl, nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
|
||||||
|
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||||
|
// // 初始化 BodyMap
|
||||||
|
// bm := make(gopay.BodyMap)
|
||||||
|
// bm.Set("appid", s.config.AppId).
|
||||||
|
// Set("mchid", s.config.MchId).
|
||||||
|
// Set("description", params.Subject).
|
||||||
|
// Set("out_trade_no", params.OutTradeNo).
|
||||||
|
// Set("time_expire", expire).
|
||||||
|
// Set("notify_url", params.NotifyURL).
|
||||||
|
// SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||||
|
// bm.Set("total", params.TotalFee).
|
||||||
|
// Set("currency", "CNY")
|
||||||
|
// }).
|
||||||
|
// SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||||
|
// bm.Set("payer_client_ip", params.ClientIP).
|
||||||
|
// SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
||||||
|
// bm.Set("type", "Wap")
|
||||||
|
// })
|
||||||
|
// })
|
||||||
|
|
||||||
|
// wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
|
||||||
|
// if err != nil {
|
||||||
|
// return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
|
||||||
|
// }
|
||||||
|
// if wxRsp.Code != wechat.Success {
|
||||||
|
// return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
|
||||||
|
// }
|
||||||
|
// return wxRsp.Response.H5Url, nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type NotifyResponse struct {
|
||||||
|
// Code string `json:"code"`
|
||||||
|
// Message string `xml:"message"`
|
||||||
|
// }
|
||||||
|
|
||||||
|
var _ PayService = (*WxPayService)(nil)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user