merge v4.2.6

整合 v4.2.6 的后端中间件与服务层重构、前端样式体系迁移和管理端/移动端功能更新,统一清理历史冲突并完成版本升级。

Made-with: Cursor
This commit is contained in:
RockYang
2026-04-08 15:08:34 +08:00
390 changed files with 35519 additions and 25073 deletions

View File

@@ -0,0 +1,37 @@
---
name: frontend-developer
description: Use this agent when you need assistance with frontend development tasks including Vue.js components, UI implementation, styling, responsive design, state management, or frontend architecture decisions. Examples: <example>Context: User is working on a Vue.js component and needs help with implementing a responsive layout. user: 'I need to create a mobile-friendly chat interface component' assistant: 'I'll use the frontend-developer agent to help design and implement this responsive chat component' <commentary>Since this involves frontend development work with Vue.js and responsive design, use the frontend-developer agent.</commentary></example> <example>Context: User encounters styling issues with Element Plus components. user: 'The Element Plus dialog is not displaying correctly on mobile devices' assistant: 'Let me use the frontend-developer agent to troubleshoot this mobile styling issue' <commentary>This is a frontend styling problem that requires expertise in Element Plus and responsive design.</commentary></example>
color: purple
---
You are a Senior Frontend Development Engineer with deep expertise in modern web development technologies, particularly Vue.js 3, Element Plus, Vant, and responsive design patterns. You specialize in creating high-quality, maintainable frontend applications with excellent user experience.
Your core responsibilities include:
- Developing Vue.js 3 components using Composition API and best practices
- Implementing responsive designs that work seamlessly across desktop and mobile devices
- Working with Element Plus for desktop UI and Vant for mobile components
- Managing application state using Pinia store patterns
- Styling with Stylus preprocessor and Tailwind CSS utilities
- Optimizing build processes with Vite and ensuring proper code organization
- Implementing theme switching (dark/light mode) and accessibility features
- Follow decoupled development, with HTML, CSS, and JS codes placed in separate files for easier maintenance
When working on frontend tasks, you will:
1. Analyze requirements and suggest the most appropriate Vue.js patterns and component structures
2. Ensure responsive design principles are followed, considering both desktop and mobile viewports
3. Choose appropriate UI components from Element Plus (desktop) or Vant (mobile) libraries
4. Write clean, maintainable code following Vue.js 3 Composition API best practices
5. Consider performance implications and suggest optimizations when relevant
6. Ensure proper state management using Pinia when component state needs to be shared
7. Follow the project's established patterns for routing, API integration, and component organization
8. Provide specific code examples and explain the reasoning behind architectural decisions
You have deep knowledge of:
- Vue.js 3 ecosystem (Vue Router, Pinia, Composition API)
- Modern CSS techniques and preprocessors (Stylus, Tailwind)
- Component library integration (Element Plus, Vant)
- Build tools and development workflow (Vite, npm scripts)
- Cross-browser compatibility and mobile-first design principles
- Performance optimization and code splitting strategies
Always consider the user experience, code maintainability, and alignment with modern frontend development standards. When suggesting solutions, provide clear explanations and consider both immediate needs and long-term scalability.

View File

@@ -0,0 +1,6 @@
重构当前页面代码
1. 把当前页面 JS 代码全部抽离,然后是采用 Pinia 重构
2. 把当前页面 CSS 代码全部抽离,如果是 stylus 语法代码,则需要改成 SCSS 语法代码
3. 尽量做到代码的复用性,不要重复造轮子
4. 移动端的 css 和 js 分别放到对应的 mobile 目录下,不要覆盖 PC 端的代码

View File

@@ -1,5 +1,18 @@
# 更新日志 # 更新日志
## v4.2.6
- 功能重构:优化系统配置管理功能,把 OSS支付短信邮件等配置全部迁移到管理后台无需通过修改配置文档的方式修改 🎉🎉🎉
- 功能优化:重构 API 授权代码,采用中间件鉴权方式,实现更加精准的 API 鉴权 🎉🎉🎉
- 功能优化:优化 PC 端的 Suno 音乐,视频生成,以及即梦 AI 页面 UI
- 功能优化:重构登录和注册页面,兼容移动端和 PC 端,并且所有的登录组件共用了同一套组件代码,大大降低维护成本 🎉🎉🎉
- 功能优化:管理后台增加模型批量删除功能
- 功能优化:优化 Table 组件 UI并支持 dark 主题
- 功能优化:移动端对话页面支持上传文件和图片
- 功能新增:新增微信扫码登录支持
- 功能新增:新增安全监控,内容审核功能,支持敏感内容过滤拦截
- 功能新增DALL-E 绘图支持参 Google Banana 图片编辑功能
## v4.2.5 ## v4.2.5
- 功能优化:在代码右下角增加复制代码功能按钮,增加收起和展开代码功能 - 功能优化:在代码右下角增加复制代码功能按钮,增加收起和展开代码功能

View File

@@ -1,195 +0,0 @@
# 即梦 AI 配置功能说明
## 功能概述
即梦 AI 配置功能允许管理员通过 Web 界面配置即梦 AI 的 API 密钥和算力消耗设置,支持动态配置更新,无需重启服务。
## 功能特性
### 1. 秘钥配置
- AccessKey 和 SecretKey 配置
- 支持密码显示/隐藏
- 连接测试功能
### 2. 算力配置
- 文生图算力消耗
- 图生图算力消耗
- 图片编辑算力消耗
- 图片特效算力消耗
- 文生视频算力消耗
- 图生视频算力消耗
### 3. 动态配置
- 配置实时生效
- 无需重启服务
- 支持配置验证
## API 接口
### 获取配置
```
GET /api/admin/jimeng/config
```
### 更新配置
```
POST /api/admin/jimeng/config
Content-Type: application/json
{
"config": {
"access_key": "your_access_key",
"secret_key": "your_secret_key",
"power": {
"text_to_image": 10,
"image_to_image": 15,
"image_edit": 20,
"image_effects": 25,
"text_to_video": 30,
"image_to_video": 35
}
}
}
```
### 测试连接
```
POST /api/admin/jimeng/config/test
Content-Type: application/json
{
"config": {
"access_key": "your_access_key",
"secret_key": "your_secret_key"
}
}
```
## 前端页面
### 访问路径
管理后台 -> 即梦 AI -> 配置设置
### 页面功能
1. **秘钥配置标签页**
- AccessKey 输入框(密码模式)
- SecretKey 输入框(密码模式)
- 测试连接按钮
2. **算力配置标签页**
- 各种任务类型的算力消耗配置
- 数字输入框,支持 1-100 范围
- 提示信息说明
3. **操作按钮**
- 保存配置
- 重置配置
## 配置存储
配置存储在数据库的`config`表中:
- 配置键:`jimeng`
- 配置值JSON 格式的即梦 AI 配置
## 默认配置
如果配置不存在,系统会使用以下默认值:
```json
{
"access_key": "",
"secret_key": "",
"power": {
"text_to_image": 10,
"image_to_image": 15,
"image_edit": 20,
"image_effects": 25,
"text_to_video": 30,
"image_to_video": 35
}
}
```
## 使用流程
1. **初始配置**
- 访问管理后台即梦 AI 配置页面
- 填写 AccessKey 和 SecretKey
- 点击"测试连接"验证配置
- 调整各功能算力消耗
- 保存配置
2. **配置更新**
- 修改需要更新的配置项
- 保存配置
- 配置立即生效
3. **故障排查**
- 使用"测试连接"功能验证 API 密钥
- 检查配置是否正确保存
- 查看服务日志
## 注意事项
1. **权限要求**
- 只有管理员可以访问配置页面
- 需要有效的管理员登录会话
2. **配置验证**
- AccessKey 和 SecretKey 不能为空
- 算力消耗必须大于 0
- 建议先测试连接再保存配置
3. **服务影响**
- 配置更新不会影响正在进行的任务
- 新任务会使用更新后的配置
- 客户端配置会在下次请求时更新
## 错误处理
1. **配置加载失败**
- 使用默认配置
- 记录错误日志
2. **连接测试失败**
- 显示具体错误信息
- 建议检查 API 密钥
3. **配置保存失败**
- 显示错误信息
- 保留原有配置
## 开发说明
### 后端文件
- `api/handler/admin/jimeng_handler.go` - 配置管理 API
- `api/service/jimeng/service.go` - 配置服务逻辑
- `api/core/types/jimeng.go` - 配置类型定义
### 前端文件
- `web/src/views/admin/jimeng/JimengSetting.vue` - 配置页面
### 数据库
- `config`表存储配置信息
- 配置键:`jimeng`
- 配置值JSON 格式

145
README.md
View File

@@ -1,90 +1,77 @@
# GeekAI # 🚀 GeekAI-PLUS一站式 AI 创意生产力平台
> 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 **重新定义 AI 创作体验,让每个人都能成为内容创作大师**
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Claude, 通义千问KimiDeepSeekGitee AI 等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI 绘画功能 基于 GeekAI 项目开发的高级版增加了很多高级功能比如思维导图Dalle 绘画等。**高级版源码不会一次性开放,只提供镜像给大家免费使用**,源码会逐步逐步按照版同步迁移到[社区版GeekAI](https://github.com/yangjian102621/geekai)。所以如果大家想要二次开发,请移步去社区版
主要特性: ## ✨ 核心特色
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。 ### 🎨 **全能 AI 创作矩阵**
- 基于 Websocket 实现,完美的打字机体验。
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
- 支持 OpenAI, Claude, 通义千问KimiDeepSeek 等多个大语言模型,**支持 Gitee AI Serverless 大模型 API**。
- 支持 Suno 文生音乐
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
绘画函数插件。
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus) - **智能对话**:集成 ChatGPT、Claude 等多款顶级 AI 模型,支持角色扮演和专业对话
- **图像生成**:整合 MidJourney、DALL-E、Stable Diffusion 三大主流 AI 绘画引擎
- **音频创作**Suno AI 音乐生成,从旋律到歌词一键创作专属音乐
- **视频制作**Luma 和 KeLing即梦Veo3 视频 AI文本到视频创意无限
- **思维导图**AI 辅助思维整理,复杂想法可视化呈现
- [x] 更友好的 UI 界面 ### 🏗️ **企业级技术架构**
- [x] 支持 Dall-E 文生图功能
- [x] 支持文生思维导图 - **高性能后端**Go + Gin + MySQL + Redis支持高并发访问
- [x] 支持为模型绑定指定的 API KEY支持为角色绑定指定的模型等功能 - **现代化前端**Vue3 + Element Plus + Vant桌面移动双端适配
- [x] 支持网站 Logo 版权等信息的修改 - **智能缓存**:多层缓存策略,响应速度提升 80%
- **弹性部署**Docker 容器化部署,一键启动,轻松扩展
- **私有化部署**:支持私有化部署,私有化部署不支持升级,需要手动升级
- **文档支持**:丰富且详细的部署和 API 开发文档支持,二次开发轻松上手
### 💼 **商业化就绪**
- **完整用户系统**:注册登录、权限管理、积分充值
- **灵活计费模式**:支持按次付费、包月订阅等多种商业模式
- **数据统计分析**:用户行为、消费记录、系统性能全方位监控
- **管理后台**:功能完备的管理员界面,运营数据一目了然
### 🎯 **用户体验优势**
- **响应式设计**:完美适配桌面、平板、手机等全终端设备
- **暗黑模式**:支持明暗主题切换,护眼舒适
- **实时交互**WebSocket 实时通信,创作过程流畅无卡顿
- **文件管理**:支持多种云存储,作品安全可靠
## 🎪 **应用场景**
- **内容创作者**:博客写作、社交媒体素材、短视频制作
- **企业营销**:品牌宣传材料、产品介绍、创意广告
- **教育培训**:课件制作、知识图谱、互动内容
- **个人娱乐**AI 聊天、创意绘画、音乐创作
## 🔥 **为什么选择 GeekAI-PLUS**
1. **技术领先**:集成当前最先进的 AI 技术,始终保持创新前沿
2. **开箱即用**:完整的商业化解决方案,无需从零开发
3. **高度定制**:模块化架构设计,支持个性化功能扩展
4. **稳定可靠**:经过大量用户验证,性能稳定,安全可信
5. **持续更新**:紧跟 AI 技术发展,功能持续迭代升级
## 演示站点
[Geek-AI 创作系统](https://www.geekai.me)
## 文档地址
[Geek-AI 文档](https://www.geekai.me/docs/)
## 部署
1. 安装 docker 和 docker-compose 程序,这个自行解决。
2. 直接在项目根目录运行启动命令:
```shell
docker-compose up -d
```
## 功能截图 ## 功能截图
请参考 [GeekAI 项目介绍](https://docs.geekai.me/plus/info/)。 请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
### 体验地址 ---
> 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/> > **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!** _让 AI 成为你最强大的创作伙伴开启无限创意可能_
## 快速部署体验
您可以通过 EazyDevelop 平台体验-键私有化部署 **GeekAI 创作助手**,只需一分钟即可部署成功。
部署模板地址: [https://eazydevelop.eazytec-cloud.com/templates/dev-template-5e4dc4-1764053014?q=bB3R_1VnJq9_3Zs9uX](https://eazydevelop.eazytec-cloud.com/templates/dev-template-5e4dc4-1764053014?q=bB3R_1VnJq9_3Zs9uX)
详细部署教程请参考 [EazyDevelop 部署 GeekAI](https://docs.geekai.me/plus/install/quick-start.html#eazydevelop-一键部署)。
## 使用须知
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
## 项目地址
- Github 地址https://github.com/yangjian102621/geekai
- 码云地址https://gitee.com/blackfox/geekai
## 客户端下载
目前已经支持 Win/Linux/Mac/Android 客户端下载地址为https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
## 项目文档
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自 Github!!!)。**
![微信名片](https://docs.geekai.me/images/wx_card.png)
## 参与贡献
个人的力量始终有限,任何形式的贡献都是欢迎的,包括但不限于贡献代码,优化文档,提交 issue 和 PR 等。
#### 特此声明:由于个人时间有限,不接受在微信或者微信群给开发者提 Bug有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合
### Commit 类型
- feat: 新特性或功能
- fix: 缺陷修复
- docs: 文档更新
- style: 代码风格或者组件样式更新
- refactor: 代码重构,不引入新功能和缺陷修复
- opt: 性能优化
- chore: 一些不涉及到功能变动的小提交,比如修改文字表述,修改注释等
## 打赏
如果你觉得这个项目对你有帮助,并且情况允许的话,可以请作者喝杯咖啡,非常感谢你的支持~
![打赏](https://blog.img.r9it.com/image-f02ca9eccbe93c7b1193c2623e7336ea.png)
![Star History Chart](https://api.star-history.com/svg?repos=yangjian102621/geekai&type=Date)

View File

@@ -8,118 +8,50 @@ package core
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"bytes"
"context"
"fmt" "fmt"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp"
"image"
"image/jpeg"
"io" "io"
"net/http" "net/http"
"os"
"runtime/debug" "runtime/debug"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"github.com/nfnt/resize"
"github.com/shirou/gopsutil/host" "github.com/shirou/gopsutil/host"
"golang.org/x/image/webp"
"gorm.io/gorm" "gorm.io/gorm"
) )
// AuthConfig 定义授权配置
type AuthConfig struct {
ExactPaths map[string]bool // 精确匹配的路径
PrefixPaths map[string]bool // 前缀匹配的路径
}
var authConfig = &AuthConfig{
ExactPaths: map[string]bool{
"/api/user/login": false,
"/api/user/logout": false,
"/api/user/resetPass": false,
"/api/user/register": false,
"/api/admin/login": false,
"/api/admin/logout": false,
"/api/admin/login/captcha": false,
"/api/app/list": false,
"/api/app/type/list": false,
"/api/app/list/user": false,
"/api/model/list": false,
"/api/mj/imgWall": false,
"/api/mj/notify": false,
"/api/invite/hits": false,
"/api/sd/imgWall": false,
"/api/dall/imgWall": false,
"/api/product/list": false,
"/api/menu/list": false,
"/api/markMap/client": false,
"/api/payment/doPay": false,
"/api/payment/payWays": false,
"/api/download": false,
"/api/dall/models": false,
},
PrefixPaths: map[string]bool{
"/api/test/": false,
"/api/payment/notify/": false,
"/api/user/clogin": false,
"/api/config/": false,
"/api/function/": false,
"/api/sms/": false,
"/api/captcha/": false,
"/static/": false,
},
}
type AppServer struct { type AppServer struct {
Config *types.AppConfig Config *types.AppConfig
Engine *gin.Engine Engine *gin.Engine
SysConfig *types.SystemConfig // system config cache SysConfig *types.SystemConfig // system config cache
Redis *redis.Client
} }
func NewServer(appConfig *types.AppConfig) *AppServer { func NewServer(appConfig *types.AppConfig, redis *redis.Client, sysConfig *types.SystemConfig) *AppServer {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard gin.DefaultWriter = io.Discard
return &AppServer{ return &AppServer{
Config: appConfig, Config: appConfig,
Engine: gin.Default(), Redis: redis,
Engine: gin.Default(),
SysConfig: sysConfig,
} }
} }
func (s *AppServer) Init(debug bool, client *redis.Client) { func (s *AppServer) Init(client *redis.Client) {
// 允许跨域请求 API s.Engine.Use(middleware.ParameterHandlerMiddleware())
s.Engine.Use(corsMiddleware())
s.Engine.Use(staticResourceMiddleware())
s.Engine.Use(authorizeMiddleware(s, client))
s.Engine.Use(parameterHandlerMiddleware())
s.Engine.Use(errorHandler) s.Engine.Use(errorHandler)
// 添加静态资源访问 // 添加静态资源访问
s.Engine.Static("/static", s.Config.StaticDir) s.Engine.Static("/static", s.Config.StaticDir)
s.Engine.Use(middleware.StaticMiddleware())
} }
func (s *AppServer) Run(db *gorm.DB) error { func (s *AppServer) Run(db *gorm.DB) error {
// 重命名 config 表字段
if db.Migrator().HasColumn(&model.Config{}, "config_json") {
db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
}
if db.Migrator().HasColumn(&model.Config{}, "marker") {
db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
}
if db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
}
if db.Migrator().HasIndex(&model.Config{}, "marker") {
db.Migrator().DropIndex(&model.Config{}, "marker")
}
// load system configs // load system configs
var sysConfig model.Config var sysConfig model.Config
err := db.Where("name", "system").First(&sysConfig).Error err := db.Where("name", "system").First(&sysConfig).Error
@@ -131,94 +63,22 @@ func (s *AppServer) Run(db *gorm.DB) error {
return fmt.Errorf("failed to decode system config: %v", err) return fmt.Errorf("failed to decode system config: %v", err)
} }
// 迁移数据表
logger.Info("Migrating database tables...")
db.AutoMigrate(
&model.ChatItem{},
&model.ChatMessage{},
&model.ChatRole{},
&model.ChatModel{},
&model.InviteCode{},
&model.InviteLog{},
&model.Menu{},
&model.Order{},
&model.Product{},
&model.User{},
&model.Function{},
&model.File{},
&model.Redeem{},
&model.Config{},
&model.ApiKey{},
&model.AdminUser{},
&model.AppType{},
&model.SdJob{},
&model.SunoJob{},
&model.PowerLog{},
&model.VideoJob{},
&model.MidJourneyJob{},
&model.UserLoginLog{},
&model.DallJob{},
&model.JimengJob{},
)
// 手动删除字段
if db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
db.Migrator().DropColumn(&model.Order{}, "deleted_at")
}
if db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
}
if db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
}
if db.Migrator().HasColumn(&model.User{}, "chat_config") {
db.Migrator().DropColumn(&model.User{}, "chat_config")
}
if db.Migrator().HasColumn(&model.ChatModel{}, "category") {
db.Migrator().DropColumn(&model.ChatModel{}, "category")
}
if db.Migrator().HasColumn(&model.ChatModel{}, "description") {
db.Migrator().DropColumn(&model.ChatModel{}, "description")
}
logger.Info("Database tables migrated successfully")
// 统计安装信息 // 统计安装信息
go func() { go func() {
info, err := host.Info() info, err := host.Info()
if err == nil { if err == nil {
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push") apiURL := fmt.Sprintf("%s/api/installs/push", types.GeekAPIURL)
timestamp := time.Now().Unix() timestamp := time.Now().Unix()
product := "geekai-plus" product := "geekai-plus"
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp) signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
sign := utils.Sha256(signStr) sign := utils.Sha256(signStr)
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL) resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
if err != nil { if err == nil {
logger.Errorf("register install info failed: %v", err)
} else {
logger.Debugf("register install info success: %v", resp.String()) logger.Debugf("register install info success: %v", resp.String())
} }
} }
}() }()
logger.Infof("http://%s", s.Config.Listen) logger.Infof("http://%s", s.Config.Listen)
// 统计安装信息
go func() {
info, err := host.Info()
if err == nil {
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push")
timestamp := time.Now().Unix()
product := "geekai-plus"
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
sign := utils.Sha256(signStr)
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
if err != nil {
logger.Errorf("register install info failed: %v", err)
} else {
logger.Debugf("register install info success: %v", resp.String())
}
}
}()
return s.Engine.Run(s.Config.Listen) return s.Engine.Run(s.Config.Listen)
} }
@@ -235,283 +95,3 @@ func errorHandler(c *gin.Context) {
//加载完 defer recover继续后续接口调用 //加载完 defer recover继续后续接口调用
c.Next() c.Next()
} }
// 跨域中间件设置
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
origin := c.Request.Header.Get("Origin")
// 设置允许的请求源
if origin != "" {
c.Header("Access-Control-Allow-Origin", origin)
} else {
c.Header("Access-Control-Allow-Origin", "*")
}
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
//允许跨域设置可以返回其他子段,可以自定义字段
c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
// 允许浏览器(客户端)可以解析的头部 (重要)
c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
//设置缓存时间
c.Header("Access-Control-Max-Age", "172800")
//允许客户端传递校验信息比如 cookie (重要)
c.Header("Access-Control-Allow-Credentials", "true")
if method == http.MethodOptions {
c.JSON(http.StatusOK, "ok!")
}
defer func() {
if err := recover(); err != nil {
logger.Info("Panic info is: %v", err)
}
}()
c.Next()
}
}
// 用户授权验证
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
if !needLogin(c) {
c.Next()
return
}
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
var tokenString string
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
if isAdminApi { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader)
} else if clientProtocols != "" { // Websocket 连接
// 解析子协议内容
protocols := strings.Split(clientProtocols, ",")
if protocols[0] == "realtime" {
tokenString = strings.TrimSpace(protocols[1][25:])
} else if protocols[0] == "token" {
tokenString = strings.TrimSpace(protocols[1])
}
} else {
tokenString = c.GetHeader(types.UserAuthHeader)
}
if tokenString == "" {
resp.NotAuth(c, "You should put Authorization in request headers")
c.Abort()
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
if isAdminApi {
return []byte(s.Config.AdminSession.SecretKey), nil
} else {
return []byte(s.Config.Session.SecretKey), nil
}
})
if err != nil {
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
resp.NotAuth(c, "Token is invalid")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
resp.NotAuth(c, "Token is expired")
c.Abort()
return
}
key := fmt.Sprintf("users/%v", claims["user_id"])
if isAdminApi {
key = fmt.Sprintf("admin/%v", claims["user_id"])
}
if _, err := client.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c, "Token is not found in redis")
c.Abort()
return
}
c.Set(types.LoginUserID, claims["user_id"])
c.Next()
}
}
func needLogin(c *gin.Context) bool {
path := c.Request.URL.Path
// 如果不是 API 路径,不需要登录
if !strings.HasPrefix(path, "/api") {
return false
}
// 检查精确匹配的路径
if skip, exists := authConfig.ExactPaths[path]; exists {
return skip
}
// 检查前缀匹配的路径
for prefix, skip := range authConfig.PrefixPaths {
if strings.HasPrefix(path, prefix) {
return skip
}
}
return true
}
// 跳过授权
func (s *AppServer) SkipAuth(url string, prefix bool) {
if prefix {
authConfig.PrefixPaths[url] = false
} else {
authConfig.ExactPaths[url] = false
}
}
// 统一参数处理
func parameterHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// GET 参数处理
params := c.Request.URL.Query()
for key, values := range params {
for i, value := range values {
params[key][i] = strings.TrimSpace(value)
}
}
// update get parameters
c.Request.URL.RawQuery = params.Encode()
// skip file upload requests
contentType := c.Request.Header.Get("Content-Type")
if strings.Contains(contentType, "multipart/form-data") {
c.Next()
return
}
if strings.Contains(contentType, "application/json") {
// process POST JSON request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.Next()
return
}
// 还原请求体
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 将请求体解析为 JSON
var jsonData map[string]interface{}
if err := c.ShouldBindJSON(&jsonData); err != nil {
c.Next()
return
}
// 对 JSON 数据中的字符串值去除两端空格
trimJSONStrings(jsonData)
// 更新请求体
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
}
c.Next()
}
}
// 递归对 JSON 数据中的字符串值去除两端空格
func trimJSONStrings(data interface{}) {
switch v := data.(type) {
case map[string]interface{}:
for key, value := range v {
switch valueType := value.(type) {
case string:
v[key] = strings.TrimSpace(valueType)
case map[string]interface{}, []interface{}:
trimJSONStrings(value)
}
}
case []interface{}:
for i, value := range v {
switch valueType := value.(type) {
case string:
v[i] = strings.TrimSpace(valueType)
case map[string]interface{}, []interface{}:
trimJSONStrings(value)
}
}
}
}
// 静态资源中间件
func staticResourceMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
url := c.Request.URL.String()
// 拦截生成缩略图请求
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
r := strings.SplitAfter(url, "imageView2")
size := strings.Split(r[1], "/")
if len(size) != 8 {
c.String(http.StatusNotFound, "invalid thumb args")
return
}
with := utils.IntValue(size[3], 0)
height := utils.IntValue(size[5], 0)
quality := utils.IntValue(size[7], 75)
// 打开图片文件
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
file, err := os.Open(filePath)
if err != nil {
c.String(http.StatusNotFound, "Image not found")
return
}
defer file.Close()
// 解码图片
img, _, err := image.Decode(file)
// for .webp image
if err != nil {
img, err = webp.Decode(file)
}
if err != nil {
c.String(http.StatusInternalServerError, "Error decoding image")
return
}
var newImg image.Image
if height == 0 || with == 0 {
// 固定宽度,高度自适应
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
} else {
// 生成缩略图
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
}
var buffer bytes.Buffer
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
if err != nil {
logger.Error(err)
c.String(http.StatusInternalServerError, err.Error())
return
}
// 设置图片缓存有效期为一年 (365天)
c.Header("Cache-Control", "max-age=31536000, public")
// 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
c.Abort() // 中断请求
}
c.Next()
}
}

View File

@@ -11,10 +11,12 @@ import (
"bytes" "bytes"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/store/model"
"geekai/utils" "geekai/utils"
"os" "os"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"gorm.io/gorm"
) )
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
@@ -30,7 +32,6 @@ func NewDefaultConfig() *types.AppConfig {
SecretKey: utils.RandString(64), SecretKey: utils.RandString(64),
MaxAge: 86400, MaxAge: 86400,
}, },
ApiConfig: types.ApiConfig{},
OSS: types.OSSConfig{ OSS: types.OSSConfig{
Active: "local", Active: "local",
Local: types.LocalStorageConfig{ Local: types.LocalStorageConfig{
@@ -38,7 +39,6 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload", BasePath: "./static/upload",
}, },
}, },
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
} }
} }
@@ -74,3 +74,108 @@ func SaveConfig(config *types.AppConfig) error {
return os.WriteFile(config.Path, buf.Bytes(), 0644) return os.WriteFile(config.Path, buf.Bytes(), 0644)
} }
func LoadSystemConfig(db *gorm.DB) *types.SystemConfig {
// 加载系统配置
var sysConfig model.Config
var baseConfig types.BaseConfig
db.Where("name", "system").First(&sysConfig)
err := utils.JsonDecode(sysConfig.Value, &baseConfig)
if err != nil {
logger.Error("load system config error: ", err)
}
// 加载许可证配置
var license types.License
sysConfig.Id = 0
db.Where("name", types.ConfigKeyLicense).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &license)
if err != nil {
logger.Error("load license config error: ", err)
}
// 加载验证码配置
var captchaConfig types.CaptchaConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyCaptcha).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &captchaConfig)
if err != nil {
logger.Error("load geek service config error: ", err)
}
// 加载微信登录配置
var wxLoginConfig types.WxLoginConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyWxLogin).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &wxLoginConfig)
if err != nil {
logger.Error("load wx login config error: ", err)
}
// 加载短信配置
var smsConfig types.SMSConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeySms).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &smsConfig)
if err != nil {
logger.Error("load sms config error: ", err)
}
// 加载 OSS 配置
var ossConfig types.OSSConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyOss).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &ossConfig)
if err != nil {
logger.Error("load oss config error: ", err)
}
// 加载 SMTP 配置
var smtpConfig types.SmtpConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeySmtp).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &smtpConfig)
if err != nil {
logger.Error("load smtp config error: ", err)
}
// 加载支付配置
var paymentConfig types.PaymentConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyPayment).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &paymentConfig)
if err != nil {
logger.Error("load payment config error: ", err)
}
// 加载文本审查配置
var moderationConfig types.ModerationConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyModeration).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &moderationConfig)
if err != nil {
logger.Error("load moderation config error: ", err)
}
// 加载即梦AI配置
var jimengConfig types.JimengConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyJimeng).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &jimengConfig)
if err != nil {
logger.Error("load jimeng config error: ", err)
}
return &types.SystemConfig{
Base: baseConfig,
License: license,
SMS: smsConfig,
OSS: ossConfig,
SMTP: smtpConfig,
Payment: paymentConfig,
Captcha: captchaConfig,
WxLogin: wxLoginConfig,
Moderation: moderationConfig,
Jimeng: jimengConfig,
}
}

109
api/core/middleware/auth.go Normal file
View File

@@ -0,0 +1,109 @@
package middleware
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt"
)
// 前端用户授权验证
func UserAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
tokenString := c.GetHeader(types.UserAuthHeader)
if tokenString == "" {
resp.NotAuth(c, "无效的授权令牌")
c.Abort()
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
resp.NotAuth(c, "令牌无效")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
resp.NotAuth(c, "令牌过期")
c.Abort()
return
}
key := fmt.Sprintf("users/%v", claims["user_id"])
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c, "当前用户已退出登录")
c.Abort()
return
}
c.Set(types.LoginUserID, claims["user_id"])
}
}
// 管理后台用户授权验证
func AdminAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
tokenString := c.GetHeader(types.AdminAuthHeader)
if tokenString == "" {
resp.NotAuth(c, "无效的授权令牌")
c.Abort()
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
resp.NotAuth(c, "令牌无效")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
resp.NotAuth(c, "令牌过期")
c.Abort()
return
}
key := fmt.Sprintf("admin/%v", claims["user_id"])
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c, "当前用户已退出登录")
c.Abort()
return
}
c.Set(types.AdminUserID, claims["user_id"])
}
}

View File

@@ -0,0 +1,80 @@
package middleware
import (
"bytes"
"geekai/utils"
"io"
"strings"
"github.com/gin-gonic/gin"
)
// 统一参数处理
func ParameterHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// GET 参数处理
params := c.Request.URL.Query()
for key, values := range params {
for i, value := range values {
params[key][i] = strings.TrimSpace(value)
}
}
// update get parameters
c.Request.URL.RawQuery = params.Encode()
// skip file upload requests
contentType := c.Request.Header.Get("Content-Type")
if strings.Contains(contentType, "multipart/form-data") {
c.Next()
return
}
if strings.Contains(contentType, "application/json") {
// process POST JSON request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.Next()
return
}
// 还原请求体
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 将请求体解析为 JSON
var jsonData map[string]any
if err := c.ShouldBindJSON(&jsonData); err != nil {
c.Next()
return
}
// 对 JSON 数据中的字符串值去除两端空格
trimJSONStrings(jsonData)
// 更新请求体
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
}
c.Next()
}
}
// 递归对 JSON 数据中的字符串值去除两端空格
func trimJSONStrings(data any) {
switch v := data.(type) {
case map[string]any:
for key, value := range v {
switch valueType := value.(type) {
case string:
v[key] = strings.TrimSpace(valueType)
case map[string]any, []any:
trimJSONStrings(value)
}
}
case []any:
for i, value := range v {
switch valueType := value.(type) {
case string:
v[i] = strings.TrimSpace(valueType)
case map[string]any, []any:
trimJSONStrings(value)
}
}
}
}

View File

@@ -0,0 +1,43 @@
package middleware
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
// RateLimitEvery 使用 Redis 做固定间隔限流:在 interval 内仅允许一次请求
// Key 优先使用登录用户ID若没有则退化为 route + IP
func RateLimitEvery(redisClient *redis.Client, interval time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
keyID := ""
if userID, ok := c.Get(types.LoginUserID); ok {
keyID = fmt.Sprintf("user:%s", utils.InterfaceToString(userID))
} else {
keyID = fmt.Sprintf("ip:%s", c.ClientIP())
}
fullPath := c.FullPath()
if fullPath == "" {
fullPath = c.Request.URL.Path
}
key := fmt.Sprintf("rl:%s:%s", fullPath, keyID)
okSet, err := redisClient.SetNX(context.Background(), key, 1, interval).Result()
if err != nil {
// Redis 异常时放行,避免误伤可用性
return
}
if !okSet {
c.JSON(http.StatusTooManyRequests, types.BizVo{Code: types.Failed, Message: "请求过于频繁,请稍后重试"})
c.Abort()
return
}
}
}

View File

@@ -0,0 +1,78 @@
package middleware
import (
"bytes"
"geekai/utils"
"image"
"image/jpeg"
"net/http"
"os"
"strings"
"github.com/gin-gonic/gin"
"github.com/nfnt/resize"
"golang.org/x/image/webp"
)
// 静态资源中间件
func StaticMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
url := c.Request.URL.String()
// 拦截生成缩略图请求
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
r := strings.SplitAfter(url, "imageView2")
size := strings.Split(r[1], "/")
if len(size) != 8 {
c.String(http.StatusNotFound, "invalid thumb args")
return
}
with := utils.IntValue(size[3], 0)
height := utils.IntValue(size[5], 0)
quality := utils.IntValue(size[7], 75)
// 打开图片文件
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
file, err := os.Open(filePath)
if err != nil {
c.String(http.StatusNotFound, "Image not found")
return
}
defer file.Close()
// 解码图片
img, _, err := image.Decode(file)
// for .webp image
if err != nil {
img, err = webp.Decode(file)
}
if err != nil {
c.String(http.StatusInternalServerError, "Error decoding image")
return
}
var newImg image.Image
if height == 0 || with == 0 {
// 固定宽度,高度自适应
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
} else {
// 生成缩略图
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
}
var buffer bytes.Buffer
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
return
}
// 设置图片缓存有效期为一年 (365天)
c.Header("Cache-Control", "max-age=31536000, public")
// 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
c.Abort() // 中断请求
}
c.Next()
}
}

View File

@@ -17,88 +17,17 @@ type AppConfig struct {
Session Session Session Session
AdminSession Session AdminSession Session
ProxyURL string ProxyURL string
MysqlDns string // mysql 连接地址 MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录 StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息 Redis RedisConfig // redis 连接信息
ApiConfig ApiConfig // ChatPlus API authorization configs SMS SMSConfig // send mobile message config
SMS SMSConfig // send mobile message config OSS OSSConfig // OSS config
OSS OSSConfig // OSS config SmtpConfig SmtpConfig // 邮件发送配置
SmtpConfig SmtpConfig // 邮件发送配置 AlipayConfig AlipayConfig // 支付宝支付渠道配置
XXLConfig XXLConfig GeekPayConfig EpayConfig // GEEK 支付配置
AlipayConfig AlipayConfig // 支付宝支付渠道配置 WechatPayConfig WxPayConfig // 微信支付渠道配置
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置 TikaHost string // TiKa 服务器地址
GeekPayConfig GeekPayConfig // GEEK 支付配置
WechatPayConfig WechatPayConfig // 微信支付渠道配置
TikaHost string // TiKa 服务器地址
}
type SmtpConfig struct {
UseTls bool // 是否使用 TLS 发送
Host string
Port int
AppName string // 应用名称
From string // 发件人邮箱地址
Password string // 发件人邮箱密码
}
type ApiConfig struct {
ApiURL string
AppId string
Token string
JimengConfig JimengConfig // 即梦AI配置
}
type AlipayConfig struct {
Enabled bool // 是否启用该支付通道
SandBox bool // 是否沙盒环境
AppId string // 应用 ID
UserId string // 支付宝用户 ID
PrivateKey string // 用户私钥文件路径
PublicKey string // 用户公钥文件路径
AlipayPublicKey string // 支付宝公钥文件路径
RootCert string // Root 秘钥路径
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
}
type WechatPayConfig struct {
Enabled bool // 是否启用该支付通道
AppId string // 公众号的APPID,如wxd678efh567hg6787
MchId string // 直连商户的商户号,由微信支付生成并下发
SerialNo string // 商户证书的证书序列号
PrivateKey string // 用户私钥文件路径
ApiV3Key string // API V3 秘钥
NotifyURL string // 异步通知地址
}
type HuPiPayConfig struct { //虎皮椒第四方支付配置
Enabled bool // 是否启用该支付通道
AppId string // App ID
AppSecret string // app 密钥
ApiURL string // 支付网关
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
}
// GeekPayConfig GEEK支付配置
type GeekPayConfig struct {
Enabled bool
AppId string // 商户 ID
PrivateKey string // 私钥
ApiURL string // API 网关
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
Methods []string // 支付方式
}
type XXLConfig struct { // XXL 任务调度配置
Enabled bool
ServerAddr string
ExecutorIp string
ExecutorPort string
AccessToken string
RegistryKey string
} }
type RedisConfig struct { type RedisConfig struct {
@@ -128,32 +57,28 @@ func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port) return fmt.Sprintf("%s:%d", c.Host, c.Port)
} }
type SystemConfig struct { type BaseConfig struct {
Title string `json:"title,omitempty"` // 网站标题 Title string `json:"title,omitempty"` // 网站标题
Slogan string `json:"slogan,omitempty"` // 网站 slogan Slogan string `json:"slogan,omitempty"` // 网站 slogan
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题 AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
Logo string `json:"logo,omitempty"` // 圆形 Logo Logo string `json:"logo,omitempty"` // 圆形 Logo
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式支持手机mobile邮箱注册email账号密码注册 RegisterWays []string `json:"register_ways,omitempty"` // 注册方式支持手机mobile邮箱注册email账号密码注册
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册 EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间 OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间,单位:分钟
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力 MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力 MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力 SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力 SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力 LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力 KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力 AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址 WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
@@ -163,15 +88,44 @@ type SystemConfig struct {
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词 SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
MjMode string `json:"mj_mode"` // midjourney 默认的API模式relax, fast, turbo MjMode string `json:"mj_mode"` // midjourney 默认的API模式relax, fast, turbo
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单 IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
Copyright string `json:"copyright"` // 版权信息 Copyright string `json:"copyright"` // 版权信息
DefaultNickname string `json:"default_nickname"` // 默认昵称 ICP string `json:"icp"` // ICP 备案号
ICP string `json:"icp"` // ICP 备案号 GaBeian string `json:"ga_beian"` // 公安备案号
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表 EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位MB MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位MB
} }
type SystemConfig struct {
Base BaseConfig
Payment PaymentConfig
OSS OSSConfig
SMS SMSConfig
SMTP SmtpConfig
Captcha CaptchaConfig
WxLogin WxLoginConfig
Jimeng JimengConfig
License License
Moderation ModerationConfig
}
// 配置键名常量
const (
ConfigKeySystem = "system"
ConfigKeyNotice = "notice"
ConfigKeyAgreement = "agreement"
ConfigKeyPrivacy = "privacy"
ConfigKeyMarkMap = "mark_map"
ConfigKeyCaptcha = "captcha"
ConfigKeyWxLogin = "wx_login"
ConfigKeyLicense = "license"
ConfigKeySms = "sms"
ConfigKeySmtp = "smtp"
ConfigKeyOss = "oss"
ConfigKeyPayment = "payment"
ConfigKeyModeration = "moderation"
ConfigKeyAI3D = "ai3d"
ConfigKeyJimeng = "jimeng"
)

33
api/core/types/geekai.go Normal file
View File

@@ -0,0 +1,33 @@
package types
import "os"
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// GeekAI 增值服务
var GeekAPIURL = "https://sapi.geekai.me"
func init() {
if os.Getenv("GEEK_API_URL") != "" {
GeekAPIURL = os.Getenv("GEEK_API_URL")
}
}
// CaptchaConfig 行为验证码配置
type CaptchaConfig struct {
ApiKey string `json:"api_key"`
Type string `json:"type"` // 验证码类型, 可选值: "dot" 或 "slide"
Enabled bool `json:"enabled"`
}
// WxLoginConfig 微信登录配置
type WxLoginConfig struct {
ApiKey string `json:"api_key"`
NotifyURL string `json:"notify_url"` // 登录成功回调 URL
Enabled bool `json:"enabled"` // 是否启用微信登录
}

View File

@@ -0,0 +1,73 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// 文本审查
type ModerationConfig struct {
Enable bool `json:"enable"` // 是否启用文本审查
Active string `json:"active"`
EnableGuide bool `json:"enable_guide"` // 是否启用模型引导提示词
GuidePrompt string `json:"guide_prompt"` // 模型引导提示词
Gitee ModerationGiteeConfig `json:"gitee"`
Baidu ModerationBaiduConfig `json:"baidu"`
Tencent ModerationTencentConfig `json:"tencent"`
}
const (
ModerationGitee = "gitee"
ModerationBaidu = "baidu"
ModerationTencent = "tencent"
)
// GiteeAI 文本审查配置
type ModerationGiteeConfig struct {
ApiKey string `json:"api_key"`
Model string `json:"model"` // 文本审核模型
}
// 百度文本审查配置
type ModerationBaiduConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
}
// 腾讯云文本审查配置
type ModerationTencentConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
}
type ModerationResult struct {
Flagged bool `json:"flagged"`
Categories map[string]bool `json:"categories"`
CategoryScores map[string]float64 `json:"category_scores"`
}
var ModerationCategories = map[string]string{
"politic": "内容涉及人物、事件或敏感的政治观点",
"porn": "明确的色情内容",
"insult": "具有侮辱、攻击性语言、人身攻击或冒犯性表达",
"violence": "包含暴力、血腥、攻击行为或煽动暴力的言论",
"illegal": "涉及违法活动的内容,如诈骗、赌博等",
"terror": "宣扬恐怖主义、极端暴力或煽动恐怖行为的内容",
"ad": "垃圾广告或未经许可的推广内容",
"spam": "无意义重复内容或诱导性信息",
"abuse": "人身攻击、恶意辱骂或侮辱性言论",
"polity": "涉及国家政治、领导人或政策的违规讨论内容",
}
// 敏感词来源
const (
ModerationSourceChat = "chat"
ModerationSourceMJ = "mj"
ModerationSourceDalle = "dalle"
ModerationSourceSD = "sd"
ModerationSourceSuno = "suno"
ModerationSourceVideo = "video"
ModerationSourceJiMeng = "jimeng"
)

View File

@@ -11,29 +11,25 @@ type OrderStatus int
const ( const (
OrderNotPaid = OrderStatus(0) OrderNotPaid = OrderStatus(0)
OrderScanned = OrderStatus(1) // 已扫码 OrderPaidSuccess = OrderStatus(2) // 已支付
OrderPaidSuccess = OrderStatus(2) OrderPaidFailed = OrderStatus(3) // 已关闭
) )
type OrderRemark struct { type OrderRemark struct {
Days int `json:"days"` // 有效期 Days int `json:"days"` // 有效期
Power int `json:"power"` // 增加算力点数 Power int `json:"power"` // 增加算力点数
Name string `json:"name"` // 产品名称 Name string `json:"name"` // 产品名称
Price float64 `json:"price"` Price float64 `json:"price"`
Discount float64 `json:"discount"`
} }
var PayMethods = map[string]string{ // PayChannel 支付渠道
var PayChannel = map[string]string{
"alipay": "支付宝商号", "alipay": "支付宝商号",
"wechat": "微信商号", "wxpay": "微信商号",
"hupi": "虎皮椒", "epay": "易支付",
"geek": "易支付",
} }
var PayNames = map[string]string{
var PayWays = map[string]string{
"alipay": "支付宝", "alipay": "支付宝",
"wxpay": "微信支付", "wxpay": "微信支付",
"qqpay": "QQ钱包",
"jdpay": "京东支付",
"douyin": "抖音支付",
"paypal": "PayPal支付",
} }

View File

@@ -8,41 +8,39 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type OSSConfig struct { type OSSConfig struct {
Active string Active string `json:"active"`
Local LocalStorageConfig Local LocalStorageConfig `json:"local"`
Minio MiniOssConfig Minio MiniOssConfig `json:"minio"`
QiNiu QiNiuOssConfig QiNiu QiNiuOssConfig `json:"qiniu"`
AliYun AliYunOssConfig AliYun AliYunOssConfig `json:"aliyun"`
} }
type MiniOssConfig struct { type MiniOssConfig struct {
Endpoint string Endpoint string `json:"endpoint"`
AccessKey string AccessKey string `json:"access_key"`
AccessSecret string AccessSecret string `json:"access_secret"`
Bucket string Bucket string `json:"bucket"`
SubDir string UseSSL bool `json:"use_ssl"`
UseSSL bool Domain string `json:"domain"`
Domain string
} }
type QiNiuOssConfig struct { type QiNiuOssConfig struct {
Zone string Zone string `json:"zone"`
AccessKey string AccessKey string `json:"access_key"`
AccessSecret string AccessSecret string `json:"access_secret"`
Bucket string Bucket string `json:"bucket"`
SubDir string Domain string `json:"domain"`
Domain string
} }
type AliYunOssConfig struct { type AliYunOssConfig struct {
Endpoint string Endpoint string `json:"endpoint"`
AccessKey string AccessKey string `json:"access_key"`
AccessSecret string AccessSecret string `json:"access_secret"`
Bucket string Bucket string `json:"bucket"`
SubDir string Domain string `json:"domain"`
Domain string
} }
type LocalStorageConfig struct { type LocalStorageConfig struct {
BasePath string BasePath string `json:"base_path"`
BaseURL string BaseURL string `json:"base_url"`
} }

60
api/core/types/payment.go Normal file
View File

@@ -0,0 +1,60 @@
package types
type PaymentConfig struct {
Alipay AlipayConfig `json:"alipay"` // 支付宝支付渠道配置
Epay EpayConfig `json:"epay"` // 易支付配置
WxPay WxPayConfig `json:"wxpay"` // 微信支付渠道配置
}
// AlipayConfig 支付宝支付配置
type AlipayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
SandBox bool `json:"sandbox"` // 是否沙盒环境
AppId string `json:"app_id"` // 应用 ID
PrivateKey string `json:"private_key"` // 应用私钥
AlipayPublicKey string `json:"alipay_public_key"` // 支付宝公钥
Domain string `json:"domain"` // 支付回调域名
}
func (c *AlipayConfig) Equal(other *AlipayConfig) bool {
return c.AppId == other.AppId &&
c.PrivateKey == other.PrivateKey &&
c.AlipayPublicKey == other.AlipayPublicKey &&
c.Domain == other.Domain
}
// WxPayConfig 微信支付配置
type WxPayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
AppId string `json:"app_id"` // 公众号的APPID,如wxd678efh567hg6787
MchId string `json:"mch_id"` // 直连商户的商户号,由微信支付生成并下发
SerialNo string `json:"serial_no"` // 商户证书的证书序列号
PrivateKey string `json:"private_key"` // 商户证书私钥
ApiV3Key string `json:"api_v3_key"` // API V3 秘钥
Domain string `json:"domain"` // 支付回调域名
}
func (c *WxPayConfig) Equal(other *WxPayConfig) bool {
return c.AppId == other.AppId &&
c.MchId == other.MchId &&
c.SerialNo == other.SerialNo &&
c.PrivateKey == other.PrivateKey &&
c.ApiV3Key == other.ApiV3Key &&
c.Domain == other.Domain
}
// EpayConfig 易支付配置
type EpayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
AppId string `json:"app_id"` // 商户 ID
PrivateKey string `json:"private_key"` // 私钥
ApiURL string `json:"api_url"` // z支付 API 网关
Domain string `json:"domain"` // 支付回调域名
}
func (c *EpayConfig) Equal(other *EpayConfig) bool {
return c.AppId == other.AppId &&
c.PrivateKey == other.PrivateKey &&
c.ApiURL == other.ApiURL &&
c.Domain == other.Domain
}

View File

@@ -8,6 +8,7 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
const LoginUserID = "LOGIN_USER_ID" const LoginUserID = "LOGIN_USER_ID"
const AdminUserID = "ADMIN_USER_ID"
const LoginUserCache = "LOGIN_USER_CACHE" const LoginUserCache = "LOGIN_USER_CACHE"
const UserAuthHeader = "Authorization" const UserAuthHeader = "Authorization"

View File

@@ -8,26 +8,23 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type SMSConfig struct { type SMSConfig struct {
Active string Active string `json:"active"`
Ali SmsConfigAli Ali SmsConfigAli `json:"aliyun"`
Bao SmsConfigBao Bao SmsConfigBao `json:"bao"`
} }
// SmsConfigAli 阿里云短信平台配置 // SmsConfigAli 阿里云短信平台配置
type SmsConfigAli struct { type SmsConfigAli struct {
AccessKey string AccessKey string `json:"access_key"`
AccessSecret string AccessSecret string `json:"access_secret"`
Product string Sign string `json:"sign"` // 短信签名
Domain string CodeTempId string `json:"code_temp_id"` // 验证码短信模板 ID
Sign string // 短信签名
CodeTempId string // 验证码短信模板 ID
} }
// SmsConfigBao 短信宝平台配置 // SmsConfigBao 短信宝平台配置
type SmsConfigBao struct { type SmsConfigBao struct {
Username string //短信宝平台注册的用户名 Username string `json:"username"` //短信宝平台注册的用户名
Password string //短信宝平台注册的密码 Password string `json:"password"` //短信宝平台注册的密码
Domain string //域 Sign string `json:"sign"` // 短信签
Sign string // 短信签名 CodeTemplate string `json:"code_template"` // 验证码短信模板 匹配
CodeTemplate string // 验证码短信模板 匹配
} }

26
api/core/types/smtp.go Normal file
View File

@@ -0,0 +1,26 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type SmtpConfig struct {
UseTls bool `json:"use_tls"` // 是否使用 TLS 发送
Host string `json:"host"` // 邮件服务器地址
Port int `json:"port"` // 邮件服务器端口
AppName string `json:"app_name"` // 应用名称
From string `json:"from"` // 发件人邮箱地址
Password string `json:"password"` // 发件人邮箱密码
}
func (s *SmtpConfig) Equal(other *SmtpConfig) bool {
return s.UseTls == other.UseTls &&
s.Host == other.Host &&
s.Port == other.Port &&
s.AppName == other.AppName &&
s.From == other.From &&
s.Password == other.Password
}

View File

@@ -70,17 +70,18 @@ type SdTaskParams struct {
// DallTask DALL-E task // DallTask DALL-E task
type DallTask struct { type DallTask struct {
ModelId uint `json:"model_id"` ModelId uint `json:"model_id"`
ModelName string `json:"model_name"` ModelName string `json:"model_name"`
Id uint `json:"id"` Image []string `json:"image,omitempty"`
UserId uint `json:"user_id"` Id uint `json:"id"`
Prompt string `json:"prompt"` UserId uint `json:"user_id"`
N int `json:"n"` Prompt string `json:"prompt"`
Quality string `json:"quality"` N int `json:"n"`
Size string `json:"size"` Quality string `json:"quality"`
Style string `json:"style"` Size string `json:"size"`
Power int `json:"power"` Style string `json:"style"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID Power int `json:"power"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
} }
type SunoTask struct { type SunoTask struct {

View File

@@ -4,7 +4,7 @@ build_name: runner-build
build_log: runner-build-errors.log build_log: runner-build-errors.log
valid_ext: .go, .tpl, .tmpl, .html valid_ext: .go, .tpl, .tmpl, .html
no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue
ignored: assets, tmp, web, .git, .idea, test, data ignored: assets, tmp, web, .git, .idea, test, data, static
build_delay: 600 build_delay: 600
colors: 1 colors: 1
log_color_main: cyan log_color_main: cyan

View File

@@ -24,11 +24,9 @@ require (
gorm.io/driver/mysql v1.4.7 gorm.io/driver/mysql v1.4.7
) )
require github.com/xxl-job/xxl-job-executor-go v1.2.0
require ( require (
github.com/go-pay/gopay v1.5.101 github.com/go-pay/gopay v1.5.101
github.com/go-rod/rod v0.116.2 github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/go-tika v0.3.1 github.com/google/go-tika v0.3.1
github.com/microcosm-cc/bluemonday v1.0.26 github.com/microcosm-cc/bluemonday v1.0.26
github.com/sashabaranov/go-openai v1.38.1 github.com/sashabaranov/go-openai v1.38.1
@@ -50,11 +48,6 @@ require (
github.com/gorilla/css v1.0.0 // indirect github.com/gorilla/css v1.0.0 // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect github.com/tklauser/numcpus v0.7.0 // indirect
github.com/ysmood/fetchup v0.3.0 // indirect
github.com/ysmood/goob v0.4.0 // indirect
github.com/ysmood/got v0.40.0 // indirect
github.com/ysmood/gson v0.7.3 // indirect
github.com/ysmood/leakless v0.9.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.uber.org/mock v0.4.0 // indirect go.uber.org/mock v0.4.0 // indirect
) )
@@ -69,7 +62,6 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gaukas/godicttls v0.0.3 // indirect github.com/gaukas/godicttls v0.0.3 // indirect
github.com/go-basic/ipv4 v1.0.0 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect

View File

@@ -46,8 +46,6 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
@@ -80,8 +78,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA=
github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
@@ -89,6 +85,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
@@ -261,22 +259,6 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8= github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU= github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
github.com/ysmood/fetchup v0.3.0 h1:UhYz9xnLEVn2ukSuK3KCgcznWpHMdrmbsPpllcylyu8=
github.com/ysmood/fetchup v0.3.0/go.mod h1:hbysoq65PXL0NQeNzUczNYIKpwpkwFL4LXMDEvIQq9A=
github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ=
github.com/ysmood/goob v0.4.0/go.mod h1:u6yx7ZhS4Exf2MwciFr6nIM8knHQIE22lFpWHnfql18=
github.com/ysmood/gop v0.2.0 h1:+tFrG0TWPxT6p9ZaZs+VY+opCvHU8/3Fk6BaNv6kqKg=
github.com/ysmood/gop v0.2.0/go.mod h1:rr5z2z27oGEbyB787hpEcx4ab8cCiPnKxn0SUHt6xzk=
github.com/ysmood/got v0.40.0 h1:ZQk1B55zIvS7zflRrkGfPDrPG3d7+JOza1ZkNxcc74Q=
github.com/ysmood/got v0.40.0/go.mod h1:W7DdpuX6skL3NszLmAsC5hT7JAhuLZhByVzHTq874Qg=
github.com/ysmood/gotrace v0.6.0 h1:SyI1d4jclswLhg7SWTL6os3L1WOKeNn/ZtzVQF8QmdY=
github.com/ysmood/gotrace v0.6.0/go.mod h1:TzhIG7nHDry5//eYZDYcTzuJLYQIkykJzCRIo4/dzQM=
github.com/ysmood/gson v0.7.3 h1:QFkWbTH8MxyUTKPkVWAENJhxqdBa4lYTQWqZCiLG6kE=
github.com/ysmood/gson v0.7.3/go.mod h1:3Kzs5zDl21g5F/BlLTNcuAGAYLKt2lV5G8D1zF3RNmg=
github.com/ysmood/leakless v0.9.0 h1:qxCG5VirSBvmi3uynXFkcnLMzkphdh3xx5FtrORwDCU=
github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY8q0JvMQ=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=

View File

@@ -11,6 +11,7 @@ import (
"context" "context"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
logger2 "geekai/logger" logger2 "geekai/logger"
@@ -19,9 +20,10 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"time"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -45,6 +47,26 @@ func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, cap
} }
} }
// RegisterRoutes 注册路由
func (h *ManagerHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/")
// 公开接口,不需要授权
group.POST("login", h.Login)
group.GET("logout", h.Logout)
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("session", h.Session)
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("enable", h.Enable)
group.GET("remove", h.Remove)
group.POST("resetPass", h.ResetPass)
}
}
// Login 登录 // Login 登录
func (h *ManagerHandler) Login(c *gin.Context) { func (h *ManagerHandler) Login(c *gin.Context) {
var data struct { var data struct {
@@ -59,19 +81,6 @@ func (h *ManagerHandler) Login(c *gin.Context) {
return return
} }
if h.App.SysConfig.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
var manager model.AdminUser var manager model.AdminUser
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager) res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil { if res.Error != nil {
@@ -135,16 +144,15 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
// Session 会话检测 // Session 会话检测
func (h *ManagerHandler) Session(c *gin.Context) { func (h *ManagerHandler) Session(c *gin.Context) {
id := h.GetLoginUserId(c) id := h.GetAdminId(c)
key := fmt.Sprintf("admin/%d", id) if id == 0 {
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil { resp.NotAuth(c, "当前用户已退出登录")
resp.NotAuth(c)
return return
} }
var manager model.AdminUser var manager model.AdminUser
res := h.DB.Where("id", id).First(&manager) err := h.DB.Where("id", id).First(&manager).Error
if res.Error != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c, "当前用户已退出登录")
return return
} }

View File

@@ -10,6 +10,7 @@ package admin
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/store/model" "geekai/store/model"
@@ -30,6 +31,20 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}} return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
} }
// RegisterRoutes 注册路由
func (h *ApiKeyHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/apikey/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}
}
func (h *ApiKeyHandler) Save(c *gin.Context) { func (h *ApiKeyHandler) Save(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`

View File

@@ -10,6 +10,7 @@ package admin
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/store/model" "geekai/store/model"
@@ -30,14 +31,29 @@ func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *ChatAppHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/role/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("sort", h.Sort)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}
}
// Save 创建或者更新某个角色 // Save 创建或者更新某个角色
func (h *ChatAppHandler) Save(c *gin.Context) { func (h *ChatAppHandler) Save(c *gin.Context) {
var data vo.ChatRole var data vo.ChatApp
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
var role model.ChatRole var role model.ChatApp
err := utils.CopyObject(data, &role) err := utils.CopyObject(data, &role)
if err != nil { if err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -65,8 +81,8 @@ func (h *ChatAppHandler) Save(c *gin.Context) {
} }
func (h *ChatAppHandler) List(c *gin.Context) { func (h *ChatAppHandler) List(c *gin.Context) {
var items []model.ChatRole var items []model.ChatApp
var roles = make([]vo.ChatRole, 0) var roles = make([]vo.ChatApp, 0)
res := h.DB.Order("sort_num ASC").Find(&items) res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "No data found") resp.ERROR(c, "No data found")
@@ -107,7 +123,7 @@ func (h *ChatAppHandler) List(c *gin.Context) {
} }
for _, v := range items { for _, v := range items {
var role vo.ChatRole var role vo.ChatApp
err := utils.CopyObject(v, &role) err := utils.CopyObject(v, &role)
if err == nil { if err == nil {
role.Id = v.Id role.Id = v.Id
@@ -135,7 +151,7 @@ func (h *ChatAppHandler) Sort(c *gin.Context) {
} }
for index, id := range data.Ids { for index, id := range data.Ids {
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error err := h.DB.Model(&model.ChatApp{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
@@ -157,7 +173,7 @@ func (h *ChatAppHandler) Set(c *gin.Context) {
return return
} }
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error err := h.DB.Model(&model.ChatApp{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
@@ -172,9 +188,8 @@ func (h *ChatAppHandler) Remove(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
res := h.DB.Where("id", id).Delete(&model.ChatRole{}) res := h.DB.Where("id", id).Delete(&model.ChatApp{})
if res.Error != nil { if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "删除失败!") resp.ERROR(c, "删除失败!")
return return
} }

View File

@@ -2,12 +2,14 @@ package admin
import ( import (
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -20,6 +22,21 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *ChatAppTypeHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/app/type/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.GET("remove", h.Remove)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
}
}
// Save 创建或更新App类型 // Save 创建或更新App类型
func (h *ChatAppTypeHandler) Save(c *gin.Context) { func (h *ChatAppTypeHandler) Save(c *gin.Context) {
var data struct { var data struct {

View File

@@ -9,6 +9,7 @@ package admin
import ( import (
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/store/model" "geekai/store/model"
@@ -28,16 +29,31 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *ChatHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/chat/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("list", h.List)
group.POST("message", h.Messages)
group.GET("history", h.History)
group.GET("remove", h.RemoveChat)
group.GET("message/remove", h.RemoveMessage)
}
}
type chatItemVo struct { type chatItemVo struct {
Username string `json:"username"` Username string `json:"username"`
UserId uint `json:"user_id"` UserId uint `json:"user_id"`
ChatId string `json:"chat_id"` ChatId string `json:"chat_id"`
Title string `json:"title"` Title string `json:"title"`
Role vo.ChatRole `json:"role"` Role vo.ChatApp `json:"role"`
Model string `json:"model"` Model string `json:"model"`
Token int `json:"token"` Token int `json:"token"`
CreatedAt int64 `json:"created_at"` CreatedAt int64 `json:"created_at"`
MsgNum int `json:"msg_num"` // 消息数量 MsgNum int `json:"msg_num"` // 消息数量
} }
func (h *ChatHandler) List(c *gin.Context) { func (h *ChatHandler) List(c *gin.Context) {
@@ -87,7 +103,7 @@ func (h *ChatHandler) List(c *gin.Context) {
} }
var messages []model.ChatMessage var messages []model.ChatMessage
var users []model.User var users []model.User
var roles []model.ChatRole var roles []model.ChatApp
h.DB.Where("chat_id IN ?", chatIds).Find(&messages) h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
h.DB.Where("id IN ?", userIds).Find(&users) h.DB.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", roleIds).Find(&roles) h.DB.Where("id IN ?", roleIds).Find(&roles)
@@ -95,7 +111,7 @@ func (h *ChatHandler) List(c *gin.Context) {
tokenMap := make(map[string]int) tokenMap := make(map[string]int)
userMap := make(map[uint]string) userMap := make(map[uint]string)
msgMap := make(map[string]int) msgMap := make(map[string]int)
roleMap := make(map[uint]vo.ChatRole) roleMap := make(map[uint]vo.ChatApp)
for _, msg := range messages { for _, msg := range messages {
tokenMap[msg.ChatId] += msg.Tokens tokenMap[msg.ChatId] += msg.Tokens
msgMap[msg.ChatId] += 1 msgMap[msg.ChatId] += 1
@@ -104,7 +120,7 @@ func (h *ChatHandler) List(c *gin.Context) {
userMap[user.Id] = user.Username userMap[user.Id] = user.Username
} }
for _, r := range roles { for _, r := range roles {
var roleVo vo.ChatRole var roleVo vo.ChatApp
err := utils.CopyObject(r, &roleVo) err := utils.CopyObject(r, &roleVo)
if err != nil { if err != nil {
continue continue

View File

@@ -8,7 +8,9 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/store/model" "geekai/store/model"
@@ -28,6 +30,22 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *ChatModelHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/model/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("set", h.Set)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
group.POST("batch-remove", h.BatchRemove)
}
}
func (h *ChatModelHandler) Save(c *gin.Context) { func (h *ChatModelHandler) Save(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`
@@ -201,3 +219,33 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// BatchRemove 批量删除模型
func (h *ChatModelHandler) BatchRemove(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if len(data.Ids) == 0 {
resp.ERROR(c, "请选择要删除的模型")
return
}
// 执行批量删除
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.ChatModel{}).Error
if err != nil {
logger.Error("批量删除模型失败:", err)
resp.ERROR(c, "批量删除失败:"+err.Error())
return
}
resp.SUCCESS(c, gin.H{
"message": fmt.Sprintf("成功删除 %d 个模型", len(data.Ids)),
"deleted_count": len(data.Ids),
})
}

View File

@@ -9,106 +9,399 @@ package admin
import ( import (
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/service" "geekai/service"
"geekai/store" "geekai/service/oss"
"geekai/service/payment"
"geekai/service/sms"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/shirou/gopsutil/host"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ConfigHandler struct { type ConfigHandler struct {
handler.BaseHandler handler.BaseHandler
levelDB *store.LevelDB licenseService *service.LicenseService
licenseService *service.LicenseService sysConfig *types.SystemConfig
alipayService *payment.AlipayService
wxpayService *payment.WxPayService
epayService *payment.EPayService
smsManager *sms.SmsManager
uploaderManager *oss.UploaderManager
smtpService *service.SmtpService
captchaService *service.CaptchaService
wxLoginService *service.WxLoginService
} }
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler { func NewConfigHandler(
app *core.AppServer,
db *gorm.DB,
licenseService *service.LicenseService,
sysConfig *types.SystemConfig,
alipayService *payment.AlipayService,
wxpayService *payment.WxPayService,
epayService *payment.EPayService,
smsManager *sms.SmsManager,
uploaderManager *oss.UploaderManager,
smtpService *service.SmtpService,
captchaService *service.CaptchaService,
wxLoginService *service.WxLoginService,
) *ConfigHandler {
return &ConfigHandler{ return &ConfigHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db}, BaseHandler: handler.BaseHandler{App: app, DB: db},
levelDB: levelDB, licenseService: licenseService,
licenseService: licenseService, sysConfig: sysConfig,
alipayService: alipayService,
wxpayService: wxpayService,
epayService: epayService,
smsManager: smsManager,
uploaderManager: uploaderManager,
smtpService: smtpService,
captchaService: captchaService,
wxLoginService: wxLoginService,
} }
} }
func (h *ConfigHandler) Update(c *gin.Context) { // RegisterRoutes 注册路由
var data struct { func (h *ConfigHandler) RegisterRoutes() {
Key string `json:"key"` rg := h.App.Engine.Group("/api/admin/config")
Config struct {
types.SystemConfig // 需要管理员登录的接口
Content string `json:"content,omitempty"` rg.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
Updated bool `json:"updated,omitempty"` {
} `json:"config"` rg.POST("update/base", h.UpdateBase)
ConfigBak types.SystemConfig `json:"config_bak,omitempty"` rg.POST("update/power", h.UpdatePower)
rg.POST("update/notice", h.UpdateNotice)
rg.POST("update/agreement", h.UpdateAgreement)
rg.POST("update/privacy", h.UpdatePrivacy)
rg.POST("update/mark_map", h.UpdateMarkMap)
rg.POST("update/captcha", h.UpdateCaptcha)
rg.POST("update/wx_login", h.UpdateWxLogin)
rg.POST("update/payment", h.UpdatePayment)
rg.POST("update/sms", h.UpdateSms)
rg.POST("update/oss", h.UpdateOss)
rg.POST("update/smtp", h.UpdateStmp)
rg.GET("get", h.Get)
rg.POST("license/active", h.Active)
rg.GET("license/get", h.GetLicense)
} }
}
// UpdateBase 更新基础配置
func (h *ConfigHandler) UpdateBase(c *gin.Context) {
var data types.BaseConfig
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
logger.Errorf("Update config failed: %v", err)
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
// ONLY authorized user can change the copyright // 未授权的话不允许修改版权
if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy { license := h.licenseService.GetLicense()
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权") if !license.IsActive && data.Copyright != h.sysConfig.Base.Copyright {
resp.ERROR(c, "未授权系统不允许修改版权信息")
return return
} }
// 如果要启用图形验证码功能,则检查是否配置了 API 服务 // 未授权的话不允许修改 Logo
if data.Config.EnabledVerify && h.App.Config.ApiConfig.AppId == "" { if !license.IsActive && data.Logo != h.sysConfig.Base.Logo {
resp.ERROR(c, "启用验证码服务需要先配置 GeekAI 官方 API 服务 AppId 和 Token") resp.ERROR(c, "未授权系统不允许修改 Logo")
return return
} }
value := utils.JsonEncode(&data.Config) err := h.Update(types.ConfigKeySystem, data)
config := model.Config{Name: data.Key, Value: value} if err != nil {
res := h.DB.FirstOrCreate(&config, model.Config{Name: data.Key}) resp.ERROR(c, err.Error())
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return return
} }
if config.Id > 0 { h.sysConfig.Base = data
config.Value = value
res := h.DB.Updates(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
// update config cache for AppServer resp.SUCCESS(c, data)
var cfg model.Config
h.DB.Where("name", data.Key).First(&cfg)
var err error
if data.Key == "system" {
err = utils.JsonDecode(cfg.Value, &h.App.SysConfig)
}
if err != nil {
resp.ERROR(c, "Failed to update config cache: "+err.Error())
return
}
logger.Infof("Update AppServer's config successfully: %v", config.Value)
}
resp.SUCCESS(c, config)
} }
// Get 获取指定的系统配置 // UpdatePower 更新系统配置
func (h *ConfigHandler) Get(c *gin.Context) { func (h *ConfigHandler) UpdatePower(c *gin.Context) {
key := c.Query("key") var data struct {
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
h.sysConfig.Base.InitPower = data.InitPower
h.sysConfig.Base.DailyPower = data.DailyPower
h.sysConfig.Base.InvitePower = data.InvitePower
h.sysConfig.Base.MjPower = data.MjPower
h.sysConfig.Base.MjActionPower = data.MjActionPower
h.sysConfig.Base.SdPower = data.SdPower
h.sysConfig.Base.SunoPower = data.SunoPower
h.sysConfig.Base.LumaPower = data.LumaPower
h.sysConfig.Base.KeLingPowers = data.KeLingPowers
err := h.Update(types.ConfigKeySystem, h.sysConfig.Base)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, h.sysConfig.Base)
}
// UpdateNotice 更新公告配置
func (h *ConfigHandler) UpdateNotice(c *gin.Context) {
var data struct {
Content string `json:"content"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyNotice, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// UpdateAgreement 更新用户协议配置
func (h *ConfigHandler) UpdateAgreement(c *gin.Context) {
var data struct {
Content string `json:"content"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyAgreement, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// UpdatePrivacy 更新隐私政策配置
func (h *ConfigHandler) UpdatePrivacy(c *gin.Context) {
var data struct {
Content string `json:"content"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyPrivacy, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// UpdateMarkMap 更新思维导图配置
func (h *ConfigHandler) UpdateMarkMap(c *gin.Context) {
var data struct {
Content string `json:"content"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyMarkMap, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// UpdateCaptcha 更新行为验证码配置
func (h *ConfigHandler) UpdateCaptcha(c *gin.Context) {
var data types.CaptchaConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyCaptcha, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
h.captchaService.UpdateConfig(data)
resp.SUCCESS(c, data)
}
// UpdatePayment 更新支付配置
func (h *ConfigHandler) UpdatePayment(c *gin.Context) {
var data types.PaymentConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyPayment, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 如果启用状态发生改变,则需要更新支付服务配置
if data.WxPay.Enabled {
err = h.wxpayService.UpdateConfig(&data.WxPay)
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
if data.Epay.Enabled {
h.epayService.UpdateConfig(&data.Epay)
}
if data.Alipay.Enabled {
err = h.alipayService.UpdateConfig(&data.Alipay)
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
h.sysConfig.Payment = data
resp.SUCCESS(c, data)
}
// UpdateSms 更新短信配置
func (h *ConfigHandler) UpdateSms(c *gin.Context) {
var data types.SMSConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeySms, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 更新服务配置
h.smsManager.UpdateConfig(data)
resp.SUCCESS(c, data)
}
// UpdateOss 更新 Oss 配置
func (h *ConfigHandler) UpdateOss(c *gin.Context) {
var data types.OSSConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyOss, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 更新服务配置
h.uploaderManager.UpdateConfig(data)
h.sysConfig.OSS = data
resp.SUCCESS(c, data)
}
// UpdateStmp 更新 Stmp 配置
func (h *ConfigHandler) UpdateStmp(c *gin.Context) {
var data types.SmtpConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeySmtp, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 更新服务配置
h.smtpService.UpdateConfig(&data)
h.sysConfig.SMTP = data
resp.SUCCESS(c, data)
}
// UpdateWxLogin 更新微信登录配置
func (h *ConfigHandler) UpdateWxLogin(c *gin.Context) {
var data types.WxLoginConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyWxLogin, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if data.Enabled {
h.wxLoginService.UpdateConfig(data)
}
h.sysConfig.WxLogin = data
resp.SUCCESS(c, data)
}
// Update 更新系统配置
func (h *ConfigHandler) Update(name string, value any) error {
var config model.Config var config model.Config
res := h.DB.Where("name", key).First(&config) err := h.DB.Where("name", name).First(&config).Error
if err != nil { // 不存在则创建
config.Name = name
config.Value = utils.JsonEncode(value)
return h.DB.Create(&config).Error
} else { // 存在则更新
config.Value = utils.JsonEncode(value)
return h.DB.Updates(&config).Error
}
}
// Get 获取指定名称的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
name := c.Query("key")
var config model.Config
res := h.DB.Where("name", name).First(&config)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, res.Error.Error()) resp.ERROR(c, res.Error.Error())
return return
} }
var value map[string]interface{} var value map[string]any
err := utils.JsonDecode(config.Value, &value) err := utils.JsonDecode(config.Value, &value)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
@@ -127,19 +420,21 @@ func (h *ConfigHandler) Active(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
info, err := host.Info()
err := h.licenseService.ActiveLicense(data.License)
license := h.licenseService.GetLicense()
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
if err := h.Update(types.ConfigKeyLicense, license); err != nil {
err = h.licenseService.ActiveLicense(data.License, info.HostID)
if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
// 更新系统配置
h.sysConfig.License = *license
resp.SUCCESS(c) resp.SUCCESS(c, license.MachineId)
} }
@@ -148,69 +443,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
license := h.licenseService.GetLicense() license := h.licenseService.GetLicense()
resp.SUCCESS(c, license) resp.SUCCESS(c, license)
} }
// FixData 修复数据
func (h *ConfigHandler) FixData(c *gin.Context) {
resp.ERROR(c, "当前升级版本没有数据需要修正!")
//var fixed bool
//version := "data_fix_4.1.4"
//err := h.levelDB.Get(version, &fixed)
//if err == nil || fixed {
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
// return
//}
//tx := h.DB.Begin()
//var users []model.User
//err = tx.Find(&users).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, user := range users {
// if user.Email != "" || user.Mobile != "" {
// continue
// }
// if utils.IsValidEmail(user.Username) {
// user.Email = user.Username
// } else if utils.IsValidMobile(user.Username) {
// user.Mobile = user.Username
// }
// err = tx.Save(&user).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//
//var orders []model.Order
//err = h.DB.Find(&orders).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, order := range orders {
// if order.PayWay == "支付宝" {
// order.PayWay = "alipay"
// order.PayType = "alipay"
// } else if order.PayWay == "微信支付" {
// order.PayWay = "wechat"
// order.PayType = "wxpay"
// } else if order.PayWay == "hupi" {
// order.PayType = "wxpay"
// }
// err = tx.Save(&order).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//tx.Commit()
//err = h.levelDB.Put(version, true)
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//resp.SUCCESS(c)
}

View File

@@ -13,10 +13,11 @@ import (
"geekai/handler" "geekai/handler"
"geekai/store/model" "geekai/store/model"
"geekai/utils/resp" "geekai/utils/resp"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
"gorm.io/gorm" "gorm.io/gorm"
"time"
) )
type DashboardHandler struct { type DashboardHandler struct {
@@ -27,46 +28,161 @@ func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *DashboardHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/dashboard/")
group.GET("stats", h.Stats)
}
// statsVo 增加 recentOrders、recentUsers 字段
// 最近订单
type OrderBrief struct {
OrderNo string `json:"order_no"`
Amount float64 `json:"amount"`
CreatedAt time.Time `json:"created_at"`
}
// 最近用户
type UserBrief struct {
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
LastActive time.Time `json:"last_active"`
}
type statsVo struct { type statsVo struct {
Users int64 `json:"users"` Users int64 `json:"users"`
Chats int64 `json:"chats"` Chats int64 `json:"chats"`
Tokens int `json:"tokens"` Tokens int `json:"tokens"`
Income float64 `json:"income"` Income float64 `json:"income"`
Chart map[string]map[string]float64 `json:"chart"` Chart map[string]map[string]float64 `json:"chart"`
TodayUsers int64 `json:"todayUsers"`
TodayChats int64 `json:"todayChats"`
TodayTokens int `json:"todayTokens"`
TodayIncome float64 `json:"todayIncome"`
TodayOrders int64 `json:"todayOrders"`
TodayImageJobs int64 `json:"todayImageJobs"`
TodayVideoJobs int64 `json:"todayVideoJobs"`
TodayMusicJobs int64 `json:"todayMusicJobs"`
Orders int64 `json:"orders"`
ImageJobs int64 `json:"imageJobs"`
VideoJobs int64 `json:"videoJobs"`
MusicJobs int64 `json:"musicJobs"`
RecentOrders []OrderBrief `json:"recentOrders"`
RecentUsers []UserBrief `json:"recentUsers"`
} }
func (h *DashboardHandler) Stats(c *gin.Context) { func (h *DashboardHandler) Stats(c *gin.Context) {
stats := statsVo{} stats := statsVo{}
// new users statistic
var userCount int64
now := time.Now() now := time.Now()
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
if res.Error == nil { // 总用户数
stats.Users = userCount h.DB.Model(&model.User{}).Count(&stats.Users)
// 今日新增用户
h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&stats.TodayUsers)
// 总对话数
h.DB.Model(&model.ChatItem{}).Count(&stats.Chats)
// 今日新增对话
h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&stats.TodayChats)
// 总算力消耗
var powerLogs []model.PowerLog
h.DB.Where("mark = ?", types.PowerSub).Find(&powerLogs)
for _, item := range powerLogs {
stats.Tokens += item.Amount
} }
// new chats statistic // 今日算力消耗
var chatCount int64 var todayPowerLogs []model.PowerLog
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount) h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", zeroTime).Find(&todayPowerLogs)
if res.Error == nil { for _, item := range todayPowerLogs {
stats.Chats = chatCount stats.TodayTokens += item.Amount
} }
// tokens took stats // 总收入
var historyMessages []model.ChatMessage var allOrders []model.Order
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages) h.DB.Where("status = ?", types.OrderPaidSuccess).Find(&allOrders)
for _, item := range historyMessages { for _, item := range allOrders {
stats.Tokens += item.Tokens
}
// 订单收入
var orders []model.Order
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
for _, item := range orders {
stats.Income += item.Amount stats.Income += item.Amount
} }
// 今日收入
var todayOrders []model.Order
h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&todayOrders)
for _, item := range todayOrders {
stats.TodayIncome += item.Amount
}
// 订单总数
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Count(&stats.Orders)
// 今日订单数
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Count(&stats.TodayOrders)
// 图片生成任务统计
var mjJobs, sdJobs, dallJobs, jimengImageJobs int64
h.DB.Model(&model.MidJourneyJob{}).Count(&mjJobs)
h.DB.Model(&model.SdJob{}).Count(&sdJobs)
h.DB.Model(&model.DallJob{}).Count(&dallJobs)
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Count(&jimengImageJobs)
stats.ImageJobs = mjJobs + sdJobs + dallJobs + jimengImageJobs
logger.Info("stats.ImageJobs", stats.ImageJobs)
// 今日图片生成任务统计
var todayMjJobs, todaySdJobs, todayDallJobs, todayJimengImageJobs int64
h.DB.Model(&model.MidJourneyJob{}).Where("created_at > ?", zeroTime).Count(&todayMjJobs)
h.DB.Model(&model.SdJob{}).Where("created_at > ?", zeroTime).Count(&todaySdJobs)
h.DB.Model(&model.DallJob{}).Where("created_at > ?", zeroTime).Count(&todayDallJobs)
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Where("created_at > ?", zeroTime).Count(&todayJimengImageJobs)
stats.TodayImageJobs = todayMjJobs + todaySdJobs + todayDallJobs + todayJimengImageJobs
// 视频生成任务统计
var videoJobs, jimengVideoJobs int64
h.DB.Model(&model.VideoJob{}).Count(&videoJobs)
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Count(&jimengVideoJobs)
stats.VideoJobs = videoJobs + jimengVideoJobs
// 今日视频生成任务统计
var todayVideoJobs, todayJimengVideoJobs int64
h.DB.Model(&model.VideoJob{}).Where("created_at > ?", zeroTime).Count(&todayVideoJobs)
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Where("created_at > ?", zeroTime).Count(&todayJimengVideoJobs)
stats.TodayVideoJobs = todayVideoJobs + todayJimengVideoJobs
// 音乐生成任务统计
h.DB.Model(&model.SunoJob{}).Count(&stats.MusicJobs)
// 今日音乐生成任务统计
h.DB.Model(&model.SunoJob{}).Where("created_at > ?", zeroTime).Count(&stats.TodayMusicJobs)
// recentOrders: 最近10条已支付订单
var orderList []model.Order
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Order("created_at desc").Limit(10).Find(&orderList)
for _, o := range orderList {
stats.RecentOrders = append(stats.RecentOrders, OrderBrief{
OrderNo: o.OrderNo,
Amount: o.Amount,
CreatedAt: o.CreatedAt,
})
}
// recentUsers: 最近10个注册用户
var userList []model.User
h.DB.Model(&model.User{}).Order("created_at desc").Limit(10).Find(&userList)
for _, u := range userList {
lastActive := u.UpdatedAt
if lastActive.IsZero() {
lastActive = u.CreatedAt
}
stats.RecentUsers = append(stats.RecentUsers, UserBrief{
Nickname: u.Nickname,
Avatar: u.Avatar,
LastActive: lastActive,
})
}
// 统计7天的订单的图表 // 统计7天的订单的图表
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02") startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
var statsChart = make(map[string]map[string]float64) var statsChart = make(map[string]map[string]float64)
@@ -81,23 +197,29 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
// 统计用户7天增加的曲线 // 统计用户7天增加的曲线
var users []model.User var users []model.User
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users) err := h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users).Error
if res.Error == nil { if err == nil {
for _, item := range users { for _, item := range users {
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1 userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
} }
} }
// 统计7天Token 消耗 // 统计7天算力消耗
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages) var chartPowerLogs []model.PowerLog
for _, item := range historyMessages { err = h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", startDate).Find(&chartPowerLogs).Error
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens) if err == nil {
for _, item := range chartPowerLogs {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Amount)
}
} }
// 统计最近7天的订单 // 统计最近7天的订单
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders) var orders []model.Order
for _, item := range orders { err = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders).Error
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64() if err == nil {
for _, item := range orders {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
} }
statsChart["users"] = userStatistic statsChart["users"] = userStatistic

View File

@@ -9,6 +9,7 @@ package admin
import ( import (
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/store/model" "geekai/store/model"
@@ -30,6 +31,21 @@ func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *FunctionHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/function/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.GET("token", h.GenToken)
}
}
func (h *FunctionHandler) Save(c *gin.Context) { func (h *FunctionHandler) Save(c *gin.Context) {
var data vo.Function var data vo.Function
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
@@ -119,7 +135,6 @@ func (h *FunctionHandler) GenToken(c *gin.Context) {
}) })
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil { if err != nil {
logger.Error("error with generate token", err)
resp.ERROR(c) resp.ERROR(c)
return return
} }

View File

@@ -10,6 +10,7 @@ package admin
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/service" "geekai/service"
@@ -33,6 +34,20 @@ func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.User
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager} return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
} }
// RegisterRoutes 注册路由
func (h *ImageHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/image/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("list/mj", h.MjList)
group.POST("list/sd", h.SdList)
group.POST("list/dall", h.DallList)
group.GET("remove", h.Remove)
}
}
type imageQuery struct { type imageQuery struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Username string `json:"username"` Username string `json:"username"`

View File

@@ -21,18 +21,18 @@ import (
// AdminJimengHandler 管理后台即梦AI处理器 // AdminJimengHandler 管理后台即梦AI处理器
type AdminJimengHandler struct { type AdminJimengHandler struct {
handler.BaseHandler handler.BaseHandler
jimengService *jimeng.Service jimengClient *jimeng.Client
userService *service.UserService userService *service.UserService
uploader *oss.UploaderManager uploader *oss.UploaderManager
} }
// NewAdminJimengHandler 创建管理后台即梦AI处理器 // NewAdminJimengHandler 创建管理后台即梦AI处理器
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengService *jimeng.Service, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler { func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengClient *jimeng.Client, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
return &AdminJimengHandler{ return &AdminJimengHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db}, BaseHandler: handler.BaseHandler{App: app, DB: db},
jimengService: jimengService, jimengClient: jimengClient,
userService: userService, userService: userService,
uploader: uploader, uploader: uploader,
} }
} }
@@ -43,7 +43,6 @@ func (h *AdminJimengHandler) RegisterRoutes() {
rg.GET("/jobs/:id", h.JobDetail) rg.GET("/jobs/:id", h.JobDetail)
rg.POST("/jobs/remove", h.BatchRemove) rg.POST("/jobs/remove", h.BatchRemove)
rg.GET("/stats", h.Stats) rg.GET("/stats", h.Stats)
rg.GET("/config", h.GetConfig)
rg.POST("/config/update", h.UpdateConfig) rg.POST("/config/update", h.UpdateConfig)
} }
@@ -213,12 +212,6 @@ func (h *AdminJimengHandler) Stats(c *gin.Context) {
resp.SUCCESS(c, result) resp.SUCCESS(c, result)
} }
// GetConfig 获取即梦AI配置
func (h *AdminJimengHandler) GetConfig(c *gin.Context) {
jimengConfig := h.jimengService.GetConfig()
resp.SUCCESS(c, jimengConfig)
}
// UpdateConfig 更新即梦AI配置 // UpdateConfig 更新即梦AI配置
func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) { func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
var req types.JimengConfig var req types.JimengConfig
@@ -266,31 +259,35 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
// 保存配置 // 保存配置
tx := h.DB.Begin() tx := h.DB.Begin()
value := utils.JsonEncode(&req) value := utils.JsonEncode(&req)
config := model.Config{Name: "jimeng", Value: value} var exist model.Config
tx.Where("name", types.ConfigKeyJimeng).First(&exist)
err := tx.FirstOrCreate(&config, model.Config{Name: "jimeng"}).Error if exist.Id > 0 {
if err != nil { exist.Value = value
resp.ERROR(c, "保存配置失败: "+err.Error()) err := tx.Updates(&exist).Error
return
}
if config.Id > 0 {
config.Value = value
err = tx.Updates(&config).Error
if err != nil { if err != nil {
resp.ERROR(c, "更新配置失败: "+err.Error()) resp.ERROR(c, "更新配置失败: "+err.Error())
return return
} }
} else {
exist.Name = types.ConfigKeyJimeng
exist.Value = value
err := tx.Create(&exist).Error
if err != nil {
resp.ERROR(c, "创建配置失败: "+err.Error())
return
}
} }
// 更新服务中的客户端配置 // 更新服务中的客户端配置
updateErr := h.jimengService.UpdateClientConfig(req.AccessKey, req.SecretKey) err := h.jimengClient.UpdateConfig(req)
if updateErr != nil { if err != nil {
resp.ERROR(c, updateErr.Error()) resp.ERROR(c, err.Error())
tx.Rollback() tx.Rollback()
return return
} }
tx.Commit() tx.Commit()
h.App.SysConfig.Jimeng = req
resp.SUCCESS(c, gin.H{"message": "配置更新成功"}) resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
} }

View File

@@ -10,6 +10,7 @@ package admin
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/service" "geekai/service"
@@ -33,6 +34,19 @@ func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.User
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager} return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
} }
// RegisterRoutes 注册路由
func (h *MediaHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/media/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("suno", h.SunoList)
group.POST("videos", h.Videos)
group.GET("remove", h.Remove)
}
}
type mediaQuery struct { type mediaQuery struct {
Type string `json:"type"` // 任务类型 luma, keling Type string `json:"type"` // 任务类型 luma, keling
Prompt string `json:"prompt"` Prompt string `json:"prompt"`

View File

@@ -27,6 +27,16 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *MenuHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/menu/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}
func (h *MenuHandler) Save(c *gin.Context) { func (h *MenuHandler) Save(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`

View File

@@ -0,0 +1,333 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/service/moderation"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ModerationHandler struct {
handler.BaseHandler
sysConfig *types.SystemConfig
moderationManager *moderation.ServiceManager
}
func NewModerationHandler(app *core.AppServer, db *gorm.DB, sysConfig *types.SystemConfig, moderationManager *moderation.ServiceManager) *ModerationHandler {
return &ModerationHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, sysConfig: sysConfig, moderationManager: moderationManager}
}
// RegisterRoutes 注册路由
func (h *ModerationHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/moderation/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("list", h.List)
group.GET("remove", h.Remove)
group.POST("batch-remove", h.BatchRemove)
group.GET("source-list", h.GetSourceList)
group.POST("config", h.UpdateModeration)
group.POST("test", h.TestModeration)
}
}
// List 获取文本审核记录列表
func (h *ModerationHandler) List(c *gin.Context) {
var data struct {
Username string `json:"username"`
Source string `json:"source"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
// 构建查询条件
if data.Username != "" {
// 通过用户名查找用户ID
var user model.User
if err := h.DB.Where("username LIKE ?", "%"+data.Username+"%").First(&user).Error; err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Source != "" {
session = session.Where("source", data.Source)
}
if data.StartDate != "" && data.EndDate != "" {
startTime := data.StartDate + " 00:00:00"
endTime := data.EndDate + " 23:59:59"
session = session.Where("created_at >= ? AND created_at <= ?", startTime, endTime)
}
// 统计总数
var total int64
session.Model(&model.Moderation{}).Count(&total)
// 分页
page := data.Page
pageSize := data.PageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
// 查询数据
var items []model.Moderation
err := session.Order("id DESC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 获取用户信息
userIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
}
var users []model.User
if len(userIds) > 0 {
h.DB.Where("id IN ?", userIds).Find(&users)
}
userMap := make(map[uint]string)
for _, user := range users {
userMap[user.Id] = user.Username
}
// 转换为响应数据
list := make([]map[string]any, 0)
for _, item := range items {
var moderation types.ModerationResult
err := utils.JsonDecode(item.Result, &moderation)
if err != nil {
continue
}
var result []string
for value, label := range types.ModerationCategories {
if moderation.Categories[value] {
result = append(result, label)
}
}
list = append(list, map[string]any{
"id": item.Id,
"user_id": item.UserId,
"username": userMap[item.UserId],
"source": item.Source,
"input": item.Input,
"output": item.Output,
"result": result,
"created_at": item.CreatedAt.Unix(),
})
}
resp.SUCCESS(c, map[string]any{
"items": list,
"total": total,
"page": page,
"page_size": pageSize,
})
}
func (h *ModerationHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Where("id", id).Delete(&model.Moderation{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// BatchRemove 批量删除文本审核记录
func (h *ModerationHandler) BatchRemove(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if len(data.Ids) == 0 {
resp.ERROR(c, "请选择要删除的记录")
return
}
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.Moderation{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// 获取 source 列表
func (h *ModerationHandler) GetSourceList(c *gin.Context) {
sources := []gin.H{
{
"id": types.ModerationSourceChat,
"name": "AI对话",
},
{
"id": types.ModerationSourceMJ,
"name": "Midjourney 绘图",
},
{
"id": types.ModerationSourceDalle,
"name": "Dalle 绘图",
},
{
"id": types.ModerationSourceSD,
"name": "StableDiffusion 绘图",
},
{
"id": types.ModerationSourceSuno,
"name": "Suno 音乐",
},
{
"id": types.ModerationSourceVideo,
"name": "视频生成",
},
{
"id": types.ModerationSourceJiMeng,
"name": "即梦AI",
},
}
resp.SUCCESS(c, sources)
}
// UpdateModeration 更新文本审查配置
func (h *ModerationHandler) UpdateModeration(c *gin.Context) {
var data types.ModerationConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var config model.Config
err := h.DB.Where("name", types.ConfigKeyModeration).First(&config).Error
if err != nil {
config.Name = types.ConfigKeyModeration
config.Value = utils.JsonEncode(data)
err = h.DB.Create(&config).Error
} else {
config.Value = utils.JsonEncode(data)
err = h.DB.Updates(&config).Error
}
if err != nil {
resp.ERROR(c, err.Error())
return
}
h.moderationManager.UpdateConfig(data)
h.sysConfig.Moderation = data
resp.SUCCESS(c, data)
}
// 测试结果类型,用于前端显示
type ModerationTestResult struct {
IsAbnormal bool `json:"isAbnormal"`
Details []ModerationTestDetail `json:"details"`
}
type ModerationTestDetail struct {
Category string `json:"category"`
Description string `json:"description"`
Confidence string `json:"confidence"`
IsCategory bool `json:"isCategory"`
}
// TestModeration 测试文本审查服务
func (h *ModerationHandler) TestModeration(c *gin.Context) {
var data struct {
Text string `json:"text"`
Service string `json:"service"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Text == "" {
resp.ERROR(c, "测试文本不能为空")
return
}
// 检查是否启用了文本审查
if !h.sysConfig.Moderation.Enable {
resp.ERROR(c, "文本审查服务未启用")
return
}
// 获取当前激活的审核服务
service := h.moderationManager.GetService()
// 执行文本审核
result, err := service.Moderate(data.Text)
if err != nil {
resp.ERROR(c, "审核服务调用失败: "+err.Error())
return
}
// 转换为前端需要的格式
testResult := ModerationTestResult{
IsAbnormal: result.Flagged,
Details: make([]ModerationTestDetail, 0),
}
// 构建详细信息
for category, description := range types.ModerationCategories {
score := result.CategoryScores[category]
isCategory := result.Categories[category]
testResult.Details = append(testResult.Details, ModerationTestDetail{
Category: category,
Description: description,
Confidence: fmt.Sprintf("%.2f", score),
IsCategory: isCategory,
})
}
resp.SUCCESS(c, testResult)
}

View File

@@ -29,6 +29,14 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *OrderHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/order/")
group.POST("list", h.List)
group.GET("remove", h.Remove)
group.GET("clear", h.Clear)
}
func (h *OrderHandler) List(c *gin.Context) { func (h *OrderHandler) List(c *gin.Context) {
var data struct { var data struct {
OrderNo string `json:"order_no"` OrderNo string `json:"order_no"`
@@ -68,16 +76,16 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix() order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix() order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay] payChannel, ok := types.PayChannel[item.Channel]
if !ok { if !ok {
payMethod = item.PayWay payChannel = item.Channel
} }
payName, ok := types.PayNames[item.PayType] payWays, ok := types.PayWays[item.PayWay]
if !ok { if !ok {
payName = item.PayWay payWays = item.PayWay
} }
order.PayMethod = payMethod order.ChannelName = payChannel
order.PayName = payName order.PayName = payWays
list = append(list, order) list = append(list, order)
} else { } else {
logger.Error(err) logger.Error(err)
@@ -121,8 +129,8 @@ func (h *OrderHandler) Clear(c *gin.Context) {
} }
deleteIds := make([]uint, 0) deleteIds := make([]uint, 0)
for _, order := range orders { for _, order := range orders {
// 只删除 15 分钟内的未支付订单 // 只删除超时的未支付订单
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) { if time.Now().After(order.CreatedAt.Add(time.Minute * time.Duration(h.App.SysConfig.Base.OrderPayTimeout))) {
deleteIds = append(deleteIds, order.Id) deleteIds = append(deleteIds, order.Id)
} }
} }

View File

@@ -28,6 +28,12 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *PowerLogHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/powerLog/")
group.POST("list", h.List)
}
func (h *PowerLogHandler) List(c *gin.Context) { func (h *PowerLogHandler) List(c *gin.Context) {
var data struct { var data struct {
Username string `json:"username"` Username string `json:"username"`

View File

@@ -15,9 +15,10 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"time"
) )
type ProductHandler struct { type ProductHandler struct {
@@ -28,14 +29,22 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *ProductHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/product/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}
func (h *ProductHandler) Save(c *gin.Context) { func (h *ProductHandler) Save(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Price float64 `json:"price"` Price float64 `json:"price"`
Discount float64 `json:"discount"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Days int `json:"days"`
Power int `json:"power"` Power int `json:"power"`
CreatedAt int64 `json:"created_at"` CreatedAt int64 `json:"created_at"`
} }
@@ -45,12 +54,10 @@ func (h *ProductHandler) Save(c *gin.Context) {
} }
item := model.Product{ item := model.Product{
Name: data.Name, Name: data.Name,
Price: data.Price, Price: data.Price,
Discount: data.Discount, Power: data.Power,
Days: data.Days, Enabled: data.Enabled}
Power: data.Power,
Enabled: data.Enabled}
item.Id = data.Id item.Id = data.Id
if item.Id > 0 { if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0) item.CreatedAt = time.Unix(data.CreatedAt, 0)

View File

@@ -29,6 +29,16 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *RedeemHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/redeem/")
group.GET("list", h.List)
group.POST("create", h.Create)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.POST("export", h.Export)
}
func (h *RedeemHandler) List(c *gin.Context) { func (h *RedeemHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1) page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20) pageSize := h.GetInt(c, "page_size", 20)

View File

@@ -9,6 +9,7 @@ package admin
import ( import (
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/handler" "geekai/handler"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
@@ -28,6 +29,17 @@ func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderMan
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager} return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
} }
// RegisterRoutes 注册路由
func (h *UploadHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/upload")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("", h.Upload)
}
}
func (h *UploadHandler) Upload(c *gin.Context) { func (h *UploadHandler) Upload(c *gin.Context) {
// 判断文件大小 // 判断文件大小
f, err := c.FormFile("file") f, err := c.FormFile("file")
@@ -36,7 +48,7 @@ func (h *UploadHandler) Upload(c *gin.Context) {
return return
} }
if h.App.SysConfig.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.MaxFileSize)*1024*1024 { if h.App.SysConfig.Base.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.Base.MaxFileSize)*1024*1024 {
resp.ERROR(c, "文件大小超过限制") resp.ERROR(c, "文件大小超过限制")
return return
} }

View File

@@ -10,6 +10,7 @@ package admin
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/service" "geekai/service"
@@ -19,10 +20,9 @@ import (
"geekai/utils/resp" "geekai/utils/resp"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -36,6 +36,22 @@ func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.Li
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli} return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
} }
// RegisterRoutes 注册路由
func (h *UserHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/user/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.GET("remove", h.Remove)
group.GET("loginLog", h.LoginLog)
group.GET("genLoginLink", h.GenLoginLink)
group.POST("resetPass", h.ResetPass)
}
}
// List 用户列表 // List 用户列表
func (h *UserHandler) List(c *gin.Context) { func (h *UserHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1) page := h.GetInt(c, "page", 1)

View File

@@ -15,9 +15,10 @@ import (
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"gorm.io/gorm"
"strings" "strings"
"gorm.io/gorm"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -69,6 +70,14 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
return uint(utils.IntValue(utils.InterfaceToString(userId), 0)) return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
} }
func (h *BaseHandler) GetAdminId(c *gin.Context) uint {
userId, ok := c.Get(types.AdminUserID)
if !ok {
return 0
}
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
}
func (h *BaseHandler) IsLogin(c *gin.Context) bool { func (h *BaseHandler) IsLogin(c *gin.Context) bool {
return h.GetLoginUserId(c) > 0 return h.GetLoginUserId(c) > 0
} }

View File

@@ -8,23 +8,45 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// 今日头条函数实现
type CaptchaHandler struct { type CaptchaHandler struct {
App *core.AppServer
service *service.CaptchaService service *service.CaptchaService
} }
func NewCaptchaHandler(s *service.CaptchaService) *CaptchaHandler { func NewCaptchaHandler(app *core.AppServer, s *service.CaptchaService, sysConfig *types.SystemConfig) *CaptchaHandler {
return &CaptchaHandler{service: s} return &CaptchaHandler{App: app, service: s}
}
// RegisterRoutes 注册路由
func (h *CaptchaHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/captcha/")
// 无需授权的接口
group.GET("get", h.Get)
group.POST("check", h.Check)
group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck)
group.GET("config", h.GetConfig)
}
func (h *CaptchaHandler) GetConfig(c *gin.Context) {
resp.SUCCESS(c, gin.H{"enabled": h.service.GetConfig().Enabled, "type": h.service.GetConfig().Type})
} }
func (h *CaptchaHandler) Get(c *gin.Context) { func (h *CaptchaHandler) Get(c *gin.Context) {
if !h.service.GetConfig().Enabled {
resp.ERROR(c, "验证码服务未启用")
return
}
data, err := h.service.Get() data, err := h.service.Get()
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
@@ -36,6 +58,11 @@ func (h *CaptchaHandler) Get(c *gin.Context) {
// Check verify the captcha data // Check verify the captcha data
func (h *CaptchaHandler) Check(c *gin.Context) { func (h *CaptchaHandler) Check(c *gin.Context) {
if !h.service.GetConfig().Enabled {
resp.ERROR(c, "验证码服务未启用")
return
}
var data struct { var data struct {
Key string `json:"key"` Key string `json:"key"`
Dots string `json:"dots"` Dots string `json:"dots"`
@@ -55,6 +82,11 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
// SlideGet 获取滑动验证图片 // SlideGet 获取滑动验证图片
func (h *CaptchaHandler) SlideGet(c *gin.Context) { func (h *CaptchaHandler) SlideGet(c *gin.Context) {
if !h.service.GetConfig().Enabled {
resp.ERROR(c, "验证码服务未启用")
return
}
data, err := h.service.SlideGet() data, err := h.service.SlideGet()
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
@@ -66,6 +98,11 @@ func (h *CaptchaHandler) SlideGet(c *gin.Context) {
// SlideCheck 滑动验证结果校验 // SlideCheck 滑动验证结果校验
func (h *CaptchaHandler) SlideCheck(c *gin.Context) { func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
if !h.service.GetConfig().Enabled {
resp.ERROR(c, "验证码服务未启用")
return
}
var data struct { var data struct {
Key string `json:"key"` Key string `json:"key"`
X int `json:"x"` X int `json:"x"`

View File

@@ -9,6 +9,7 @@ package handler
import ( import (
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
@@ -19,18 +20,31 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
type ChatRoleHandler struct { type ChatAppHandler struct {
BaseHandler BaseHandler
} }
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &ChatAppHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ChatAppHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/app/")
group.GET("list", h.List)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("list/user", h.ListByUser)
group.POST("update", h.UpdateApp)
}
} }
// List 获取用户聊天应用列表 // List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) { func (h *ChatAppHandler) List(c *gin.Context) {
tid := h.GetInt(c, "tid", 0) tid := h.GetInt(c, "tid", 0)
var roles []model.ChatRole var roles []model.ChatApp
session := h.DB.Where("enable", true) session := h.DB.Where("enable", true)
if tid > 0 { if tid > 0 {
session = session.Where("tid", tid) session = session.Where("tid", tid)
@@ -41,9 +55,9 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
return return
} }
var roleVos = make([]vo.ChatRole, 0) var roleVos = make([]vo.ChatApp, 0)
for _, r := range roles { for _, r := range roles {
var v vo.ChatRole var v vo.ChatApp
err := utils.CopyObject(r, &v) err := utils.CopyObject(r, &v)
if err == nil { if err == nil {
v.Id = r.Id v.Id = r.Id
@@ -54,10 +68,10 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
} }
// ListByUser 获取用户添加的角色列表 // ListByUser 获取用户添加的角色列表
func (h *ChatRoleHandler) ListByUser(c *gin.Context) { func (h *ChatAppHandler) ListByUser(c *gin.Context) {
id := h.GetInt(c, "id", 0) id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
var roles []model.ChatRole var roles []model.ChatApp
session := h.DB.Where("enable", true) session := h.DB.Where("enable", true)
// 如果用户没登录,则获取所有角色 // 如果用户没登录,则获取所有角色
if userId > 0 { if userId > 0 {
@@ -86,9 +100,9 @@ func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
return return
} }
var roleVos = make([]vo.ChatRole, 0) var roleVos = make([]vo.ChatApp, 0)
for _, r := range roles { for _, r := range roles {
var v vo.ChatRole var v vo.ChatApp
err := utils.CopyObject(r, &v) err := utils.CopyObject(r, &v)
if err == nil { if err == nil {
v.Id = r.Id v.Id = r.Id
@@ -98,8 +112,8 @@ func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
resp.SUCCESS(c, roleVos) resp.SUCCESS(c, roleVos)
} }
// UpdateRole 更新用户聊天角色 // UpdateApp 更新用户聊天应用
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) { func (h *ChatAppHandler) UpdateApp(c *gin.Context) {
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)

View File

@@ -19,6 +19,12 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *ChatAppTypeHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/app/type/")
group.GET("list", h.List)
}
// List 获取App类型列表 // List 获取App类型列表
func (h *ChatAppTypeHandler) List(c *gin.Context) { func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType var items []model.AppType

View File

@@ -14,8 +14,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/moderation"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
@@ -39,6 +41,7 @@ import (
const ( const (
ChatEventStart = "start" ChatEventStart = "start"
ChatEventEnd = "end" ChatEventEnd = "end"
ChatEventComplete = "complete"
ChatEventError = "error" ChatEventError = "error"
ChatEventMessageDelta = "message_delta" ChatEventMessageDelta = "message_delta"
ChatEventTitle = "title" ChatEventTitle = "title"
@@ -54,44 +57,69 @@ type ChatInput struct {
Stream bool `json:"stream"` Stream bool `json:"stream"`
Files []vo.File `json:"files"` Files []vo.File `json:"files"`
ChatModel model.ChatModel `json:"chat_model,omitempty"` ChatModel model.ChatModel `json:"chat_model,omitempty"`
ChatRole model.ChatRole `json:"chat_role,omitempty"` ChatRole model.ChatApp `json:"chat_role,omitempty"`
LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID用于重新生成答案的时候过滤上下文 LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID用于重新生成答案的时候过滤上下文
} }
type ChatHandler struct { type ChatHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
licenseService *service.LicenseService licenseService *service.LicenseService
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
userService *service.UserService userService *service.UserService
moderationManager *moderation.ServiceManager
} }
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler { func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService, moderationManager *moderation.ServiceManager) *ChatHandler {
return &ChatHandler{ return &ChatHandler{
BaseHandler: BaseHandler{App: app, DB: db}, BaseHandler: BaseHandler{App: app, DB: db},
redis: redis, redis: redis,
uploadManager: manager, uploadManager: manager,
licenseService: licenseService, licenseService: licenseService,
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
userService: userService, userService: userService,
moderationManager: moderationManager,
}
}
// RegisterRoutes 注册路由
func (h *ChatHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/chat/")
// 聊天接口不需要授权已在authConfig中配置
group.Any("message", h.Chat)
// 其他接口需要用户授权
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.GET("detail", h.Detail)
group.POST("update", h.Update)
group.GET("remove", h.Remove)
group.GET("history", h.History)
group.GET("clear", h.Clear)
group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate)
group.POST("tts", h.TextToSpeech)
} }
} }
// Chat 处理聊天请求 // Chat 处理聊天请求
func (h *ChatHandler) Chat(c *gin.Context) { func (h *ChatHandler) Chat(c *gin.Context) {
var input ChatInput
if err := c.ShouldBindJSON(&input); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 设置SSE响应头 // 设置SSE响应头
c.Header("Prompt-Type", "text/event-stream") c.Header("Prompt-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no") c.Header("X-Accel-Buffering", "no")
var input ChatInput
if err := c.ShouldBindJSON(&input); err != nil {
pushMessage(c, ChatEventError, types.InvalidArgs)
c.Abort()
return
}
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
@@ -113,7 +141,7 @@ func (h *ChatHandler) Chat(c *gin.Context) {
} }
// 验证聊天角色 // 验证聊天角色
var chatRole model.ChatRole var chatRole model.ChatApp
err := h.DB.First(&chatRole, input.RoleId).Error err := h.DB.First(&chatRole, input.RoleId).Error
if err != nil || !chatRole.Enable { if err != nil || !chatRole.Enable {
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!") pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
@@ -166,7 +194,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
} }
if userVo.Power < input.ChatModel.Power { if userVo.Power < input.ChatModel.Power {
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d[立即购买](/member)。", userVo.Power, input.ChatModel.Power) return fmt.Errorf("您的算力不足,请购买算力。")
} }
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() { if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
@@ -229,17 +257,24 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
// 加载聊天上下文 // 加载聊天上下文
chatCtx := make([]any, 0) chatCtx := make([]any, 0)
messages := make([]any, 0) messages := make([]any, 0)
if h.App.SysConfig.EnableContext { if h.App.SysConfig.Base.EnableContext {
_ = utils.JsonDecode(input.ChatRole.Context, &messages) _ = utils.JsonDecode(input.ChatRole.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 { if h.App.SysConfig.Base.ContextDeep > 0 {
var historyMessages []model.ChatMessage var historyMessages []model.ChatMessage
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId) dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
if input.LastMsgId > 0 { // 重新生成逻辑 if input.LastMsgId > 0 { // 重新生成逻辑
var lastMessage model.ChatMessage
err = dbSession.Where("id <= ?", input.LastMsgId).Where("type", types.PromptMsg).First(&lastMessage).Error
if err != nil {
input.LastMsgId = 0
} else {
input.LastMsgId = lastMessage.Id
}
dbSession = dbSession.Where("id < ?", input.LastMsgId) dbSession = dbSession.Where("id < ?", input.LastMsgId)
// 删除对应的聊天记录 // 删除对应的聊天记录
h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{}) h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{})
} }
err = dbSession.Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages).Error err = dbSession.Limit(h.App.SysConfig.Base.ContextDeep).Order("id DESC").Find(&historyMessages).Error
if err == nil { if err == nil {
for i := len(historyMessages) - 1; i >= 0; i-- { for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i] msg := historyMessages[i]
@@ -267,7 +302,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
} }
// 上下文的深度超出了模型的最大上下文深度 // 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.SysConfig.ContextDeep { if len(chatCtx) >= h.App.SysConfig.Base.ContextDeep {
break break
} }
@@ -277,6 +312,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
} }
reqMgs := make([]any, 0) reqMgs := make([]any, 0)
// 添加引导提示词,防止模型生成违规内容
if h.App.SysConfig.Moderation.EnableGuide {
reqMgs = append(reqMgs, map[string]any{
"role": "system",
"content": h.App.SysConfig.Moderation.GuidePrompt,
})
}
for i := len(chatCtx) - 1; i >= 0; i-- { for i := len(chatCtx) - 1; i >= 0; i-- {
reqMgs = append(reqMgs, chatCtx[i]) reqMgs = append(reqMgs, chatCtx[i])
} }
@@ -295,16 +338,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
}, },
}) })
} else { } else {
// 如果不是逆向模型,则提取文件内容 // 处理文件,提取文件内容
modelValue := input.ChatModel.Value content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
if !(strings.Contains(modelValue, "-all") || strings.HasPrefix(modelValue, "gpt-4-gizmo") || strings.HasPrefix(modelValue, "claude")) { if err != nil {
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost) logger.Error("error with read file: ", err)
if err != nil { continue
logger.Error("error with read file: ", err) } else {
continue fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
} else { logger.Debugf("fileContents: %s", fileContents)
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
}
} }
} }
} }
@@ -320,16 +361,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
} }
if len(imgList) > 0 { if len(imgList) > 0 {
imgList = append(imgList, map[string]interface{}{ imgList = append(imgList, map[string]any{
"type": "text", "type": "text",
"text": input.Prompt, "text": input.Prompt,
}) })
req.Messages = append(reqMgs, map[string]interface{}{ req.Messages = append(reqMgs, map[string]any{
"role": "user", "role": "user",
"content": imgList, "content": imgList,
}) })
} else { } else {
req.Messages = append(reqMgs, map[string]interface{}{ req.Messages = append(reqMgs, map[string]any{
"role": "user", "role": "user",
"content": finalPrompt, "content": finalPrompt,
}) })
@@ -445,7 +486,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) { func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly // if the chat model bind a KEY, use it directly
if input.ChatModel.KeyId > 0 { if input.ChatModel.KeyId > 0 {
h.DB.Where("id", input.ChatModel.KeyId).Find(apiKey) h.DB.Where("id", input.ChatModel.KeyId).Where("enabled", true).Find(apiKey)
} else { // use the last unused key } else { // use the last unused key
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey) h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
} }
@@ -516,6 +557,7 @@ func (h *ChatHandler) subUserPower(userVo vo.User, input ChatInput, promptTokens
} }
func (h *ChatHandler) saveChatHistory( func (h *ChatHandler) saveChatHistory(
c *gin.Context,
req types.ApiRequest, req types.ApiRequest,
usage Usage, usage Usage,
message types.Message, message types.Message,
@@ -524,6 +566,34 @@ func (h *ChatHandler) saveChatHistory(
promptCreatedAt time.Time, promptCreatedAt time.Time,
replyCreatedAt time.Time) { replyCreatedAt time.Time) {
// 文本审核
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(usage.Content)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
logger.Debugf("moderationResult: %+v", moderationResult)
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: userVo.Id,
Source: types.ModerationSourceChat,
Input: usage.Prompt,
Output: usage.Content,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
pushMessage(c, ChatEventError, "很抱歉内容触发敏感词预警AI 无法回答!!!")
// 更新用户算力
if input.ChatModel.Power > 0 {
h.subUserPower(userVo, input, 0, 0)
}
return
}
}
// 追加聊天记录 // 追加聊天记录
// for prompt // for prompt
var promptTokens, replyTokens, totalTokens int var promptTokens, replyTokens, totalTokens int
@@ -586,6 +656,22 @@ func (h *ChatHandler) saveChatHistory(
logger.Error("failed to save reply history message: ", err) logger.Error("failed to save reply history message: ", err)
} }
// 发送完整聊天记录给前端
var messageVo vo.ChatMessage
err = utils.CopyObject(historyReplyMsg, &messageVo)
if err == nil {
// 解析内容
var content vo.MsgContent
err = utils.JsonDecode(historyReplyMsg.Content, &content)
if err != nil {
content.Text = historyReplyMsg.Content
}
messageVo.Content = content
messageVo.CreatedAt = historyReplyMsg.CreatedAt.Unix()
messageVo.UpdatedAt = historyReplyMsg.UpdatedAt.Unix()
pushMessage(c, ChatEventComplete, messageVo)
}
// 更新用户算力 // 更新用户算力
if input.ChatModel.Power > 0 { if input.ChatModel.Power > 0 {
h.subUserPower(userVo, input, promptTokens, replyTokens) h.subUserPower(userVo, input, promptTokens, replyTokens)
@@ -710,221 +796,3 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
logger.Error("写入音频数据到响应失败:", err) logger.Error("写入音频数据到响应失败:", err)
} }
} }
// // OPenAI 消息发送实现
// func (h *ChatHandler) sendOpenAiMessage(
// req types.ApiRequest,
// userVo vo.User,
// ctx context.Context,
// session *types.ChatSession,
// role model.ChatRole,
// prompt string,
// c *gin.Context) error {
// promptCreatedAt := time.Now() // 记录提问时间
// start := time.Now()
// var apiKey = model.ApiKey{}
// response, err := h.doRequest(ctx, req, session, &apiKey)
// logger.Info("HTTP请求完成耗时", time.Since(start))
// if err != nil {
// if strings.Contains(err.Error(), "context canceled") {
// return fmt.Errorf("用户取消了请求:%s", prompt)
// } else if strings.Contains(err.Error(), "no available key") {
// return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
// }
// return err
// } else {
// defer response.Body.Close()
// }
// if response.StatusCode != 200 {
// body, _ := io.ReadAll(response.Body)
// return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
// }
// contentType := response.Header.Get("Prompt-Type")
// if strings.Contains(contentType, "text/event-stream") {
// replyCreatedAt := time.Now() // 记录回复时间
// // 循环读取 Chunk 消息
// var message = types.Message{Role: "assistant"}
// var contents = make([]string, 0)
// var function model.Function
// var toolCall = false
// var arguments = make([]string, 0)
// var reasoning = false
// pushMessage(c, ChatEventStart, "开始响应")
// scanner := bufio.NewScanner(response.Body)
// for scanner.Scan() {
// line := scanner.Text()
// if !strings.Contains(line, "data:") || len(line) < 30 {
// continue
// }
// var responseBody = types.ApiResponse{}
// err = json.Unmarshal([]byte(line[6:]), &responseBody)
// if err != nil { // 数据解析出错
// return errors.New(line)
// }
// if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
// continue
// }
// if responseBody.Choices[0].Delta.Prompt == nil &&
// responseBody.Choices[0].Delta.ToolCalls == nil &&
// responseBody.Choices[0].Delta.ReasoningContent == "" {
// continue
// }
// if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
// pushMessage(c, ChatEventError, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
// break
// }
// var tool types.ToolCall
// if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
// tool = responseBody.Choices[0].Delta.ToolCalls[0]
// if toolCall && tool.Function.Name == "" {
// arguments = append(arguments, tool.Function.Arguments)
// continue
// }
// }
// // 兼容 Function Call
// fun := responseBody.Choices[0].Delta.FunctionCall
// if fun.Name != "" {
// tool = *new(types.ToolCall)
// tool.Function.Name = fun.Name
// } else if toolCall {
// arguments = append(arguments, fun.Arguments)
// continue
// }
// if !utils.IsEmptyValue(tool) {
// res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
// if res.Error == nil {
// toolCall = true
// callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": callMsg,
// })
// contents = append(contents, callMsg)
// }
// continue
// }
// if responseBody.Choices[0].FinishReason == "tool_calls" ||
// responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
// break
// }
// // output stopped
// if responseBody.Choices[0].FinishReason != "" {
// break // 输出完成或者输出中断了
// } else { // 正常输出结果
// // 兼容思考过程
// if responseBody.Choices[0].Delta.ReasoningContent != "" {
// reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
// if !reasoning {
// reasoningContent = fmt.Sprintf("<think>%s", reasoningContent)
// reasoning = true
// }
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": reasoningContent,
// })
// contents = append(contents, reasoningContent)
// } else if responseBody.Choices[0].Delta.Prompt != "" {
// finalContent := responseBody.Choices[0].Delta.Prompt
// if reasoning {
// finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Prompt)
// reasoning = false
// }
// contents = append(contents, utils.InterfaceToString(finalContent))
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": finalContent,
// })
// }
// }
// } // end for
// if err := scanner.Err(); err != nil {
// if strings.Contains(err.Error(), "context canceled") {
// logger.Info("用户取消了请求:", prompt)
// } else {
// logger.Error("信息读取出错:", err)
// }
// }
// if toolCall { // 调用函数完成任务
// params := make(map[string]any)
// _ = utils.JsonDecode(strings.Join(arguments, ""), &params)
// logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
// params["user_id"] = userVo.Id
// var apiRes types.BizVo
// r, err := req2.C().R().SetHeader("Body-Type", "application/json").
// SetHeader("Authorization", function.Token).
// SetBody(params).Post(function.Action)
// errMsg := ""
// if err != nil {
// errMsg = err.Error()
// } else {
// all, _ := io.ReadAll(r.Body)
// err = json.Unmarshal(all, &apiRes)
// if err != nil {
// errMsg = err.Error()
// } else if apiRes.Code != types.Success {
// errMsg = apiRes.Message
// }
// }
// if errMsg != "" {
// errMsg = "调用函数工具出错:" + errMsg
// contents = append(contents, errMsg)
// } else {
// errMsg = utils.InterfaceToString(apiRes.Data)
// contents = append(contents, errMsg)
// }
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": errMsg,
// })
// }
// // 消息发送成功
// if len(contents) > 0 {
// usage := Usage{
// Prompt: prompt,
// Prompt: strings.Join(contents, ""),
// PromptTokens: 0,
// CompletionTokens: 0,
// TotalTokens: 0,
// }
// message.Prompt = usage.Prompt
// h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
// }
// } else {
// var respVo OpenAIResVo
// body, err := io.ReadAll(response.Body)
// if err != nil {
// return fmt.Errorf("读取响应失败:%v", body)
// }
// err = json.Unmarshal(body, &respVo)
// if err != nil {
// return fmt.Errorf("解析响应失败:%v", body)
// }
// content := respVo.Choices[0].Message.Prompt
// if strings.HasPrefix(req.Model, "o1-") {
// content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Prompt)
// }
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": content,
// })
// respVo.Usage.Prompt = prompt
// respVo.Usage.Prompt = content
// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
// }
// return nil
// }

View File

@@ -42,9 +42,9 @@ func (h *ChatHandler) List(c *gin.Context) {
modelValues = append(modelValues, chat.Model) modelValues = append(modelValues, chat.Model)
} }
var roles []model.ChatRole var roles []model.ChatApp
var models []model.ChatModel var models []model.ChatModel
roleMap := make(map[uint]model.ChatRole) roleMap := make(map[uint]model.ChatApp)
modelMap := make(map[string]model.ChatModel) modelMap := make(map[string]model.ChatModel)
h.DB.Where("id IN ?", roleIds).Find(&roles) h.DB.Where("id IN ?", roleIds).Find(&roles)
h.DB.Where("value IN ?", modelValues).Find(&models) h.DB.Where("value IN ?", modelValues).Find(&models)
@@ -205,7 +205,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
} }
// 填充角色名称 // 填充角色名称
var role model.ChatRole var role model.ChatApp
res = h.DB.Where("id", chatItem.RoleId).First(&role) res = h.DB.Where("id", chatItem.RoleId).First(&role)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "Role not found") resp.ERROR(c, "Role not found")

View File

@@ -26,6 +26,12 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *ChatModelHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/model/")
group.GET("list", h.List)
}
// List 模型列表 // List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) { func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel var items []model.ChatModel

View File

@@ -226,7 +226,7 @@ func (h *ChatHandler) sendOpenAiMessage(
TotalTokens: 0, TotalTokens: 0,
} }
message.Content = usage.Content message.Content = usage.Content
h.saveChatHistory(req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt) h.saveChatHistory(c, req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt)
} }
} else { // 非流式输出 } else { // 非流式输出
var respVo OpenAIResVo var respVo OpenAIResVo
@@ -242,7 +242,7 @@ func (h *ChatHandler) sendOpenAiMessage(
pushMessage(c, "text", content) pushMessage(c, "text", content)
respVo.Usage.Prompt = input.Prompt respVo.Usage.Prompt = input.Prompt
respVo.Usage.Content = content respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now()) h.saveChatHistory(c, req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now())
} }
return nil return nil

View File

@@ -27,6 +27,15 @@ func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService} return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
} }
// RegisterRoutes 注册路由
func (h *ConfigHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/config/")
// 无需授权的接口
group.GET("get", h.Get)
group.GET("license", h.License)
}
// Get 获取指定的系统配置 // Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) { func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key") key := c.Query("key")

View File

@@ -10,9 +10,11 @@ package handler
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/dalle" "geekai/service/dalle"
"geekai/service/moderation"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
@@ -25,16 +27,18 @@ import (
type DallJobHandler struct { type DallJobHandler struct {
BaseHandler BaseHandler
dallService *dalle.Service dallService *dalle.Service
uploader *oss.UploaderManager uploader *oss.UploaderManager
userService *service.UserService userService *service.UserService
moderationManager *moderation.ServiceManager
} }
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler { func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *DallJobHandler {
return &DallJobHandler{ return &DallJobHandler{
dallService: service, dallService: service,
uploader: manager, uploader: manager,
userService: userService, userService: userService,
moderationManager: moderationManager,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
@@ -42,6 +46,24 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
} }
} }
// RegisterRoutes 注册路由
func (h *DallJobHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/dall/")
// 公开接口,不需要授权
group.GET("imgWall", h.ImgWall)
group.GET("models", h.GetModels)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}
}
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *DallJobHandler) Image(c *gin.Context) { func (h *DallJobHandler) Image(c *gin.Context) {
var data types.DallTask var data types.DallTask
@@ -50,6 +72,29 @@ func (h *DallJobHandler) Image(c *gin.Context) {
return return
} }
// 文本审核
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceDalle,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,提示词未通过文本审核,请重新输入!")
return
}
}
var chatModel model.ChatModel var chatModel model.ChatModel
if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil { if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil {
resp.ERROR(c, "模型不存在") resp.ERROR(c, "模型不存在")
@@ -73,11 +118,12 @@ func (h *DallJobHandler) Image(c *gin.Context) {
UserId: uint(userId), UserId: uint(userId),
ModelId: chatModel.Id, ModelId: chatModel.Id,
ModelName: chatModel.Value, ModelName: chatModel.Value,
Image: data.Image,
Prompt: data.Prompt, Prompt: data.Prompt,
Quality: data.Quality, Quality: data.Quality,
Size: data.Size, Size: data.Size,
Style: data.Style, Style: data.Style,
TranslateModelId: h.App.SysConfig.AssistantModelId, TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
Power: chatModel.Power, Power: chatModel.Power,
} }
job := model.DallJob{ job := model.DallJob{

View File

@@ -13,7 +13,6 @@ import (
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/crawler"
"geekai/service/dalle" "geekai/service/dalle"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
@@ -31,7 +30,6 @@ import (
type FunctionHandler struct { type FunctionHandler struct {
BaseHandler BaseHandler
config types.ApiConfig
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
dallService *dalle.Service dallService *dalle.Service
userService *service.UserService userService *service.UserService
@@ -49,13 +47,23 @@ func NewFunctionHandler(
App: server, App: server,
DB: db, DB: db,
}, },
config: config.ApiConfig,
uploadManager: manager, uploadManager: manager,
dallService: dallService, dallService: dallService,
userService: userService, userService: userService,
} }
} }
// RegisterRoutes 注册路由
func (h *FunctionHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/function/")
group.GET("list", h.List)
// 需要用户授权的接口
group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
}
type resVo struct { type resVo struct {
Code types.BizCode `json:"code"` Code types.BizCode `json:"code"`
Message string `json:"message"` Message string `json:"message"`
@@ -107,16 +115,10 @@ func (h *FunctionHandler) WeiBo(c *gin.Context) {
return return
} }
if h.config.Token == "" { url := fmt.Sprintf("%s/api/weibo/fetch", types.GeekAPIURL)
resp.ERROR(c, "无效的 API Token")
return
}
url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
var res resVo var res resVo
r, err := req.C().R(). r, err := req.C().R().
SetHeader("AppId", h.config.AppId). SetHeader("Authorization", "Bearer geekai-plus").
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetSuccessResult(&res).Get(url) SetSuccessResult(&res).Get(url)
if err != nil { if err != nil {
resp.ERROR(c, fmt.Sprintf("%v", err)) resp.ERROR(c, fmt.Sprintf("%v", err))
@@ -146,16 +148,10 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
return return
} }
if h.config.Token == "" { url := fmt.Sprintf("%s/api/zaobao/fetch", types.GeekAPIURL)
resp.ERROR(c, "无效的 API Token")
return
}
url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
var res resVo var res resVo
r, err := req.C().R(). r, err := req.C().R().
SetHeader("AppId", h.config.AppId). SetHeader("Authorization", "Bearer geekai-plus").
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetSuccessResult(&res).Get(url) SetSuccessResult(&res).Get(url)
if err != nil { if err != nil {
resp.ERROR(c, fmt.Sprintf("%v", err)) resp.ERROR(c, fmt.Sprintf("%v", err))
@@ -193,16 +189,23 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
return return
} }
var chatModel model.ChatModel
res := h.DB.Where("type = ?", "img").Where("enabled", true).First(&chatModel)
if res.Error != nil {
resp.ERROR(c, "没有找到可用的AI绘图模型")
return
}
logger.Debugf("绘画参数:%+v", params) logger.Debugf("绘画参数:%+v", params)
var user model.User var user model.User
res := h.DB.Where("id = ?", params["user_id"]).First(&user) res = h.DB.Where("id = ?", params["user_id"]).First(&user)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "当前用户不存在!") resp.ERROR(c, "当前用户不存在!")
return return
} }
if user.Power < h.App.SysConfig.DallPower { if user.Power < chatModel.Power {
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足") resp.ERROR(c, "创建绘图任务失败,算力不足")
return return
} }
@@ -211,24 +214,24 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
task := types.DallTask{ task := types.DallTask{
UserId: user.Id, UserId: user.Id,
Prompt: prompt, Prompt: prompt,
ModelId: 0, ModelId: chatModel.Id,
ModelName: "dall-e-3", ModelName: chatModel.Value,
TranslateModelId: h.App.SysConfig.AssistantModelId, TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
N: 1, N: 1,
Quality: "standard", Quality: "standard",
Size: "1024x1024", Size: "1024x1024",
Style: "vivid", Style: "vivid",
Power: h.App.SysConfig.DallPower, Power: chatModel.Power,
} }
job := model.DallJob{ job := model.DallJob{
UserId: user.Id, UserId: user.Id,
Prompt: prompt, Prompt: prompt,
Power: h.App.SysConfig.DallPower, Power: chatModel.Power,
TaskInfo: utils.JsonEncode(task), TaskInfo: utils.JsonEncode(task),
} }
err := h.DB.Create(&job).Error err := h.DB.Create(&job).Error
if err != nil { if err != nil {
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+err.Error()) resp.ERROR(c, "创建绘图任务失败:"+err.Error())
return return
} }
@@ -253,76 +256,6 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
resp.SUCCESS(c, content) resp.SUCCESS(c, content)
} }
// 实现一个联网搜索的函数工具,采用爬虫实现
func (h *FunctionHandler) WebSearch(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
resp.ERROR(c, err.Error())
return
}
var params map[string]interface{}
if err := c.ShouldBindJSON(&params); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 从参数中获取搜索关键词
keyword, ok := params["keyword"].(string)
if !ok || keyword == "" {
resp.ERROR(c, "搜索关键词不能为空")
return
}
// 从参数中获取最大页数默认为1页
maxPages := 1
if pages, ok := params["max_pages"].(float64); ok {
maxPages = int(pages)
}
// 获取用户ID
userID, ok := params["user_id"].(float64)
if !ok {
resp.ERROR(c, "用户ID不能为空")
return
}
// 查询用户信息
var user model.User
res := h.DB.Where("id = ?", int(userID)).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户不存在")
return
}
// 检查用户算力是否足够
searchPower := 1 // 每次搜索消耗1点算力
if user.Power < searchPower {
resp.ERROR(c, "算力不足,无法执行网络搜索")
return
}
// 执行网络搜索
searchResults, err := crawler.SearchWeb(keyword, maxPages)
if err != nil {
resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err))
return
}
// 扣减用户算力
err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{
Type: types.PowerConsume,
Model: "web_search",
Remark: fmt.Sprintf("网络搜索:%s", utils.CutWords(keyword, 10)),
})
if err != nil {
resp.ERROR(c, "扣减算力失败:"+err.Error())
return
}
// 返回搜索结果
resp.SUCCESS(c, searchResults)
}
// List 获取所有的工具函数列表 // List 获取所有的工具函数列表
func (h *FunctionHandler) List(c *gin.Context) { func (h *FunctionHandler) List(c *gin.Context) {
var items []model.Function var items []model.Function

View File

@@ -8,14 +8,18 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"strings"
) )
// InviteHandler 用户邀请 // InviteHandler 用户邀请
@@ -27,6 +31,23 @@ func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *InviteHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/invite/")
// 公开接口,不需要授权
group.GET("hits", h.Hits)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("code", h.Code)
group.GET("list", h.List)
group.GET("stats", h.Stats)
group.GET("rules", h.Rules)
}
}
// Code 获取当前用户邀请码 // Code 获取当前用户邀请码
func (h *InviteHandler) Code(c *gin.Context) { func (h *InviteHandler) Code(c *gin.Context) {
userId := h.GetLoginUserId(c) userId := h.GetLoginUserId(c)
@@ -65,21 +86,34 @@ func (h *InviteHandler) List(c *gin.Context) {
var total int64 var total int64
session.Model(&model.InviteLog{}).Count(&total) session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog var items []model.InviteLog
var list = make([]vo.InviteLog, 0)
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items) err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items).Error
if res.Error == nil { if err != nil {
for _, item := range items { resp.ERROR(c, err.Error())
var v vo.InviteLog return
err := utils.CopyObject(item, &v) }
if err == nil {
v.Id = item.Id userIds := make([]uint, 0)
v.CreatedAt = item.CreatedAt.Unix() for _, item := range items {
list = append(list, v) userIds = append(userIds, item.UserId)
} else { }
logger.Error(err) userMap := make(map[uint]model.User)
} var users []model.User
h.DB.Model(&model.User{}).Where("id IN (?)", userIds).Find(&users)
for _, user := range users {
userMap[user.Id] = user
}
var list = make([]vo.InviteLog, 0)
for _, item := range items {
var v vo.InviteLog
err := utils.CopyObject(item, &v)
if err != nil {
continue
} }
v.CreatedAt = item.CreatedAt.Unix()
v.Avatar = userMap[item.UserId].Avatar
list = append(list, v)
} }
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list)) resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
} }
@@ -90,3 +124,89 @@ func (h *InviteHandler) Hits(c *gin.Context) {
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1)) h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// Stats 获取邀请统计
func (h *InviteHandler) Stats(c *gin.Context) {
userId := h.GetLoginUserId(c)
// 获取邀请码
var inviteCode model.InviteCode
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
if res.Error != nil {
resp.ERROR(c, "邀请码不存在")
return
}
// 统计累计邀请数
var totalInvite int64
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ?", userId).Count(&totalInvite)
// 统计今日邀请数
today := time.Now().Format("2006-01-02")
var todayInvite int64
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ? AND DATE(created_at) = ?", userId, today).Count(&todayInvite)
// 获取系统配置中的邀请奖励
var config model.Config
var invitePower int = 200 // 默认值
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
var configMap map[string]any
if utils.JsonDecode(config.Value, &configMap) == nil {
if power, ok := configMap["invite_power"].(float64); ok {
invitePower = int(power)
}
}
}
// 计算获得奖励总数
rewardTotal := int(totalInvite) * invitePower
// 构建邀请链接
inviteLink := fmt.Sprintf("%s/register?invite=%s", h.App.Config.StaticUrl, inviteCode.Code)
stats := vo.InviteStats{
InviteCount: int(totalInvite),
RewardTotal: rewardTotal,
TodayInvite: int(todayInvite),
InviteCode: inviteCode.Code,
InviteLink: inviteLink,
}
resp.SUCCESS(c, stats)
}
// Rules 获取奖励规则
func (h *InviteHandler) Rules(c *gin.Context) {
// 获取系统配置中的邀请奖励
var config model.Config
var invitePower int = 200 // 默认值
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
var configMap map[string]interface{}
if utils.JsonDecode(config.Value, &configMap) == nil {
if power, ok := configMap["invite_power"].(float64); ok {
invitePower = int(power)
}
}
}
rules := []vo.RewardRule{
{
Id: 1,
Title: "好友注册",
Desc: "好友通过邀请链接成功注册",
Icon: "icon-user-fill",
Color: "#1989fa",
Reward: invitePower,
},
{
Id: 2,
Title: "好友首次充值",
Desc: "好友首次充值任意金额",
Icon: "icon-money",
Color: "#07c160",
Reward: invitePower * 2, // 假设首次充值奖励是注册奖励的2倍
},
}
resp.SUCCESS(c, rules)
}

View File

@@ -2,11 +2,12 @@ package handler
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/jimeng" "geekai/service/jimeng"
"geekai/service/moderation"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
@@ -19,27 +20,34 @@ import (
// JimengHandler 即梦AI处理器 // JimengHandler 即梦AI处理器
type JimengHandler struct { type JimengHandler struct {
BaseHandler BaseHandler
jimengService *jimeng.Service jimengService *jimeng.Service
userService *service.UserService userService *service.UserService
moderationManager *moderation.ServiceManager
} }
// NewJimengHandler 创建即梦AI处理器 // NewJimengHandler 创建即梦AI处理器
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService) *JimengHandler { func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService, moderationManager *moderation.ServiceManager) *JimengHandler {
return &JimengHandler{ return &JimengHandler{
BaseHandler: BaseHandler{App: app, DB: db}, BaseHandler: BaseHandler{App: app, DB: db},
jimengService: jimengService, jimengService: jimengService,
userService: userService, userService: userService,
moderationManager: moderationManager,
} }
} }
// RegisterRoutes 注册路由,新增统一任务接口 // RegisterRoutes 注册路由,新增统一任务接口
func (h *JimengHandler) RegisterRoutes() { func (h *JimengHandler) RegisterRoutes() {
rg := h.App.Engine.Group("/api/jimeng") group := h.App.Engine.Group("/api/jimeng/")
rg.POST("task", h.CreateTask) // 只保留统一任务接口
rg.GET("power-config", h.GetPowerConfig) // 新增算力配置接口 // 需要用户授权的接口
rg.POST("jobs", h.Jobs) group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
rg.GET("remove", h.Remove) {
rg.GET("retry", h.Retry) group.POST("task", h.CreateTask)
group.GET("power-config", h.GetPowerConfig)
group.POST("jobs", h.Jobs)
group.GET("remove", h.Remove)
group.GET("retry", h.Retry)
}
} }
// JimengTaskRequest 统一任务请求结构体 // JimengTaskRequest 统一任务请求结构体
@@ -70,6 +78,31 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
// 文本审核
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(req.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceJiMeng,
Input: req.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
// 新增:除图像特效外,其他任务类型必须有提示词 // 新增:除图像特效外,其他任务类型必须有提示词
if req.TaskType != "image_effects" && req.Prompt == "" { if req.TaskType != "image_effects" && req.Prompt == "" {
resp.ERROR(c, "提示词不能为空") resp.ERROR(c, "提示词不能为空")
@@ -153,12 +186,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
"seed": req.Seed, "seed": req.Seed,
"scale": req.Scale, "scale": req.Scale,
} }
if len(req.ImageUrls) > 0 { params["image_urls"] = []string{req.ImageInput}
params["image_urls"] = req.ImageUrls
}
if len(req.BinaryDataBase64) > 0 {
params["binary_data_base64"] = req.BinaryDataBase64
}
case "image_effects": case "image_effects":
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects) powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
taskType = model.JMTaskTypeImageEffects taskType = model.JMTaskTypeImageEffects
@@ -181,9 +209,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
taskType = model.JMTaskTypeTextToVideo taskType = model.JMTaskTypeTextToVideo
reqKey = jimeng.ReqKeyTextToVideo reqKey = jimeng.ReqKeyTextToVideo
modelName = "即梦文生视频" modelName = "即梦文生视频"
if req.Seed == 0 {
req.Seed = -1
}
if req.AspectRatio == "" { if req.AspectRatio == "" {
req.AspectRatio = jimeng.AspectRatio16_9 req.AspectRatio = jimeng.AspectRatio16_9
} }
@@ -196,9 +221,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
taskType = model.JMTaskTypeImageToVideo taskType = model.JMTaskTypeImageToVideo
reqKey = jimeng.ReqKeyImageToVideo reqKey = jimeng.ReqKeyImageToVideo
modelName = "即梦图生视频" modelName = "即梦图生视频"
if req.Seed == 0 {
req.Seed = -1
}
params = map[string]any{ params = map[string]any{
"seed": req.Seed, "seed": req.Seed,
"aspect_ratio": req.AspectRatio, "aspect_ratio": req.AspectRatio,
@@ -333,8 +355,10 @@ func (h *JimengHandler) Remove(c *gin.Context) {
resp.ERROR(c, "无权限操作") resp.ERROR(c, "无权限操作")
return return
} }
if job.Status != model.JMTaskStatusFailed {
resp.ERROR(c, "只有失败的任务能删除") // 正在运行中的任务能删除
if job.Status == model.JMTaskStatusGenerating || job.Status == model.JMTaskStatusInQueue {
resp.ERROR(c, "正在运行中的任务不能删除,否则无法退回算力")
return return
} }
@@ -345,17 +369,20 @@ func (h *JimengHandler) Remove(c *gin.Context) {
return return
} }
// 退回算力 // 失败任务删除后退回算力
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{ if job.Status != model.JMTaskStatusFailed {
Type: types.PowerRefund, err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
Model: "jimeng", Type: types.PowerRefund,
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power), Model: "jimeng",
}) Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
if err != nil { })
resp.ERROR(c, "退回算力失败") if err != nil {
tx.Rollback() resp.ERROR(c, "退回算力失败")
return tx.Rollback()
return
}
} }
tx.Commit() tx.Commit()
resp.SUCCESS(c, gin.H{}) resp.SUCCESS(c, gin.H{})
@@ -408,7 +435,7 @@ func (h *JimengHandler) Retry(c *gin.Context) {
// getPowerFromConfig 从配置中获取指定类型的算力消耗 // getPowerFromConfig 从配置中获取指定类型的算力消耗
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int { func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
config := h.jimengService.GetConfig() config := h.App.SysConfig.Jimeng
switch taskType { switch taskType {
case model.JMTaskTypeTextToImage: case model.JMTaskTypeTextToImage:
@@ -430,7 +457,7 @@ func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
// GetPowerConfig 获取即梦各任务类型算力消耗配置 // GetPowerConfig 获取即梦各任务类型算力消耗配置
func (h *JimengHandler) GetPowerConfig(c *gin.Context) { func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
config := h.jimengService.GetConfig() config := h.App.SysConfig.Jimeng
resp.SUCCESS(c, gin.H{ resp.SUCCESS(c, gin.H{
"text_to_image": config.Power.TextToImage, "text_to_image": config.Power.TextToImage,
"image_to_image": config.Power.ImageToImage, "image_to_image": config.Power.ImageToImage,

View File

@@ -10,6 +10,7 @@ package handler
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/store/model" "geekai/store/model"
@@ -35,6 +36,17 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
} }
} }
// RegisterRoutes 注册路由
func (h *MarkMapHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/markMap/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("gen", h.Generate)
}
}
// Generate 生成思维导图 // Generate 生成思维导图
func (h *MarkMapHandler) Generate(c *gin.Context) { func (h *MarkMapHandler) Generate(c *gin.Context) {
var data struct { var data struct {

View File

@@ -13,6 +13,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -25,6 +26,12 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *MenuHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/menu/")
group.GET("list", h.List)
}
// List 数据列表 // List 数据列表
func (h *MenuHandler) List(c *gin.Context) { func (h *MenuHandler) List(c *gin.Context) {
index := h.GetBool(c, "index") index := h.GetBool(c, "index")
@@ -33,7 +40,7 @@ func (h *MenuHandler) List(c *gin.Context) {
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
session = session.Where("enabled", true) session = session.Where("enabled", true)
if index { if index {
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs) session = session.Where("id IN ?", h.App.SysConfig.Base.IndexNavs)
} }
res := session.Order("sort_num ASC").Find(&items) res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil { if res.Error == nil {

View File

@@ -10,9 +10,11 @@ package handler
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/mj" "geekai/service/mj"
"geekai/service/moderation"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
@@ -27,18 +29,20 @@ import (
type MidJourneyHandler struct { type MidJourneyHandler struct {
BaseHandler BaseHandler
mjService *mj.Service mjService *mj.Service
snowflake *service.Snowflake snowflake *service.Snowflake
uploader *oss.UploaderManager uploader *oss.UploaderManager
userService *service.UserService userService *service.UserService
moderationManager *moderation.ServiceManager
} }
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler { func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *MidJourneyHandler {
return &MidJourneyHandler{ return &MidJourneyHandler{
snowflake: snowflake, snowflake: snowflake,
mjService: service, mjService: service,
uploader: manager, uploader: manager,
userService: userService, userService: userService,
moderationManager: moderationManager,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
@@ -46,6 +50,25 @@ func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.S
} }
} }
// RegisterRoutes 注册路由
func (h *MidJourneyHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/mj/")
// 公开接口,不需要授权
group.GET("imgWall", h.ImgWall)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("image", h.Image)
group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}
}
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
@@ -53,7 +76,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
return false return false
} }
if user.Power < h.App.SysConfig.MjPower { if user.Power < h.App.SysConfig.Base.MjPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false return false
} }
@@ -90,6 +113,29 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return return
} }
// 文本审核
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceMJ,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
var params = "" var params = ""
if data.Rate != "" && !strings.Contains(params, "--ar") { if data.Rate != "" && !strings.Contains(params, "--ar") {
params += " --ar " + data.Rate params += " --ar " + data.Rate
@@ -159,8 +205,8 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Params: params, Params: params,
UserId: userId, UserId: userId,
ImgArr: data.ImgArr, ImgArr: data.ImgArr,
Mode: h.App.SysConfig.MjMode, Mode: h.App.SysConfig.Base.MjMode,
TranslateModelId: h.App.SysConfig.AssistantModelId, TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
} }
job := model.MidJourneyJob{ job := model.MidJourneyJob{
Type: data.TaskType, Type: data.TaskType,
@@ -169,7 +215,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
TaskInfo: utils.JsonEncode(task), TaskInfo: utils.JsonEncode(task),
Progress: 0, Progress: 0,
Prompt: fmt.Sprintf("%s %s", data.Prompt, params), Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
Power: h.App.SysConfig.MjPower, Power: h.App.SysConfig.Base.MjPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
opt := "绘图" opt := "绘图"
@@ -232,7 +278,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Index: data.Index, Index: data.Index,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode, Mode: h.App.SysConfig.Base.MjMode,
} }
job := model.MidJourneyJob{ job := model.MidJourneyJob{
Type: types.TaskUpscale.String(), Type: types.TaskUpscale.String(),
@@ -240,7 +286,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
TaskId: taskId, TaskId: taskId,
TaskInfo: utils.JsonEncode(task), TaskInfo: utils.JsonEncode(task),
Progress: 0, Progress: 0,
Power: h.App.SysConfig.MjActionPower, Power: h.App.SysConfig.Base.MjActionPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
@@ -287,7 +333,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
ChannelId: data.ChannelId, ChannelId: data.ChannelId,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode, Mode: h.App.SysConfig.Base.MjMode,
} }
job := model.MidJourneyJob{ job := model.MidJourneyJob{
Type: types.TaskVariation.String(), Type: types.TaskVariation.String(),
@@ -296,7 +342,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
TaskId: taskId, TaskId: taskId,
TaskInfo: utils.JsonEncode(task), TaskInfo: utils.JsonEncode(task),
Progress: 0, Progress: 0,
Power: h.App.SysConfig.MjActionPower, Power: h.App.SysConfig.Base.MjActionPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {

View File

@@ -9,6 +9,7 @@ package handler
import ( import (
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service/oss" "geekai/service/oss"
"geekai/store/model" "geekai/store/model"
@@ -32,6 +33,22 @@ func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManage
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager} return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
} }
// RegisterRoutes 注册路由
func (h *NetHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/upload")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("", h.Upload)
group.POST("list", h.List)
group.GET("remove", h.Remove)
}
// 公开接口,不需要授权
h.App.Engine.GET("/api/download", h.Download)
}
func (h *NetHandler) Upload(c *gin.Context) { func (h *NetHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file") file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil { if err != nil {

View File

@@ -9,12 +9,12 @@ package handler
import ( import (
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -28,6 +28,18 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *OrderHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/order/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.GET("query", h.Query)
}
}
// List 订单列表 // List 订单列表
func (h *OrderHandler) List(c *gin.Context) { func (h *OrderHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1) page := h.GetInt(c, "page", 1)
@@ -48,20 +60,21 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix() order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix() order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay] payChannel, ok := types.PayChannel[item.Channel]
if !ok { if !ok {
payMethod = item.PayWay payChannel = item.PayWay
} }
payName, ok := types.PayNames[item.PayType] payWays, ok := types.PayWays[item.PayWay]
if !ok { if !ok {
payName = item.PayWay payWays = item.PayWay
} }
order.PayMethod = payMethod order.ChannelName = payChannel
order.PayName = payName order.PayName = payWays
list = append(list, order) list = append(list, order)
} else { } else {
logger.Error(err) logger.Error(err)
} }
} }
} }
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list)) resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
@@ -82,17 +95,8 @@ func (h *OrderHandler) Query(c *gin.Context) {
return return
} }
counter := 0 var item model.Order
for { h.DB.Where("order_no = ?", orderNo).First(&item)
time.Sleep(time.Second)
var item model.Order
h.DB.Where("order_no = ?", orderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
}
counter++
}
resp.SUCCESS(c, gin.H{"status": order.Status}) resp.SUCCESS(c, gin.H{"status": order.Status})
} }

View File

@@ -11,6 +11,7 @@ import (
"embed" "embed"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/payment" "geekai/service/payment"
@@ -33,52 +34,148 @@ type PayWay struct {
// PaymentHandler 支付服务回调 handler // PaymentHandler 支付服务回调 handler
type PaymentHandler struct { type PaymentHandler struct {
BaseHandler BaseHandler
alipayService *payment.AlipayService alipayService *payment.AlipayService
huPiPayService *payment.HuPiPayService epayService *payment.EPayService
geekPayService *payment.GeekPayService wxpayService *payment.WxPayService
wechatPayService *payment.WechatPayService snowflake *service.Snowflake
snowflake *service.Snowflake userService *service.UserService
userService *service.UserService fs embed.FS
fs embed.FS lock sync.Mutex
lock sync.Mutex config *types.PaymentConfig
signKey string // 用来签名的随机秘钥
} }
func NewPaymentHandler( func NewPaymentHandler(
server *core.AppServer, server *core.AppServer,
alipayService *payment.AlipayService, alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService, geekPayService *payment.EPayService,
geekPayService *payment.GeekPayService, wxpayService *payment.WxPayService,
wechatPayService *payment.WechatPayService,
db *gorm.DB, db *gorm.DB,
userService *service.UserService, userService *service.UserService,
snowflake *service.Snowflake, snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler { fs embed.FS,
sysConfig *types.SystemConfig) *PaymentHandler {
return &PaymentHandler{ return &PaymentHandler{
alipayService: alipayService, alipayService: alipayService,
huPiPayService: huPiPayService, epayService: geekPayService,
geekPayService: geekPayService, wxpayService: wxpayService,
wechatPayService: wechatPayService, snowflake: snowflake,
snowflake: snowflake, userService: userService,
userService: userService, fs: fs,
fs: fs, lock: sync.Mutex{},
lock: sync.Mutex{},
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: server, App: server,
DB: db, DB: db,
}, },
signKey: utils.RandString(32), config: &sysConfig.Payment,
} }
} }
func (h *PaymentHandler) Pay(c *gin.Context) { // RegisterRoutes 注册路由
func (h *PaymentHandler) RegisterRoutes() {
rg := h.App.Engine.Group("/api/payment/")
// 支付回调接口(公开)
rg.POST("notify/alipay", h.AlipayNotify)
rg.GET("notify/epay", h.EPayNotify)
rg.POST("notify/wxpay", h.WxpayNotify)
// 需要用户登录的接口
rg.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
rg.POST("create", h.CreateOrder)
}
}
func (h *PaymentHandler) StartSyncOrders() {
go func() {
for {
err := h.SyncOrders()
if err != nil {
logger.Error(err)
}
time.Sleep(time.Second * 5)
}
}()
}
// SyncOrders 同步订单状态
func (h *PaymentHandler) SyncOrders() error {
defer func() {
if err := recover(); err != nil {
logger.Errorf("同步订单状态发生异常: %v", err)
}
}()
var orders []model.Order
err := h.DB.Where("status", types.OrderNotPaid).Where("checked", false).Find(&orders).Error
if err != nil {
return err
}
for _, order := range orders {
time.Sleep(time.Second * 1)
//超时15分钟的订单直接标记为已关闭
if time.Now().After(order.CreatedAt.Add(time.Minute * 5)) {
h.DB.Model(&model.Order{}).Where("id", order.Id).Update("checked", true)
logger.Errorf("订单超时:%v", order)
continue
}
// 查询订单状态
var res payment.OrderInfo
switch order.Channel {
case payment.PayChannelEpay:
res, err = h.epayService.Query(order.OrderNo)
if err != nil {
logger.Errorf("error with query order info: %v", err)
continue
}
// 微信支付
case payment.PayChannelWX:
res, err = h.wxpayService.Query(order.OrderNo)
logger.Debugf("微信支付订单状态:%+v", res)
if err != nil {
logger.Errorf("error with query order info: %v", err)
continue
}
case payment.PayChannelAL:
res, err = h.alipayService.Query(order.OrderNo)
if err != nil {
logger.Errorf("error with query order info: %v", err)
continue
}
}
// 订单已关闭
if res.Closed() {
h.DB.Model(&model.Order{}).Where("id", order.Id).Updates(map[string]any{
"checked": true,
"status": types.OrderPaidFailed,
})
logger.Errorf("订单已关闭:%v", order)
continue
}
// 订单未支付,不处理,继续轮询
if !res.Success() {
continue
}
// 订单支付成功
err = h.paySuccess(res)
if err != nil {
logger.Errorf("error with deal order: %v", err)
continue
}
}
return nil
}
func (h *PaymentHandler) CreateOrder(c *gin.Context) {
var data struct { var data struct {
PayWay string `json:"pay_way"` PayWay string `json:"pay_way,omitempty"` // 支付方式:支付宝,微信
PayType string `json:"pay_type"` Pid int `json:"pid,omitempty"`
ProductId int `json:"product_id"` Device string `json:"device,omitempty"`
UserId int `json:"user_id"` Domain string `json:"domain,omitempty"` // 支付回调域名
Device string `json:"device"` Channel string `json:"channel,omitempty"`
Host string `json:"host"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@@ -86,7 +183,7 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
} }
var product model.Product var product model.Product
err := h.DB.Where("id", data.ProductId).First(&product).Error err := h.DB.Where("id", data.Pid).First(&product).Error
if err != nil { if err != nil {
resp.ERROR(c, "Product not found") resp.ERROR(c, "Product not found")
return return
@@ -97,136 +194,118 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
resp.ERROR(c, "error with generate trade no: "+err.Error()) resp.ERROR(c, "error with generate trade no: "+err.Error())
return return
} }
userId := h.GetLoginUserId(c)
var user model.User var user model.User
err = h.DB.Where("id", data.UserId).First(&user).Error err = h.DB.Where("id", userId).First(&user).Error
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }
amount := product.Discount amount := product.Price
var payURL, returnURL, notifyURL string var payURL, notifyURL string
switch data.PayWay { switch data.PayWay {
case "alipay": case "wxpay":
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付 logger.Debugf("微信支付,%+v", data)
notifyURL = h.App.Config.AlipayConfig.NotifyURL data.Channel = payment.PayChannelWX
} else { // 优先使用微信官方支付
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host) if h.config.WxPay.Enabled {
} data.Channel = "wxpay"
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付 if h.config.WxPay.Domain != "" {
returnURL = h.App.Config.AlipayConfig.ReturnURL data.Domain = h.config.WxPay.Domain
} else { }
returnURL = fmt.Sprintf("%s/payReturn", data.Host) notifyURL = fmt.Sprintf("%s/api/payment/notify/wxpay", data.Domain)
} payURL, err = h.wxpayService.Pay(payment.PayRequest{
money := fmt.Sprintf("%.2f", amount)
if data.Device == "wechat" {
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
OutTradeNo: orderNo, OutTradeNo: orderNo,
Subject: product.Name, TotalFee: fmt.Sprintf("%d", int(amount*100)),
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
} else {
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
}
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
return
}
break
case "wechat":
if h.App.Config.WechatPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
}
if data.Device == "wechat" {
payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
OutTradeNo: orderNo,
TotalFee: int(amount * 100),
Subject: product.Name, Subject: product.Name,
NotifyURL: notifyURL, NotifyURL: notifyURL,
ClientIP: c.ClientIP(), ClientIP: c.ClientIP(),
Device: data.Device,
PayWay: payment.PayWayWX,
}) })
} else { if err != nil {
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{ resp.ERROR(c, err.Error())
return
}
} else if h.config.Epay.Enabled { // 聚合支付
logger.Debugf("聚合支付%+v", data)
data.Channel = payment.PayChannelEpay
if h.config.Epay.Domain != "" {
data.Domain = h.config.Epay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
params := payment.PayRequest{
OutTradeNo: orderNo, OutTradeNo: orderNo,
TotalFee: int(amount * 100),
Subject: product.Name, Subject: product.Name,
TotalFee: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
PayWay: payment.PayWayWX,
NotifyURL: notifyURL,
}
r, err := h.epayService.Pay(params)
logger.Debugf("请求支付结果,%+v", r)
if err != nil {
resp.ERROR(c, err.Error())
return
} else {
payURL = r
}
} else {
resp.ERROR(c, "系统没有配置可用的支付渠道!")
return
}
case "alipay":
if h.config.Alipay.Enabled {
logger.Debugf("支付宝,%+v", data)
data.Channel = payment.PayChannelAL
if h.config.Alipay.Domain != "" { // 用于本地调试支付
data.Domain = h.config.Alipay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Domain)
money := fmt.Sprintf("%.2f", amount)
payURL, err = h.alipayService.Pay(payment.PayRequest{
Device: data.Device,
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
NotifyURL: notifyURL, NotifyURL: notifyURL,
}) })
}
if err != nil {
resp.ERROR(c, err.Error())
return
}
break
case "hupi":
if h.App.Config.HuPiPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
}
if h.App.Config.HuPiPayConfig.ReturnURL != "" {
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name,
NotifyURL: notifyURL,
ReturnURL: returnURL,
WapName: "GeekAI助手",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
payURL = r.URL
break
case "geek":
if h.App.Config.GeekPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.GeekPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
}
if h.App.Config.GeekPayConfig.ReturnURL != "" {
data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
}
if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
params := payment.GeekPayParams{
OutTradeNo: orderNo,
Method: "web",
Name: product.Name,
Money: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
Type: data.PayType,
ReturnURL: returnURL,
NotifyURL: notifyURL,
}
res, err := h.geekPayService.Pay(params) if err != nil {
if err != nil { resp.ERROR(c, "error with generate pay url: "+err.Error())
resp.ERROR(c, err.Error()) return
}
} else if h.config.Epay.Enabled { // 聚合支付
logger.Debugf("聚合支付,%+v", data)
data.Channel = payment.PayChannelEpay
if h.config.Epay.Domain != "" {
data.Domain = h.config.Epay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
params := payment.PayRequest{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
PayWay: data.PayWay,
NotifyURL: notifyURL,
}
r, err := h.epayService.Pay(params)
if err != nil {
resp.ERROR(c, err.Error())
return
} else {
payURL = r
}
} else {
resp.ERROR(c, "系统没有配置可用的支付渠道!")
return return
} }
payURL = res.PayURL
default: default:
resp.ERROR(c, "不支持的支付渠道") resp.ERROR(c, "不支持的支付渠道")
return return
@@ -234,43 +313,40 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
// 创建订单 // 创建订单
remark := types.OrderRemark{ remark := types.OrderRemark{
Days: product.Days, Power: product.Power,
Power: product.Power, Name: product.Name,
Name: product.Name, Price: product.Price,
Price: product.Price,
Discount: product.Discount,
} }
order := model.Order{ order := model.Order{
UserId: user.Id, UserId: user.Id,
Username: user.Username, Username: user.Username,
ProductId: product.Id, OrderNo: orderNo,
OrderNo: orderNo, Subject: product.Name,
Subject: product.Name, Amount: amount,
Amount: amount, Status: types.OrderNotPaid,
Status: types.OrderNotPaid, PayWay: data.PayWay,
PayWay: data.PayWay, Channel: data.Channel,
PayType: data.PayType, Remark: utils.JsonEncode(remark),
Remark: utils.JsonEncode(remark),
} }
err = h.DB.Create(&order).Error err = h.DB.Create(&order).Error
if err != nil { if err != nil {
resp.ERROR(c, "error with create order: "+err.Error()) resp.ERROR(c, "error with create order: "+err.Error())
return return
} }
resp.SUCCESS(c, payURL) resp.SUCCESS(c, gin.H{"pay_url": payURL, "order_no": orderNo})
} }
// 异步通知回调公共逻辑 // 支付成功处理
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { func (h *PaymentHandler) paySuccess(info payment.OrderInfo) error {
h.lock.Lock()
defer h.lock.Unlock()
var order model.Order var order model.Order
err := h.DB.Where("order_no = ?", orderNo).First(&order).Error err := h.DB.Where("order_no", info.OutTradeNo).First(&order).Error
if err != nil { if err != nil {
return fmt.Errorf("error with fetch order: %v", err) return fmt.Errorf("error with fetch order: %v", err)
} }
h.lock.Lock()
defer h.lock.Unlock()
// 已支付订单,直接返回 // 已支付订单,直接返回
if order.Status == types.OrderPaidSuccess { if order.Status == types.OrderPaidSuccess {
return nil return nil
@@ -290,19 +366,21 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
// 增加用户算力 // 增加用户算力
err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{ err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{
Type: types.PowerRecharge, Type: types.PowerRecharge,
Model: order.PayWay, Model: order.Subject,
Remark: fmt.Sprintf("充值算力,金额:%f订单号%s", order.Amount, order.OrderNo), Remark: fmt.Sprintf("充值算力,金额:%f订单号%s", order.Amount, order.OrderNo),
CreatedAt: time.Now(),
}) })
if err != nil { if err != nil {
return err return err
} }
// 更新订单状态 // 更新订单状态
order.PayTime = time.Now().Unix() order.PayTime = utils.Str2stamp(info.PayTime)
order.Status = types.OrderPaidSuccess order.Status = types.OrderPaidSuccess
order.TradeNo = tradeNo order.TradeNo = info.TradeId
err = h.DB.Updates(&order).Error order.Checked = true
err = h.DB.Debug().Updates(&order).Error
if err != nil { if err != nil {
return fmt.Errorf("error with update order info: %v", err) return fmt.Errorf("error with update order info: %v", err)
} }
@@ -317,54 +395,6 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
return nil return nil
} }
// GetPayWays 获取支付方式
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
payWays := make([]gin.H, 0)
if h.App.Config.AlipayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
}
if h.App.Config.HuPiPayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
}
if h.App.Config.GeekPayConfig.Enabled {
for _, v := range h.App.Config.GeekPayConfig.Methods {
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
}
}
if h.App.Config.WechatPayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
}
resp.SUCCESS(c, payWays)
}
// HuPiPayNotify 虎皮椒支付异步回调
func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
c.String(http.StatusOK, "fail")
return
}
orderNo := c.Request.Form.Get("trade_order_id")
tradeNo := c.Request.Form.Get("open_order_id")
logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
if err = h.huPiPayService.Check(orderNo); err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
}
// AlipayNotify 支付宝支付回调 // AlipayNotify 支付宝支付回调
func (h *PaymentHandler) AlipayNotify(c *gin.Context) { func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
err := c.Request.ParseForm() err := c.Request.ParseForm()
@@ -373,16 +403,15 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
return return
} }
result := h.alipayService.TradeVerify(c.Request) orderInfo, err := h.alipayService.Query(c.Request.Form.Get("out_trade_no"))
logger.Infof("收到支付宝商号订单支付回调:%+v", result) logger.Infof("收到支付宝商号订单支付回调:%+v", orderInfo)
if !result.Success() { if !orderInfo.Success() {
logger.Error("订单校验失败:", result.Message) logger.Errorf("订单校验失败:%v", err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
tradeNo := c.Request.Form.Get("trade_no") err = h.paySuccess(orderInfo)
err = h.notify(result.OutTradeNo, tradeNo)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
@@ -392,28 +421,35 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
c.String(http.StatusOK, "success") c.String(http.StatusOK, "success")
} }
// GeekPayNotify 支付异步回调 // EPayNotify 易支付支付异步回调
func (h *PaymentHandler) GeekPayNotify(c *gin.Context) { func (h *PaymentHandler) EPayNotify(c *gin.Context) {
var params = make(map[string]string) var params = make(map[string]string)
for k := range c.Request.URL.Query() { for k := range c.Request.URL.Query() {
params[k] = c.Query(k) params[k] = c.Query(k)
} }
logger.Infof("收到GeekPay订单支付回调:%+v", params) logger.Infof("收到易支付订单支付回调:%+v", params)
// 检查支付状态 // 检查支付状态, 如果未支付,则返回成功
if params["trade_status"] != "TRADE_SUCCESS" { if params["trade_status"] != "TRADE_SUCCESS" {
c.String(http.StatusOK, "success") c.String(http.StatusOK, "success")
return return
} }
sign := h.geekPayService.Sign(params) sign := h.epayService.Sign(params)
if sign != c.Query("sign") { if sign != c.Query("sign") {
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign")) logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
// 查询订单状态
order, err := h.epayService.Query(params["out_trade_no"])
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
err := h.notify(params["out_trade_no"], params["trade_no"]) err = h.paySuccess(order)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
@@ -423,26 +459,23 @@ func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
c.String(http.StatusOK, "success") c.String(http.StatusOK, "success")
} }
// WechatPayNotify 微信商户支付异步回调 // WxpayNotify 微信商户支付异步回调
func (h *PaymentHandler) WechatPayNotify(c *gin.Context) { func (h *PaymentHandler) WxpayNotify(c *gin.Context) {
err := c.Request.ParseForm() err := c.Request.ParseForm()
if err != nil { if err != nil {
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")
return return
} }
result := h.wechatPayService.TradeVerify(c.Request) orderInfo, err := h.wxpayService.TradeVerify(c.Request)
logger.Infof("收到微信商号订单支付回调:%+v", result) logger.Infof("收到微信商号订单支付回调:%+v", orderInfo)
if !result.Success() { if err != nil {
logger.Error("订单校验失败:", err) logger.Errorf("订单校验失败:%v", err)
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{"code": "FAIL"})
"code": "FAIL",
"message": err.Error(),
})
return return
} }
err = h.notify(result.OutTradeNo, result.TradeId) err = h.paySuccess(orderInfo)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
c.String(http.StatusOK, "fail") c.String(http.StatusOK, "fail")

View File

@@ -9,11 +9,13 @@ package handler
import ( import (
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/store/model" "geekai/store/model"
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -27,6 +29,18 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *PowerLogHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/powerLog/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("list", h.List)
group.GET("stats", h.Stats)
}
}
func (h *PowerLogHandler) List(c *gin.Context) { func (h *PowerLogHandler) List(c *gin.Context) {
var data struct { var data struct {
Model string `json:"model"` Model string `json:"model"`
@@ -72,3 +86,45 @@ func (h *PowerLogHandler) List(c *gin.Context) {
} }
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
} }
// Stats 获取用户算力统计
func (h *PowerLogHandler) Stats(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.NotAuth(c)
return
}
// 获取用户信息(包含余额)
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
resp.ERROR(c, "用户不存在")
return
}
// 计算总消费(所有支出记录)
var totalConsume int64
h.DB.Model(&model.PowerLog{}).
Where("user_id", userId).
Where("mark", types.PowerSub).
Select("COALESCE(SUM(amount), 0)").
Scan(&totalConsume)
// 计算今日消费
today := time.Now().Format("2006-01-02")
var todayConsume int64
h.DB.Model(&model.PowerLog{}).
Where("user_id", userId).
Where("mark", types.PowerSub).
Where("DATE(created_at) = ?", today).
Select("COALESCE(SUM(amount), 0)").
Scan(&todayConsume)
stats := map[string]interface{}{
"total": totalConsume,
"today": todayConsume,
"balance": user.Power,
}
resp.SUCCESS(c, stats)
}

View File

@@ -13,6 +13,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -25,6 +26,12 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}} return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
} }
// RegisterRoutes 注册路由
func (h *ProductHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/product/")
group.GET("list", h.List)
}
// List 模型列表 // List 模型列表
func (h *ProductHandler) List(c *gin.Context) { func (h *ProductHandler) List(c *gin.Context) {
var items []model.Product var items []model.Product

View File

@@ -10,12 +10,14 @@ package handler
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
@@ -39,6 +41,20 @@ func NewPromptHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
} }
} }
// RegisterRoutes 注册路由
func (h *PromptHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/prompt/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis)).Use(middleware.RateLimitEvery(h.App.Redis, 30*time.Second))
{
group.POST("lyric", h.Lyric)
group.POST("image", h.Image)
group.POST("video", h.Video)
group.POST("meta", h.MetaPrompt)
}
}
// Lyric 生成歌词 // Lyric 生成歌词
func (h *PromptHandler) Lyric(c *gin.Context) { func (h *PromptHandler) Lyric(c *gin.Context) {
var data struct { var data struct {
@@ -48,25 +64,12 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId) content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
if h.App.SysConfig.PromptPower > 0 {
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成歌词",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c, content) resp.SUCCESS(c, content)
} }
@@ -79,23 +82,12 @@ func (h *PromptHandler) Image(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId) content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
if h.App.SysConfig.PromptPower > 0 {
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成绘画提示词",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c, strings.Trim(content, `"`)) resp.SUCCESS(c, strings.Trim(content, `"`))
} }
@@ -108,25 +100,12 @@ func (h *PromptHandler) Video(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId) content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
if h.App.SysConfig.PromptPower > 0 {
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成视频脚本",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c, strings.Trim(content, `"`)) resp.SUCCESS(c, strings.Trim(content, `"`))
} }
@@ -158,9 +137,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
} }
func (h *PromptHandler) getPromptModel() string { func (h *PromptHandler) getPromptModel() string {
if h.App.SysConfig.AssistantModelId > 0 { if h.App.SysConfig.Base.AssistantModelId > 0 {
var chatModel model.ChatModel var chatModel model.ChatModel
h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel) h.DB.Where("id", h.App.SysConfig.Base.AssistantModelId).First(&chatModel)
return chatModel.Value return chatModel.Value
} }
return "gpt-4o" return "gpt-4o"

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/store/model" "geekai/store/model"
@@ -39,6 +40,18 @@ func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *servic
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService} return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
} }
// RegisterRoutes 注册路由
func (h *RealtimeHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/realtime/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.Any("", h.Connection)
group.POST("voice", h.VoiceChat)
}
}
func (h *RealtimeHandler) Connection(c *gin.Context) { func (h *RealtimeHandler) Connection(c *gin.Context) {
// 获取客户端请求中指定的子协议 // 获取客户端请求中指定的子协议
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol") clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
@@ -154,7 +167,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
return return
} }
if user.Power < h.App.SysConfig.AdvanceVoicePower { if user.Power < h.App.SysConfig.Base.AdvanceVoicePower {
resp.ERROR(c, "当前用户算力不足,无法使用该功能") resp.ERROR(c, "当前用户算力不足,无法使用该功能")
return return
} }
@@ -198,7 +211,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 扣减算力 // 扣减算力
err = h.userService.DecreasePower(userId, h.App.SysConfig.AdvanceVoicePower, model.PowerLog{ err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.AdvanceVoicePower, model.PowerLog{
Type: types.PowerConsume, Type: types.PowerConsume,
Model: "advanced-voice", Model: "advanced-voice",
Remark: "实时语音通话", Remark: "实时语音通话",

View File

@@ -10,14 +10,16 @@ package handler
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/store/model" "geekai/store/model"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
) )
type RedeemHandler struct { type RedeemHandler struct {
@@ -30,6 +32,17 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService} return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
} }
// RegisterRoutes 注册路由
func (h *RedeemHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/redeem/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("verify", h.Verify)
}
}
func (h *RedeemHandler) Verify(c *gin.Context) { func (h *RedeemHandler) Verify(c *gin.Context) {
var data struct { var data struct {
Code string `json:"code"` Code string `json:"code"`

View File

@@ -10,8 +10,10 @@ package handler
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/moderation"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/sd" "geekai/service/sd"
"geekai/store" "geekai/store"
@@ -28,12 +30,13 @@ import (
type SdJobHandler struct { type SdJobHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
sdService *sd.Service sdService *sd.Service
uploader *oss.UploaderManager uploader *oss.UploaderManager
snowflake *service.Snowflake snowflake *service.Snowflake
leveldb *store.LevelDB leveldb *store.LevelDB
userService *service.UserService userService *service.UserService
moderationManager *moderation.ServiceManager
} }
func NewSdJobHandler(app *core.AppServer, func NewSdJobHandler(app *core.AppServer,
@@ -42,13 +45,15 @@ func NewSdJobHandler(app *core.AppServer,
manager *oss.UploaderManager, manager *oss.UploaderManager,
snowflake *service.Snowflake, snowflake *service.Snowflake,
userService *service.UserService, userService *service.UserService,
levelDB *store.LevelDB) *SdJobHandler { levelDB *store.LevelDB,
moderationManager *moderation.ServiceManager) *SdJobHandler {
return &SdJobHandler{ return &SdJobHandler{
sdService: service, sdService: service,
uploader: manager, uploader: manager,
snowflake: snowflake, snowflake: snowflake,
leveldb: levelDB, leveldb: levelDB,
userService: userService, userService: userService,
moderationManager: moderationManager,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
@@ -56,6 +61,23 @@ func NewSdJobHandler(app *core.AppServer,
} }
} }
// RegisterRoutes 注册路由
func (h *SdJobHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/sd/")
// 公开接口,不需要授权
group.GET("imgWall", h.ImgWall)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}
}
func (h *SdJobHandler) preCheck(c *gin.Context) bool { func (h *SdJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
@@ -63,7 +85,7 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
return false return false
} }
if user.Power < h.App.SysConfig.SdPower { if user.Power < h.App.SysConfig.Base.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false return false
} }
@@ -84,6 +106,29 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return return
} }
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceSD,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
if data.Width <= 0 { if data.Width <= 0 {
data.Width = 512 data.Width = 512
} }
@@ -131,7 +176,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
HdSteps: data.HdSteps, HdSteps: data.HdSteps,
}, },
UserId: userId, UserId: userId,
TranslateModelId: h.App.SysConfig.AssistantModelId, TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
} }
job := model.SdJob{ job := model.SdJob{
@@ -142,7 +187,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
TaskInfo: utils.JsonEncode(task), TaskInfo: utils.JsonEncode(task),
Prompt: data.Prompt, Prompt: data.Prompt,
Progress: 0, Progress: 0,
Power: h.App.SysConfig.SdPower, Power: h.App.SysConfig.Base.SdPower,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
res := h.DB.Create(&job) res := h.DB.Create(&job)

View File

@@ -24,24 +24,31 @@ const CodeStorePrefix = "/verify/codes/"
type SmsHandler struct { type SmsHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
sms *sms.ServiceManager sms *sms.SmsManager
smtp *service.SmtpService smtp *service.SmtpService
captcha *service.CaptchaService captchaService *service.CaptchaService
} }
func NewSmsHandler( func NewSmsHandler(
app *core.AppServer, app *core.AppServer,
client *redis.Client, client *redis.Client,
sms *sms.ServiceManager, sms *sms.SmsManager,
smtp *service.SmtpService, smtp *service.SmtpService,
captcha *service.CaptchaService) *SmsHandler { captcha *service.CaptchaService) *SmsHandler {
return &SmsHandler{ return &SmsHandler{
redis: client, redis: client,
sms: sms, sms: sms,
captcha: captcha, captchaService: captcha,
smtp: smtp, smtp: smtp,
BaseHandler: BaseHandler{App: app}} BaseHandler: BaseHandler{App: app}}
}
// RegisterRoutes 注册路由
func (h *SmsHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/sms/")
// 无需授权的接口
group.POST("code", h.SendCode)
} }
// SendCode 发送验证码 // SendCode 发送验证码
@@ -56,12 +63,12 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
if h.App.SysConfig.EnabledVerify { if h.captchaService.GetConfig().Enabled {
var check bool var check bool
if data.X != 0 { if data.X != 0 {
check = h.captcha.SlideCheck(data) check = h.captchaService.SlideCheck(data)
} else { } else {
check = h.captcha.Check(data) check = h.captchaService.Check(data)
} }
if !check { if !check {
resp.ERROR(c, "请先完人机验证") resp.ERROR(c, "请先完人机验证")
@@ -72,14 +79,14 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
code := utils.RandomNumber(6) code := utils.RandomNumber(6)
var err error var err error
if strings.Contains(data.Receiver, "@") { // email if strings.Contains(data.Receiver, "@") { // email
if !utils.Contains(h.App.SysConfig.RegisterWays, "email") { if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "email") {
resp.ERROR(c, "系统已禁用邮箱注册!") resp.ERROR(c, "系统已禁用邮箱注册!")
return return
} }
// 检查邮箱后缀是否在白名单 // 检查邮箱后缀是否在白名单
if len(h.App.SysConfig.EmailWhiteList) > 0 { if len(h.App.SysConfig.Base.EmailWhiteList) > 0 {
inWhiteList := false inWhiteList := false
for _, suffix := range h.App.SysConfig.EmailWhiteList { for _, suffix := range h.App.SysConfig.Base.EmailWhiteList {
if strings.HasSuffix(data.Receiver, suffix) { if strings.HasSuffix(data.Receiver, suffix) {
inWhiteList = true inWhiteList = true
break break
@@ -92,7 +99,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
} }
err = h.smtp.SendVerifyCode(data.Receiver, code) err = h.smtp.SendVerifyCode(data.Receiver, code)
} else { } else {
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") { if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "mobile") {
resp.ERROR(c, "系统已禁用手机号注册!") resp.ERROR(c, "系统已禁用手机号注册!")
return return
} }

View File

@@ -10,8 +10,10 @@ package handler
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/moderation"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/suno" "geekai/service/suno"
"geekai/store/model" "geekai/store/model"
@@ -26,20 +28,41 @@ import (
type SunoHandler struct { type SunoHandler struct {
BaseHandler BaseHandler
sunoService *suno.Service sunoService *suno.Service
uploader *oss.UploaderManager uploader *oss.UploaderManager
userService *service.UserService userService *service.UserService
moderationManager *moderation.ServiceManager
} }
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler { func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *SunoHandler {
return &SunoHandler{ return &SunoHandler{
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
}, },
sunoService: service, sunoService: service,
uploader: uploader, uploader: uploader,
userService: userService, userService: userService,
moderationManager: moderationManager,
}
}
// RegisterRoutes 注册路由
func (h *SunoHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/suno/")
// 公开接口,不需要授权
group.GET("play", h.Play)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("create", h.Create)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
group.POST("update", h.Update)
group.GET("detail", h.Detail)
} }
} }
@@ -64,13 +87,36 @@ func (h *SunoHandler) Create(c *gin.Context) {
return return
} }
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceSuno,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }
if user.Power < h.App.SysConfig.SunoPower { if user.Power < h.App.SysConfig.Base.SunoPower {
resp.ERROR(c, "您的算力不足,请充值后再试!") resp.ERROR(c, "您的算力不足,请充值后再试!")
return return
} }
@@ -118,7 +164,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
RefSongId: data.RefSongId, RefSongId: data.RefSongId,
RefTaskId: data.RefTaskId, RefTaskId: data.RefTaskId,
ExtendSecs: data.ExtendSecs, ExtendSecs: data.ExtendSecs,
Power: h.App.SysConfig.SunoPower, Power: h.App.SysConfig.Base.SunoPower,
SongId: utils.RandString(32), SongId: utils.RandString(32),
} }
if data.Lyrics != "" { if data.Lyrics != "" {

View File

@@ -1,21 +1,36 @@
package handler package handler
import ( import (
"geekai/core"
"geekai/core/middleware"
"geekai/service" "geekai/service"
"geekai/service/payment" "geekai/service/payment"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
) )
type TestHandler struct { type TestHandler struct {
App *core.AppServer
db *gorm.DB db *gorm.DB
snowflake *service.Snowflake snowflake *service.Snowflake
js *payment.GeekPayService js *payment.EPayService
} }
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler { func NewTestHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, js *payment.EPayService) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js} return &TestHandler{App: app, db: db, snowflake: snowflake, js: js}
}
// RegisterRoutes 注册路由
func (h *TestHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/test/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.Any("sse", h.PostTest, h.SseTest)
}
} }
func (h *TestHandler) SseTest(c *gin.Context) { func (h *TestHandler) SseTest(c *gin.Context) {

View File

@@ -8,8 +8,10 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"context"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/store" "geekai/store"
@@ -20,8 +22,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/imroc/req/v3"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@@ -36,8 +36,10 @@ type UserHandler struct {
redis *redis.Client redis *redis.Client
levelDB *store.LevelDB levelDB *store.LevelDB
licenseService *service.LicenseService licenseService *service.LicenseService
captcha *service.CaptchaService captchaService *service.CaptchaService
userService *service.UserService userService *service.UserService
wxLoginService *service.WxLoginService
ipSearcher *xdb.Searcher
} }
func NewUserHandler( func NewUserHandler(
@@ -48,15 +50,45 @@ func NewUserHandler(
levelDB *store.LevelDB, levelDB *store.LevelDB,
captcha *service.CaptchaService, captcha *service.CaptchaService,
userService *service.UserService, userService *service.UserService,
wxLoginService *service.WxLoginService,
ipSearcher *xdb.Searcher,
licenseService *service.LicenseService) *UserHandler { licenseService *service.LicenseService) *UserHandler {
return &UserHandler{ return &UserHandler{
BaseHandler: BaseHandler{DB: db, App: app}, BaseHandler: BaseHandler{DB: db, App: app},
searcher: searcher, searcher: searcher,
redis: client, redis: client,
levelDB: levelDB, levelDB: levelDB,
captcha: captcha, captchaService: captcha,
licenseService: licenseService, licenseService: licenseService,
userService: userService, userService: userService,
wxLoginService: wxLoginService,
ipSearcher: ipSearcher,
}
}
// RegisterRoutes 注册路由
func (h *UserHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/user/")
// 公开接口,不需要授权
group.POST("register", h.Register)
group.POST("login", h.Login)
group.POST("resetPass", h.ResetPass)
group.GET("login/qrcode", h.GetWxLoginQRCode)
group.POST("login/callback", h.WxLoginCallback)
group.GET("login/status", h.GetWxLoginState)
group.GET("logout", h.Logout)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("session", h.Session)
group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass)
group.POST("bind/mobile", h.BindMobile)
group.POST("bind/email", h.BindEmail)
group.GET("signin", h.SignIn)
} }
} }
@@ -80,12 +112,13 @@ func (h *UserHandler) Register(c *gin.Context) {
return return
} }
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" { // 人机验证
if h.captchaService.GetConfig().Enabled {
var check bool var check bool
if data.X != 0 { if data.X != 0 {
check = h.captcha.SlideCheck(data) check = h.captchaService.SlideCheck(data)
} else { } else {
check = h.captcha.Check(data) check = h.captchaService.Check(data)
} }
if !check { if !check {
resp.ERROR(c, "请先完人机验证") resp.ERROR(c, "请先完人机验证")
@@ -125,30 +158,8 @@ func (h *UserHandler) Register(c *gin.Context) {
} }
} }
// 验证邀请码
inviteCode := model.InviteCode{}
if data.InviteCode != "" {
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
if res.Error != nil {
resp.ERROR(c, "无效的邀请码")
return
}
}
salt := utils.RandString(8)
user := model.User{
Username: data.Username,
Password: utils.GenPassword(data.Password, salt),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatConfig: "{}",
ChatModels: "{}",
Power: h.App.SysConfig.InitPower,
}
// check if the username is existing // check if the username is existing
user := model.User{Username: data.Username, Password: data.Password}
var item model.User var item model.User
session := h.DB.Session(&gorm.Session{}) session := h.DB.Session(&gorm.Session{})
if data.Mobile != "" { if data.Mobile != "" {
@@ -168,78 +179,19 @@ func (h *UserHandler) Register(c *gin.Context) {
return return
} }
// 被邀请人也获得赠送算力 user, err := h.createNewUser(user, data.InviteCode)
if data.InviteCode != "" { if err != nil {
user.Power += h.App.SysConfig.InvitePower
}
if h.licenseService.GetLicense().Configs.DeCopy {
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
} else {
defaultNickname := h.App.SysConfig.DefaultNickname
if defaultNickname == "" {
defaultNickname = "极客学长"
}
user.Nickname = fmt.Sprintf("%s@%d", defaultNickname, utils.RandomNumber(6))
}
tx := h.DB.Begin()
if err := tx.Create(&user).Error; err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
// 记录邀请关系 token, err := h.doLogin(&user, c.ClientIP())
if data.InviteCode != "" {
// 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 {
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.InvitePower, model.PowerLog{
Type: types.PowerInvite,
Model: "Invite",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
// 添加邀请记录
err := tx.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
}).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
// 自动登录创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil { if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error()) resp.ERROR(c, err.Error())
return return
} }
// 保存到 redis
key = fmt.Sprintf("users/%d", user.Id) resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
} }
// Login 用户登录 // Login 用户登录
@@ -255,15 +207,12 @@ func (h *UserHandler) Login(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
verifyKey := fmt.Sprintf("users/verify/%s", data.Username) if h.captchaService.GetConfig().Enabled {
needVerify, err := h.redis.Get(c, verifyKey).Bool()
if h.App.SysConfig.EnabledVerify && needVerify {
var check bool var check bool
if data.X != 0 { if data.X != 0 {
check = h.captcha.SlideCheck(data) check = h.captchaService.SlideCheck(data)
} else { } else {
check = h.captcha.Check(data) check = h.captchaService.Check(data)
} }
if !check { if !check {
resp.ERROR(c, "请先完人机验证") resp.ERROR(c, "请先完人机验证")
@@ -274,54 +223,28 @@ func (h *UserHandler) Login(c *gin.Context) {
var user model.User var user model.User
res := h.DB.Where("username = ?", data.Username).First(&user) res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil { if res.Error != nil {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名不存在") resp.ERROR(c, "用户名不存在")
return return
} }
password := utils.GenPassword(data.Password, user.Salt) password := utils.GenPassword(data.Password, user.Salt)
if password != user.Password { if password != user.Password {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名或密码错误") resp.ERROR(c, "用户名或密码错误")
return return
} }
if user.Status == false { if !user.Status {
resp.ERROR(c, "该用户已被禁止登录,请联系管理员") resp.ERROR(c, "该用户已被禁止登录,请联系管理员")
return return
} }
// 更新最后登录时间和IP token, err := h.doLogin(&user, c.ClientIP())
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.DB.Model(&user).Updates(user)
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil { if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error()) resp.ERROR(c, err.Error())
return return
} }
// 保存到 redis
sessionKey := fmt.Sprintf("users/%d", user.Id) resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
// 移除登录行为验证码
h.redis.Del(c, verifyKey)
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
} }
// Logout 注 销 // Logout 注 销
@@ -333,134 +256,165 @@ func (h *UserHandler) Logout(c *gin.Context) {
resp.SUCCESS(c) resp.SUCCESS(c)
} }
// CLogin 第三方登录请求二维码 // GetWxLoginQRCode 获取微信登录二维码URL
func (h *UserHandler) CLogin(c *gin.Context) { func (h *UserHandler) GetWxLoginQRCode(c *gin.Context) {
returnURL := h.GetTrim(c, "return_url") if !h.wxLoginService.GetConfig().Enabled {
var res types.BizVo resp.ERROR(c, "微信登录功能未启用")
apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL) return
r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}). }
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)). if h.wxLoginService.GetConfig().ApiKey == "" {
SetSuccessResult(&res). resp.ERROR(c, "微信登录服务令牌未配置")
Post(apiURL) return
}
state := utils.RandString(32)
qrCodeURL, err := h.wxLoginService.GetLoginQrCodeUrl(state)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status)
return
}
if res.Code != types.Success { resp.SUCCESS(c, gin.H{
resp.ERROR(c, "error with http response: "+res.Message) "url": qrCodeURL,
return "state": state,
} })
resp.SUCCESS(c, res.Data)
} }
// CLoginCallback 第三方登录回调 // 查询微信登录状态
func (h *UserHandler) CLoginCallback(c *gin.Context) { func (h *UserHandler) GetWxLoginState(c *gin.Context) {
loginType := c.Query("login_type") state := c.Query("state")
code := c.Query("code") if state == "" {
userId := h.GetInt(c, "user_id", 0) resp.ERROR(c, "参数错误")
action := c.Query("action") return
}
var res types.BizVo status, err := h.wxLoginService.GetLoginStatus(state)
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
SetSuccessResult(&res).
Post(apiURL)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status) if status.Status != service.LoginStatusSuccess {
resp.SUCCESS(c, status)
return return
} }
if res.Code != types.Success { // 登录成功
resp.ERROR(c, "error with http response: "+res.Message)
return
}
// login successfully
data := res.Data.(map[string]interface{})
var user model.User var user model.User
if action == "bind" && userId > 0 { h.DB.Where("openid = ?", status.OpenID).First(&user)
err = h.DB.Where("openid", data["openid"]).First(&user).Error if user.Id == 0 {
if err == nil { // 创建新用户
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑") user, err = h.createNewUser(model.User{OpenId: status.OpenID}, "")
return
}
err = h.DB.Where("id", userId).First(&user).Error
if err != nil { if err != nil {
resp.ERROR(c, "绑定用户不存在") resp.ERROR(c, err.Error())
return return
} }
}
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error token, err := h.doLogin(&user, c.ClientIP())
if err != nil { if err != nil {
resp.ERROR(c, "更新用户信息失败,"+err.Error()) resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, gin.H{"token": ""})
return return
} }
session := gin.H{} status.Status = service.LoginStatusExpired
tx := h.DB.Where("openid", data["openid"]).First(&user) h.wxLoginService.SetLoginStatus(state, *status)
if tx.Error != nil {
// create new user
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
salt := utils.RandString(8) status.Status = service.LoginStatusSuccess
password := fmt.Sprintf("%d", utils.RandomNumber(8)) status.Token = token
user = model.User{ resp.SUCCESS(c, status)
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)), }
Password: utils.GenPassword(password, salt),
Avatar: fmt.Sprintf("%s", data["avatar"]),
Salt: salt,
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
Power: h.App.SysConfig.InitPower,
OpenId: fmt.Sprintf("%s", data["openid"]),
Nickname: fmt.Sprintf("%s", data["nickname"]),
}
tx = h.DB.Create(&user) // createNewUser 创建新用户
if tx.Error != nil { func (h *UserHandler) createNewUser(user model.User, inviteCode string) (model.User, error) {
resp.ERROR(c, "保存数据失败") if user.OpenId != "" {
logger.Error(tx.Error) user.Platform = "wechat"
return user.Nickname = fmt.Sprintf("微信用户@%d", utils.RandomNumber(6))
user.Username = fmt.Sprintf("wx@%d", utils.RandomNumber(8))
user.Password = "geekai123"
} else {
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
if user.Username == "" || user.Password == "" {
return user, fmt.Errorf("用户名或密码不能为空")
} }
session["username"] = user.Username
session["password"] = password
} else { // login directly
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.DB.Model(&user).Updates(user)
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
} }
salt := utils.RandString(8)
user.Salt = salt
user.Password = utils.GenPassword(user.Password, salt)
user.Avatar = "/images/avatar/user.png"
user.Status = true
user.ChatRoles = utils.JsonEncode([]string{"gpt"})
user.ChatConfig = "{}"
user.ChatModels = "{}"
user.Power = h.App.SysConfig.Base.InitPower
// 创建用户
tx := h.DB.Begin()
if err := tx.Create(&user).Error; err != nil {
return user, err
}
// 记录邀请关系
if inviteCode != "" {
inviteCode := model.InviteCode{}
err := h.DB.Where("code = ?", inviteCode).First(&inviteCode).Error
if err != nil {
return user, fmt.Errorf("无效的邀请码")
}
// 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", inviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.Base.InvitePower > 0 {
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.Base.InvitePower, model.PowerLog{
Type: types.PowerInvite,
Model: "Invite",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.Base.InvitePower, inviteCode.Code, user.Username),
})
if err != nil {
tx.Rollback()
return user, err
}
// 添加邀请记录
err = tx.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.Base.InvitePower),
}).Error
if err != nil {
tx.Rollback()
return user, err
}
}
}
tx.Commit()
return user, nil
}
// doLogin 执行登录操作
func (h *UserHandler) doLogin(user *model.User, ip string) (string, error) {
// 更新最后登录时间和IP
user.LastLoginIp = ip
user.LastLoginAt = time.Now().Unix()
err := h.DB.Model(user).Updates(user).Error
if err != nil {
return "", fmt.Errorf("failed to update user: %v", err)
}
// 记录登录日志
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: ip,
LoginAddress: utils.Ip2Region(h.ipSearcher, ip),
})
// 创建 token // 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id, "user_id": user.Id,
@@ -468,17 +422,42 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
}) })
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil { if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error()) return "", fmt.Errorf("failed to generate token: %v", err)
return
} }
// 保存到 redis // 保存到 redis
key := fmt.Sprintf("users/%d", user.Id) sessionKey := fmt.Sprintf("users/%d", user.Id)
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil { if _, err = h.redis.Set(context.Background(), sessionKey, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error()) return "", fmt.Errorf("error with save token: %v", err)
}
return tokenString, nil
}
// WxLoginCallback 微信登录回调处理
func (h *UserHandler) WxLoginCallback(c *gin.Context) {
var data struct {
OpenID string `json:"openid"`
State string `json:"state"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return return
} }
session["token"] = tokenString
resp.SUCCESS(c, session) if data.OpenID == "" || data.State == "" {
resp.ERROR(c, "参数错误")
return
}
// 设置登录状态
status := service.LoginStatus{
Status: service.LoginStatusSuccess,
OpenID: data.OpenID,
}
h.wxLoginService.SetLoginStatus(data.State, status)
resp.SUCCESS(c, status)
} }
// Session 获取/验证会话 // Session 获取/验证会话
@@ -742,11 +721,11 @@ func (h *UserHandler) SignIn(c *gin.Context) {
// 签到 // 签到
h.levelDB.Put(key, true) h.levelDB.Put(key, true)
if h.App.SysConfig.DailyPower > 0 { if h.App.SysConfig.Base.DailyPower > 0 {
h.userService.IncreasePower(userId, h.App.SysConfig.DailyPower, model.PowerLog{ h.userService.IncreasePower(userId, h.App.SysConfig.Base.DailyPower, model.PowerLog{
Type: types.PowerSignIn, Type: types.PowerSignIn,
Model: "SignIn", Model: "SignIn",
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.DailyPower), Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.Base.DailyPower),
}) })
} }
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@@ -10,8 +10,10 @@ package handler
import ( import (
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/middleware"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/moderation"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/video" "geekai/service/video"
"geekai/store/model" "geekai/store/model"
@@ -26,20 +28,37 @@ import (
type VideoHandler struct { type VideoHandler struct {
BaseHandler BaseHandler
videoService *video.Service videoService *video.Service
uploader *oss.UploaderManager uploader *oss.UploaderManager
userService *service.UserService userService *service.UserService
moderationManager *moderation.ServiceManager
} }
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler { func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *VideoHandler {
return &VideoHandler{ return &VideoHandler{
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
DB: db, DB: db,
}, },
videoService: service, videoService: service,
uploader: uploader, uploader: uploader,
userService: userService, userService: userService,
moderationManager: moderationManager,
}
}
// RegisterRoutes 注册路由
func (h *VideoHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/video/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("luma/create", h.LumaCreate)
group.POST("keling/create", h.KeLingCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
} }
} }
@@ -62,13 +81,36 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
return return
} }
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceVideo,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }
if user.Power < h.App.SysConfig.LumaPower { if user.Power < h.App.SysConfig.Base.LumaPower {
resp.ERROR(c, "您的算力不足,请充值后再试!") resp.ERROR(c, "您的算力不足,请充值后再试!")
return return
} }
@@ -85,14 +127,14 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
Type: types.VideoLuma, Type: types.VideoLuma,
Prompt: data.Prompt, Prompt: data.Prompt,
Params: params, Params: params,
TranslateModelId: h.App.SysConfig.AssistantModelId, TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
} }
// 插入数据库 // 插入数据库
job := model.VideoJob{ job := model.VideoJob{
UserId: uint(userId), UserId: uint(userId),
Type: types.VideoLuma, Type: types.VideoLuma,
Prompt: data.Prompt, Prompt: data.Prompt,
Power: h.App.SysConfig.LumaPower, Power: h.App.SysConfig.Base.LumaPower,
TaskInfo: utils.JsonEncode(task), TaskInfo: utils.JsonEncode(task),
} }
tx := h.DB.Create(&job) tx := h.DB.Create(&job)
@@ -147,7 +189,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
// 计算当前任务所需算力 // 计算当前任务所需算力
key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration) key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration)
power := h.App.SysConfig.KeLingPowers[key] power := h.App.SysConfig.Base.KeLingPowers[key]
if power == 0 { if power == 0 {
resp.ERROR(c, "当前模型暂不支持") resp.ERROR(c, "当前模型暂不支持")
return return
@@ -181,7 +223,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
Type: types.VideoKeLing, Type: types.VideoKeLing,
Prompt: data.Prompt, Prompt: data.Prompt,
Params: params, Params: params,
TranslateModelId: h.App.SysConfig.AssistantModelId, TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
Channel: data.Channel, Channel: data.Channel,
} }
// 插入数据库 // 插入数据库

View File

@@ -19,6 +19,7 @@ import (
"geekai/service/dalle" "geekai/service/dalle"
"geekai/service/jimeng" "geekai/service/jimeng"
"geekai/service/mj" "geekai/service/mj"
"geekai/service/moderation"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/payment" "geekai/service/payment"
"geekai/service/sd" "geekai/service/sd"
@@ -30,7 +31,7 @@ import (
"log" "log"
"os" "os"
"os/signal" "os/signal"
"strconv" "runtime/debug"
"syscall" "syscall"
"time" "time"
@@ -71,15 +72,16 @@ func main() {
if configFile == "" { if configFile == "" {
configFile = "config.toml" configFile = "config.toml"
} }
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
logger.Info("Loading config file: ", configFile) logger.Info("Loading config file: ", configFile)
if !debug { defer func() {
defer func() { if err := recover(); err != nil {
if err := recover(); err != nil { logger.Error("Panic Error:", err)
logger.Error("Panic Error:", err) // 打印堆栈信息
if os.Getenv("GEEKAI_DEBUG") == "true" {
debug.PrintStack()
} }
}() }
} }()
app := fx.New( app := fx.New(
// 初始化配置应用配置 // 初始化配置应用配置
@@ -89,16 +91,16 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
config.Path = configFile config.Path = configFile
if debug {
_ = core.SaveConfig(config)
}
return config return config
}), }),
// 创建应用服务 // 创建应用服务
fx.Provide(core.NewServer), fx.Provide(core.NewServer),
// 初始化 // 初始化
fx.Invoke(func(s *core.AppServer, client *redis.Client) { fx.Invoke(func(s *core.AppServer, client *redis.Client) {
s.Init(debug, client) s.Init(client)
}),
fx.Provide(func(db *gorm.DB) *types.SystemConfig {
return core.LoadSystemConfig(db)
}), }),
// 初始化数据库 // 初始化数据库
@@ -126,7 +128,7 @@ func main() {
}), }),
// 创建控制器 // 创建控制器
fx.Provide(handler.NewChatRoleHandler), fx.Provide(handler.NewChatAppHandler),
fx.Provide(handler.NewUserHandler), fx.Provide(handler.NewUserHandler),
fx.Provide(handler.NewChatHandler), fx.Provide(handler.NewChatHandler),
fx.Provide(handler.NewNetHandler), fx.Provide(handler.NewNetHandler),
@@ -143,6 +145,12 @@ func main() {
fx.Provide(handler.NewPowerLogHandler), fx.Provide(handler.NewPowerLogHandler),
fx.Provide(handler.NewJimengHandler), fx.Provide(handler.NewJimengHandler),
fx.Provide(service.NewMigrationService),
fx.Invoke(func(migrationService *service.MigrationService) {
migrationService.StartMigrate()
}),
// 管理后台控制器
fx.Provide(admin.NewConfigHandler), fx.Provide(admin.NewConfigHandler),
fx.Provide(admin.NewAdminHandler), fx.Provide(admin.NewAdminHandler),
fx.Provide(admin.NewApiKeyHandler), fx.Provide(admin.NewApiKeyHandler),
@@ -153,34 +161,23 @@ func main() {
fx.Provide(admin.NewChatModelHandler), fx.Provide(admin.NewChatModelHandler),
fx.Provide(admin.NewProductHandler), fx.Provide(admin.NewProductHandler),
fx.Provide(admin.NewOrderHandler), fx.Provide(admin.NewOrderHandler),
fx.Provide(admin.NewChatHandler),
fx.Provide(admin.NewPowerLogHandler), fx.Provide(admin.NewPowerLogHandler),
fx.Provide(admin.NewAdminJimengHandler), fx.Provide(admin.NewAdminJimengHandler),
// 创建服务
fx.Provide(sms.NewSendServiceManager),
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
return service.NewCaptchaService(config.ApiConfig)
}),
fx.Provide(oss.NewUploaderManager),
fx.Provide(dalle.NewService),
fx.Invoke(func(s *dalle.Service) {
s.Run()
s.DownloadImages()
s.CheckTaskStatus()
}),
fx.Provide(service.NewMigrationService),
fx.Invoke(func(s *service.MigrationService) {
s.Migrate()
}),
// 邮件服务 // 邮件服务
fx.Provide(service.NewSmtpService), fx.Provide(service.NewSmtpService),
// License 服务 // License 服务
fx.Provide(service.NewLicenseService), fx.Provide(service.NewLicenseService),
fx.Invoke(func(licenseService *service.LicenseService) { fx.Invoke(func(licenseService *service.LicenseService) {
// licenseService.SyncLicense() licenseService.SyncLicense()
}),
// Dalle 服务
fx.Provide(dalle.NewService),
fx.Invoke(func(s *dalle.Service) {
s.Run()
s.DownloadImages()
s.CheckTaskStatus()
}), }),
// MidJourney service pool // MidJourney service pool
@@ -213,302 +210,179 @@ func main() {
}), }),
// 即梦AI 服务 // 即梦AI 服务
fx.Provide(jimeng.NewClient),
fx.Provide(jimeng.NewService), fx.Provide(jimeng.NewService),
fx.Invoke(func(service *jimeng.Service) { fx.Invoke(func(service *jimeng.Service) {
service.Start() service.Start()
}), }),
fx.Provide(service.NewUserService),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
fx.Provide(payment.NewJPayService),
fx.Provide(payment.NewWechatService),
fx.Provide(service.NewSnowflake), fx.Provide(service.NewSnowflake),
fx.Provide(service.NewXXLJobExecutor),
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) { // 创建短信服务
if config.XXLConfig.Enabled { fx.Provide(sms.NewAliYunSmsService),
go func() { fx.Provide(sms.NewBaoSmsService),
log.Fatal(exec.Run()) fx.Provide(sms.NewSmsManager),
}() fx.Provide(func(config *types.SystemConfig) *service.CaptchaService {
} return service.NewCaptchaService(config.Captcha)
}),
fx.Provide(func(config *types.SystemConfig, client *redis.Client) *service.WxLoginService {
return service.NewWxLoginService(config.WxLogin, client)
}),
// 支付服务
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewEPayService),
fx.Provide(payment.NewWxpayService),
// 文件上传服务
fx.Provide(oss.NewLocalStorage),
fx.Provide(oss.NewMiniOss),
fx.Provide(oss.NewQiNiuOss),
fx.Provide(oss.NewAliYunOss),
fx.Provide(oss.NewUploaderManager),
// 用户服务
fx.Provide(service.NewUserService),
// 文本审查服务
fx.Provide(moderation.NewGiteeAIModeration),
fx.Provide(moderation.NewBaiduAIModeration),
fx.Provide(moderation.NewTencentAIModeration),
fx.Provide(moderation.NewServiceManager),
fx.Provide(admin.NewModerationHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ModerationHandler) {
h.RegisterRoutes()
}), }),
// 注册路由 // 注册路由
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatAppHandler) {
group := s.Engine.Group("/api/app/") h.RegisterRoutes()
group.GET("list", h.List)
group.GET("list/user", h.ListByUser)
group.POST("update", h.UpdateRole)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) { fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
group := s.Engine.Group("/api/user/") h.RegisterRoutes()
group.POST("register", h.Register)
group.POST("login", h.Login)
group.GET("logout", h.Logout)
group.GET("session", h.Session)
group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass)
group.POST("bind/mobile", h.BindMobile)
group.POST("bind/email", h.BindEmail)
group.POST("resetPass", h.ResetPass)
group.GET("clogin", h.CLogin)
group.GET("clogin/callback", h.CLoginCallback)
group.GET("signin", h.SignIn)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
group := s.Engine.Group("/api/chat/") h.RegisterRoutes()
group.Any("message", h.Chat)
group.GET("list", h.List)
group.GET("detail", h.Detail)
group.POST("update", h.Update)
group.GET("remove", h.Remove)
group.GET("history", h.History)
group.GET("clear", h.Clear)
group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate)
group.POST("tts", h.TextToSpeech)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) { fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
s.Engine.POST("/api/upload", h.Upload) h.RegisterRoutes()
s.Engine.POST("/api/upload/list", h.List)
s.Engine.GET("/api/upload/remove", h.Remove)
s.Engine.GET("/api/download", h.Download)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
group := s.Engine.Group("/api/sms/") h.RegisterRoutes()
group.POST("code", h.SendCode)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) { fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) {
group := s.Engine.Group("/api/captcha/") h.RegisterRoutes()
group.GET("get", h.Get)
group.POST("check", h.Check)
group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) { fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
group := s.Engine.Group("/api/redeem/") h.RegisterRoutes()
group.POST("verify", h.Verify)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
group := s.Engine.Group("/api/mj/") h.RegisterRoutes()
group.POST("image", h.Image)
group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd") h.RegisterRoutes()
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/") h.RegisterRoutes()
group.GET("get", h.Get)
group.GET("license", h.License)
}), }),
// 管理后台控制器 // 管理后台路由注册
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
group := s.Engine.Group("/api/admin/config") h.RegisterRoutes()
group.POST("update", h.Update)
group.GET("get", h.Get)
group.POST("active", h.Active)
group.GET("fixData", h.FixData)
group.GET("license", h.GetLicense)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
group := s.Engine.Group("/api/admin/") h.RegisterRoutes()
group.POST("login", h.Login)
group.GET("logout", h.Logout)
group.GET("session", h.Session)
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("enable", h.Enable)
group.GET("remove", h.Remove)
group.POST("resetPass", h.ResetPass)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
group := s.Engine.Group("/api/admin/apikey/") h.RegisterRoutes()
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) { fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
group := s.Engine.Group("/api/admin/user/") h.RegisterRoutes()
group.GET("list", h.List)
group.POST("save", h.Save)
group.GET("remove", h.Remove)
group.GET("loginLog", h.LoginLog)
group.GET("genLoginLink", h.GenLoginLink)
group.POST("resetPass", h.ResetPass)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
group := s.Engine.Group("/api/admin/role/") h.RegisterRoutes()
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("sort", h.Sort)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) { fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
group := s.Engine.Group("/api/admin/redeem/") h.RegisterRoutes()
group.GET("list", h.List)
group.POST("create", h.Create)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.POST("export", h.Export)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) { fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
group := s.Engine.Group("/api/admin/dashboard/") h.RegisterRoutes()
group.GET("stats", h.Stats)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.ChatModelHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatModelHandler) {
group := s.Engine.Group("/api/model/") h.RegisterRoutes()
group.GET("list", h.List)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ChatModelHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ChatModelHandler) {
group := s.Engine.Group("/api/admin/model/") h.RegisterRoutes()
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("set", h.Set)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) { fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
group := s.Engine.Group("/api/payment/") h.RegisterRoutes()
group.POST("doPay", h.Pay) h.StartSyncOrders()
group.GET("payWays", h.GetPayWays)
group.POST("notify/alipay", h.AlipayNotify)
group.GET("notify/geek", h.GeekPayNotify)
group.POST("notify/wechat", h.WechatPayNotify)
group.POST("notify/hupi", h.HuPiPayNotify)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
group := s.Engine.Group("/api/admin/product/") h.RegisterRoutes()
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.OrderHandler) { fx.Invoke(func(s *core.AppServer, h *admin.OrderHandler) {
group := s.Engine.Group("/api/admin/order/") h.RegisterRoutes()
group.POST("list", h.List)
group.GET("remove", h.Remove)
group.GET("clear", h.Clear)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) { fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
group := s.Engine.Group("/api/order/") h.RegisterRoutes()
group.GET("list", h.List)
group.GET("query", h.Query)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
group := s.Engine.Group("/api/product/") h.RegisterRoutes()
group.GET("list", h.List)
}), }),
fx.Provide(handler.NewInviteHandler), fx.Provide(handler.NewInviteHandler),
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) { fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
group := s.Engine.Group("/api/invite/") h.RegisterRoutes()
group.GET("code", h.Code)
group.GET("list", h.List)
group.GET("hits", h.Hits)
}), }),
fx.Provide(admin.NewFunctionHandler), fx.Provide(admin.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) { fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
group := s.Engine.Group("/api/admin/function/") h.RegisterRoutes()
group.POST("save", h.Save)
group.POST("set", h.Set)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("token", h.GenToken)
}), }),
fx.Provide(admin.NewUploadHandler), fx.Provide(admin.NewUploadHandler),
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) { fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
s.Engine.POST("/api/admin/upload", h.Upload) h.RegisterRoutes()
}), }),
fx.Provide(handler.NewFunctionHandler), fx.Provide(handler.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) { fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
group := s.Engine.Group("/api/function/") h.RegisterRoutes()
group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
group.POST("websearch", h.WebSearch)
group.GET("list", h.List)
}), }),
fx.Provide(admin.NewChatHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/") h.RegisterRoutes()
group.POST("list", h.List)
group.POST("message", h.Messages)
group.GET("history", h.History)
group.GET("remove", h.RemoveChat)
group.GET("message/remove", h.RemoveMessage)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) { fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
group := s.Engine.Group("/api/powerLog/") h.RegisterRoutes()
group.POST("list", h.List)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) { fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
group := s.Engine.Group("/api/admin/powerLog/") h.RegisterRoutes()
group.POST("list", h.List)
}), }),
fx.Provide(admin.NewMenuHandler), fx.Provide(admin.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) { fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
group := s.Engine.Group("/api/admin/menu/") h.RegisterRoutes()
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}), }),
fx.Provide(handler.NewMenuHandler), fx.Provide(handler.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) { fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
group := s.Engine.Group("/api/menu/") h.RegisterRoutes()
group.GET("list", h.List)
}), }),
fx.Provide(handler.NewMarkMapHandler), fx.Provide(handler.NewMarkMapHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) { fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
s.Engine.POST("/api/markMap/gen", h.Generate) h.RegisterRoutes()
}), }),
fx.Provide(handler.NewDallJobHandler), fx.Provide(handler.NewDallJobHandler),
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
group := s.Engine.Group("/api/dall") h.RegisterRoutes()
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
group.GET("models", h.GetModels)
}), }),
fx.Provide(handler.NewSunoHandler), fx.Provide(handler.NewSunoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
group := s.Engine.Group("/api/suno") h.RegisterRoutes()
group.POST("create", h.Create)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
group.POST("update", h.Update)
group.GET("detail", h.Detail)
group.GET("play", h.Play)
}), }),
fx.Provide(handler.NewVideoHandler), fx.Provide(handler.NewVideoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) { fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
group := s.Engine.Group("/api/video") h.RegisterRoutes()
group.POST("luma/create", h.LumaCreate)
group.POST("keling/create", h.KeLingCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}), }),
// 即梦AI 路由 // 即梦AI 路由
@@ -520,30 +394,19 @@ func main() {
}), }),
fx.Provide(admin.NewChatAppTypeHandler), fx.Provide(admin.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
group := s.Engine.Group("/api/admin/app/type") h.RegisterRoutes()
group.POST("save", h.Save)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
}), }),
fx.Provide(handler.NewChatAppTypeHandler), fx.Provide(handler.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
group := s.Engine.Group("/api/app/type") h.RegisterRoutes()
group.GET("list", h.List)
}), }),
fx.Provide(handler.NewTestHandler), fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) { fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test") h.RegisterRoutes()
group.Any("sse", h.PostTest, h.SseTest)
}), }),
fx.Provide(handler.NewPromptHandler), fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) { fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt") h.RegisterRoutes()
group.POST("/lyric", h.Lyric)
group.POST("/image", h.Image)
group.POST("/video", h.Video)
group.POST("/meta", h.MetaPrompt)
}), }),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) { fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
go func() { go func() {
@@ -568,23 +431,15 @@ func main() {
}), }),
fx.Provide(admin.NewImageHandler), fx.Provide(admin.NewImageHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) {
group := s.Engine.Group("/api/admin/image") h.RegisterRoutes()
group.POST("/list/mj", h.MjList)
group.POST("/list/sd", h.SdList)
group.POST("/list/dall", h.DallList)
group.GET("/remove", h.Remove)
}), }),
fx.Provide(admin.NewMediaHandler), fx.Provide(admin.NewMediaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) { fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
group := s.Engine.Group("/api/admin/media") h.RegisterRoutes()
group.POST("/suno", h.SunoList)
group.POST("/videos", h.Videos)
group.GET("/remove", h.Remove)
}), }),
fx.Provide(handler.NewRealtimeHandler), fx.Provide(handler.NewRealtimeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) { fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
s.Engine.Any("/api/realtime", h.Connection) h.RegisterRoutes()
s.Engine.POST("/api/realtime/voice", h.VoiceChat)
}), }),
) )
// 启动应用程序 // 启动应用程序

View File

@@ -8,35 +8,38 @@ package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
"github.com/imroc/req/v3"
"time" "time"
"github.com/imroc/req/v3"
) )
type CaptchaService struct { type CaptchaService struct {
config types.ApiConfig config types.CaptchaConfig
client *req.Client client *req.Client
} }
func NewCaptchaService(config types.ApiConfig) *CaptchaService { func NewCaptchaService(captchaConfig types.CaptchaConfig) *CaptchaService {
return &CaptchaService{ return &CaptchaService{
config: config, config: captchaConfig,
client: req.C().SetTimeout(10 * time.Second), client: req.C().SetTimeout(10 * time.Second),
} }
} }
func (s *CaptchaService) UpdateConfig(config types.CaptchaConfig) {
s.config = config
}
func (s *CaptchaService) GetConfig() types.CaptchaConfig {
return s.config
}
func (s *CaptchaService) Get() (interface{}, error) { func (s *CaptchaService) Get() (interface{}, error) {
if s.config.Token == "" { url := fmt.Sprintf("%s/api/captcha/get", types.GeekAPIURL)
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/get", s.config.ApiURL)
var res types.BizVo var res types.BizVo
r, err := s.client.R(). r, err := s.client.R().
SetHeader("AppId", s.config.AppId). SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetSuccessResult(&res).Get(url) SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err) return nil, fmt.Errorf("请求 API 失败:%v", err)
@@ -49,12 +52,11 @@ func (s *CaptchaService) Get() (interface{}, error) {
return res.Data, nil return res.Data, nil
} }
func (s *CaptchaService) Check(data interface{}) bool { func (s *CaptchaService) Check(data any) bool {
url := fmt.Sprintf("%s/api/captcha/check", s.config.ApiURL) url := fmt.Sprintf("%s/api/captcha/check", types.GeekAPIURL)
var res types.BizVo var res types.BizVo
r, err := s.client.R(). r, err := s.client.R().
SetHeader("AppId", s.config.AppId). SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetBodyJsonMarshal(data). SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url) SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {
@@ -68,16 +70,11 @@ func (s *CaptchaService) Check(data interface{}) bool {
return true return true
} }
func (s *CaptchaService) SlideGet() (interface{}, error) { func (s *CaptchaService) SlideGet() (any, error) {
if s.config.Token == "" { url := fmt.Sprintf("%s/api/captcha/slide/get", types.GeekAPIURL)
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
var res types.BizVo var res types.BizVo
r, err := s.client.R(). r, err := s.client.R().
SetHeader("AppId", s.config.AppId). SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetSuccessResult(&res).Get(url) SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err) return nil, fmt.Errorf("请求 API 失败:%v", err)
@@ -90,12 +87,11 @@ func (s *CaptchaService) SlideGet() (interface{}, error) {
return res.Data, nil return res.Data, nil
} }
func (s *CaptchaService) SlideCheck(data interface{}) bool { func (s *CaptchaService) SlideCheck(data any) bool {
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL) url := fmt.Sprintf("%s/api/captcha/slide/check", types.GeekAPIURL)
var res types.BizVo var res types.BizVo
r, err := s.client.R(). r, err := s.client.R().
SetHeader("AppId", s.config.AppId). SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetBodyJsonMarshal(data). SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url) SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {

View File

@@ -1,333 +0,0 @@
package crawler
import (
"context"
"errors"
"fmt"
"geekai/logger"
"net/url"
"strings"
"time"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/launcher"
"github.com/go-rod/rod/lib/proto"
)
// Service 网络爬虫服务
type Service struct {
browser *rod.Browser
}
// NewService 创建一个新的爬虫服务
func NewService() (*Service, error) {
// 启动浏览器
path, _ := launcher.LookPath()
u := launcher.New().Bin(path).
Headless(true). // 无头模式
Set("disable-web-security", ""). // 禁用网络安全限制
Set("disable-gpu", ""). // 禁用 GPU 加速
Set("no-sandbox", ""). // 禁用沙箱模式
Set("disable-setuid-sandbox", ""). // 禁用 setuid 沙箱
MustLaunch()
browser := rod.New().ControlURL(u).MustConnect()
return &Service{
browser: browser,
}, nil
}
// SearchResult 搜索结果
type SearchResult struct {
Title string `json:"title"` // 标题
URL string `json:"url"` // 链接
Content string `json:"content"` // 内容摘要
}
// WebSearch 网络搜索
func (s *Service) WebSearch(keyword string, maxPages int) ([]SearchResult, error) {
if keyword == "" {
return nil, errors.New("搜索关键词不能为空")
}
if maxPages <= 0 {
maxPages = 1
}
if maxPages > 10 {
maxPages = 10 // 最多搜索 10 页
}
results := make([]SearchResult, 0)
// 使用百度搜索
searchURL := fmt.Sprintf("https://www.baidu.com/s?wd=%s", url.QueryEscape(keyword))
// 设置页面超时
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 创建页面
page := s.browser.MustPage()
defer page.MustClose()
// 设置视口大小
err := page.SetViewport(&proto.EmulationSetDeviceMetricsOverride{
Width: 1280,
Height: 800,
})
if err != nil {
return nil, fmt.Errorf("设置视口失败: %v", err)
}
// 导航到搜索页面
err = page.Context(ctx).Navigate(searchURL)
if err != nil {
return nil, fmt.Errorf("导航到搜索页面失败: %v", err)
}
// 等待搜索结果加载完成
err = page.WaitLoad()
if err != nil {
return nil, fmt.Errorf("等待页面加载完成失败: %v", err)
}
// 分析当前页面的搜索结果
for i := 0; i < maxPages; i++ {
if i > 0 {
// 点击下一页按钮
nextPage, err := page.Element("a.n")
if err != nil || nextPage == nil {
break // 没有下一页
}
err = nextPage.Click(proto.InputMouseButtonLeft, 1)
if err != nil {
break // 点击下一页失败
}
// 等待新页面加载
err = page.WaitLoad()
if err != nil {
break
}
}
// 提取搜索结果
resultElements, err := page.Elements(".result, .c-container")
if err != nil || resultElements == nil {
continue
}
for _, result := range resultElements {
// 获取标题
titleElement, err := result.Element("h3, .t")
if err != nil || titleElement == nil {
continue
}
title, err := titleElement.Text()
if err != nil {
continue
}
// 获取 URL
linkElement, err := titleElement.Element("a")
if err != nil || linkElement == nil {
continue
}
href, err := linkElement.Attribute("href")
if err != nil || href == nil {
continue
}
// 获取内容摘要 - 尝试多个可能的选择器
var contentElement *rod.Element
var content string
// 尝试多个可能的选择器来适应不同版本的百度搜索结果
selectors := []string{".content-right_8Zs40", ".c-abstract", ".content_LJ0WN", ".content"}
for _, selector := range selectors {
contentElement, err = result.Element(selector)
if err == nil && contentElement != nil {
content, _ = contentElement.Text()
if content != "" {
break
}
}
}
// 如果所有选择器都失败,尝试直接从结果块中提取文本
if content == "" {
// 获取结果元素的所有文本
fullText, err := result.Text()
if err == nil && fullText != "" {
// 简单处理:从全文中移除标题,剩下的可能是摘要
fullText = strings.Replace(fullText, title, "", 1)
// 清理文本
content = strings.TrimSpace(fullText)
// 限制内容长度
if len(content) > 200 {
content = content[:200] + "..."
}
}
}
// 添加到结果集
results = append(results, SearchResult{
Title: title,
URL: *href,
Content: content,
})
// 限制结果数量,每页最多 10 条
if len(results) >= 10*maxPages {
break
}
}
}
// 获取真实 URL百度搜索结果中的 URL 是短链接,需要跳转获取真实 URL
for i, result := range results {
realURL, err := s.getRedirectURL(result.URL)
if err == nil && realURL != "" {
results[i].URL = realURL
}
}
return results, nil
}
// 获取真实 URL
func (s *Service) getRedirectURL(shortURL string) (string, error) {
// 创建页面
page, err := s.browser.Page(proto.TargetCreateTarget{URL: ""})
if err != nil {
return shortURL, err // 返回原始URL
}
defer func() {
_ = page.Close()
}()
// 导航到短链接
err = page.Navigate(shortURL)
if err != nil {
return shortURL, err // 返回原始URL
}
// 等待重定向完成
time.Sleep(2 * time.Second)
// 获取当前 URL
info, err := page.Info()
if err != nil {
return shortURL, err // 返回原始URL
}
return info.URL, nil
}
// Close 关闭浏览器
func (s *Service) Close() error {
if s.browser != nil {
err := s.browser.Close()
s.browser = nil
return err
}
return nil
}
// SearchWeb 封装的搜索方法
func SearchWeb(keyword string, maxPages int) (string, error) {
// 添加panic恢复机制
defer func() {
if r := recover(); r != nil {
log := logger.GetLogger()
log.Errorf("爬虫服务崩溃: %v", r)
}
}()
service, err := NewService()
if err != nil {
return "", fmt.Errorf("创建爬虫服务失败: %v", err)
}
defer service.Close()
// 设置超时上下文
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
// 使用goroutine和通道来处理超时
resultChan := make(chan []SearchResult, 1)
errChan := make(chan error, 1)
go func() {
results, err := service.WebSearch(keyword, maxPages)
if err != nil {
errChan <- err
return
}
resultChan <- results
}()
// 等待结果或超时
select {
case <-ctx.Done():
return "", fmt.Errorf("搜索超时: %v", ctx.Err())
case err := <-errChan:
return "", fmt.Errorf("搜索失败: %v", err)
case results := <-resultChan:
if len(results) == 0 {
return "未找到关于 \"" + keyword + "\" 的相关搜索结果", nil
}
// 格式化结果
var builder strings.Builder
builder.WriteString(fmt.Sprintf("为您找到关于 \"%s\" 的 %d 条搜索结果:\n\n", keyword, len(results)))
for i, result := range results {
// // 尝试打开链接获取实际内容
// page := service.browser.MustPage()
// defer page.MustClose()
// // 设置页面超时
// pageCtx, pageCancel := context.WithTimeout(context.Background(), 10*time.Second)
// defer pageCancel()
// // 导航到目标页面
// err := page.Context(pageCtx).Navigate(result.URL)
// if err == nil {
// // 等待页面加载
// _ = page.WaitLoad()
// // 获取页面标题
// title, err := page.Eval("() => document.title")
// if err == nil && title.Value.String() != "" {
// result.Title = title.Value.String()
// }
// // 获取页面主要内容
// if content, err := page.Element("body"); err == nil {
// if text, err := content.Text(); err == nil {
// // 清理并截取内容
// text = strings.TrimSpace(text)
// if len(text) > 200 {
// text = text[:200] + "..."
// }
// result.Prompt = text
// }
// }
// }
builder.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, result.Title))
builder.WriteString(fmt.Sprintf(" 链接: %s\n", result.URL))
if result.Content != "" {
builder.WriteString(fmt.Sprintf(" 摘要: %s\n", result.Content))
}
builder.WriteString("\n")
}
return builder.String(), nil
}
}

View File

@@ -16,6 +16,7 @@ import (
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"strings"
"time" "time"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
@@ -94,12 +95,14 @@ func (s *Service) Run() {
} }
type imgReq struct { type imgReq struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Image []string `json:"image,omitempty"`
N int `json:"n,omitempty"` Prompt string `json:"prompt"`
Size string `json:"size,omitempty"` N int `json:"n,omitempty"`
Quality string `json:"quality,omitempty"` Size string `json:"size,omitempty"`
Style string `json:"style,omitempty"` Quality string `json:"quality,omitempty"`
Style string `json:"style,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
} }
type imgRes struct { type imgRes struct {
@@ -122,15 +125,6 @@ type ErrRes struct {
func (s *Service) Image(task types.DallTask, sync bool) (string, error) { func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task) logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt
// translate prompt
if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, prompt), task.TranslateModelId)
if err == nil {
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
}
var chatModel model.ChatModel var chatModel model.ChatModel
if task.ModelId > 0 { if task.ModelId > 0 {
@@ -160,12 +154,17 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL) apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
reqBody := imgReq{ reqBody := imgReq{
Model: chatModel.Value, Model: chatModel.Value,
Prompt: prompt, Prompt: task.Prompt,
N: 1, N: 1,
Size: task.Size, Size: task.Size,
Style: task.Style, Style: task.Style,
Quality: task.Quality, Quality: task.Quality,
} }
// 图片编辑
if len(task.Image) > 0 {
reqBody.Prompt = fmt.Sprintf("%s, %s", strings.Join(task.Image, " "), task.Prompt)
}
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody) logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json"). r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value). SetHeader("Authorization", "Bearer "+apiKey.Value).
@@ -188,7 +187,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
var imgURL string var imgURL string
var data = map[string]interface{}{ var data = map[string]interface{}{
"progress": 100, "progress": 100,
"prompt": prompt, "prompt": task.Prompt,
} }
// 如果返回的是base64则需要上传到oss // 如果返回的是base64则需要上传到oss
if res.Data[0].B64Json != "" { if res.Data[0].B64Json != "" {
@@ -210,11 +209,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
var content string var content string
if sync { if sync {
imgURL, err := s.downloadImage(task.Id, res.Data[0].Url) content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", task.Prompt, imgURL)
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
} }
return content, nil return content, nil

View File

@@ -3,8 +3,10 @@ package jimeng
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"geekai/core/types"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"github.com/volcengine/volc-sdk-golang/base" "github.com/volcengine/volc-sdk-golang/base"
"github.com/volcengine/volc-sdk-golang/service/visual" "github.com/volcengine/volc-sdk-golang/service/visual"
@@ -13,14 +15,22 @@ import (
// Client 即梦API客户端 // Client 即梦API客户端
type Client struct { type Client struct {
visual *visual.Visual visual *visual.Visual
config types.JimengConfig
} }
// NewClient 创建即梦API客户端 // NewClient 创建即梦API客户端
func NewClient(accessKey, secretKey string) *Client { func NewClient(sysConfig *types.SystemConfig) *Client {
client := &Client{}
client.UpdateConfig(sysConfig.Jimeng)
return client
}
func (c *Client) UpdateConfig(config types.JimengConfig) error {
// 使用官方SDK的visual实例 // 使用官方SDK的visual实例
visualInstance := visual.NewInstance() visualInstance := visual.NewInstance()
visualInstance.Client.SetAccessKey(accessKey) visualInstance.Client.SetAccessKey(config.AccessKey)
visualInstance.Client.SetSecretKey(secretKey) visualInstance.Client.SetSecretKey(config.SecretKey)
// 添加即梦AI专有的API配置 // 添加即梦AI专有的API配置
jimengApis := map[string]*base.ApiInfo{ jimengApis := map[string]*base.ApiInfo{
@@ -55,9 +65,32 @@ func NewClient(accessKey, secretKey string) *Client {
visualInstance.Client.ApiInfoList[name] = info visualInstance.Client.ApiInfoList[name] = info
} }
return &Client{ c.config = config
visual: visualInstance, c.visual = visualInstance
return c.testConnection()
}
// testConnection 测试即梦AI连接
func (c *Client) testConnection() error {
// 使用一个简单的查询任务来测试连接
testReq := &QueryTaskRequest{
ReqKey: "test_connection",
TaskId: "test_task_id_12345",
} }
_, err := c.QueryTask(testReq)
// 即使任务不存在,只要不是认证错误就说明连接正常
if err != nil {
// 检查是否是认证错误
if strings.Contains(err.Error(), "InvalidAccessKey") {
return fmt.Errorf("认证失败请检查AccessKey和SecretKey是否正确")
}
// 其他错误(如任务不存在)说明连接正常
return nil
}
return nil
} }
// SubmitTask 提交异步任务 // SubmitTask 提交异步任务

View File

@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
@@ -16,8 +15,6 @@ import (
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/core/types"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
) )
@@ -36,17 +33,8 @@ type Service struct {
} }
// NewService 创建即梦服务 // NewService 创建即梦服务
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service { func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager, client *Client) *Service {
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli) taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
// 从数据库加载配置
var config model.Config
db.Where("name = ?", "Jimeng").First(&config)
var jimengConfig types.JimengConfig
if config.Id > 0 {
_ = utils.JsonDecode(config.Value, &jimengConfig)
}
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Service{ return &Service{
db: db, db: db,
@@ -378,7 +366,7 @@ func (s *Service) pollTaskStatus() {
for _, job := range jobs { for _, job := range jobs {
// 任务超时处理 // 任务超时处理
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) { if job.UpdatedAt.Before(time.Now().Add(-10 * time.Minute)) {
s.handleTaskError(job.Id, "task timeout") s.handleTaskError(job.Id, "task timeout")
continue continue
} }
@@ -391,7 +379,7 @@ func (s *Service) pollTaskStatus() {
}) })
if err != nil { if err != nil {
logger.Errorf("query jimeng task status failed: %v", err) s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", err.Error()))
continue continue
} }
@@ -446,9 +434,7 @@ func (s *Service) pollTaskStatus() {
s.handleTaskError(job.Id, "task not found") s.handleTaskError(job.Id, "task not found")
case model.JMTaskStatusExpired: case model.JMTaskStatusExpired:
// 任务过期 continue
s.handleTaskError(job.Id, "task expired")
default: default:
logger.Warnf("unknown task status: %s", resp.Data.Status) logger.Warnf("unknown task status: %s", resp.Data.Status)
} }
@@ -524,77 +510,3 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
} }
return &job, nil return &job, nil
} }
// testConnection 测试即梦AI连接
func (s *Service) testConnection(accessKey, secretKey string) error {
testClient := NewClient(accessKey, secretKey)
// 使用一个简单的查询任务来测试连接
testReq := &QueryTaskRequest{
ReqKey: "test_connection",
TaskId: "test_task_id_12345",
}
_, err := testClient.QueryTask(testReq)
// 即使任务不存在,只要不是认证错误就说明连接正常
if err != nil {
// 检查是否是认证错误
if strings.Contains(err.Error(), "InvalidAccessKey") {
return fmt.Errorf("认证失败请检查AccessKey和SecretKey是否正确")
}
// 其他错误(如任务不存在)说明连接正常
return nil
}
return nil
}
// UpdateClientConfig 更新客户端配置
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
// 创建新的客户端
newClient := NewClient(accessKey, secretKey)
// 测试新客户端是否可用
err := s.testConnection(accessKey, secretKey)
if err != nil {
return err
}
// 更新客户端
s.client = newClient
return nil
}
var defaultPower = types.JimengPower{
TextToImage: 20,
ImageToImage: 20,
ImageEdit: 20,
ImageEffects: 20,
TextToVideo: 300,
ImageToVideo: 300,
}
// GetConfig 获取即梦AI配置
func (s *Service) GetConfig() *types.JimengConfig {
var config model.Config
err := s.db.Where("name", "jimeng").First(&config).Error
if err != nil {
// 如果配置不存在,返回默认配置
return &types.JimengConfig{
AccessKey: "",
SecretKey: "",
Power: defaultPower,
}
}
var jimengConfig types.JimengConfig
err = utils.JsonDecode(config.Value, &jimengConfig)
if err != nil {
return &types.JimengConfig{
AccessKey: "",
SecretKey: "",
Power: defaultPower,
}
}
return &jimengConfig
}

View File

@@ -8,30 +8,37 @@ package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt" "fmt"
"geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/store" "geekai/store/model"
"geekai/utils"
"strings"
"time" "time"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"github.com/shirou/gopsutil/host"
"gorm.io/gorm"
) )
type LicenseService struct { type LicenseService struct {
config types.ApiConfig
levelDB *store.LevelDB
license *types.License license *types.License
urlWhiteList []string urlWhiteList []string
machineId string machineId string
db *gorm.DB
} }
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService { func NewLicenseService(sysConfig *types.SystemConfig, db *gorm.DB) *LicenseService {
var license types.License var machineId string
info, err := host.Info()
if err == nil {
machineId = info.HostID
}
logger.Infof("License: %+v", sysConfig.License)
return &LicenseService{ return &LicenseService{
config: server.Config.ApiConfig, license: &sysConfig.License,
levelDB: levelDB, machineId: machineId,
license: &license, db: db,
machineId: "",
} }
} }
@@ -46,15 +53,15 @@ type License struct {
} }
// ActiveLicense 激活 License // ActiveLicense 激活 License
func (s *LicenseService) ActiveLicense(license string, machineId string) error { func (s *LicenseService) ActiveLicense(license string) error {
var res struct { var res struct {
Code types.BizCode `json:"code"` Code types.BizCode `json:"code"`
Message string `json:"message"` Message string `json:"message"`
Data License `json:"data"` Data License `json:"data"`
} }
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active") apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/active")
response, err := req.C().R(). response, err := req.C().R().
SetBody(map[string]string{"license": license, "machine_id": machineId}). SetBody(map[string]string{"license": license, "machine_id": s.machineId}).
SetSuccessResult(&res).Post(apiURL) SetSuccessResult(&res).Post(apiURL)
if err != nil { if err != nil {
return fmt.Errorf("发送激活请求失败: %v", err) return fmt.Errorf("发送激活请求失败: %v", err)
@@ -68,17 +75,24 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error {
return fmt.Errorf("激活失败:%v", res.Message) return fmt.Errorf("激活失败:%v", res.Message)
} }
if res.Data.ExpiredAt > 0 && res.Data.ExpiredAt < time.Now().Unix() {
return fmt.Errorf("License 已过期")
}
s.license = &types.License{ s.license = &types.License{
Key: license, Key: license,
MachineId: machineId, MachineId: s.machineId,
Configs: res.Data.Configs, Configs: res.Data.Configs,
ExpiredAt: res.Data.ExpiredAt, ExpiredAt: res.Data.ExpiredAt,
IsActive: true, IsActive: true,
} }
err = s.levelDB.Put(types.LicenseKey, s.license)
// 保存 License 到数据库
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
if err != nil { if err != nil {
return fmt.Errorf("保存许可证书失败:%v", err) return fmt.Errorf("保存 License 到数据库失败: %v", err)
} }
return nil return nil
} }
@@ -96,6 +110,11 @@ func (s *LicenseService) SyncLicense() {
s.license.IsActive = false s.license.IsActive = false
} else { } else {
s.license = license s.license = license
// 保存 License 到数据库
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
if err != nil {
logger.Errorf("保存 License 到数据库失败: %v", err)
}
} }
urls, err := s.fetchUrlWhiteList() urls, err := s.fetchUrlWhiteList()
@@ -109,33 +128,30 @@ func (s *LicenseService) SyncLicense() {
} }
func (s *LicenseService) fetchLicense() (*types.License, error) { func (s *LicenseService) fetchLicense() (*types.License, error) {
//var res struct { var res struct {
// Code types.BizCode `json:"code"` Code types.BizCode `json:"code"`
// Message string `json:"message"` Message string `json:"message"`
// Data License `json:"data"` Data License `json:"data"`
//} }
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check") apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/check")
//response, err := req.C().R(). response, err := req.C().R().
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}). SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
// SetSuccessResult(&res).Post(apiURL) SetSuccessResult(&res).Post(apiURL)
//if err != nil { if err != nil {
// return nil, fmt.Errorf("发送激活请求失败: %v", err) return nil, fmt.Errorf("License 同步失败: %v", err)
//} }
//if response.IsErrorState() { if response.IsErrorState() {
// return nil, fmt.Errorf("激活失败:%v", response.Status) return nil, fmt.Errorf("License 同步失败:%v", response.Status)
//} }
//if res.Code != types.Success { if res.Code != types.Success {
// return nil, fmt.Errorf("激活失败:%v", res.Message) return nil, fmt.Errorf("License 同步失败:%v", res.Message)
//} }
return &types.License{ return &types.License{
Key: "abc", Key: res.Data.License,
MachineId: "abc", MachineId: res.Data.MachineId,
Configs: types.LicenseConfig{ Configs: res.Data.Configs,
UserNum: 10000, ExpiredAt: res.Data.ExpiredAt,
DeCopy: false,
},
ExpiredAt: 0,
IsActive: true, IsActive: true,
}, nil }, nil
} }
@@ -146,7 +162,7 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
Message string `json:"message"` Message string `json:"message"`
Data []string `json:"data"` Data []string `json:"data"`
} }
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls") apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/urls")
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL) response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("发送请求失败: %v", err) return nil, fmt.Errorf("发送请求失败: %v", err)
@@ -163,35 +179,46 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
// GetLicense 获取许可信息 // GetLicense 获取许可信息
func (s *LicenseService) GetLicense() *types.License { func (s *LicenseService) GetLicense() *types.License {
if s.license == nil {
var config model.Config
s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).First(&config)
if config.Value != "" {
utils.JsonDecode(config.Value, &s.license)
}
}
return s.license return s.license
} }
func (s *LicenseService) SetLicense(licenseKey string) {
s.license.Key = licenseKey
}
// IsValidApiURL 判断是否合法的中转 URL // IsValidApiURL 判断是否合法的中转 URL
func (s *LicenseService) IsValidApiURL(uri string) error { func (s *LicenseService) IsValidApiURL(uri string) error {
// 获得许可授权的直接放行 // 获得许可授权的直接放行
return nil if s.license.IsActive {
//if s.license.IsActive { if s.license.MachineId != s.machineId {
// if s.license.MachineId != s.machineId { return errors.New("系统使用了盗版的许可证书")
// return errors.New("系统使用了盗版的许可证书") }
// }
// if time.Now().Unix() > s.license.ExpiredAt {
// if time.Now().Unix() > s.license.ExpiredAt { return errors.New("系统许可证书已经过期")
// return errors.New("系统许可证书已经过期") }
// } return nil
// return nil }
//}
// if len(s.urlWhiteList) == 0 {
//if len(s.urlWhiteList) == 0 { urls, err := s.fetchUrlWhiteList()
// urls, err := s.fetchUrlWhiteList() if err == nil {
// if err == nil { s.urlWhiteList = urls
// s.urlWhiteList = urls }
// } }
//}
// for _, v := range s.urlWhiteList {
//for _, v := range s.urlWhiteList { if strings.HasPrefix(uri, v) {
// if strings.HasPrefix(uri, v) { return nil
// return nil }
// } }
//} return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
} }

View File

@@ -1,52 +1,342 @@
package service package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved. // Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license // Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file. // that can be found in the LICENSE file.
// * @Author yangjian102621@163.com // @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"context"
"encoding/json"
"fmt"
"geekai/core/types"
"geekai/store"
"geekai/store/model" "geekai/store/model"
"strings"
"github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
) )
const (
// 迁移状态Redis key
MigrationStatusKey = "config_migration:status"
// 迁移完成标志
MigrationCompleted = "completed"
)
// MigrationService 配置迁移服务
type MigrationService struct { type MigrationService struct {
db *gorm.DB db *gorm.DB
redisClient *redis.Client
appConfig *types.AppConfig
levelDB *store.LevelDB
licenseService *LicenseService
} }
func NewMigrationService(db *gorm.DB) *MigrationService { func NewMigrationService(db *gorm.DB, redisClient *redis.Client, appConfig *types.AppConfig, levelDB *store.LevelDB, licenseService *LicenseService) *MigrationService {
return &MigrationService{db: db} return &MigrationService{
db: db,
redisClient: redisClient,
appConfig: appConfig,
levelDB: levelDB,
licenseService: licenseService,
}
} }
func (s *MigrationService) Migrate() error { func (s *MigrationService) StartMigrate() {
err := s.db.AutoMigrate( go func() {
&model.AdminUser{}, s.MigrateConfig(s.appConfig)
&model.ApiKey{}, s.TableMigration()
&model.AppType{}, s.MigrateLicense()
&model.ChatItem{}, }()
&model.ChatMessage{}, }
&model.ChatModel{},
&model.ChatRole{}, // 迁移 License
&model.Config{}, func (s *MigrationService) MigrateLicense() {
&model.DallJob{}, key := "migrate:license"
&model.File{}, if s.redisClient.Get(context.Background(), key).Val() == "1" {
&model.Function{}, logger.Info("License 已迁移,跳过迁移")
&model.InviteCode{}, return
&model.InviteLog{}, }
&model.Menu{},
&model.MidJourneyJob{}, logger.Info("开始迁移 License...")
&model.Order{}, var license types.License
&model.PowerLog{}, err := s.levelDB.Get(types.LicenseKey, &license)
&model.Product{}, if err != nil {
&model.Redeem{}, license = types.License{
&model.SdJob{}, Key: "",
&model.SunoJob{}, MachineId: "",
&model.User{}, Configs: types.LicenseConfig{UserNum: 0, DeCopy: false},
&model.UserLoginLog{}, ExpiredAt: 0,
&model.VideoJob{}, IsActive: false,
) }
return err }
logger.Infof("迁移 License: %+v", license)
if err := s.saveConfig(types.ConfigKeyLicense, license); err != nil {
logger.Errorf("迁移 License 失败: %v", err)
return
}
s.licenseService.SetLicense(license.Key)
logger.Info("迁移 License 完成")
s.redisClient.Set(context.Background(), key, "1", 0)
}
// 迁移配置内容
func (s *MigrationService) MigrateConfigContent() error {
// 用户协议
if err := s.saveConfig(types.ConfigKeyPrivacy, map[string]string{
"content": "用户协议内容",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 隐私政策
if err := s.saveConfig(types.ConfigKeyAgreement, map[string]string{
"content": "隐私政策内容",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 思维导图
if err := s.saveConfig(types.ConfigKeyMarkMap, map[string]string{
"content": `# GeekAI 演示站
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
- 基于 Websocket 实现,完美的打字机体验。
- 内置了各种预训练好的角色应用,轻松满足你的各种聊天和应用需求。
- 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型。
- 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件。`,
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 微信登录配置
if err := s.saveConfig(types.ConfigKeyWxLogin, map[string]string{
"api_key": "",
"notify_url": "",
"enabled": "false",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 验证码配置
if err := s.saveConfig(types.ConfigKeyCaptcha, map[string]string{
"api_key": "",
"type": "dot",
"enabled": "false",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 文本审核
if err := s.saveConfig(types.ConfigKeyModeration, map[string]any{
"enable": "false",
"active": "gitee",
"enable_guide": "false",
"guide_prompt": "",
"gitee": map[string]string{
"api_key": "",
"model": "Security-semantic-filtering",
},
"baidu": map[string]string{
"access_key": "",
"secret_key": "",
},
"tencent": map[string]string{
"access_key": "",
"secret_key": "",
},
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
return nil
}
// 数据表迁移
func (s *MigrationService) TableMigration() {
// 新数据表
s.db.AutoMigrate(&model.Moderation{})
// 订单字段整理
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")
}
if !s.db.Migrator().HasColumn(&model.Order{}, "checked") {
s.db.Migrator().AddColumn(&model.Order{}, "checked")
}
// 重命名 config 表字段
if s.db.Migrator().HasColumn(&model.Config{}, "config_json") {
s.db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
}
if s.db.Migrator().HasColumn(&model.Config{}, "marker") {
s.db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
}
if s.db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
s.db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
}
if s.db.Migrator().HasIndex(&model.Config{}, "marker") {
s.db.Migrator().DropIndex(&model.Config{}, "marker")
}
// 手动删除字段
if s.db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.Order{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.User{}, "chat_config") {
s.db.Migrator().DropColumn(&model.User{}, "chat_config")
}
if s.db.Migrator().HasColumn(&model.ChatModel{}, "category") {
s.db.Migrator().DropColumn(&model.ChatModel{}, "category")
}
if s.db.Migrator().HasColumn(&model.ChatModel{}, "description") {
s.db.Migrator().DropColumn(&model.ChatModel{}, "description")
}
if s.db.Migrator().HasColumn(&model.Product{}, "discount") {
s.db.Migrator().DropColumn(&model.Product{}, "discount")
}
if s.db.Migrator().HasColumn(&model.Product{}, "days") {
s.db.Migrator().DropColumn(&model.Product{}, "days")
}
if s.db.Migrator().HasColumn(&model.Product{}, "app_url") {
s.db.Migrator().DropColumn(&model.Product{}, "app_url")
}
if s.db.Migrator().HasColumn(&model.Product{}, "url") {
s.db.Migrator().DropColumn(&model.Product{}, "url")
}
}
// 迁移配置数据
func (s *MigrationService) MigrateConfig(config *types.AppConfig) error {
logger.Info("开始迁移配置到数据库...")
// 迁移支付配置
if err := s.migratePaymentConfig(config); err != nil {
logger.Errorf("迁移支付配置失败: %v", err)
return err
}
// 迁移存储配置
if err := s.migrateStorageConfig(config); err != nil {
logger.Errorf("迁移存储配置失败: %v", err)
return err
}
// 迁移通信配置
if err := s.migrateCommunicationConfig(config); err != nil {
logger.Errorf("迁移通信配置失败: %v", err)
return err
}
// 迁移配置内容
if err := s.MigrateConfigContent(); err != nil {
logger.Errorf("迁移配置内容失败: %v", err)
return err
}
logger.Info("配置迁移完成")
return nil
}
// 迁移支付配置
func (s *MigrationService) migratePaymentConfig(config *types.AppConfig) error {
paymentConfig := types.PaymentConfig{
Alipay: config.AlipayConfig,
Epay: config.GeekPayConfig,
WxPay: config.WechatPayConfig,
}
if err := s.saveConfig(types.ConfigKeyPayment, paymentConfig); err != nil {
return err
}
return nil
}
// 迁移存储配置
func (s *MigrationService) migrateStorageConfig(config *types.AppConfig) error {
ossConfig := types.OSSConfig{
Active: config.OSS.Active,
Local: config.OSS.Local,
Minio: config.OSS.Minio,
QiNiu: config.OSS.QiNiu,
AliYun: config.OSS.AliYun,
}
return s.saveConfig(types.ConfigKeyOss, ossConfig)
}
// 迁移通信配置
func (s *MigrationService) migrateCommunicationConfig(config *types.AppConfig) error {
// SMTP配置
smtpConfig := map[string]any{
"use_tls": config.SmtpConfig.UseTls,
"host": config.SmtpConfig.Host,
"port": config.SmtpConfig.Port,
"app_name": config.SmtpConfig.AppName,
"from": config.SmtpConfig.From,
"password": config.SmtpConfig.Password,
}
if err := s.saveConfig(types.ConfigKeySmtp, smtpConfig); err != nil {
return err
}
// 短信配置
smsConfig := map[string]any{
"active": strings.ToLower(config.SMS.Active),
"aliyun": map[string]any{
"access_key": config.SMS.Ali.AccessKey,
"access_secret": config.SMS.Ali.AccessSecret,
"sign": config.SMS.Ali.Sign,
"code_temp_id": config.SMS.Ali.CodeTempId,
},
"bao": map[string]any{
"username": config.SMS.Bao.Username,
"password": config.SMS.Bao.Password,
"sign": config.SMS.Bao.Sign,
"code_template": config.SMS.Bao.CodeTemplate,
},
}
return s.saveConfig(types.ConfigKeySms, smsConfig)
}
// 保存配置到数据库
func (s *MigrationService) saveConfig(key string, config any) error {
// 检查是否已存在
var existingConfig model.Config
if err := s.db.Where("name", key).First(&existingConfig).Error; err == nil {
// 配置已存在,跳过
logger.Infof("配置 %s 已存在,跳过迁移", key)
return nil
}
// 序列化配置
configJSON, err := json.Marshal(config)
if err != nil {
return err
}
// 保存到数据库
newConfig := model.Config{
Name: key,
Value: string(configJSON),
}
if err := s.db.Create(&newConfig).Error; err != nil {
return err
}
logger.Infof("成功迁移配置 %s", key)
return nil
} }

View File

@@ -67,25 +67,6 @@ func (s *Service) Run() {
continue continue
} }
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), task.TranslateModelId)
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// use fast mode as default // use fast mode as default
if task.Mode == "" { if task.Mode == "" {
task.Mode = "fast" task.Mode = "fast"

View File

@@ -0,0 +1,33 @@
package moderation
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"geekai/core/types"
)
type BaiduAIModeration struct {
config types.ModerationBaiduConfig
}
func NewBaiduAIModeration(sysConfig *types.SystemConfig) *BaiduAIModeration {
return &BaiduAIModeration{
config: sysConfig.Moderation.Baidu,
}
}
func (s *BaiduAIModeration) UpdateConfig(config types.ModerationBaiduConfig) {
s.config = config
}
func (s *BaiduAIModeration) Moderate(text string) (types.ModerationResult, error) {
return types.ModerationResult{}, errors.New("not implemented")
}
var _ Service = (*BaiduAIModeration)(nil)

View File

@@ -0,0 +1,58 @@
package moderation
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"geekai/core/types"
"github.com/imroc/req/v3"
)
type GiteeAIModeration struct {
config types.ModerationGiteeConfig
apiURL string
}
func NewGiteeAIModeration(sysConfig *types.SystemConfig) *GiteeAIModeration {
return &GiteeAIModeration{
config: sysConfig.Moderation.Gitee,
apiURL: "https://ai.gitee.com/v1/moderations",
}
}
func (s *GiteeAIModeration) UpdateConfig(config types.ModerationGiteeConfig) {
s.config = config
}
type GiteeAIModerationResult struct {
ID string `json:"id"`
Model string `json:"model"`
Results []types.ModerationResult `json:"results"`
}
func (s *GiteeAIModeration) Moderate(text string) (types.ModerationResult, error) {
body := map[string]any{
"input": text,
"model": s.config.Model,
}
var res GiteeAIModerationResult
r, err := req.C().R().SetHeader("Authorization", "Bearer "+s.config.ApiKey).SetBody(body).SetSuccessResult(&res).Post(s.apiURL)
if err != nil {
return types.ModerationResult{}, err
}
if r.IsErrorState() {
return types.ModerationResult{}, errors.New(r.String())
}
return res.Results[0], nil
}
var _ Service = (*GiteeAIModeration)(nil)

View File

@@ -0,0 +1,58 @@
package moderation
import (
"geekai/core/types"
logger2 "geekai/logger"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
var logger = logger2.GetLogger()
type Service interface {
Moderate(text string) (types.ModerationResult, error)
}
type ServiceManager struct {
gitee *GiteeAIModeration
baidu *BaiduAIModeration
tencent *TencentAIModeration
active string
}
func NewServiceManager(gitee *GiteeAIModeration, baidu *BaiduAIModeration, tencent *TencentAIModeration) *ServiceManager {
return &ServiceManager{
gitee: gitee,
baidu: baidu,
tencent: tencent,
}
}
func (s *ServiceManager) GetService() Service {
switch s.active {
case types.ModerationBaidu:
return s.baidu
case types.ModerationTencent:
return s.tencent
default:
return s.gitee
}
}
func (s *ServiceManager) UpdateConfig(config types.ModerationConfig) {
switch config.Active {
case types.ModerationGitee:
s.gitee.UpdateConfig(config.Gitee)
case types.ModerationBaidu:
s.baidu.UpdateConfig(config.Baidu)
case types.ModerationTencent:
s.tencent.UpdateConfig(config.Tencent)
}
s.active = config.Active
}

View File

@@ -0,0 +1,33 @@
package moderation
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"geekai/core/types"
)
type TencentAIModeration struct {
config types.ModerationTencentConfig
}
func NewTencentAIModeration(sysConfig *types.SystemConfig) *TencentAIModeration {
return &TencentAIModeration{
config: sysConfig.Moderation.Tencent,
}
}
func (s *TencentAIModeration) UpdateConfig(config types.ModerationTencentConfig) {
s.config = config
}
func (s *TencentAIModeration) Moderate(text string) (types.ModerationResult, error) {
return types.ModerationResult{}, errors.New("not implemented")
}
var _ Service = (*TencentAIModeration)(nil)

View File

@@ -23,35 +23,35 @@ import (
) )
type AliYunOss struct { type AliYunOss struct {
config *types.AliYunOssConfig config types.AliYunOssConfig
bucket *oss.Bucket bucket *oss.Bucket
proxyURL string proxyURL string
} }
func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) { func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*AliYunOss, error) {
config := &appConfig.OSS.AliYun s := &AliYunOss{
// 创建 OSS 客户端 proxyURL: appConfig.ProxyURL,
}
err := s.UpdateConfig(sysConfig.OSS.AliYun)
if err != nil {
logger.Warnf("阿里云OSS初始化失败: %v", err)
}
return s, nil
}
func (s *AliYunOss) UpdateConfig(config types.AliYunOssConfig) error {
client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret) client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret)
if err != nil { if err != nil {
return nil, err return err
} }
// 获取存储空间
bucket, err := client.Bucket(config.Bucket) bucket, err := client.Bucket(config.Bucket)
if err != nil { if err != nil {
return nil, err return err
} }
s.bucket = bucket
if config.SubDir == "" { s.config = config
config.SubDir = "gpt" return nil
}
return &AliYunOss{
config: config,
bucket: bucket,
proxyURL: appConfig.ProxyURL,
}, nil
} }
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) { func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
@@ -68,7 +68,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
defer src.Close() defer src.Close()
fileExt := filepath.Ext(file.Filename) fileExt := filepath.Ext(file.Filename)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
// 上传文件 // 上传文件
err = s.bucket.PutObject(objectKey, src) err = s.bucket.PutObject(objectKey, src)
if err != nil { if err != nil {
@@ -102,7 +102,7 @@ func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
if ext == "" { if ext == "" {
ext = filepath.Ext(parse.Path) ext = filepath.Ext(parse.Path)
} }
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext) objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
// 上传文件字节数据 // 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData)) err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
if err != nil { if err != nil {
@@ -116,7 +116,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err) return "", fmt.Errorf("error decoding base64:%v", err)
} }
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro()) objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
// 上传文件字节数据 // 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData)) err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil { if err != nil {
@@ -128,8 +128,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
func (s AliYunOss) Delete(fileURL string) error { func (s AliYunOss) Delete(fileURL string) error {
var objectKey string var objectKey string
if strings.HasPrefix(fileURL, "http") { if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL) objectKey = filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else { } else {
objectKey = fileURL objectKey = fileURL
} }

View File

@@ -21,17 +21,21 @@ import (
) )
type LocalStorage struct { type LocalStorage struct {
config *types.LocalStorageConfig config types.LocalStorageConfig
proxyURL string proxyURL string
} }
func NewLocalStorage(config *types.AppConfig) LocalStorage { func NewLocalStorage(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *LocalStorage {
return LocalStorage{ return &LocalStorage{
config: &config.OSS.Local, config: sysConfig.OSS.Local,
proxyURL: config.ProxyURL, proxyURL: appConfig.ProxyURL,
} }
} }
func (s *LocalStorage) UpdateConfig(config types.LocalStorageConfig) {
s.config = config
}
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) { func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
file, err := ctx.FormFile(name) file, err := ctx.FormFile(name)
if err != nil { if err != nil {

View File

@@ -24,24 +24,32 @@ import (
) )
type MiniOss struct { type MiniOss struct {
config *types.MiniOssConfig config types.MiniOssConfig
client *minio.Client client *minio.Client
proxyURL string proxyURL string
} }
func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) { func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) {
config := &appConfig.OSS.Minio
s := &MiniOss{proxyURL: appConfig.ProxyURL}
err := s.UpdateConfig(sysConfig.OSS.Minio)
if err != nil {
logger.Warnf("MinioOSS初始化失败: %v", err)
}
return s, nil
}
func (s *MiniOss) UpdateConfig(config types.MiniOssConfig) error {
minioClient, err := minio.New(config.Endpoint, &minio.Options{ minioClient, err := minio.New(config.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""), Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""),
Secure: config.UseSSL, Secure: config.UseSSL,
}) })
if err != nil { if err != nil {
return MiniOss{}, err return err
} }
if config.SubDir == "" { s.config = config
config.SubDir = "gpt" s.client = minioClient
} return nil
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
} }
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) { func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
@@ -62,7 +70,7 @@ func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string,
if ext == "" { if ext == "" {
ext = filepath.Ext(parse.Path) ext = filepath.Ext(parse.Path)
} }
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext) filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
info, err := s.client.PutObject( info, err := s.client.PutObject(
context.Background(), context.Background(),
s.config.Bucket, s.config.Bucket,
@@ -89,7 +97,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
defer fileReader.Close() defer fileReader.Close()
fileExt := filepath.Ext(file.Filename) fileExt := filepath.Ext(file.Filename)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{ info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Body-Type"), ContentType: file.Header.Get("Body-Type"),
}) })
@@ -111,7 +119,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err) return "", fmt.Errorf("error decoding base64:%v", err)
} }
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro()) objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
info, err := s.client.PutObject( info, err := s.client.PutObject(
context.Background(), context.Background(),
s.config.Bucket, s.config.Bucket,
@@ -128,8 +136,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
func (s MiniOss) Delete(fileURL string) error { func (s MiniOss) Delete(fileURL string) error {
var objectKey string var objectKey string
if strings.HasPrefix(fileURL, "http") { if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL) objectKey = filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else { } else {
objectKey = fileURL objectKey = fileURL
} }

View File

@@ -24,18 +24,24 @@ import (
"github.com/qiniu/go-sdk/v7/storage" "github.com/qiniu/go-sdk/v7/storage"
) )
type QinNiuOss struct { type QiNiuOss struct {
config *types.QiNiuOssConfig config types.QiNiuOssConfig
mac *qbox.Mac mac *qbox.Mac
putPolicy storage.PutPolicy putPolicy storage.PutPolicy
uploader *storage.FormUploader uploader *storage.FormUploader
manager *storage.BucketManager bucket *storage.BucketManager
proxyURL string proxyURL string
} }
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss { func NewQiNiuOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *QiNiuOss {
config := &appConfig.OSS.QiNiu s := &QiNiuOss{
// build storage uploader proxyURL: appConfig.ProxyURL,
}
s.UpdateConfig(sysConfig.OSS.QiNiu)
return s
}
func (s *QiNiuOss) UpdateConfig(config types.QiNiuOssConfig) {
zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone)) zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone))
if !ok { if !ok {
zone = storage.ZoneHuanan zone = storage.ZoneHuanan
@@ -47,20 +53,13 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
putPolicy := storage.PutPolicy{ putPolicy := storage.PutPolicy{
Scope: config.Bucket, Scope: config.Bucket,
} }
if config.SubDir == "" { s.config = config
config.SubDir = "gpt" s.mac = mac
} s.putPolicy = putPolicy
return QinNiuOss{ s.uploader = formUploader
config: config, s.bucket = storage.NewBucketManager(mac, &storeConfig)
mac: mac,
putPolicy: putPolicy,
uploader: formUploader,
manager: storage.NewBucketManager(mac, &storeConfig),
proxyURL: appConfig.ProxyURL,
}
} }
func (s QiNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
// 解析表单 // 解析表单
file, err := ctx.FormFile(name) file, err := ctx.FormFile(name)
if err != nil { if err != nil {
@@ -74,7 +73,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
defer src.Close() defer src.Close()
fileExt := filepath.Ext(file.Filename) fileExt := filepath.Ext(file.Filename)
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
// 上传文件 // 上传文件
ret := storage.PutRet{} ret := storage.PutRet{}
extra := storage.PutExtra{} extra := storage.PutExtra{}
@@ -93,7 +92,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
} }
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) { func (s QiNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
var fileData []byte var fileData []byte
var err error var err error
if useProxy { if useProxy {
@@ -111,7 +110,7 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
if ext == "" { if ext == "" {
ext = filepath.Ext(parse.Path) ext = filepath.Ext(parse.Path)
} }
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext) key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
ret := storage.PutRet{} ret := storage.PutRet{}
extra := storage.PutExtra{} extra := storage.PutExtra{}
// 上传文件字节数据 // 上传文件字节数据
@@ -122,12 +121,12 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
} }
func (s QinNiuOss) PutBase64(base64Img string) (string, error) { func (s QiNiuOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img) imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil { if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err) return "", fmt.Errorf("error decoding base64:%v", err)
} }
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro()) objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
ret := storage.PutRet{} ret := storage.PutRet{}
extra := storage.PutExtra{} extra := storage.PutExtra{}
// 上传文件字节数据 // 上传文件字节数据
@@ -138,16 +137,15 @@ func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
} }
func (s QinNiuOss) Delete(fileURL string) error { func (s QiNiuOss) Delete(fileURL string) error {
var objectKey string var objectKey string
if strings.HasPrefix(fileURL, "http") { if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL) objectKey = filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
} else { } else {
objectKey = fileURL objectKey = fileURL
} }
return s.manager.Delete(s.config.Bucket, objectKey) return s.bucket.Delete(s.config.Bucket, objectKey)
} }
var _ Uploader = QinNiuOss{} var _ Uploader = QiNiuOss{}

View File

@@ -9,10 +9,10 @@ package oss
import "github.com/gin-gonic/gin" import "github.com/gin-gonic/gin"
const Local = "LOCAL" const Local = "local"
const Minio = "MINIO" const Minio = "minio"
const QiNiu = "QINIU" const QiNiu = "qiniu"
const AliYun = "ALIYUN" const AliYun = "aliyun"
type File struct { type File struct {
Name string `json:"name"` Name string `json:"name"`

View File

@@ -9,45 +9,58 @@ package oss
import ( import (
"geekai/core/types" "geekai/core/types"
"strings"
logger2 "geekai/logger"
) )
var logger = logger2.GetLogger()
type UploaderManager struct { type UploaderManager struct {
handler Uploader local *LocalStorage
aliyun *AliYunOss
mini *MiniOss
qiniu *QiNiuOss
active string
} }
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) { func NewUploaderManager(sysConfig *types.SystemConfig, local *LocalStorage, aliyun *AliYunOss, mini *MiniOss, qiniu *QiNiuOss) (*UploaderManager, error) {
active := Local if sysConfig.OSS.Active == "" {
if config.OSS.Active != "" { sysConfig.OSS.Active = Local
active = strings.ToUpper(config.OSS.Active)
}
var handler Uploader
switch active {
case Local:
handler = NewLocalStorage(config)
break
case Minio:
client, err := NewMiniOss(config)
if err != nil {
return nil, err
}
handler = client
break
case QiNiu:
handler = NewQiNiuOss(config)
break
case AliYun:
client, err := NewAliYunOss(config)
if err != nil {
return nil, err
}
handler = client
break
} }
return &UploaderManager{handler: handler}, nil return &UploaderManager{
active: sysConfig.OSS.Active,
local: local,
aliyun: aliyun,
mini: mini,
qiniu: qiniu,
}, nil
} }
func (m *UploaderManager) GetUploadHandler() Uploader { func (m *UploaderManager) GetUploadHandler() Uploader {
return m.handler switch m.active {
case Local:
return m.local
case AliYun:
return m.aliyun
case Minio:
return m.mini
case QiNiu:
return m.qiniu
}
return m.local
}
func (m *UploaderManager) UpdateConfig(config types.OSSConfig) {
switch config.Active {
case Local:
m.local.UpdateConfig(config.Local)
case AliYun:
m.aliyun.UpdateConfig(config.AliYun)
case Minio:
m.mini.UpdateConfig(config.Minio)
case QiNiu:
m.qiniu.UpdateConfig(config.QiNiu)
}
m.active = config.Active
} }

View File

@@ -12,129 +12,98 @@ import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/alipay"
"net/http" "net/http"
"os" "os"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/alipay"
) )
type AlipayService struct { type AlipayService struct {
config *types.AlipayConfig
client *alipay.Client client *alipay.Client
config *types.AlipayConfig
} }
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) { func NewAlipayService(sysConfig *types.SystemConfig) (*AlipayService, error) {
config := appConfig.AlipayConfig config := sysConfig.Payment.Alipay
if !config.Enabled { if !config.Enabled {
logger.Info("Disabled Alipay service") logger.Debug("Disabled Alipay service")
return nil, nil
} }
priKey, err := readKey(config.PrivateKey)
service := &AlipayService{config: &config}
if config.Enabled {
err := service.UpdateConfig(&config)
if err != nil {
logger.Errorf("支付宝服务初始化失败: %v", err)
}
}
return service, nil
}
func (s *AlipayService) UpdateConfig(config *types.AlipayConfig) error {
client, err := alipay.NewClient(config.AppId, config.PrivateKey, !config.SandBox)
if err != nil { if err != nil {
return nil, fmt.Errorf("error with read App Private key: %v", err) return fmt.Errorf("error with initialize alipay service: %v", err)
} }
client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox) s.client = client
if err != nil { s.config = config
return nil, fmt.Errorf("error with initialize alipay service: %v", err) if os.Getenv("GEEKAI_DEBUG") == "true" {
logger.Info("Alipay Debug mode is enabled")
client.DebugSwitch = gopay.DebugOn
} }
return nil
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
return nil, fmt.Errorf("error with load payment public key: %v", err)
}
return &AlipayService{config: &config, client: client}, nil
} }
type AlipayParams struct { func (s *AlipayService) Pay(params PayRequest) (string, error) {
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"subject"`
TotalFee string `json:"total_fee"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("quit_url", params.ReturnURL)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "QUICK_WAP_WAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
}
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap) bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject) bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo) bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("total_amount", params.TotalFee) bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY") return s.client.TradeWapPay(context.Background(), bm)
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm) }
func (s *AlipayService) Query(outTradeNo string) (OrderInfo, error) {
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with trade query: %v", err)
}
switch rsp.Response.TradeStatus {
case "TRADE_SUCCESS":
logger.Debugf("支付宝查询订单成功:%+v", rsp.Response)
return OrderInfo{
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Status: Success,
PayTime: rsp.Response.SendPayDate,
}, nil
case "TRADE_CLOSED":
return OrderInfo{Status: Closed}, nil
default:
return OrderInfo{}, fmt.Errorf("error with trade query: %v", rsp.Response.TradeStatus)
}
} }
// TradeVerify 交易验证 // TradeVerify 交易验证
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo { func (s *AlipayService) TradeVerify(request *http.Request) (OrderInfo, error) {
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法 notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
if err != nil { if err != nil {
return NotifyVo{ return OrderInfo{}, fmt.Errorf("error with parse notify request: %v", err)
Status: Failure,
Message: "error with parse notify request: " + err.Error(),
}
} }
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq) _, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
if err != nil { if err != nil {
return NotifyVo{ return OrderInfo{}, fmt.Errorf("error with verify sign: %v", err)
Status: Failure,
Message: "error with verify sign: " + err.Error(),
}
} }
return s.TradeQuery(request.Form.Get("out_trade_no")) return s.Query(request.Form.Get("out_trade_no"))
} }
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo { var _ PayService = (*AlipayService)(nil)
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
//查询订单
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
}
}
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
return NotifyVo{
Status: Success,
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Subject: rsp.Response.Subject,
Message: "OK",
}
} else {
return NotifyVo{
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo,
}
}
}
func readKey(filename string) (string, error) {
data, err := os.ReadFile(filename)
if err != nil {
return "", err
}
return string(data), nil
}

View File

@@ -22,41 +22,30 @@ import (
"time" "time"
) )
// GeekPayService Geek 支付服务 // EPayService 支付服务
type GeekPayService struct { type EPayService struct {
config *types.GeekPayConfig config *types.EpayConfig
} }
func NewJPayService(appConfig *types.AppConfig) *GeekPayService { func NewEPayService(sysConfig *types.SystemConfig) *EPayService {
return &GeekPayService{ return &EPayService{
config: &appConfig.GeekPayConfig, config: &sysConfig.Payment.Epay,
} }
} }
type GeekPayParams struct { func (s *EPayService) UpdateConfig(config *types.EpayConfig) {
Method string `json:"method"` // 接口类型 s.config = config
Device string `json:"device"` // 设备类型
Type string `json:"type"` // 支付方式
OutTradeNo string `json:"out_trade_no"` // 商户订单号
Name string `json:"name"` // 商品名称
Money string `json:"money"` // 商品金额
ClientIP string `json:"clientip"` //用户IP地址
SubOpenId string `json:"sub_openid"` // 微信用户 openid仅小程序支付需要
SubAppId string `json:"sub_appid"` // 小程序 AppId仅小程序支付需要
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
} }
// Pay 支付订单 // Pay 支付订单
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) { func (s *EPayService) Pay(params PayRequest) (string, error) {
p := map[string]string{ p := map[string]string{
"pid": s.config.AppId, "pid": s.config.AppId,
//"method": params.Method,
"device": params.Device, "device": params.Device,
"type": params.Type, "type": params.PayWay,
"out_trade_no": params.OutTradeNo, "out_trade_no": params.OutTradeNo,
"name": params.Name, "name": params.Subject,
"money": params.Money, "money": params.TotalFee,
"clientip": params.ClientIP, "clientip": params.ClientIP,
"notify_url": params.NotifyURL, "notify_url": params.NotifyURL,
"return_url": params.ReturnURL, "return_url": params.ReturnURL,
@@ -64,10 +53,21 @@ func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
} }
p["sign"] = s.Sign(p) p["sign"] = s.Sign(p)
p["sign_type"] = "MD5" p["sign_type"] = "MD5"
return s.sendRequest(s.config.ApiURL, p) resp, err := s.sendRequest(s.config.ApiURL, p)
if err != nil {
return "", err
}
if resp.Code != 1 {
return "", errors.New(resp.Msg)
}
if resp.PayURL != "" {
return resp.PayURL, nil
} else {
return resp.QrCode, nil
}
} }
func (s *GeekPayService) Sign(params map[string]string) string { func (s *EPayService) Sign(params map[string]string) string {
// 按字母顺序排序参数 // 按字母顺序排序参数
var keys []string var keys []string
for k := range params { for k := range params {
@@ -100,7 +100,7 @@ type GeekPayResp struct {
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接 UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
} }
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) { func (s *EPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
form := url.Values{} form := url.Values{}
for k, v := range params { for k, v := range params {
form.Add(k, v) form.Add(k, v)
@@ -137,3 +137,61 @@ func (s *GeekPayService) sendRequest(endpoint string, params map[string]string)
} }
return &r, nil return &r, nil
} }
func (s *EPayService) Query(outTradeNo string) (OrderInfo, error) {
params := url.Values{}
params.Set("act", "order")
params.Set("pid", s.config.AppId)
params.Set("key", s.config.PrivateKey)
params.Set("out_trade_no", outTradeNo)
apiURL := fmt.Sprintf("%s/api.php?%s", s.config.ApiURL, params.Encode())
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: tr}
resp, err := client.Get(apiURL)
if err != nil {
return OrderInfo{}, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return OrderInfo{}, err
}
logger.Debugf(string(body))
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Status string `json:"status"`
Name string `json:"name"`
Money string `json:"money"`
EndTime string `json:"endtime"`
TradeNo string `json:"trade_no"`
}
if err := json.Unmarshal(body, &result); err != nil {
return OrderInfo{}, errors.New("订单查询响应解析失败")
}
if result.Code != 1 {
return OrderInfo{}, errors.New(result.Msg)
}
logger.Debugf("订单信息:%+v", result)
orderInfo := OrderInfo{
OutTradeNo: outTradeNo,
TradeId: result.TradeNo,
Amount: result.Money,
PayTime: result.EndTime,
}
if result.Status == "1" {
orderInfo.Status = Success
} else {
orderInfo.Status = Failure
}
return orderInfo, nil
}
var _ PayService = (*EPayService)(nil)

View File

@@ -1,171 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
type HuPiPayService struct {
appId string
appSecret string
apiURL string
}
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
return &HuPiPayService{
appId: config.HuPiPayConfig.AppId,
appSecret: config.HuPiPayConfig.AppSecret,
apiURL: config.HuPiPayConfig.ApiURL,
}
}
type HuPiPayParams struct {
AppId string `json:"appid"`
Version string `json:"version"`
TradeOrderId string `json:"trade_order_id"`
TotalFee string `json:"total_fee"`
Title string `json:"title"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
WapName string `json:"wap_name"`
CallbackURL string `json:"callback_url"`
Time string `json:"time"`
NonceStr string `json:"nonce_str"`
Type string `json:"type"`
WapUrl string `json:"wap_url"`
}
type HuPiPayResp struct {
Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg,omitempty"`
}
// Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10)
params.AppId = s.appId
params.Time = simple
params.NonceStr = simple
encode := utils.JsonEncode(params)
m := make(map[string]string)
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
data.Add(k, fmt.Sprintf("%v", v))
}
// 生成签名
data.Add("hash", s.Sign(data))
// 发送支付请求
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
}
defer resp.Body.Close()
all, err := io.ReadAll(resp.Body)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
}
var res HuPiPayResp
err = utils.JsonDecode(string(all), &res)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
}
if res.ErrCode != 0 {
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
}
return res, nil
}
// Sign 签名方法
func (s *HuPiPayService) Sign(params url.Values) string {
params.Del(`Sign`)
var keys = make([]string, 0, 0)
for key := range params {
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += s.appSecret
md5bs := md5.Sum([]byte(src))
return hex.EncodeToString(md5bs[:])
}
// Check 校验订单状态
func (s *HuPiPayService) Check(outTradeNo string) error {
data := url.Values{}
data.Add("appid", s.appId)
data.Add("out_trade_order", outTradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp)
data.Add("nonce_str", stamp)
data.Add("hash", s.Sign(data))
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ErrCode int `json:"errcode"`
Data struct {
Status string `json:"status"`
OpenOrderId string `json:"open_order_id"`
} `json:"data,omitempty"`
ErrMsg string `json:"errmsg"`
Hash string `json:"hash"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ErrCode == 0 && r.Data.Status == "OD" {
return nil
} else {
logger.Debugf("%+v", r)
return errors.New("order not paid" + r.ErrMsg)
}
}

View File

@@ -0,0 +1,54 @@
package payment
// 支付渠道定义
const PayChannelAL = "alipay" // 支付宝
const PayChannelWX = "wxpay" // 微信支付
const PayChannelEpay = "epay" // 易支付
// 支付方式
const PayWayAL = "alipay"
const PayWayWX = "wxpay"
const (
Success = 0
Failure = 1
Closed = 2
)
type PayRequest struct {
OutTradeNo string // 商户订单号
Subject string // 商品名称
TotalFee string // 商品金额
ReturnURL string // 回调地址
NotifyURL string // 回调地址
// 易支付专有参数
Method string // 接口类型
Device string // 设备类型
PayWay string // 支付方式
ClientIP string //用户IP地址
OpenID string // 用户openid
}
type OrderInfo struct {
Mchid string // 商户号
OutTradeNo string // 商户订单号
TradeId string // 交易号
Amount string // 金额
Status int // 状态 0: 未支付 1: 已支付 2: 已关闭
PayTime string // 完成支付时间
}
func (o OrderInfo) Closed() bool {
return o.Status == Closed
}
func (o OrderInfo) Success() bool {
return o.Status == Success
}
type PayService interface {
Pay(params PayRequest) (string, error) // 生成支付链接
Query(outTradeNo string) (OrderInfo, error) // 查询订单
}

View File

@@ -1,19 +0,0 @@
package payment
type NotifyVo struct {
Status int
OutTradeNo string // 商户订单号
TradeId string // 交易ID
Amount string // 交易金额
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == Success
}
const (
Success = 0
Failure = 1
)

View File

@@ -1,144 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
"net/http"
"time"
)
type WechatPayService struct {
config *types.WechatPayConfig
client *wechat.ClientV3
}
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
config := appConfig.WechatPayConfig
if !config.Enabled {
logger.Info("Disabled WechatPay service")
return nil, nil
}
priKey, err := readKey(config.PrivateKey)
if err != nil {
return nil, fmt.Errorf("error with read App Private key: %v", err)
}
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
if err != nil {
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
}
//client.DebugSwitch = gopay.DebugOn
return &WechatPayService{config: &config, client: client}, nil
}
type WechatPayParams struct {
OutTradeNo string `json:"out_trade_no"`
TotalFee int `json:"total_fee"`
Subject string `json:"subject"`
ClientIP string `json:"client_ip"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
})
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
}).
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP).
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
bm.Set("type", "Wap")
})
})
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.H5Url, nil
}
type NotifyResponse struct {
Code string `json:"code"`
Message string `xml:"message"`
}
// TradeVerify 交易验证
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
}
// TODO: 这里验签程序有 Bug一直报错crypto/rsa: verification error先暂时取消验签
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
//if err != nil {
// return fmt.Errorf("error with client v3 verify sign: %v", err)
//}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
}
return NotifyVo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
}
}

View File

@@ -0,0 +1,217 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/http"
"os"
"time"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
)
type WxPayService struct {
config *types.WxPayConfig
client *wechat.ClientV3
}
func NewWxpayService(sysConfig *types.SystemConfig) (*WxPayService, error) {
config := sysConfig.Payment.WxPay
if !config.Enabled {
logger.Debug("Disabled WechatPay service")
}
service := &WxPayService{config: &config}
if config.Enabled {
err := service.UpdateConfig(&config)
if err != nil {
logger.Errorf("微信支付服务初始化失败: %v", err)
}
}
return service, nil
}
func (s *WxPayService) UpdateConfig(config *types.WxPayConfig) error {
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey)
if err != nil {
return fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return fmt.Errorf("error with autoVerifySign: %v", err)
}
s.client = client
if os.Getenv("GEEKAI_DEBUG") == "true" {
logger.Info("WechatPay Debug mode is enabled")
client.DebugSwitch = gopay.DebugOn
}
s.config = config
return nil
}
func (s *WxPayService) Pay(params PayRequest) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", utils.IntValue(params.TotalFee, 0)).
Set("currency", "CNY")
})
logger.Debugf("wxpay params: %+v", bm)
if params.Device == "mobile" {
bm.SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP)
}).SetBodyMap("payer", func(bm gopay.BodyMap) {
bm.Set("openid", params.OpenID)
})
wxRsp, err := s.client.V3TransactionJsapi(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Jsapi: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.PrepayId, nil
} else if params.Device == "pc" {
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
return "", nil
}
func (s *WxPayService) Query(outTradeNo string) (OrderInfo, error) {
wxRsp, err := s.client.V3TransactionQueryOrder(context.Background(), wechat.OutTradeNo, outTradeNo)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 transaction query: %v", err)
}
if wxRsp.Code != wechat.Success {
return OrderInfo{}, fmt.Errorf("error status with querying order: %v", wxRsp.Error)
}
if wxRsp.Response.TradeState == "CLOSED" {
return OrderInfo{Status: Closed}, nil
}
orderInfo := OrderInfo{
OutTradeNo: wxRsp.Response.OutTradeNo,
TradeId: wxRsp.Response.TransactionId,
Amount: fmt.Sprintf("%d", wxRsp.Response.Amount.Total/100),
PayTime: wxRsp.Response.SuccessTime,
}
if wxRsp.Response.TradeState == "SUCCESS" {
orderInfo.Status = Success
} else {
orderInfo.Status = Failure
}
return orderInfo, nil
}
// TradeVerify 交易验证
func (s *WxPayService) TradeVerify(request *http.Request) (OrderInfo, error) {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 parse notify: %v", err)
}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 decrypt: %v", err)
}
return OrderInfo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
PayTime: result.SuccessTime,
}, nil
}
// func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// // 初始化 BodyMap
// bm := make(gopay.BodyMap)
// bm.Set("appid", s.config.AppId).
// Set("mchid", s.config.MchId).
// Set("description", params.Subject).
// Set("out_trade_no", params.OutTradeNo).
// Set("time_expire", expire).
// Set("notify_url", params.NotifyURL).
// SetBodyMap("amount", func(bm gopay.BodyMap) {
// bm.Set("total", params.TotalFee).
// Set("currency", "CNY")
// })
// wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
// if err != nil {
// return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
// }
// if wxRsp.Code != wechat.Success {
// return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
// }
// return wxRsp.Response.CodeUrl, nil
// }
// func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// // 初始化 BodyMap
// bm := make(gopay.BodyMap)
// bm.Set("appid", s.config.AppId).
// Set("mchid", s.config.MchId).
// Set("description", params.Subject).
// Set("out_trade_no", params.OutTradeNo).
// Set("time_expire", expire).
// Set("notify_url", params.NotifyURL).
// SetBodyMap("amount", func(bm gopay.BodyMap) {
// bm.Set("total", params.TotalFee).
// Set("currency", "CNY")
// }).
// SetBodyMap("scene_info", func(bm gopay.BodyMap) {
// bm.Set("payer_client_ip", params.ClientIP).
// SetBodyMap("h5_info", func(bm gopay.BodyMap) {
// bm.Set("type", "Wap")
// })
// })
// wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
// if err != nil {
// return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
// }
// if wxRsp.Code != wechat.Success {
// return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
// }
// return wxRsp.Response.H5Url, nil
// }
// type NotifyResponse struct {
// Code string `json:"code"`
// Message string `xml:"message"`
// }
var _ PayService = (*WxPayService)(nil)

Some files were not shown because too many files have changed in this diff Show More