mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-11 05:34:25 +08:00
Compare commits
116 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ba2e000a1 | ||
|
|
805c2a045e | ||
|
|
1b1c327c35 | ||
|
|
7d0a05ee11 | ||
|
|
825a1b1027 | ||
|
|
950e7d1b00 | ||
|
|
d8f9f48278 | ||
|
|
bdd76addf3 | ||
|
|
2936f21f12 | ||
|
|
705ca4d20a | ||
|
|
6a066f1b8e | ||
|
|
b7d137247a | ||
|
|
4373642ebd | ||
|
|
5bf90920f5 | ||
|
|
edcbb3e226 | ||
|
|
5213bdf08b | ||
|
|
cf817fd8ea | ||
|
|
a2481ff1cf | ||
|
|
bc7d06d3e5 | ||
|
|
8e81dfa12a | ||
|
|
0ff76f0f21 | ||
|
|
787caa84c8 | ||
|
|
c2503e663a | ||
|
|
405a88862b | ||
|
|
296eabe09a | ||
|
|
54b45ec2ff | ||
|
|
c434f85045 | ||
|
|
4d10279870 | ||
|
|
9de9489673 | ||
|
|
9814fec930 | ||
|
|
53ba731159 | ||
|
|
b2f57aa483 | ||
|
|
4c2dba1004 | ||
|
|
79adc871ef | ||
|
|
8144fada25 | ||
|
|
754ba02263 | ||
|
|
7ddf57ae06 | ||
|
|
cc5180a6f7 | ||
|
|
9f44c34d34 | ||
|
|
b793b81768 | ||
|
|
233f6e00f0 | ||
|
|
b7dba68549 | ||
|
|
bdea12c51a | ||
|
|
a27d9ea259 | ||
|
|
7cd824c284 | ||
|
|
e27d95e2b5 | ||
|
|
6839827db0 | ||
|
|
d6a04f96fe | ||
|
|
5f820b9dc1 | ||
|
|
6788edbe9d | ||
|
|
3895305882 | ||
|
|
c2acbaaa94 | ||
|
|
4a99be2f15 | ||
|
|
27c816cf3b | ||
|
|
0d81776212 | ||
|
|
cccab31c0f | ||
|
|
4ddf3bf2bf | ||
|
|
3d37a3d367 | ||
|
|
73d8236697 | ||
|
|
3c34e8e0e7 | ||
|
|
922202734a | ||
|
|
b270960a04 | ||
|
|
5c4899df6e | ||
|
|
2f0215ac87 | ||
|
|
dd5cc206e5 | ||
|
|
657ecccee3 | ||
|
|
2e023cb8dc | ||
|
|
e933f32d9c | ||
|
|
0b2501c1d8 | ||
|
|
c1d892069e | ||
|
|
bda335212d | ||
|
|
06f4cdc649 | ||
|
|
9bf7fa4081 | ||
|
|
4ca9dfd9c0 | ||
|
|
adfee8bf58 | ||
|
|
fbfa2a71a9 | ||
|
|
a7237fe62f | ||
|
|
c3c454b7d7 | ||
|
|
d4d708d44b | ||
|
|
7f0b6a3a46 | ||
|
|
c2a7c089d2 | ||
|
|
df5bd4df60 | ||
|
|
79b6010104 | ||
|
|
97b0a98793 | ||
|
|
d900a3d08e | ||
|
|
cdf5b66729 | ||
|
|
a6c00c42fa | ||
|
|
be8a0ec184 | ||
|
|
b02e3aad95 | ||
|
|
08eca511ad | ||
|
|
c34e911596 | ||
|
|
8a452c3072 | ||
|
|
13bfb14107 | ||
|
|
4188b0969e | ||
|
|
0c27795a10 | ||
|
|
d05693c5c1 | ||
|
|
c0b2063b38 | ||
|
|
4d183747b1 | ||
|
|
08fe1b2f75 | ||
|
|
db3e8a267e | ||
|
|
8fc62682c4 | ||
|
|
75031914a3 | ||
|
|
a4c9fdd95a | ||
|
|
6a9bfeb5aa | ||
|
|
e654766f60 | ||
|
|
0ef6955f96 | ||
|
|
b4501557c9 | ||
|
|
a2ed99e6cb | ||
|
|
6bd6bb3885 | ||
|
|
399cf65fc9 | ||
|
|
24906a6df1 | ||
|
|
d772bbebe6 | ||
|
|
14988853a3 | ||
|
|
7b3f16ac9f | ||
|
|
82b2755c18 | ||
|
|
4e4dc4cb73 |
@@ -1,37 +0,0 @@
|
||||
---
|
||||
name: frontend-developer
|
||||
description: Use this agent when you need assistance with frontend development tasks including Vue.js components, UI implementation, styling, responsive design, state management, or frontend architecture decisions. Examples: <example>Context: User is working on a Vue.js component and needs help with implementing a responsive layout. user: 'I need to create a mobile-friendly chat interface component' assistant: 'I'll use the frontend-developer agent to help design and implement this responsive chat component' <commentary>Since this involves frontend development work with Vue.js and responsive design, use the frontend-developer agent.</commentary></example> <example>Context: User encounters styling issues with Element Plus components. user: 'The Element Plus dialog is not displaying correctly on mobile devices' assistant: 'Let me use the frontend-developer agent to troubleshoot this mobile styling issue' <commentary>This is a frontend styling problem that requires expertise in Element Plus and responsive design.</commentary></example>
|
||||
color: purple
|
||||
---
|
||||
|
||||
You are a Senior Frontend Development Engineer with deep expertise in modern web development technologies, particularly Vue.js 3, Element Plus, Vant, and responsive design patterns. You specialize in creating high-quality, maintainable frontend applications with excellent user experience.
|
||||
|
||||
Your core responsibilities include:
|
||||
- Developing Vue.js 3 components using Composition API and best practices
|
||||
- Implementing responsive designs that work seamlessly across desktop and mobile devices
|
||||
- Working with Element Plus for desktop UI and Vant for mobile components
|
||||
- Managing application state using Pinia store patterns
|
||||
- Styling with Stylus preprocessor and Tailwind CSS utilities
|
||||
- Optimizing build processes with Vite and ensuring proper code organization
|
||||
- Implementing theme switching (dark/light mode) and accessibility features
|
||||
- Follow decoupled development, with HTML, CSS, and JS codes placed in separate files for easier maintenance
|
||||
|
||||
When working on frontend tasks, you will:
|
||||
1. Analyze requirements and suggest the most appropriate Vue.js patterns and component structures
|
||||
2. Ensure responsive design principles are followed, considering both desktop and mobile viewports
|
||||
3. Choose appropriate UI components from Element Plus (desktop) or Vant (mobile) libraries
|
||||
4. Write clean, maintainable code following Vue.js 3 Composition API best practices
|
||||
5. Consider performance implications and suggest optimizations when relevant
|
||||
6. Ensure proper state management using Pinia when component state needs to be shared
|
||||
7. Follow the project's established patterns for routing, API integration, and component organization
|
||||
8. Provide specific code examples and explain the reasoning behind architectural decisions
|
||||
|
||||
You have deep knowledge of:
|
||||
- Vue.js 3 ecosystem (Vue Router, Pinia, Composition API)
|
||||
- Modern CSS techniques and preprocessors (Stylus, Tailwind)
|
||||
- Component library integration (Element Plus, Vant)
|
||||
- Build tools and development workflow (Vite, npm scripts)
|
||||
- Cross-browser compatibility and mobile-first design principles
|
||||
- Performance optimization and code splitting strategies
|
||||
|
||||
Always consider the user experience, code maintainability, and alignment with modern frontend development standards. When suggesting solutions, provide clear explanations and consider both immediate needs and long-term scalability.
|
||||
@@ -1,6 +0,0 @@
|
||||
重构当前页面代码
|
||||
|
||||
1. 把当前页面 JS 代码全部抽离,然后是采用 Pinia 重构
|
||||
2. 把当前页面 CSS 代码全部抽离,如果是 stylus 语法代码,则需要改成 SCSS 语法代码
|
||||
3. 尽量做到代码的复用性,不要重复造轮子
|
||||
4. 移动端的 css 和 js 分别放到对应的 mobile 目录下,不要覆盖 PC 端的代码
|
||||
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: Bug 报告 🐛
|
||||
description: 为 chatgpt-plus 提交错误报告
|
||||
description: 为 geekai 提交错误报告
|
||||
labels: ['Bug']
|
||||
body:
|
||||
- type: checkboxes
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
2
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: 功能优化 🚀
|
||||
description: 为 chatgpt-plus 提交优化建议
|
||||
description: 为 geekai 提交优化建议
|
||||
labels: ['feature']
|
||||
body:
|
||||
- type: checkboxes
|
||||
|
||||
648
CHANGELOG.md
648
CHANGELOG.md
@@ -1,464 +1,350 @@
|
||||
# 更新日志
|
||||
|
||||
## v4.2.6
|
||||
|
||||
- 功能重构:优化系统配置管理功能,把 OSS,支付,短信,邮件等配置全部迁移到管理后台,无需通过修改配置文档的方式修改 🎉🎉🎉
|
||||
- 功能优化:重构 API 授权代码,采用中间件鉴权方式,实现更加精准的 API 鉴权 🎉🎉🎉
|
||||
- 功能优化:优化 PC 端的 Suno 音乐,视频生成,以及即梦 AI 页面 UI
|
||||
- 功能优化:重构登录和注册页面,兼容移动端和 PC 端,并且所有的登录组件共用了同一套组件代码,大大降低维护成本 🎉🎉🎉
|
||||
- 功能优化:管理后台增加模型批量删除功能
|
||||
- 功能优化:优化 Table 组件 UI,并支持 dark 主题
|
||||
- 功能优化:移动端对话页面支持上传文件和图片
|
||||
- 功能新增:新增微信扫码登录支持
|
||||
- 功能新增:新增安全监控,内容审核功能,支持敏感内容过滤拦截
|
||||
- 功能新增:DALL-E 绘图支持参 Google Banana 图片编辑功能
|
||||
|
||||
## v4.2.5
|
||||
|
||||
- 功能优化:在代码右下角增加复制代码功能按钮,增加收起和展开代码功能
|
||||
- Bug 修复:修复 Shift + Enter 不换行的 Bug
|
||||
- Bug 修复:修复管理后台菜单添加页面的文本错误
|
||||
- Bug 修复:解决聊天页面异常退出不断重连的 bug
|
||||
- 功能优化:把 Luma 和可灵视频生成页面整合成一个视频创作中心页面,统一管理视频任务
|
||||
- 功能新增:增加即梦 AI 专题页面,支持即梦官方原生 API 的图片和视频生成 🎉🎉🎉
|
||||
|
||||
## v4.2.4
|
||||
|
||||
- 功能优化:更改前端构建技术选型,使用 Vite 构建,提升构建速度和兼容性
|
||||
- 功能优化:使用 SSE 发送消息,替换原来的 Websocket 消息方案
|
||||
- 功能新增:管理后台支持设置默认昵称
|
||||
- 功能优化:支持 Suno v4.5 模型支持
|
||||
- 功能新增:用户注册和用户登录增加用户协议和隐私政策功能,需要用户同意协议才可注册和登录。
|
||||
- 功能优化:修改重新回答功能,撤回千面的问答内容为可编辑内容,撤回的内容不会增加额外的上下文
|
||||
- 功能优化:优化聊天记录的存储结构,增加模型名称字段,支持存储更长的模型名称
|
||||
- Bug 修复:聊天应用绑定模型后无效,还是会轮询 API KEY,导致一会成功,一会请求失败。
|
||||
- 功能优化:如果管理后台没有启用会员充值菜单,移动端也不显示充值套餐功能
|
||||
|
||||
## v4.2.3
|
||||
|
||||
- 功能优化:增加模型分组与模型描述,采用卡片展示模式改进模型选择功能体验
|
||||
- 功能优化:化思维导图下载图片的清晰度以及解决拖动、缩放操作后下载图片内容不全问题
|
||||
- Bug 修复:修复 MJ 画图页面已画出的图,点复制指令无效问题
|
||||
- 功能优化:MJ 画图的分辨率支持自定义,优先使用 prompt 中--ar 参数
|
||||
- Bug 修复:修复 MJ 绘画 U1-V1,拼写错误
|
||||
- 功能优化:支持自动迁移数据表结构,无需在手动执行 SQL 了
|
||||
- 功能优化:移除首页的文字动画效果
|
||||
- 功能优化:在聊天页面增加对话列表展开和隐藏功能
|
||||
- 功能优化:聊天页面增加 AI 思考中动画效果
|
||||
|
||||
## v4.2.2
|
||||
|
||||
- 功能优化:开启图形验证码功能的时候现检查是否配置了 API 服务,防止开启之后没法登录的 Bug。
|
||||
- 功能优化:支持原生的 DeepSeek 推理模型 API,聊天 API KEY 支持设置完整的 API 路径,比如 https://api.geekai.pro/v1/chat/completions
|
||||
- 功能优化:支持 GPT-4o 图片编辑功能。
|
||||
- 功能新增:对话页面支持 AI 输出语音播报(TTS)。
|
||||
- 功能优化:替换瀑布流组件,优化用户体验。
|
||||
- 功能优化:生成思维导图时候自动缓存上一次的结果。
|
||||
- 功能优化:优化 MJ 绘图页面,增加 MJ-V7 模型支持。
|
||||
- 功能优化:后台管理增加生成一键登录链接地址功能
|
||||
|
||||
## v4.2.1
|
||||
|
||||
- 功能新增:新增支持可灵生成视频,支持文生视频,图生生视频。
|
||||
- Bug 修复:修复手机端登录页面 Logo 无法修改的问题。
|
||||
- 功能新增:重构所有异步任务(绘图,音乐,视频)更新方式,使用 http pull 来替代 websocket。
|
||||
- 功能优化:优化 Luma 图生视频功能,支持本地上传图片和远程图片。
|
||||
- Bug 修复:修复移动端聊天页面新建对话时候角色没有更模型绑定的 Bug。
|
||||
- 功能优化:优化聊天页面代码块样式,优化公式的解析。
|
||||
- 功能优化:在绘图,视频相关 API 增加提示词长度的检查,防止提示词超出导致写入数据库失败。
|
||||
- Bug 修复:优化 Redis 连接池配置,增加连接池超时时间,单核服务器报错 `redis: connection pool timeout`。
|
||||
- 功能优化:优化邮件验证码发送逻辑,更新邮件发送成功提示。
|
||||
|
||||
## v4.2.0
|
||||
|
||||
- 功能优化:优化聊天页面 Notice 组件样式,采用 Vuepress 文档样式
|
||||
- Bug 修复:修复主题切换的组件显示异常问题
|
||||
- 功能优化:支持 DeepSeek-R1 推理模型,优化推理样式输出
|
||||
- 功能优化:优化 Suno 歌曲播放按钮样式,居中显示
|
||||
- 功能优化:后台管理新增模型的时候,可以绑定所有的 API KEY,而不只是能绑定 Chat 类型的 API KEY
|
||||
- 功能新增:新增每日签到功能,每日签到可以获得算力奖励
|
||||
- 功能优化:兼容 OpenAI o3 系列模型
|
||||
- 功能优化:API 默认开启允许跨域调用
|
||||
- 功能优化:优化 docker-compose.yaml 配置,增加各容器依赖关系
|
||||
|
||||
## v4.1.9
|
||||
|
||||
- 功能优化:优化系统配置,移除已废弃的配置项
|
||||
- 功能优化:GPT-O1 模型支持流式输出
|
||||
- 功能优化:优化代码引用快样式,支持主题切换
|
||||
- 功能优化:登录,注册页面允许替换用户自己的 Logo 和 Title
|
||||
- Bug 修复:修复 OpenAI 实时语音通话没有检测用户算力不足的 Bug
|
||||
- 功能新增:管理后台增加算力日志查询功能,支持按用户,按模型,按日期,按类型查询算力日志
|
||||
- 功能优化:支持为模型绑定 Dalle 和 chat 类型的 API KEY
|
||||
- 功能新增:支持管理后台设置 ICP 备案号
|
||||
|
||||
## v4.1.8
|
||||
|
||||
- 功能优化:**UI 全新改版,支持主题切换**。 :rocket: :rocket: :rocket:
|
||||
- 功能新增:Gitee AI API 接口接入,目前支持 Gitee 的 SD 绘图接口,支持 Gitee 的 AI 对话接口。:rocket: :rocket: :rocket:
|
||||
- Bug 修复:修复音 Luma API 更新导致任务响应解析失败的错误
|
||||
- 功能优化:支持 Suno v4.0 模型支持
|
||||
- Bug 修复:修复 Suno 已完成任务删除失败的 错误
|
||||
- 功能新增:支持 OpenAI 实时语音通话功能,目前已经支持按次收费,支持管理员设置每次实时语音通话的算力消耗
|
||||
- 功能新增:生成提示词需要消耗算力,支持管理员设置每次生成提示词的算力消耗,防止被白嫖
|
||||
- 功能新增:DALL-E-3 绘图支持 Flux 绘图模型,支持在管理后添加 Flux,SD 等绘图模型
|
||||
- 功能优化:Markdown 支持解析 emoji 表情
|
||||
- 功能优化:当管理后台禁用了某个绘图菜单的时候,移动端绘图菜单也会同步禁用(不显示该功能)
|
||||
|
||||
## v4.1.7
|
||||
|
||||
- Bug 修复:手机邮箱相关的注册问题 [#IB0HS5](https://gitee.com/blackfox/geekai/issues/IB0HS5)
|
||||
- Bug 修复:音乐视频无法下载,思维导图下载后看不清文字[#IB0N2E](https://gitee.com/blackfox/geekai/issues/IB0N2E)
|
||||
- 功能优化:保存所有 AIGC 任务的原始信息,程序启动之后自动将未执行的任务加入到 redis 队列
|
||||
- 功能优化:失败的任务自动退回算力,而不需要在删除的时候再退回
|
||||
- 功能新增:支持设置一个专门的模型来翻译提示词,提供 Mate 提示词生成功能
|
||||
- Bug 修复:修复图片对话的时候,上下文不起作用的 Bug
|
||||
- 功能新增:管理后台新增批量导出兑换码功能
|
||||
* Bug修复:手机邮箱相关的注册问题 [#IB0HS5](https://gitee.com/blackfox/geekai/issues/IB0HS5)
|
||||
* Bug修复:音乐视频无法下载,思维导图下载后看不清文字[#IB0N2E](https://gitee.com/blackfox/geekai/issues/IB0N2E)
|
||||
* 功能优化:保存所有AIGC任务的原始信息,程序启动之后自动将未执行的任务加入到 redis 队列
|
||||
* 功能优化:失败的任务自动退回算力,而不需要在删除的时候再退回
|
||||
* 功能新增:支持设置一个专门的模型来翻译提示词,提供 Mate 提示词生成功能
|
||||
* Bug修复:修复图片对话的时候,上下文不起作用的Bug
|
||||
* 功能新增:管理后台新增批量导出兑换码功能
|
||||
|
||||
## v4.1.6
|
||||
|
||||
- 功能新增:**支持 OpenAI 实时语音对话功能** :rocket: :rocket: :rocket:, Beta 版,目前没有做算力计费控制,目前只有 VIP 用户可以使用。
|
||||
- 功能优化:优化 MysQL 容器配置文档,解决 MysQL 容器资源占用过高问题
|
||||
- 功能新增:管理后台增加 AI 绘图任务管理,可在管理后台浏览和删除用户的绘图任务
|
||||
- 功能新增:管理后台增加 Suno 和 Luma 任务管理功能
|
||||
- Bug 修复:修复管理后台删除兑换码报 404 错误
|
||||
- 功能优化:优化充值产品定价逻辑,可以设置原价和优惠价,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**。
|
||||
* 功能新增:**支持OpenAI实时语音对话功能** :rocket: :rocket: :rocket:, Beta 版,目前没有做算力计费控制,目前只有 VIP 用户可以使用。
|
||||
* 功能优化:优化MysQL容器配置文档,解决MysQL容器资源占用过高问题
|
||||
* 功能新增:管理后台增加AI绘图任务管理,可在管理后台浏览和删除用户的绘图任务
|
||||
* 功能新增:管理后台增加Suno和Luma任务管理功能
|
||||
* Bug修复:修复管理后台删除兑换码报 404 错误
|
||||
* 功能优化:优化充值产品定价逻辑,可以设置原价和优惠价,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**,**升级当前版本之后请务必要到管理后台去重新设置一下产品价格,以免造成损失!!!**。
|
||||
|
||||
## v4.1.5
|
||||
|
||||
- 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接
|
||||
- Bug 修复:兼容手机端原生微信支付和支付宝支付渠道
|
||||
- Bug 修复:修复删除绘图任务时候因为字段长度过短导致 SQL 执行失败问题
|
||||
- 功能优化:优化 Vue 组件通信代码,使用共享数据来替换之前的事件订阅模式,效率更高一些
|
||||
- 功能优化:优化思维导图生成功果页面,优化用户体验
|
||||
* 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接
|
||||
* Bug修复:兼容手机端原生微信支付和支付宝支付渠道
|
||||
* Bug修复:修复删除绘图任务时候因为字段长度过短导致SQL执行失败问题
|
||||
* 功能优化:优化 Vue 组件通信代码,使用共享数据来替换之前的事件订阅模式,效率更高一些
|
||||
* 功能优化:优化思维导图生成功果页面,优化用户体验
|
||||
|
||||
## v4.1.4
|
||||
|
||||
- 功能优化:用户文件列表组件增加分页功能支持
|
||||
- Bug 修复:修复用户注册失败 Bug,注册操作只弹出一次行为验证码
|
||||
- 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
|
||||
- 功能新增:给 AI 应用(角色)增加分类,前端支持分类筛选
|
||||
- 功能优化:允许用户在聊天页面设置是否使用流式输出或者一次性输出,兼容 GPT-O1 模型。
|
||||
- 功能优化:移除 PayJS 支付渠道支持,PayJs 已经关闭注册服务,请使用其他支付方式。
|
||||
- 功能新增:新增 GeeK 易支付支付渠道,支持支付宝,微信支付,QQ 钱包,京东支付,抖音支付,Paypal 支付等支付方式
|
||||
- Bug 修复:修复注册页面 tab 组件没有自动选中问题 [#6](https://github.com/yangjian102621/geekai-plus/issues/6)
|
||||
- 功能优化:Luma 生成视频任务增加自动翻译功能
|
||||
- Bug 修复:Suno 和 Luma 任务没有判断用户算力
|
||||
- 功能新增:邮箱注册增加邮箱后缀白名单,防止使用某些垃圾邮箱注册薅羊毛
|
||||
- 功能优化:清空未支付订单时,只清空超过 15 分钟未支付的订单
|
||||
* 功能优化:用户文件列表组件增加分页功能支持
|
||||
* Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码
|
||||
* 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
|
||||
* 功能新增:给 AI 应用(角色)增加分类,前端支持分类筛选
|
||||
* 功能优化:允许用户在聊天页面设置是否使用流式输出或者一次性输出,兼容 GPT-O1 模型。
|
||||
* 功能优化:移除PayJS支付渠道支持,PayJs已经关闭注册服务,请使用其他支付方式。
|
||||
* 功能新增:新增GeeK易支付支付渠道,支持支付宝,微信支付,QQ钱包,京东支付,抖音支付,Paypal支付等支付方式
|
||||
* Bug修复:修复注册页面 tab 组件没有自动选中问题 [#6](https://github.com/yangjian102621/geekai-plus/issues/6)
|
||||
* 功能优化:Luma生成视频任务增加自动翻译功能
|
||||
* Bug修复:Suno 和 Luma 任务没有判断用户算力
|
||||
* 功能新增:邮箱注册增加邮箱后缀白名单,防止使用某些垃圾邮箱注册薅羊毛
|
||||
* 功能优化:清空未支付订单时,只清空超过15分钟未支付的订单
|
||||
|
||||
## v4.1.3
|
||||
|
||||
- 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
|
||||
- 功能优化:重构找回密码模块,支持通过手机或者邮箱找回密码
|
||||
- 功能优化:管理后台给可以拖动排序的组件添加拖动图标
|
||||
- 功能优化:Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
|
||||
- Bug 修复:手机端角色和模型选择不生效
|
||||
- Bug 修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常
|
||||
- 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
|
||||
- 功能优化:给算力增减服务全部加上数据库事务和同步锁
|
||||
- 功能优化:支持用户在前端对话界面选择插件
|
||||
- 功能新增:支持 Luma 文生视频功能
|
||||
* 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
|
||||
* 功能优化:重构找回密码模块,支持通过手机或者邮箱找回密码
|
||||
* 功能优化:管理后台给可以拖动排序的组件添加拖动图标
|
||||
* 功能优化:Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
|
||||
* Bug修复:手机端角色和模型选择不生效
|
||||
* Bug修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常
|
||||
* 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
|
||||
* 功能优化:给算力增减服务全部加上数据库事务和同步锁
|
||||
* 功能优化:支持用户在前端对话界面选择插件
|
||||
* 功能新增:支持 Luma 文生视频功能
|
||||
|
||||
## v4.1.2
|
||||
|
||||
- Bug 修复:修复思维导图页面获取模型失败的问题
|
||||
- 功能优化:优化 MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力
|
||||
- Bug 修复:修复后台拖动排序组件 Bug
|
||||
- 功能优化:更新数据库失败时候显示具体的的报错信息
|
||||
- Bug 修复:修复管理后台对话详情页内容显示异常问题
|
||||
- 功能优化:管理后台新增清空所有未支付订单的功能
|
||||
- 功能优化:给会话信息和系统配置数据加上缓存功能,减少 http 请求
|
||||
- 功能新增:移除微信机器人收款功能,增加卡密功能,支持用户使用卡密兑换算力
|
||||
* Bug修复:修复思维导图页面获取模型失败的问题
|
||||
* 功能优化:优化MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力
|
||||
* Bug修复:修复后台拖动排序组件 Bug
|
||||
* 功能优化:更新数据库失败时候显示具体的的报错信息
|
||||
* Bug修复:修复管理后台对话详情页内容显示异常问题
|
||||
* 功能优化:管理后台新增清空所有未支付订单的功能
|
||||
* 功能优化:给会话信息和系统配置数据加上缓存功能,减少 http 请求
|
||||
* 功能新增:移除微信机器人收款功能,增加卡密功能,支持用户使用卡密兑换算力
|
||||
|
||||
## v4.1.1
|
||||
|
||||
- Bug 修复:修复 GPT 模型 function call 调用后没有输出的问题
|
||||
- 功能新增:允许获取 License 授权用户可以自定义版权信息
|
||||
- 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件
|
||||
- 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求
|
||||
- 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
|
||||
- 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
|
||||
- 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式
|
||||
- 功能新增:允许在管理后台设置首页显示的导航菜单
|
||||
- Bug 修复:修复注册页面先显示关闭注册组件,然后再显示注册组件
|
||||
- 功能新增:增加 Suno 文生歌曲功能
|
||||
- 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
|
||||
- 功能优化:在所有列表页面增加返回顶部按钮
|
||||
* Bug修复:修复 GPT 模型 function call 调用后没有输出的问题
|
||||
* 功能新增:允许获取 License 授权用户可以自定义版权信息
|
||||
* 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件
|
||||
* 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求
|
||||
* 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
|
||||
* 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
|
||||
* 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式
|
||||
* 功能新增:允许在管理后台设置首页显示的导航菜单
|
||||
* Bug修复:修复注册页面先显示关闭注册组件,然后再显示注册组件
|
||||
* 功能新增:增加 Suno 文生歌曲功能
|
||||
* 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
|
||||
* 功能优化:在所有列表页面增加返回顶部按钮
|
||||
|
||||
## v4.1.0
|
||||
* bug修复:修复移动端修改聊天标题不生效的问题
|
||||
* Bug修复:修复用户注册不显示用户名的问题
|
||||
* Bug修复:修复管理后台拖动排序不生效的问题
|
||||
* 功能优化:允许用户设置自定义首页背景图片
|
||||
* 功能新增:**支持AI解读 PDF, Word, Excel等文件**
|
||||
* 功能优化:优化聊天界面的用户上传文件的列表样式
|
||||
* 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
|
||||
* 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
|
||||
|
||||
- bug 修复:修复移动端修改聊天标题不生效的问题
|
||||
- Bug 修复:修复用户注册不显示用户名的问题
|
||||
- Bug 修复:修复管理后台拖动排序不生效的问题
|
||||
- 功能优化:允许用户设置自定义首页背景图片
|
||||
- 功能新增:**支持 AI 解读 PDF, Word, Excel 等文件**
|
||||
- 功能优化:优化聊天界面的用户上传文件的列表样式
|
||||
- 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
|
||||
- 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
|
||||
|
||||
## v4.0.9
|
||||
|
||||
- 环境升级:升级 Golang 到 go1.22.4
|
||||
- 功能增加:接入微信商户号支付渠道
|
||||
- Bug 修复:修复前端页面菜单把页面撑开,底部留白问题
|
||||
- 功能优化:聊天页面自动根据内容调整输入框的高度
|
||||
- Bug 修复:修复 Dalle 绘图失败退回算力的问题
|
||||
- 功能优化:邀请码注册时被邀请人也可以获得赠送的算力
|
||||
- 功能优化:允许设置邮件验证码的抬头
|
||||
- Bug 修复:修复免费模型不会记录聊天记录的 bug
|
||||
- Bug 修复:修复聊天输入公式显示异常的 Bug
|
||||
* 环境升级:升级 Golang 到 go1.22.4
|
||||
* 功能增加:接入微信商户号支付渠道
|
||||
* Bug修复:修复前端页面菜单把页面撑开,底部留白问题
|
||||
* 功能优化:聊天页面自动根据内容调整输入框的高度
|
||||
* Bug修复:修复Dalle绘图失败退回算力的问题
|
||||
* 功能优化:邀请码注册时被邀请人也可以获得赠送的算力
|
||||
* 功能优化:允许设置邮件验证码的抬头
|
||||
* Bug修复:修复免费模型不会记录聊天记录的bug
|
||||
* Bug修复:修复聊天输入公式显示异常的Bug
|
||||
|
||||
## v4.0.8
|
||||
|
||||
- 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
|
||||
- 功能优化:当数据库更新失败的时候记录错误日志
|
||||
- 功能优化:聊天输入框会随着输入内容的增多自动调整高度
|
||||
- Bug 修复:修复移动端聊天页面模型切换不生效的 Bug
|
||||
- 功能优化:给 PC 端扫码支付增加签名验证和有效期验证
|
||||
- Bug 修复:修复支付码生成 API 权限控制的问题
|
||||
- Bug 修复:模型算力设置为 0 时,不扣减用户算力,并且不记录算力消费日志
|
||||
- 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
|
||||
- 功能新增:H5 端支持 Dalle 绘图
|
||||
* 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
|
||||
* 功能优化:当数据库更新失败的时候记录错误日志
|
||||
* 功能优化:聊天输入框会随着输入内容的增多自动调整高度
|
||||
* Bug修复:修复移动端聊天页面模型切换不生效的Bug
|
||||
* 功能优化:给PC端扫码支付增加签名验证和有效期验证
|
||||
* Bug修复:修复支付码生成API权限控制的问题
|
||||
* Bug修复:模型算力设置为0时,不扣减用户算力,并且不记录算力消费日志
|
||||
* 功能优化:新增随机背景配置项,可以在后台设置,首页使用 Bing 壁纸作为背景图片
|
||||
* 功能新增:H5端支持 Dalle 绘图
|
||||
|
||||
## v4.0.7
|
||||
|
||||
- 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片
|
||||
- Bug 修复:修复弹窗等于图形验证码一直验证失败的问题
|
||||
- 功能重构:重构前端 UI 页面,增加顶部导航
|
||||
- 功能优化:优化 Vue 非父子组件之间的通信方式
|
||||
- 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
|
||||
* 功能优化:升级quic-go,支持 Go1.21
|
||||
* 功能优化:添加导航菜单的时候支持框入外部链接,并支持上传自定义菜单图片
|
||||
* Bug修复:修复弹窗等于图形验证码一直验证失败的问题
|
||||
* 功能重构:重构前端 UI 页面,增加顶部导航
|
||||
* 功能优化:优化 Vue 非父子组件之间的通信方式
|
||||
* 功能优化:优化 ItemList 组件,自动根据页面宽度计算 cols 数量
|
||||
|
||||
## v4.0.6
|
||||
|
||||
- Bug 修复:修复 PC 端画廊页面的瀑布流组件样式错乱问题
|
||||
- 功能新增:给思维导图增加 ToolBar,实现思维导图的放大缩小和定位
|
||||
- Bug 修复:修复思维导图不扣费的 Bug
|
||||
- Bug 修复:修复管理后台角色删除失败的 Bug
|
||||
- Bug 修复:兼容最新版秋叶 SD 懒人包的 SD API,新增 scheduler 参数
|
||||
- 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
|
||||
- Bug 修复:修复注册用户提示注册人数达到上限的 Bug
|
||||
- 功能优化:将 MJ,SD,Dall 绘画页面的任务列表全改成瀑布流组件
|
||||
* Bug修复:修复PC端画廊页面的瀑布流组件样式错乱问题
|
||||
* 功能新增:给思维导图增加 ToolBar,实现思维导图的放大缩小和定位
|
||||
* Bug修复:修复思维导图不扣费的Bug
|
||||
* Bug修复:修复管理后台角色删除失败的Bug
|
||||
* Bug修复:兼容最新版秋叶SD懒人包的 SD API,新增 scheduler 参数
|
||||
* 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
|
||||
* Bug修复:修复注册用户提示注册人数达到上限的 Bug
|
||||
* 功能优化:将MJ,SD,Dall绘画页面的任务列表全改成瀑布流组件
|
||||
|
||||
## v4.0.5
|
||||
|
||||
- 功能优化:已授权系统在后台显示授权信息
|
||||
- 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
|
||||
- 功能优化:优化首页登录注册页面的 UI
|
||||
- BUG 修复:修复 License 验证的逻辑漏洞
|
||||
- Bug 修复:后台添加用户的时候密码规则限制跟前台注册保持一致
|
||||
- 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
|
||||
- 功能新增:移动端新增 DALL-E 绘画功能
|
||||
- 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
|
||||
- 功能新增:移动支持免登录预览功能
|
||||
- Bug 修复:解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
|
||||
- Bug 修复:修复部分中转 API 模型会出现第一输出的字符被淹没的 Bug
|
||||
* 功能优化:已授权系统在后台显示授权信息
|
||||
* 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
|
||||
* 功能优化:优化首页登录注册页面的 UI
|
||||
* BUG修复:修复License验证的逻辑漏洞
|
||||
* Bug修复:后台添加用户的时候密码规则限制跟前台注册保持一致
|
||||
* 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
|
||||
* 功能新增:移动端新增 DALL-E 绘画功能
|
||||
* 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
|
||||
* 功能新增:移动支持免登录预览功能
|
||||
* Bug修复:解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
|
||||
* Bug修复:修复部分中转 API 模型会出现第一输出的字符被淹没的Bug
|
||||
|
||||
## v4.0.4
|
||||
|
||||
- Bug 修复:修复统一千问第二句不回复的问题
|
||||
- 功能优化:MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
|
||||
- 功能新增:Dalle AI 绘画功能实现
|
||||
- Bug 修复:修复思维导图格式乱码问题
|
||||
- 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
|
||||
- 功能新增:支持从应用列表直接和某个应用对话
|
||||
- 功能优化:优化算力日志的页面和首页的 UI
|
||||
- 功能新增:支持思维导图导出 PNG 图片下载
|
||||
* Bug修复:修复统一千问第二句不回复的问题
|
||||
* 功能优化:MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
|
||||
* 功能新增:Dalle AI 绘画功能实现
|
||||
* Bug修复:修复思维导图格式乱码问题
|
||||
* 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
|
||||
* 功能新增:支持从应用列表直接和某个应用对话
|
||||
* 功能优化:优化算力日志的页面和首页的UI
|
||||
* 功能新增:支持思维导图导出 PNG 图片下载
|
||||
|
||||
## v4.0.3
|
||||
|
||||
- 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
|
||||
- Bug 修复:兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
|
||||
- Bug 修复:修复 MidJourney 在任务超时后出现后面的任务覆盖前面任务的问题
|
||||
- 功能新增:支持上传图片和视觉模型
|
||||
- 功能优化:优化聊天页面的复制代码按钮样式乱码
|
||||
- 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
|
||||
- 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用 GPT3.5 或者 GPT4
|
||||
- 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
|
||||
- 功能新增:支持管理后台 Logo 修改
|
||||
* 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
|
||||
* Bug修复:兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
|
||||
* Bug修复:修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题
|
||||
* 功能新增:支持上传图片和视觉模型
|
||||
* 功能优化:优化聊天页面的复制代码按钮样式乱码
|
||||
* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
|
||||
* 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用GPT3.5或者 GPT4
|
||||
* 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
|
||||
* 功能新增:支持管理后台 Logo 修改
|
||||
|
||||
## 4.0.2
|
||||
|
||||
- 功能新增:支持前端菜单可以配置
|
||||
- 功能优化:在登录和注册界面标题显示软件版本号
|
||||
- 功能优化:MJ 绘画支持 --sref 和 --cref 图片一致性参数
|
||||
- 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
|
||||
- Bug 修复:解决因为图片上传使用相对路径而导致融图失败的问题。
|
||||
- 功能新增:手机端支持 Stable-Diffusion 绘画
|
||||
- 功能新增:管理后台登录页面增加行为验证码,防止爆破
|
||||
* 功能新增:支持前端菜单可以配置
|
||||
* 功能优化:在登录和注册界面标题显示软件版本号
|
||||
* 功能优化:MJ 绘画支持 --sref 和 --cref 图片一致性参数
|
||||
* 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
|
||||
* Bug修复:解决因为图片上传使用相对路径而导致融图失败的问题。
|
||||
* 功能新增:手机端支持 Stable-Diffusion 绘画
|
||||
* 功能新增:管理后台登录页面增加行为验证码,防止爆破
|
||||
|
||||
## v4.0.1
|
||||
|
||||
- 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion
|
||||
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion
|
||||
发行版,稳定性更强一些
|
||||
- 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容
|
||||
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容
|
||||
MJ-Plus 中转
|
||||
- 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
|
||||
- Bug 修复:修复 iphone 手机无法通过图形验证码的 Bug,使用滑动验证码替换
|
||||
- Bug 修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的 Bug
|
||||
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
|
||||
* Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换
|
||||
* Bug修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
|
||||
|
||||
## v4.0.0
|
||||
|
||||
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI 对话,MJ 绘画,SD 绘画,DALL 绘画)全部使用算力来兑换。
|
||||
只要你的算力值余额不为 0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗 1 个单位算力,一次 GPT4 对话消耗 10 个算力。一次 MJ
|
||||
对话消耗 15 个算力...
|
||||
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。
|
||||
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ
|
||||
对话消耗15个算力...
|
||||
|
||||
- 功能重构:重构整体系统,全部采用算力来进行结算
|
||||
- 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
|
||||
- 功能优化:移动端聊天页面图片支持预览和放大功能
|
||||
- 功能优化:MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
|
||||
- 功能优化:**PC 端不登录也可以预览功能,只有在发起操作的时候才需要登录**
|
||||
- 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
|
||||
- 功能新增:支持 H5 支付
|
||||
- 功能优化:支持数学公式的识别和美化输出
|
||||
- 功能新增:新增算力消费日志功能
|
||||
- 功能优化:整合 XXL-JOB 实现订单清理,每日算力派发,VIP 算力重置等任务
|
||||
- 功能新增:管理后台新增 7 日内新增用户和新增订单统计
|
||||
* 功能重构:重构整体系统,全部采用算力来进行结算
|
||||
* 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
|
||||
* 功能优化:移动端聊天页面图片支持预览和放大功能
|
||||
* 功能优化:MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
|
||||
* 功能优化:**PC端不登录也可以预览功能,只有在发起操作的时候才需要登录**
|
||||
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
|
||||
* 功能新增:支持H5支付
|
||||
* 功能优化:支持数学公式的识别和美化输出
|
||||
* 功能新增:新增算力消费日志功能
|
||||
* 功能优化:整合 XXL-JOB 实现订单清理,每日算力派发,VIP 算力重置等任务
|
||||
* 功能新增:管理后台新增7日内新增用户和新增订单统计
|
||||
|
||||
## v3.2.7
|
||||
|
||||
- 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
|
||||
- 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
|
||||
- Bug 修复:修复 issue [
|
||||
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
|
||||
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
|
||||
* Bug修复:修复 issue [
|
||||
管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532)
|
||||
- 功能优化:在对话和聊天记录表中新增冗余字段 model,存储对话模型
|
||||
- Bug 修复:IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
|
||||
- Bug 修复:重新生成按钮功能失效问题
|
||||
- Bug 修复:对话输入 HTML 标签不显示的问题
|
||||
- 功能优化:gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
|
||||
- 功能新增:新增删除文件功能
|
||||
- Bug 修复:解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
|
||||
- 功能新增:后台管理新怎对话查看和检索功能
|
||||
* 功能优化:在对话和聊天记录表中新增冗余字段 model,存储对话模型
|
||||
* Bug修复:IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
|
||||
* Bug修复:重新生成按钮功能失效问题
|
||||
* Bug修复:对话输入HTML标签不显示的问题
|
||||
* 功能优化:gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
|
||||
* 功能新增:新增删除文件功能
|
||||
* Bug修复:解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
|
||||
* 功能新增:后台管理新怎对话查看和检索功能
|
||||
|
||||
## v3.2.6
|
||||
|
||||
- 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
|
||||
- 功能优化:兼用旧版本微信收款消息解析
|
||||
- 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
|
||||
- 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
|
||||
- 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
|
||||
- 功能新增:新增网站公告,可以在管理后台自定义配置
|
||||
- 功能新增:新增阿里通义千问大模型支持
|
||||
- Bug 修复:修复 MJ 放大任务失败时候 img_call 会增加的 Bug
|
||||
- 功能优化:新增虎皮椒和 PayJS 订单状态校验功能,增加安全性
|
||||
- Bug 修复:修复微信转账交易 ID 提取失败 Bug
|
||||
- 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
|
||||
- 功能新增:新增短信宝短信平台发送平台集成
|
||||
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
|
||||
* 功能优化:兼用旧版本微信收款消息解析
|
||||
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
|
||||
* 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
|
||||
* 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
|
||||
* 功能新增:新增网站公告,可以在管理后台自定义配置
|
||||
* 功能新增:新增阿里通义千问大模型支持
|
||||
* Bug修复:修复 MJ 放大任务失败时候 img_call 会增加的 Bug
|
||||
* 功能优化:新增虎皮椒和PayJS订单状态校验功能,增加安全性
|
||||
* Bug修复:修复微信转账交易 ID 提取失败 Bug
|
||||
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
|
||||
* 功能新增:新增短信宝短信平台发送平台集成
|
||||
|
||||
## v3.2.5
|
||||
|
||||
- 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
|
||||
- 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
|
||||
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
|
||||
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
|
||||
Plus 账号了!!!
|
||||
- 功能优化:增强 markdown 图片和引用块解析。
|
||||
- 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
|
||||
- 功能优化:function call 兼用中转 API。
|
||||
- Bug 修复:修复部分已知的 Bug。
|
||||
* 功能优化:增强 markdown 图片和引用块解析。
|
||||
* 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
|
||||
* 功能优化:function call 兼用中转 API。
|
||||
* Bug修复:修复部分已知的 Bug。
|
||||
|
||||
## v3.2.4.1
|
||||
|
||||
- 功能新增:新增 PayJs 支付通道
|
||||
- Bug 修复:紧急修复后台添加用户失败问题
|
||||
- Bug 修复:紧急修复使用中转 API-KEY 无法绘图的问题
|
||||
- Bug 修复:允许用户关闭手机和邮箱注册通道,移除验证码依赖
|
||||
* 功能新增:新增 PayJs 支付通道
|
||||
* Bug修复:紧急修复后台添加用户失败问题
|
||||
* Bug修复:紧急修复使用中转 API-KEY 无法绘图的问题
|
||||
* Bug修复:允许用户关闭手机和邮箱注册通道,移除验证码依赖
|
||||
|
||||
## v3.2.4
|
||||
|
||||
- 功能新增:重磅更新,支持邮箱注册
|
||||
- 功能优化:优化函数调用授权
|
||||
- 功能优化:给用户表新增 nickname 字段
|
||||
- 功能优化:管理后台给聊天角色增加启用/禁用开关
|
||||
- Bug 修复:SD 绘画出现重复扣减绘图次数
|
||||
- 功能优化:优化聊天对话导出样式,适应移动端
|
||||
- 功能新增:众筹核销可以选择兑换对话还是绘图的额度
|
||||
- Bug 修复:修复[从历史记录获取 reply 有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92)
|
||||
- Bug 修复:修复 MidJourney 绘图任务调度 Bug,为 task_id 建议唯一索引
|
||||
- 功能重构:重构了 API KEY 模块,支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。
|
||||
* 功能新增:重磅更新,支持邮箱注册
|
||||
* 功能优化:优化函数调用授权
|
||||
* 功能优化:给用户表新增 nickname 字段
|
||||
* 功能优化:管理后台给聊天角色增加启用/禁用开关
|
||||
* Bug修复:SD绘画出现重复扣减绘图次数
|
||||
* 功能优化:优化聊天对话导出样式,适应移动端
|
||||
* 功能新增:众筹核销可以选择兑换对话还是绘图的额度
|
||||
* Bug修复:修复[从历史记录获取reply有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92)
|
||||
* Bug修复:修复 MidJourney 绘图任务调度Bug,为 task_id 建议唯一索引
|
||||
* 功能重构:重构了 API KEY模块,支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。
|
||||
|
||||
## v3.2.3
|
||||
|
||||
- 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现
|
||||
- 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数
|
||||
- Bug 修复:修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73)
|
||||
- Bug 修复:确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道
|
||||
- 功能新增:新增删除绘图任何和图片功能
|
||||
- Bug 修复:修复虎皮椒支付二维码重复扫码时报错问题
|
||||
- 功能优化:自动将 AI 绘画中的中文提示词翻译成英文
|
||||
- 功能优化:优化 AI 绘画的大图压缩算法,新增图片缓存
|
||||
- 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率
|
||||
- Bug 修复:修复[Azure Api 更换 api-version 参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71)
|
||||
- Bug 修复:修复科大讯飞 V1.5 API 请求失败的问题
|
||||
- Bug 修复:绘图失败后,自动恢复用户的剩余绘图次数
|
||||
- 功能新增:为移动端新增 SD 绘图功能,分享功能
|
||||
* 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现
|
||||
* 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数
|
||||
* Bug修复:修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73)
|
||||
* Bug修复:确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道
|
||||
* 功能新增:新增删除绘图任何和图片功能
|
||||
* Bug修复:修复虎皮椒支付二维码重复扫码时报错问题
|
||||
* 功能优化:自动将 AI 绘画中的中文提示词翻译成英文
|
||||
* 功能优化:优化AI绘画的大图压缩算法,新增图片缓存
|
||||
* 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率
|
||||
* Bug修复:修复[Azure Api 更换api-version参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71)
|
||||
* Bug修复:修复科大讯飞 V1.5 API 请求失败的问题
|
||||
* Bug修复:绘图失败后,自动恢复用户的剩余绘图次数
|
||||
* 功能新增:为移动端新增 SD 绘图功能,分享功能
|
||||
|
||||
## v3.2.2
|
||||
|
||||
- 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务
|
||||
- 功能新增:AI 绘画页面增加翻译和重写提示词功能
|
||||
- 功能优化:OSS 上传组件支持在 Bucket 下设置二级目录
|
||||
- Bug 修复:修复阿里云 OSS 访问路径错误
|
||||
- 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket
|
||||
* 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务
|
||||
* 功能新增:AI绘画页面增加翻译和重写提示词功能
|
||||
* 功能优化:OSS上传组件支持在 Bucket 下设置二级目录
|
||||
* Bug修复:修复阿里云 OSS 访问路径错误
|
||||
* 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket
|
||||
|
||||
## v3.2.1
|
||||
|
||||
- 功能优化:切换角色和模型的时候自动创建新的对话
|
||||
- Bug 修复:修复文件上传失败 No such file bug
|
||||
- 功能新增:MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数
|
||||
- Bug 修复:[PC 端对话在刷新后异常](https://github.com/yangjian102621/chatgpt-plus/issues/59)
|
||||
- 功能新增:增加 arm64 架构打包脚本
|
||||
- 功能新增:支持 dall-e3 绘图的 API 地址自定义配置
|
||||
- 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道
|
||||
* 功能优化:切换角色和模型的时候自动创建新的对话
|
||||
* Bug修复:修复文件上传失败No such file bug
|
||||
* 功能新增:MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数
|
||||
* Bug修复:[PC端对话在刷新后异常](https://github.com/yangjian102621/chatgpt-plus/issues/59)
|
||||
* 功能新增:增加 arm64 架构打包脚本
|
||||
* 功能新增:支持 dall-e3 绘图的 API 地址自定义配置
|
||||
* 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道
|
||||
|
||||
## v3.2.0
|
||||
|
||||
- 功能新增:新增邀请注册功能
|
||||
- 功能优化:增加中间件自动对 HTTP 请求的参数去掉首尾空格
|
||||
- 功能优化:增加中间件自动为大图片生成缩略图
|
||||
- 功能优化:MidJourney 页面图片加载优化,实现图片预览懒加载
|
||||
- 功能新增:新增 DALL-E-3 绘画支持,并作为对话页面默认绘画插件
|
||||
- Bug 修复:修复阿里云 OSS 域名设置不起做用的 bug
|
||||
- Bug 修复:修复 MidJourney 绘图失败后重复添加到队列的问题
|
||||
* 功能新增:新增邀请注册功能
|
||||
* 功能优化:增加中间件自动对HTTP请求的参数去掉首尾空格
|
||||
* 功能优化:增加中间件自动为大图片生成缩略图
|
||||
* 功能优化:MidJourney 页面图片加载优化,实现图片预览懒加载
|
||||
* 功能新增:新增 DALL-E-3 绘画支持,并作为对话页面默认绘画插件
|
||||
* Bug修复:修复阿里云 OSS 域名设置不起做用的bug
|
||||
* Bug修复:修复MidJourney绘图失败后重复添加到队列的问题
|
||||
|
||||
## v3.1.9
|
||||
|
||||
- 功能新增:增加讯飞星火大模型 v3.0 支持
|
||||
- 功能新增:新增找回密码功能
|
||||
- 功能新增:支持 Markdown 代码复制功能
|
||||
- Bug 修复: xxl-job 任务调度失败的 Bug
|
||||
- 功能优化:优化前端页面菜单图标,使用自定义图标替换 icon-font
|
||||
- Bug 修复:Stable-Diffusion 绘画成功之后没有扣减用户画图次数
|
||||
- 功能优化:优化会员充值页面 ItemList 组件
|
||||
- 功能优化:给首页 Logo 增加链接
|
||||
- Bug 修复:[新建会话时,提示"请输入合法的手机号" ](https://github.com/yangjian102621/chatgpt-plus/issues/51)
|
||||
- Bug 修复:聊天上下文失效问题
|
||||
- 功能优化:关闭注册时显示联系管理员二维码
|
||||
- 功能优化:移除 leveldb 依赖,使用 redis 替换相应的功能
|
||||
- Bug 修复:后台启用用户 VIP 不生效问题
|
||||
- 功能优化:充值支付页面的支付说明文字可以后台配置
|
||||
- Bug 修复:ChatGLM,百度文心,科大讯飞模型输出代码不换行问题
|
||||
* 功能新增:增加讯飞星火大模型 v3.0 支持
|
||||
* 功能新增:新增找回密码功能
|
||||
* 功能新增:支持 Markdown 代码复制功能
|
||||
* Bug修复: xxl-job 任务调度失败的 Bug
|
||||
* 功能优化:优化前端页面菜单图标,使用自定义图标替换 icon-font
|
||||
* Bug修复:Stable-Diffusion 绘画成功之后没有扣减用户画图次数
|
||||
* 功能优化:优化会员充值页面 ItemList 组件
|
||||
* 功能优化:给首页 Logo 增加链接
|
||||
* Bug修复:[新建会话时,提示"请输入合法的手机号" ](https://github.com/yangjian102621/chatgpt-plus/issues/51)
|
||||
* Bug修复:聊天上下文失效问题
|
||||
* 功能优化:关闭注册时显示联系管理员二维码
|
||||
* 功能优化:移除 leveldb 依赖,使用 redis 替换相应的功能
|
||||
* Bug修复:后台启用用户 VIP 不生效问题
|
||||
* 功能优化:充值支付页面的支付说明文字可以后台配置
|
||||
* Bug修复:ChatGLM,百度文心,科大讯飞模型输出代码不换行问题
|
||||
|
||||
## v3.1.8
|
||||
|
||||
1. 功能新增:新增会员套餐充值,点卡充值,订单系统,集成支付宝支付通道
|
||||
2. Bug 修复:修复 MidJourney API 参数版本更新导致调用失败问题
|
||||
3. Bug 修复:修复 Stable Diffusion 调用后没有更新绘图调用次数问题
|
||||
4. Bug 修复:修复七牛云上传报错 expired token
|
||||
5. Bug 修复:修复高权重模型导致的对话次数为负数的漏洞
|
||||
2. Bug修复:修复 MidJourney API 参数版本更新导致调用失败问题
|
||||
3. Bug修复:修复 Stable Diffusion 调用后没有更新绘图调用次数问题
|
||||
4. Bug修复:修复七牛云上传报错 expired token
|
||||
5. Bug修复:修复高权重模型导致的对话次数为负数的漏洞
|
||||
6. 功能优化:将聊天报错信息定义为统一常量,方便修改
|
||||
7. 功能优化:优化 markdown 表格显示样式,覆写 Element-Plus 表格样式
|
||||
8. 功能优化:增加倒数计时组件,定期自动清理未支付的订单
|
||||
|
||||
## v3.1.7
|
||||
|
||||
1. 功能新增:支持文心 4.0 AI 模型
|
||||
1. 功能新增:支持文心4.0 AI 模型
|
||||
2. 功能新增:可以在管理后台为用户绑定指定的 AI 模型,如只给某个用户使用 GPT-4 模型
|
||||
3. 功能新增:模型新增权重字段,不同的模型每次调用耗费的点数可以设置不同,比如 GPT4 是 GPT3.5 的 10 倍
|
||||
3. 功能新增:模型新增权重字段,不同的模型每次调用耗费的点数可以设置不同,比如GPT4是GPT3.5的10倍
|
||||
4. 功能新增:新增系统配置关闭 AI 模型的函数功能
|
||||
5. 功能优化:优化 MidJourney 专业绘画页面图片预览样式
|
||||
|
||||
## v3.1.6
|
||||
|
||||
1. 功能新增:新增 AI 绘画照片墙功能页面,供用户查看所有的 AI 绘画作品
|
||||
1. 功能新增:新增AI 绘画照片墙功能页面,供用户查看所有的 AI 绘画作品
|
||||
2. 功能新增:新增 AI 角色应用功能页面,用户可以添加自己感兴趣的应用
|
||||
3. 功能优化:优化瀑布流组件的页面布局
|
||||
4. 功能优化:新注册用户成功之后自动登录
|
||||
@@ -470,55 +356,55 @@
|
||||
2. 功能新增:新增科大讯飞星火大模型 API 接入支持
|
||||
3. 功能重构:将 chat_handler 的所有功能实现放入单独的包中
|
||||
4. 功能新增:新增系统配置 `enabled_function` 用于启用和关闭函数功能
|
||||
5. Bug 修复:修复管理后台更新 API Key 失败的 Bug
|
||||
6. Bug 修复:修复新建的对话无法更新对话标题的 Bug
|
||||
5. Bug修复:修复管理后台更新 API Key 失败的 Bug
|
||||
6. Bug修复:修复新建的对话无法更新对话标题的 Bug
|
||||
7. 功能优化:其他一些小的体验优化工作
|
||||
|
||||
## v3.1.4
|
||||
|
||||
1. 功能新增:新增阿里云 OSS 图片上传实现,目前已支持本地存储,七牛云,Minio 和阿里云 OSS 四种存储介质。
|
||||
1. 功能新增:新增阿里云 OSS 图片上传实现,目前已支持本地存储,七牛云,Minio和阿里云 OSS 四种存储介质。
|
||||
2. 功能新增:**增加 Stable Diffusion 绘画功能页面**。
|
||||
3. 功能重构:将 [chatgpt-plus-exts](https://github.com/yangjian102621/chatgpt-plus-exts) 合并到本项目,部署更加简单,无需部署两个项目了。
|
||||
4. Bug 修复:修复[用户注册报错 BUG #37](https://github.com/yangjian102621/chatgpt-plus/issues/37)。
|
||||
5. Bug 修复:修复 MidJourney API 接口升级导致图片文保存失败的 Bug。
|
||||
4. Bug修复:修复[用户注册报错BUG #37](https://github.com/yangjian102621/chatgpt-plus/issues/37)。
|
||||
5. Bug修复:修复 MidJourney API 接口升级导致图片文保存失败的 Bug。
|
||||
6. 功能优化:增加阿里云短信服务配置项 `Sign` 和 `CodeTempId` 用来配置自己的短信签名和短信验证码模版 ID。
|
||||
7. 功能优化:添加系统配置用来设置自定义的众筹微信收款二维码。
|
||||
8. 功能优化:优化绘画页面的弹窗样式和页面布局。
|
||||
|
||||
## v3.1.3
|
||||
|
||||
1. 页面重构:重后 Home 页面,拆分成聊天,MJ 绘画,SD 绘画,应用广场等多个功能菜单。
|
||||
1. 页面重构:重后 Home 页面,拆分成聊天,MJ绘画,SD 绘画,应用广场等多个功能菜单。
|
||||
2. 功能新增:新增 MidJourney 专业绘画页面,开放更高级的 MJ 绘画姿势。
|
||||
3. 功能优化:采用队列的方式控制绘画任务并发,简化任务回调通知逻辑,给任务回调加锁。
|
||||
4. 功能优化:精简用户表字段,删除用户名和昵称,只保留手机号。
|
||||
5. 功能优化:优化文件上传服务工厂实现,只创建激活的 Uploader 服务,节省资源。
|
||||
6. Bug 修复:修复 JWT token 有效期计算错误的 Bug。
|
||||
6. Bug修复:修复 JWT token 有效期计算错误的 Bug。
|
||||
|
||||
## v3.1.2
|
||||
|
||||
1. 功能新增:新增七牛云 OSS 实现,目前已支持三种文件上传服务:Local, Minio, QiNiu OSS。
|
||||
2. 功能新增:新增桌面版,使用 electron 套壳网页版。
|
||||
3. Bug 修复:自动去除众筹核销时候转账单号中的空格,防止复制的时候多复制了空格。
|
||||
3. Bug修复:自动去除众筹核销时候转账单号中的空格,防止复制的时候多复制了空格。
|
||||
4. 功能优化:ChatPlus.vue 页面支持通过 chat_id path variable 来定位到指定的聊天。
|
||||
5. 功能优化:取消导出聊天页面的授权验证
|
||||
6. 功能优化:所有路由跳转都使用绝对路径
|
||||
|
||||
## v3.1.1
|
||||
|
||||
紧急修复版本,采用弹窗的方式显示验证码,解决验证码在低分辨率下被掩盖的 Bug
|
||||
紧急修复版本,采用弹窗的方式显示验证码,解决验证码在低分辨率下被掩盖的Bug
|
||||
|
||||
## v3.1.0(大版本更新)
|
||||
|
||||
1. 功能重构:将聊天模型独立拆分,以便支持多平台模型,目前已经内置支持 OPenAI,Azure 以及
|
||||
ChatGLM,用户可以在这两个平台的模型中随意切换,体验不同的模型聊天。
|
||||
2. 功能重构:重写系统 API 授权机制,使用 JWT 替换传统的 session 会话授权,使得 API 授权变得更加灵活。
|
||||
3. 功能重构:重构文件夹上传服务,支持多种文件上传存储 handler,目前已经实现本地存储和 minio oss 存储。
|
||||
3. 功能重构:重构文件夹上传服务,支持多种文件上传存储handler,目前已经实现本地存储和 minio oss 存储。
|
||||
4. 功能优化:更新头像自动删除旧的图片资源。
|
||||
5. 功能优化:将应用日志在终端输出的同时存盘,方便 docker 部署查看日志。
|
||||
6. 功能新增:允许用户配置自己的 OPenAI,Azure 以及 ChatGLM API KEY。
|
||||
7. 功能优化:优化移动版的行为验证码样式,修复低分辨率显示器验证码被遮挡的 Bug
|
||||
8. 升级 gin, element-plus,redis 组件到最新版本。
|
||||
9. Bug 修复:修复若干已知的的 Bug
|
||||
9. Bug修复:修复若干已知的的 Bug
|
||||
|
||||
## v3.0.7
|
||||
|
||||
@@ -528,7 +414,7 @@
|
||||
4. 功能新增:支持导出聊天记录为 PDF 文件。
|
||||
5. 功能优化:在后台 dashboard 页面新增统计今日众筹收入。
|
||||
6. 功能优化:支持用户设置默认的 GPT 模型
|
||||
7. Bug 修复:修复若干已知的的 Bug
|
||||
7. Bug修复:修复若干已知的的 Bug
|
||||
|
||||
## v3.0.6
|
||||
|
||||
@@ -536,8 +422,8 @@
|
||||
2. 管理后台:新增重置用户密码功能
|
||||
3. 管理后台:支持关闭注册功能,新增添加用户功能,适用于内部使用场景
|
||||
4. 管理后台:新增仪表盘页面,统计当天的新增用户,新增会话数据,以及 Token 消耗
|
||||
5. Bug 修复:修复注册页面验证码不显示 Bug
|
||||
6. Bug 修复:优化上下文 Token 计算算法,修复聊天上下文超出限制时循环发送消息的 Bug
|
||||
5. Bug修复:修复注册页面验证码不显示 Bug
|
||||
6. Bug修复:优化上下文 Token 计算算法,修复聊天上下文超出限制时循环发送消息的 Bug
|
||||
7. 功能修正:允许用户使用手机号码登录
|
||||
8. 功能优化:更新系统配置后同步更新服务端内存变量数据
|
||||
9. 功能优化:优化打包脚本,减少容器镜像大小
|
||||
@@ -595,5 +481,5 @@
|
||||
4. 新增聊天设置功能,用户可以导入自己的 API KEY
|
||||
5. 保存聊天记录,支持聊天上下文。
|
||||
6. 重构后台管理模块,更友好,扩展性更好的后台管理系统。
|
||||
7. 引入 ip2region 组件,记录用户的登录 IP 和地址。
|
||||
8. 支持会话搜索过滤。
|
||||
7. 引入 ip2region 组件,记录用户的登录IP和地址。
|
||||
8. 支持会话搜索过滤。
|
||||
66
CLAUDE.md
66
CLAUDE.md
@@ -1,66 +0,0 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Build Commands
|
||||
|
||||
### Go Backend (api/)
|
||||
- **Development**: `cd api && go run main.go` (uses config.toml)
|
||||
- **Build**: `cd api && make` (builds both amd64 and arm64 binaries)
|
||||
- **Individual builds**: `make amd64` or `make arm64`
|
||||
- **Clean**: `make clean`
|
||||
- **Config**: Copy `config.sample.toml` to `config.toml` and configure
|
||||
|
||||
### Web Frontend (web/)
|
||||
- **Development**: `cd web && npm run dev` (runs on Vite dev server with --host)
|
||||
- **Build**: `cd web && npm run build`
|
||||
- **Lint**: `cd web && npm run lint` (ESLint with auto-fix)
|
||||
|
||||
### Testing
|
||||
- Backend tests: `cd api/test && bash run_crawler_test.sh`
|
||||
- No specific frontend test configuration found
|
||||
|
||||
## Project Architecture
|
||||
|
||||
### Backend (Go)
|
||||
- **Framework**: Gin web framework with dependency injection via uber-go/fx
|
||||
- **Database**: GORM with MySQL, Redis for caching, LevelDB for local storage
|
||||
- **Authentication**: JWT tokens with Redis session storage
|
||||
- **Middleware**: CORS, authorization, parameter handling, static resource serving
|
||||
- **Structure**:
|
||||
- `handler/`: HTTP request handlers (REST API endpoints)
|
||||
- `service/`: Business logic services (AI integrations, payments, etc.)
|
||||
- `store/`: Database models and data access layer
|
||||
- `core/`: Application server and middleware configuration
|
||||
- `utils/`: Utility functions and helpers
|
||||
|
||||
### Frontend (Vue.js)
|
||||
- **Framework**: Vue 3 with Composition API
|
||||
- **UI Components**: Element Plus + Vant (mobile components)
|
||||
- **State Management**: Pinia
|
||||
- **Routing**: Vue Router with nested routes
|
||||
- **Build Tool**: Vite
|
||||
- **CSS**: Stylus preprocessor with Tailwind CSS utilities
|
||||
- **Features**: Responsive design (desktop/mobile views), theme switching (dark/light)
|
||||
|
||||
### Key Features
|
||||
- **AI Chat**: Multiple chat models and conversation management
|
||||
- **Image Generation**: MidJourney, Stable Diffusion, DALL-E integration
|
||||
- **Audio/Video**: Suno music creation, Luma/KeLing video generation
|
||||
- **User Management**: Authentication, payments, power logs, invitations
|
||||
- **Admin Panel**: Comprehensive management interface
|
||||
|
||||
### Database Models
|
||||
Key entities: User, ChatItem, ChatMessage, ChatRole, ChatModel, Order, Product, AdminUser, and various job types for AI services.
|
||||
|
||||
### API Structure
|
||||
- User APIs: `/api/user/*` (auth, profile, settings)
|
||||
- Chat APIs: `/api/chat/*` (conversations, messages)
|
||||
- AI Service APIs: `/api/mj/*`, `/api/sd/*`, `/api/dall/*`, `/api/suno/*`, `/api/video/*`
|
||||
- Admin APIs: `/api/admin/*` (management functions)
|
||||
|
||||
### Configuration
|
||||
- Backend: TOML configuration file (`config.toml`)
|
||||
- Database: MySQL with automatic migrations
|
||||
- Services: Redis, various AI API integrations
|
||||
- File Storage: Local, Aliyun OSS, MinIO, Qiniu options
|
||||
148
README.md
148
README.md
@@ -1,77 +1,91 @@
|
||||
# 🚀 GeekAI-PLUS:一站式 AI 创意生产力平台
|
||||
# GeekAI
|
||||
> 根据[《生成式人工智能服务管理暂行办法》](https://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||
|
||||
**重新定义 AI 创作体验,让每个人都能成为内容创作大师**
|
||||
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Claude, 通义千问,Kimi,DeepSeek,Gitee AI 等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
|
||||
|
||||
基于 GeekAI 项目开发的高级版,增加了很多高级功能,比如思维导图,Dalle 绘画等。**高级版源码不会一次性开放,只提供镜像给大家免费使用**,源码会逐步逐步按照版同步迁移到[社区版(GeekAI)](https://github.com/yangjian102621/geekai)。所以如果大家想要二次开发,请移步去社区版。
|
||||
主要特性:
|
||||
|
||||
## ✨ 核心特色
|
||||
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
||||
- 基于 Websocket 实现,完美的打字机体验。
|
||||
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
|
||||
- 支持 OpenAI, Claude, 通义千问,Kimi,DeepSeek等多个大语言模型,**支持 Gitee AI Serverless 大模型 API**。
|
||||
- 支持 Suno 文生音乐
|
||||
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
|
||||
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
||||
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
||||
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
|
||||
绘画函数插件。
|
||||
|
||||
### 🎨 **全能 AI 创作矩阵**
|
||||
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus)
|
||||
|
||||
- **智能对话**:集成 ChatGPT、Claude 等多款顶级 AI 模型,支持角色扮演和专业对话
|
||||
- **图像生成**:整合 MidJourney、DALL-E、Stable Diffusion 三大主流 AI 绘画引擎
|
||||
- **音频创作**:Suno AI 音乐生成,从旋律到歌词一键创作专属音乐
|
||||
- **视频制作**:Luma 和 KeLing,即梦,Veo3 视频 AI,文本到视频,创意无限
|
||||
- **思维导图**:AI 辅助思维整理,复杂想法可视化呈现
|
||||
|
||||
### 🏗️ **企业级技术架构**
|
||||
|
||||
- **高性能后端**:Go + Gin + MySQL + Redis,支持高并发访问
|
||||
- **现代化前端**:Vue3 + Element Plus + Vant,桌面移动双端适配
|
||||
- **智能缓存**:多层缓存策略,响应速度提升 80%
|
||||
- **弹性部署**:Docker 容器化部署,一键启动,轻松扩展
|
||||
- **私有化部署**:支持私有化部署,私有化部署不支持升级,需要手动升级
|
||||
- **文档支持**:丰富且详细的部署和 API 开发文档支持,二次开发轻松上手
|
||||
|
||||
### 💼 **商业化就绪**
|
||||
|
||||
- **完整用户系统**:注册登录、权限管理、积分充值
|
||||
- **灵活计费模式**:支持按次付费、包月订阅等多种商业模式
|
||||
- **数据统计分析**:用户行为、消费记录、系统性能全方位监控
|
||||
- **管理后台**:功能完备的管理员界面,运营数据一目了然
|
||||
|
||||
### 🎯 **用户体验优势**
|
||||
|
||||
- **响应式设计**:完美适配桌面、平板、手机等全终端设备
|
||||
- **暗黑模式**:支持明暗主题切换,护眼舒适
|
||||
- **实时交互**:WebSocket 实时通信,创作过程流畅无卡顿
|
||||
- **文件管理**:支持多种云存储,作品安全可靠
|
||||
|
||||
## 🎪 **应用场景**
|
||||
|
||||
- **内容创作者**:博客写作、社交媒体素材、短视频制作
|
||||
- **企业营销**:品牌宣传材料、产品介绍、创意广告
|
||||
- **教育培训**:课件制作、知识图谱、互动内容
|
||||
- **个人娱乐**:AI 聊天、创意绘画、音乐创作
|
||||
|
||||
## 🔥 **为什么选择 GeekAI-PLUS?**
|
||||
|
||||
1. **技术领先**:集成当前最先进的 AI 技术,始终保持创新前沿
|
||||
2. **开箱即用**:完整的商业化解决方案,无需从零开发
|
||||
3. **高度定制**:模块化架构设计,支持个性化功能扩展
|
||||
4. **稳定可靠**:经过大量用户验证,性能稳定,安全可信
|
||||
5. **持续更新**:紧跟 AI 技术发展,功能持续迭代升级
|
||||
|
||||
## 演示站点
|
||||
|
||||
[Geek-AI 创作系统](https://www.geekai.me)
|
||||
|
||||
## 文档地址
|
||||
|
||||
[Geek-AI 文档](https://www.geekai.me/docs/)
|
||||
|
||||
## 部署
|
||||
|
||||
1. 安装 docker 和 docker-compose 程序,这个自行解决。
|
||||
2. 直接在项目根目录运行启动命令:
|
||||
```shell
|
||||
docker-compose up -d
|
||||
```
|
||||
- [x] 更友好的 UI 界面
|
||||
- [x] 支持 Dall-E 文生图功能
|
||||
- [x] 支持文生思维导图
|
||||
- [x] 支持为模型绑定指定的 API KEY,支持为角色绑定指定的模型等功能
|
||||
- [x] 支持网站 Logo 版权等信息的修改
|
||||
|
||||
## 功能截图
|
||||
|
||||
请参考 [GeekAI 项目介绍](https://docs.geekai.me/info/)。
|
||||
|
||||
---
|
||||
### 体验地址
|
||||
|
||||
_让 AI 成为你最强大的创作伙伴,开启无限创意可能!_
|
||||
> 免费体验地址:[https://chat.geekai.me](https://chat.geekai.me) <br/>
|
||||
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
|
||||
|
||||
## 快速部署
|
||||
|
||||
请参考文档 [**GeekAI 快速部署**](https://docs.geekai.me/install/)。
|
||||
|
||||
## 使用须知
|
||||
|
||||
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
|
||||
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
|
||||
|
||||
## 项目地址
|
||||
|
||||
* Github 地址:https://github.com/yangjian102621/geekai
|
||||
* 码云地址:https://gitee.com/blackfox/geekai
|
||||
|
||||
## 客户端下载
|
||||
|
||||
目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
|
||||
|
||||
## TODOLIST
|
||||
|
||||
* [ ] 支持基于知识库的 AI 问答
|
||||
* [ ] 文生视频,文生歌曲功能
|
||||
* [ ] 微信支付功能
|
||||
|
||||
## 项目文档
|
||||
|
||||
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
|
||||
|
||||
详细的部署和开发文档请参考 [**GeekAI 文档**](https://docs.geekai.me)。
|
||||
|
||||
加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自Github!!!)。**
|
||||
|
||||

|
||||
|
||||
## 参与贡献
|
||||
|
||||
个人的力量始终有限,任何形式的贡献都是欢迎的,包括但不限于贡献代码,优化文档,提交 issue 和 PR 等。
|
||||
|
||||
#### 特此声明:由于个人时间有限,不接受在微信或者微信群给开发者提 Bug,有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合!
|
||||
|
||||
### Commit 类型
|
||||
|
||||
* feat: 新特性或功能
|
||||
* fix: 缺陷修复
|
||||
* docs: 文档更新
|
||||
* style: 代码风格或者组件样式更新
|
||||
* refactor: 代码重构,不引入新功能和缺陷修复
|
||||
* opt: 性能优化
|
||||
* chore: 一些不涉及到功能变动的小提交,比如修改文字表述,修改注释等
|
||||
|
||||
## 打赏
|
||||
|
||||
如果你觉得这个项目对你有帮助,并且情况允许的话,可以请作者喝杯咖啡,非常感谢你的支持~
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
1
api/.gitignore
vendored
1
api/.gitignore
vendored
@@ -17,6 +17,5 @@ bin
|
||||
data
|
||||
config.toml
|
||||
static/upload
|
||||
static/audio
|
||||
storage.json
|
||||
res/certs/wechat/apiclient_key.pem
|
||||
|
||||
@@ -3,11 +3,11 @@ NAME := geekai
|
||||
all: amd64 arm64
|
||||
|
||||
amd64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "-s -w" -o bin/$(NAME)-linux main.go
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-linux main.go
|
||||
.PHONY: amd64
|
||||
|
||||
arm64:
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -ldflags "-s -w" -o bin/$(NAME)-linux main.go
|
||||
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -o bin/$(NAME)-linux main.go
|
||||
.PHONY: arm64
|
||||
|
||||
clean:
|
||||
|
||||
@@ -8,77 +8,91 @@ package core
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/nfnt/resize"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"golang.org/x/image/webp"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AppServer struct {
|
||||
Debug bool
|
||||
Config *types.AppConfig
|
||||
Engine *gin.Engine
|
||||
SysConfig *types.SystemConfig // system config cache
|
||||
Redis *redis.Client
|
||||
}
|
||||
|
||||
func NewServer(appConfig *types.AppConfig, redis *redis.Client, sysConfig *types.SystemConfig) *AppServer {
|
||||
func NewServer(appConfig *types.AppConfig) *AppServer {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
return &AppServer{
|
||||
Config: appConfig,
|
||||
Redis: redis,
|
||||
Engine: gin.Default(),
|
||||
SysConfig: sysConfig,
|
||||
Debug: false,
|
||||
Config: appConfig,
|
||||
Engine: gin.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AppServer) Init(client *redis.Client) {
|
||||
s.Engine.Use(middleware.ParameterHandlerMiddleware())
|
||||
func (s *AppServer) Init(debug bool, client *redis.Client) {
|
||||
// 允许跨域请求 API
|
||||
s.Engine.Use(corsMiddleware())
|
||||
s.Engine.Use(staticResourceMiddleware())
|
||||
s.Engine.Use(authorizeMiddleware(s, client))
|
||||
s.Engine.Use(parameterHandlerMiddleware())
|
||||
s.Engine.Use(errorHandler)
|
||||
// 添加静态资源访问
|
||||
s.Engine.Static("/static", s.Config.StaticDir)
|
||||
s.Engine.Use(middleware.StaticMiddleware())
|
||||
}
|
||||
|
||||
func (s *AppServer) Run(db *gorm.DB) error {
|
||||
// load system configs
|
||||
var sysConfig model.Config
|
||||
err := db.Where("name", "system").First(&sysConfig).Error
|
||||
err := db.Where("marker", "system").First(&sysConfig).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load system config: %v", err)
|
||||
}
|
||||
err = utils.JsonDecode(sysConfig.Value, &s.SysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode system config: %v", err)
|
||||
}
|
||||
logger.Infof("http://%s", s.Config.Listen)
|
||||
|
||||
// 统计安装信息
|
||||
go func() {
|
||||
info, err := host.Info()
|
||||
if err == nil {
|
||||
apiURL := fmt.Sprintf("%s/api/installs/push", types.GeekAPIURL)
|
||||
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push")
|
||||
timestamp := time.Now().Unix()
|
||||
product := "geekai-plus"
|
||||
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
|
||||
sign := utils.Sha256(signStr)
|
||||
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
logger.Errorf("register install info failed: %v", err)
|
||||
} else {
|
||||
logger.Debugf("register install info success: %v", resp.String())
|
||||
}
|
||||
}
|
||||
}()
|
||||
logger.Infof("http://%s", s.Config.Listen)
|
||||
|
||||
return s.Engine.Run(s.Config.Listen)
|
||||
}
|
||||
|
||||
@@ -95,3 +109,287 @@ func errorHandler(c *gin.Context) {
|
||||
//加载完 defer recover,继续后续接口调用
|
||||
c.Next()
|
||||
}
|
||||
|
||||
// 跨域中间件设置
|
||||
func corsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
// 设置允许的请求源
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
||||
//允许跨域设置可以返回其他子段,可以自定义字段
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
|
||||
// 允许浏览器(客户端)可以解析的头部 (重要)
|
||||
c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
||||
//设置缓存时间
|
||||
c.Header("Access-Control-Max-Age", "172800")
|
||||
//允许客户端传递校验信息比如 cookie (重要)
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
if method == http.MethodOptions {
|
||||
c.JSON(http.StatusOK, "ok!")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Info("Panic info is: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 用户授权验证
|
||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
var tokenString string
|
||||
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
||||
if isAdminApi { // 后台管理 API
|
||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||
} else if clientProtocols != "" { // Websocket 连接
|
||||
// 解析子协议内容
|
||||
protocols := strings.Split(clientProtocols, ",")
|
||||
if protocols[0] == "realtime" {
|
||||
tokenString = strings.TrimSpace(protocols[1][25:])
|
||||
} else if protocols[0] == "token" {
|
||||
tokenString = strings.TrimSpace(protocols[1])
|
||||
}
|
||||
} else {
|
||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||
}
|
||||
|
||||
if tokenString == "" {
|
||||
if needLogin(c) {
|
||||
resp.NotAuth(c, "You should put Authorization in request headers")
|
||||
c.Abort()
|
||||
return
|
||||
} else { // 直接放行
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
if isAdminApi {
|
||||
return []byte(s.Config.AdminSession.SecretKey), nil
|
||||
} else {
|
||||
return []byte(s.Config.Session.SecretKey), nil
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
if err != nil && needLogin(c) {
|
||||
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid && needLogin(c) {
|
||||
resp.NotAuth(c, "Token is invalid")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
|
||||
resp.NotAuth(c, "Token is expired")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("users/%v", claims["user_id"])
|
||||
if isAdminApi {
|
||||
key = fmt.Sprintf("admin/%v", claims["user_id"])
|
||||
}
|
||||
if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
|
||||
resp.NotAuth(c, "Token is not found in redis")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set(types.LoginUserID, claims["user_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func needLogin(c *gin.Context) bool {
|
||||
if c.Request.URL.Path == "/api/user/login" ||
|
||||
c.Request.URL.Path == "/api/user/logout" ||
|
||||
c.Request.URL.Path == "/api/user/resetPass" ||
|
||||
c.Request.URL.Path == "/api/admin/login" ||
|
||||
c.Request.URL.Path == "/api/admin/logout" ||
|
||||
c.Request.URL.Path == "/api/admin/login/captcha" ||
|
||||
c.Request.URL.Path == "/api/user/register" ||
|
||||
c.Request.URL.Path == "/api/chat/history" ||
|
||||
c.Request.URL.Path == "/api/chat/detail" ||
|
||||
c.Request.URL.Path == "/api/chat/list" ||
|
||||
c.Request.URL.Path == "/api/app/list" ||
|
||||
c.Request.URL.Path == "/api/app/type/list" ||
|
||||
c.Request.URL.Path == "/api/app/list/user" ||
|
||||
c.Request.URL.Path == "/api/model/list" ||
|
||||
c.Request.URL.Path == "/api/mj/imgWall" ||
|
||||
c.Request.URL.Path == "/api/mj/notify" ||
|
||||
c.Request.URL.Path == "/api/invite/hits" ||
|
||||
c.Request.URL.Path == "/api/sd/imgWall" ||
|
||||
c.Request.URL.Path == "/api/dall/imgWall" ||
|
||||
c.Request.URL.Path == "/api/product/list" ||
|
||||
c.Request.URL.Path == "/api/menu/list" ||
|
||||
c.Request.URL.Path == "/api/markMap/client" ||
|
||||
c.Request.URL.Path == "/api/payment/doPay" ||
|
||||
c.Request.URL.Path == "/api/payment/payWays" ||
|
||||
c.Request.URL.Path == "/api/suno/detail" ||
|
||||
c.Request.URL.Path == "/api/suno/play" ||
|
||||
c.Request.URL.Path == "/api/download" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 统一参数处理
|
||||
func parameterHandlerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// GET 参数处理
|
||||
params := c.Request.URL.Query()
|
||||
for key, values := range params {
|
||||
for i, value := range values {
|
||||
params[key][i] = strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
// update get parameters
|
||||
c.Request.URL.RawQuery = params.Encode()
|
||||
// skip file upload requests
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// process POST JSON request body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 还原请求体
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
// 将请求体解析为 JSON
|
||||
var jsonData map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&jsonData); err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 对 JSON 数据中的字符串值去除两端空格
|
||||
trimJSONStrings(jsonData)
|
||||
// 更新请求体
|
||||
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 递归对 JSON 数据中的字符串值去除两端空格
|
||||
func trimJSONStrings(data interface{}) {
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
for key, value := range v {
|
||||
switch valueType := value.(type) {
|
||||
case string:
|
||||
v[key] = strings.TrimSpace(valueType)
|
||||
case map[string]interface{}, []interface{}:
|
||||
trimJSONStrings(value)
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
for i, value := range v {
|
||||
switch valueType := value.(type) {
|
||||
case string:
|
||||
v[i] = strings.TrimSpace(valueType)
|
||||
case map[string]interface{}, []interface{}:
|
||||
trimJSONStrings(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 静态资源中间件
|
||||
func staticResourceMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
url := c.Request.URL.String()
|
||||
// 拦截生成缩略图请求
|
||||
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
|
||||
r := strings.SplitAfter(url, "imageView2")
|
||||
size := strings.Split(r[1], "/")
|
||||
if len(size) != 8 {
|
||||
c.String(http.StatusNotFound, "invalid thumb args")
|
||||
return
|
||||
}
|
||||
with := utils.IntValue(size[3], 0)
|
||||
height := utils.IntValue(size[5], 0)
|
||||
quality := utils.IntValue(size[7], 75)
|
||||
|
||||
// 打开图片文件
|
||||
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
c.String(http.StatusNotFound, "Image not found")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 解码图片
|
||||
img, _, err := image.Decode(file)
|
||||
// for .webp image
|
||||
if err != nil {
|
||||
img, err = webp.Decode(file)
|
||||
}
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "Error decoding image")
|
||||
return
|
||||
}
|
||||
|
||||
var newImg image.Image
|
||||
if height == 0 || with == 0 {
|
||||
// 固定宽度,高度自适应
|
||||
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
|
||||
} else {
|
||||
// 生成缩略图
|
||||
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
|
||||
}
|
||||
var buffer bytes.Buffer
|
||||
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 设置图片缓存有效期为一年 (365天)
|
||||
c.Header("Cache-Control", "max-age=31536000, public")
|
||||
// 直接输出图像数据流
|
||||
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
|
||||
c.Abort() // 中断请求
|
||||
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,12 +11,10 @@ import (
|
||||
"bytes"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"os"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
@@ -32,6 +30,7 @@ func NewDefaultConfig() *types.AppConfig {
|
||||
SecretKey: utils.RandString(64),
|
||||
MaxAge: 86400,
|
||||
},
|
||||
ApiConfig: types.ApiConfig{},
|
||||
OSS: types.OSSConfig{
|
||||
Active: "local",
|
||||
Local: types.LocalStorageConfig{
|
||||
@@ -39,6 +38,7 @@ func NewDefaultConfig() *types.AppConfig {
|
||||
BasePath: "./static/upload",
|
||||
},
|
||||
},
|
||||
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,108 +74,3 @@ func SaveConfig(config *types.AppConfig) error {
|
||||
|
||||
return os.WriteFile(config.Path, buf.Bytes(), 0644)
|
||||
}
|
||||
|
||||
func LoadSystemConfig(db *gorm.DB) *types.SystemConfig {
|
||||
// 加载系统配置
|
||||
var sysConfig model.Config
|
||||
var baseConfig types.BaseConfig
|
||||
db.Where("name", "system").First(&sysConfig)
|
||||
err := utils.JsonDecode(sysConfig.Value, &baseConfig)
|
||||
if err != nil {
|
||||
logger.Error("load system config error: ", err)
|
||||
}
|
||||
|
||||
// 加载许可证配置
|
||||
var license types.License
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyLicense).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &license)
|
||||
if err != nil {
|
||||
logger.Error("load license config error: ", err)
|
||||
}
|
||||
|
||||
// 加载验证码配置
|
||||
var captchaConfig types.CaptchaConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyCaptcha).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &captchaConfig)
|
||||
if err != nil {
|
||||
logger.Error("load geek service config error: ", err)
|
||||
}
|
||||
|
||||
// 加载微信登录配置
|
||||
var wxLoginConfig types.WxLoginConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyWxLogin).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &wxLoginConfig)
|
||||
if err != nil {
|
||||
logger.Error("load wx login config error: ", err)
|
||||
}
|
||||
|
||||
// 加载短信配置
|
||||
var smsConfig types.SMSConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeySms).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &smsConfig)
|
||||
if err != nil {
|
||||
logger.Error("load sms config error: ", err)
|
||||
}
|
||||
|
||||
// 加载 OSS 配置
|
||||
var ossConfig types.OSSConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyOss).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &ossConfig)
|
||||
if err != nil {
|
||||
logger.Error("load oss config error: ", err)
|
||||
}
|
||||
|
||||
// 加载 SMTP 配置
|
||||
var smtpConfig types.SmtpConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeySmtp).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &smtpConfig)
|
||||
if err != nil {
|
||||
logger.Error("load smtp config error: ", err)
|
||||
}
|
||||
|
||||
// 加载支付配置
|
||||
var paymentConfig types.PaymentConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyPayment).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &paymentConfig)
|
||||
if err != nil {
|
||||
logger.Error("load payment config error: ", err)
|
||||
}
|
||||
|
||||
// 加载文本审查配置
|
||||
var moderationConfig types.ModerationConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyModeration).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &moderationConfig)
|
||||
if err != nil {
|
||||
logger.Error("load moderation config error: ", err)
|
||||
}
|
||||
|
||||
// 加载即梦AI配置
|
||||
var jimengConfig types.JimengConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyJimeng).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &jimengConfig)
|
||||
if err != nil {
|
||||
logger.Error("load jimeng config error: ", err)
|
||||
}
|
||||
|
||||
return &types.SystemConfig{
|
||||
Base: baseConfig,
|
||||
License: license,
|
||||
SMS: smsConfig,
|
||||
OSS: ossConfig,
|
||||
SMTP: smtpConfig,
|
||||
Payment: paymentConfig,
|
||||
Captcha: captchaConfig,
|
||||
WxLogin: wxLoginConfig,
|
||||
Moderation: moderationConfig,
|
||||
Jimeng: jimengConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// 前端用户授权验证
|
||||
func UserAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenString := c.GetHeader(types.UserAuthHeader)
|
||||
if tokenString == "" {
|
||||
resp.NotAuth(c, "无效的授权令牌")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
resp.NotAuth(c, "令牌无效")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||||
resp.NotAuth(c, "令牌过期")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("users/%v", claims["user_id"])
|
||||
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
|
||||
resp.NotAuth(c, "当前用户已退出登录")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set(types.LoginUserID, claims["user_id"])
|
||||
}
|
||||
}
|
||||
|
||||
// 管理后台用户授权验证
|
||||
func AdminAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenString := c.GetHeader(types.AdminAuthHeader)
|
||||
if tokenString == "" {
|
||||
resp.NotAuth(c, "无效的授权令牌")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
resp.NotAuth(c, "令牌无效")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||||
resp.NotAuth(c, "令牌过期")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("admin/%v", claims["user_id"])
|
||||
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
|
||||
resp.NotAuth(c, "当前用户已退出登录")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(types.AdminUserID, claims["user_id"])
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 统一参数处理
|
||||
func ParameterHandlerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// GET 参数处理
|
||||
params := c.Request.URL.Query()
|
||||
for key, values := range params {
|
||||
for i, value := range values {
|
||||
params[key][i] = strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
// update get parameters
|
||||
c.Request.URL.RawQuery = params.Encode()
|
||||
// skip file upload requests
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// process POST JSON request body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 还原请求体
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
// 将请求体解析为 JSON
|
||||
var jsonData map[string]any
|
||||
if err := c.ShouldBindJSON(&jsonData); err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 对 JSON 数据中的字符串值去除两端空格
|
||||
trimJSONStrings(jsonData)
|
||||
// 更新请求体
|
||||
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 递归对 JSON 数据中的字符串值去除两端空格
|
||||
func trimJSONStrings(data any) {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
for key, value := range v {
|
||||
switch valueType := value.(type) {
|
||||
case string:
|
||||
v[key] = strings.TrimSpace(valueType)
|
||||
case map[string]any, []any:
|
||||
trimJSONStrings(value)
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for i, value := range v {
|
||||
switch valueType := value.(type) {
|
||||
case string:
|
||||
v[i] = strings.TrimSpace(valueType)
|
||||
case map[string]any, []any:
|
||||
trimJSONStrings(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
// RateLimitEvery 使用 Redis 做固定间隔限流:在 interval 内仅允许一次请求
|
||||
// Key 优先使用登录用户ID,若没有则退化为 route + IP
|
||||
func RateLimitEvery(redisClient *redis.Client, interval time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
keyID := ""
|
||||
if userID, ok := c.Get(types.LoginUserID); ok {
|
||||
keyID = fmt.Sprintf("user:%s", utils.InterfaceToString(userID))
|
||||
} else {
|
||||
keyID = fmt.Sprintf("ip:%s", c.ClientIP())
|
||||
}
|
||||
|
||||
fullPath := c.FullPath()
|
||||
if fullPath == "" {
|
||||
fullPath = c.Request.URL.Path
|
||||
}
|
||||
key := fmt.Sprintf("rl:%s:%s", fullPath, keyID)
|
||||
|
||||
okSet, err := redisClient.SetNX(context.Background(), key, 1, interval).Result()
|
||||
if err != nil {
|
||||
// Redis 异常时放行,避免误伤可用性
|
||||
return
|
||||
}
|
||||
if !okSet {
|
||||
c.JSON(http.StatusTooManyRequests, types.BizVo{Code: types.Failed, Message: "请求过于频繁,请稍后重试"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"geekai/utils"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/nfnt/resize"
|
||||
"golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// 静态资源中间件
|
||||
func StaticMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
url := c.Request.URL.String()
|
||||
// 拦截生成缩略图请求
|
||||
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
|
||||
r := strings.SplitAfter(url, "imageView2")
|
||||
size := strings.Split(r[1], "/")
|
||||
if len(size) != 8 {
|
||||
c.String(http.StatusNotFound, "invalid thumb args")
|
||||
return
|
||||
}
|
||||
with := utils.IntValue(size[3], 0)
|
||||
height := utils.IntValue(size[5], 0)
|
||||
quality := utils.IntValue(size[7], 75)
|
||||
|
||||
// 打开图片文件
|
||||
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
c.String(http.StatusNotFound, "Image not found")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 解码图片
|
||||
img, _, err := image.Decode(file)
|
||||
// for .webp image
|
||||
if err != nil {
|
||||
img, err = webp.Decode(file)
|
||||
}
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "Error decoding image")
|
||||
return
|
||||
}
|
||||
|
||||
var newImg image.Image
|
||||
if height == 0 || with == 0 {
|
||||
// 固定宽度,高度自适应
|
||||
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
|
||||
} else {
|
||||
// 生成缩略图
|
||||
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
|
||||
}
|
||||
var buffer bytes.Buffer
|
||||
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 设置图片缓存有效期为一年 (365天)
|
||||
c.Header("Cache-Control", "max-age=31536000, public")
|
||||
// 直接输出图像数据流
|
||||
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
|
||||
c.Abort() // 中断请求
|
||||
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -9,20 +9,20 @@ package types
|
||||
|
||||
// ApiRequest API 请求实体
|
||||
type ApiRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Messages []any `json:"messages,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Functions []any `json:"functions,omitempty"` // 兼容中转平台
|
||||
ResponseFormat any `json:"response_format,omitempty"` // 响应格式
|
||||
Model string `json:"model,omitempty"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Messages []interface{} `json:"messages,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
||||
ResponseFormat interface{} `json:"response_format,omitempty"` // 响应格式
|
||||
|
||||
ToolChoice string `json:"tool_choice,omitempty"`
|
||||
|
||||
Input map[string]any `json:"input,omitempty"` //兼容阿里通义千问
|
||||
Parameters map[string]any `json:"parameters,omitempty"` //兼容阿里通义千问
|
||||
Input map[string]interface{} `json:"input,omitempty"` //兼容阿里通义千问
|
||||
Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
@@ -41,17 +41,27 @@ type ChoiceItem struct {
|
||||
}
|
||||
|
||||
type Delta struct {
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
Content any `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FunctionCall struct {
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
Content interface{} `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
} `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatSession 聊天会话对象
|
||||
type ChatSession struct {
|
||||
UserId uint `json:"user_id"`
|
||||
ClientIP string `json:"client_ip"` // 客户端 IP
|
||||
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
||||
Model ChatModel `json:"model"` // GPT 模型
|
||||
Start int64 `json:"start"` // 开始请求时间戳
|
||||
Tools []int `json:"tools"` // 工具函数列表
|
||||
Stream bool `json:"stream"` // 是否采用流式输出
|
||||
}
|
||||
|
||||
type ChatModel struct {
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -59,8 +69,6 @@ type ChatModel struct {
|
||||
Power int `json:"power"`
|
||||
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||
Description string `json:"description"` //模型描述
|
||||
Category string `json:"category"` //模型类别
|
||||
Temperature float32 `json:"temperature"` // 模型温度
|
||||
KeyId int `json:"key_id"` // 绑定 API KEY
|
||||
}
|
||||
@@ -87,7 +95,6 @@ const (
|
||||
PowerInvite = PowerType(4) // 邀请奖励
|
||||
PowerRedeem = PowerType(5) // 众筹
|
||||
PowerGift = PowerType(6) // 系统赠送
|
||||
PowerSignIn = PowerType(7) // 每日签到
|
||||
)
|
||||
|
||||
func (t PowerType) String() string {
|
||||
@@ -100,12 +107,7 @@ func (t PowerType) String() string {
|
||||
return "退款"
|
||||
case PowerRedeem:
|
||||
return "兑换"
|
||||
case PowerGift:
|
||||
return "赠送"
|
||||
case PowerInvite:
|
||||
return "邀请"
|
||||
case PowerSignIn:
|
||||
return "签到"
|
||||
|
||||
}
|
||||
return "其他"
|
||||
}
|
||||
|
||||
@@ -17,17 +17,87 @@ type AppConfig struct {
|
||||
Session Session
|
||||
AdminSession Session
|
||||
ProxyURL string
|
||||
MysqlDns string // mysql 连接地址
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
SMS SMSConfig // send mobile message config
|
||||
OSS OSSConfig // OSS config
|
||||
SmtpConfig SmtpConfig // 邮件发送配置
|
||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||
GeekPayConfig EpayConfig // GEEK 支付配置
|
||||
WechatPayConfig WxPayConfig // 微信支付渠道配置
|
||||
TikaHost string // TiKa 服务器地址
|
||||
MysqlDns string // mysql 连接地址
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
ApiConfig ApiConfig // ChatPlus API authorization configs
|
||||
SMS SMSConfig // send mobile message config
|
||||
OSS OSSConfig // OSS config
|
||||
SmtpConfig SmtpConfig // 邮件发送配置
|
||||
XXLConfig XXLConfig
|
||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
|
||||
GeekPayConfig GeekPayConfig // GEEK 支付配置
|
||||
WechatPayConfig WechatPayConfig // 微信支付渠道配置
|
||||
TikaHost string // TiKa 服务器地址
|
||||
}
|
||||
|
||||
type SmtpConfig struct {
|
||||
UseTls bool // 是否使用 TLS 发送
|
||||
Host string
|
||||
Port int
|
||||
AppName string // 应用名称
|
||||
From string // 发件人邮箱地址
|
||||
Password string // 发件人邮箱密码
|
||||
}
|
||||
|
||||
type ApiConfig struct {
|
||||
ApiURL string
|
||||
AppId string
|
||||
Token string
|
||||
}
|
||||
|
||||
type AlipayConfig struct {
|
||||
Enabled bool // 是否启用该支付通道
|
||||
SandBox bool // 是否沙盒环境
|
||||
AppId string // 应用 ID
|
||||
UserId string // 支付宝用户 ID
|
||||
PrivateKey string // 用户私钥文件路径
|
||||
PublicKey string // 用户公钥文件路径
|
||||
AlipayPublicKey string // 支付宝公钥文件路径
|
||||
RootCert string // Root 秘钥路径
|
||||
NotifyURL string // 异步通知地址
|
||||
ReturnURL string // 同步通知地址
|
||||
}
|
||||
|
||||
type WechatPayConfig struct {
|
||||
Enabled bool // 是否启用该支付通道
|
||||
AppId string // 公众号的APPID,如:wxd678efh567hg6787
|
||||
MchId string // 直连商户的商户号,由微信支付生成并下发
|
||||
SerialNo string // 商户证书的证书序列号
|
||||
PrivateKey string // 用户私钥文件路径
|
||||
ApiV3Key string // API V3 秘钥
|
||||
NotifyURL string // 异步通知地址
|
||||
}
|
||||
|
||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
||||
Enabled bool // 是否启用该支付通道
|
||||
AppId string // App ID
|
||||
AppSecret string // app 密钥
|
||||
ApiURL string // 支付网关
|
||||
NotifyURL string // 异步通知地址
|
||||
ReturnURL string // 同步通知地址
|
||||
}
|
||||
|
||||
// GeekPayConfig GEEK支付配置
|
||||
type GeekPayConfig struct {
|
||||
Enabled bool
|
||||
AppId string // 商户 ID
|
||||
PrivateKey string // 私钥
|
||||
ApiURL string // API 网关
|
||||
NotifyURL string // 异步通知地址
|
||||
ReturnURL string // 同步通知地址
|
||||
Methods []string // 支付方式
|
||||
}
|
||||
|
||||
type XXLConfig struct { // XXL 任务调度配置
|
||||
Enabled bool
|
||||
ServerAddr string
|
||||
ExecutorIp string
|
||||
ExecutorPort string
|
||||
AccessToken string
|
||||
RegistryKey string
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
@@ -57,28 +127,28 @@ func (c RedisConfig) Url() string {
|
||||
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||
}
|
||||
|
||||
type BaseConfig struct {
|
||||
Title string `json:"title,omitempty"` // 网站标题
|
||||
Slogan string `json:"slogan,omitempty"` // 网站 slogan
|
||||
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
|
||||
Logo string `json:"logo,omitempty"` // 圆形 Logo
|
||||
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
|
||||
type SystemConfig struct {
|
||||
Title string `json:"title,omitempty"` // 网站标题
|
||||
Slogan string `json:"slogan,omitempty"` // 网站 slogan
|
||||
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
|
||||
Logo string `json:"logo,omitempty"` // 方形 Logo
|
||||
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
|
||||
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
|
||||
|
||||
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
|
||||
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
|
||||
|
||||
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间,单位:分钟
|
||||
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
|
||||
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
|
||||
|
||||
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
|
||||
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
|
||||
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
|
||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
|
||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||
|
||||
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
||||
|
||||
@@ -88,44 +158,12 @@ type BaseConfig struct {
|
||||
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
|
||||
MjMode string `json:"mj_mode"` // midjourney 默认的API模式,relax, fast, turbo
|
||||
|
||||
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
||||
Copyright string `json:"copyright"` // 版权信息
|
||||
ICP string `json:"icp"` // ICP 备案号
|
||||
GaBeian string `json:"ga_beian"` // 公安备案号
|
||||
IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片
|
||||
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
||||
Copyright string `json:"copyright"` // 版权信息
|
||||
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
|
||||
|
||||
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
|
||||
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
|
||||
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
|
||||
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位:MB
|
||||
TranslateModelId int `json:"translate_model_id"` // 用来做提示词翻译的大模型 id
|
||||
}
|
||||
|
||||
type SystemConfig struct {
|
||||
Base BaseConfig
|
||||
Payment PaymentConfig
|
||||
OSS OSSConfig
|
||||
SMS SMSConfig
|
||||
SMTP SmtpConfig
|
||||
Captcha CaptchaConfig
|
||||
WxLogin WxLoginConfig
|
||||
Jimeng JimengConfig
|
||||
License License
|
||||
Moderation ModerationConfig
|
||||
}
|
||||
|
||||
// 配置键名常量
|
||||
const (
|
||||
ConfigKeySystem = "system"
|
||||
ConfigKeyNotice = "notice"
|
||||
ConfigKeyAgreement = "agreement"
|
||||
ConfigKeyPrivacy = "privacy"
|
||||
ConfigKeyMarkMap = "mark_map"
|
||||
ConfigKeyCaptcha = "captcha"
|
||||
ConfigKeyWxLogin = "wx_login"
|
||||
ConfigKeyLicense = "license"
|
||||
ConfigKeySms = "sms"
|
||||
ConfigKeySmtp = "smtp"
|
||||
ConfigKeyOss = "oss"
|
||||
ConfigKeyPayment = "payment"
|
||||
ConfigKeyModeration = "moderation"
|
||||
ConfigKeyAI3D = "ai3d"
|
||||
ConfigKeyJimeng = "jimeng"
|
||||
)
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
package types
|
||||
|
||||
import "os"
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// GeekAI 增值服务
|
||||
var GeekAPIURL = "https://sapi.geekai.me"
|
||||
|
||||
func init() {
|
||||
if os.Getenv("GEEK_API_URL") != "" {
|
||||
GeekAPIURL = os.Getenv("GEEK_API_URL")
|
||||
}
|
||||
}
|
||||
|
||||
// CaptchaConfig 行为验证码配置
|
||||
type CaptchaConfig struct {
|
||||
ApiKey string `json:"api_key"`
|
||||
Type string `json:"type"` // 验证码类型, 可选值: "dot" 或 "slide"
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// WxLoginConfig 微信登录配置
|
||||
type WxLoginConfig struct {
|
||||
ApiKey string `json:"api_key"`
|
||||
NotifyURL string `json:"notify_url"` // 登录成功回调 URL
|
||||
Enabled bool `json:"enabled"` // 是否启用微信登录
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package types
|
||||
|
||||
// JimengConfig 即梦AI配置
|
||||
type JimengConfig struct {
|
||||
AccessKey string `json:"access_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Power JimengPower `json:"power"`
|
||||
}
|
||||
|
||||
// JimengPower 即梦AI算力配置
|
||||
type JimengPower struct {
|
||||
TextToImage int `json:"text_to_image"`
|
||||
ImageToImage int `json:"image_to_image"`
|
||||
ImageEdit int `json:"image_edit"`
|
||||
ImageEffects int `json:"image_effects"`
|
||||
TextToVideo int `json:"text_to_video"`
|
||||
ImageToVideo int `json:"image_to_video"`
|
||||
}
|
||||
@@ -16,7 +16,7 @@ type MKey interface {
|
||||
string | int | uint
|
||||
}
|
||||
type MValue interface {
|
||||
*WsClient | context.CancelFunc | []any
|
||||
*WsClient | *ChatSession | context.CancelFunc | []interface{}
|
||||
}
|
||||
type LMap[K MKey, T MValue] struct {
|
||||
lock sync.RWMutex
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// 文本审查
|
||||
type ModerationConfig struct {
|
||||
Enable bool `json:"enable"` // 是否启用文本审查
|
||||
Active string `json:"active"`
|
||||
EnableGuide bool `json:"enable_guide"` // 是否启用模型引导提示词
|
||||
GuidePrompt string `json:"guide_prompt"` // 模型引导提示词
|
||||
Gitee ModerationGiteeConfig `json:"gitee"`
|
||||
Baidu ModerationBaiduConfig `json:"baidu"`
|
||||
Tencent ModerationTencentConfig `json:"tencent"`
|
||||
}
|
||||
|
||||
const (
|
||||
ModerationGitee = "gitee"
|
||||
ModerationBaidu = "baidu"
|
||||
ModerationTencent = "tencent"
|
||||
)
|
||||
|
||||
// GiteeAI 文本审查配置
|
||||
type ModerationGiteeConfig struct {
|
||||
ApiKey string `json:"api_key"`
|
||||
Model string `json:"model"` // 文本审核模型
|
||||
}
|
||||
|
||||
// 百度文本审查配置
|
||||
type ModerationBaiduConfig struct {
|
||||
AccessKey string `json:"access_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
}
|
||||
|
||||
// 腾讯云文本审查配置
|
||||
type ModerationTencentConfig struct {
|
||||
AccessKey string `json:"access_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
}
|
||||
|
||||
type ModerationResult struct {
|
||||
Flagged bool `json:"flagged"`
|
||||
Categories map[string]bool `json:"categories"`
|
||||
CategoryScores map[string]float64 `json:"category_scores"`
|
||||
}
|
||||
|
||||
var ModerationCategories = map[string]string{
|
||||
"politic": "内容涉及人物、事件或敏感的政治观点",
|
||||
"porn": "明确的色情内容",
|
||||
"insult": "具有侮辱、攻击性语言、人身攻击或冒犯性表达",
|
||||
"violence": "包含暴力、血腥、攻击行为或煽动暴力的言论",
|
||||
"illegal": "涉及违法活动的内容,如诈骗、赌博等",
|
||||
"terror": "宣扬恐怖主义、极端暴力或煽动恐怖行为的内容",
|
||||
"ad": "垃圾广告或未经许可的推广内容",
|
||||
"spam": "无意义重复内容或诱导性信息",
|
||||
"abuse": "人身攻击、恶意辱骂或侮辱性言论",
|
||||
"polity": "涉及国家政治、领导人或政策的违规讨论内容",
|
||||
}
|
||||
|
||||
// 敏感词来源
|
||||
const (
|
||||
ModerationSourceChat = "chat"
|
||||
ModerationSourceMJ = "mj"
|
||||
ModerationSourceDalle = "dalle"
|
||||
ModerationSourceSD = "sd"
|
||||
ModerationSourceSuno = "suno"
|
||||
ModerationSourceVideo = "video"
|
||||
ModerationSourceJiMeng = "jimeng"
|
||||
)
|
||||
@@ -11,25 +11,29 @@ type OrderStatus int
|
||||
|
||||
const (
|
||||
OrderNotPaid = OrderStatus(0)
|
||||
OrderPaidSuccess = OrderStatus(2) // 已支付
|
||||
OrderPaidFailed = OrderStatus(3) // 已关闭
|
||||
OrderScanned = OrderStatus(1) // 已扫码
|
||||
OrderPaidSuccess = OrderStatus(2)
|
||||
)
|
||||
|
||||
type OrderRemark struct {
|
||||
Days int `json:"days"` // 有效期
|
||||
Power int `json:"power"` // 增加算力点数
|
||||
Name string `json:"name"` // 产品名称
|
||||
Price float64 `json:"price"`
|
||||
Days int `json:"days"` // 有效期
|
||||
Power int `json:"power"` // 增加算力点数
|
||||
Name string `json:"name"` // 产品名称
|
||||
Price float64 `json:"price"`
|
||||
Discount float64 `json:"discount"`
|
||||
}
|
||||
|
||||
// PayChannel 支付渠道
|
||||
var PayChannel = map[string]string{
|
||||
var PayMethods = map[string]string{
|
||||
"alipay": "支付宝商号",
|
||||
"wxpay": "微信商号",
|
||||
"epay": "易支付",
|
||||
"wechat": "微信商号",
|
||||
"hupi": "虎皮椒",
|
||||
"geek": "易支付",
|
||||
}
|
||||
|
||||
var PayWays = map[string]string{
|
||||
var PayNames = map[string]string{
|
||||
"alipay": "支付宝",
|
||||
"wxpay": "微信支付",
|
||||
"qqpay": "QQ钱包",
|
||||
"jdpay": "京东支付",
|
||||
"douyin": "抖音支付",
|
||||
"paypal": "PayPal支付",
|
||||
}
|
||||
|
||||
@@ -8,39 +8,41 @@ package types
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type OSSConfig struct {
|
||||
Active string `json:"active"`
|
||||
Local LocalStorageConfig `json:"local"`
|
||||
Minio MiniOssConfig `json:"minio"`
|
||||
QiNiu QiNiuOssConfig `json:"qiniu"`
|
||||
AliYun AliYunOssConfig `json:"aliyun"`
|
||||
Active string
|
||||
Local LocalStorageConfig
|
||||
Minio MiniOssConfig
|
||||
QiNiu QiNiuOssConfig
|
||||
AliYun AliYunOssConfig
|
||||
}
|
||||
|
||||
type MiniOssConfig struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
AccessKey string `json:"access_key"`
|
||||
AccessSecret string `json:"access_secret"`
|
||||
Bucket string `json:"bucket"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
Domain string `json:"domain"`
|
||||
Endpoint string
|
||||
AccessKey string
|
||||
AccessSecret string
|
||||
Bucket string
|
||||
SubDir string
|
||||
UseSSL bool
|
||||
Domain string
|
||||
}
|
||||
|
||||
type QiNiuOssConfig struct {
|
||||
Zone string `json:"zone"`
|
||||
AccessKey string `json:"access_key"`
|
||||
AccessSecret string `json:"access_secret"`
|
||||
Bucket string `json:"bucket"`
|
||||
Domain string `json:"domain"`
|
||||
Zone string
|
||||
AccessKey string
|
||||
AccessSecret string
|
||||
Bucket string
|
||||
SubDir string
|
||||
Domain string
|
||||
}
|
||||
|
||||
type AliYunOssConfig struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
AccessKey string `json:"access_key"`
|
||||
AccessSecret string `json:"access_secret"`
|
||||
Bucket string `json:"bucket"`
|
||||
Domain string `json:"domain"`
|
||||
Endpoint string
|
||||
AccessKey string
|
||||
AccessSecret string
|
||||
Bucket string
|
||||
SubDir string
|
||||
Domain string
|
||||
}
|
||||
|
||||
type LocalStorageConfig struct {
|
||||
BasePath string `json:"base_path"`
|
||||
BaseURL string `json:"base_url"`
|
||||
BasePath string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
package types
|
||||
|
||||
type PaymentConfig struct {
|
||||
Alipay AlipayConfig `json:"alipay"` // 支付宝支付渠道配置
|
||||
Epay EpayConfig `json:"epay"` // 易支付配置
|
||||
WxPay WxPayConfig `json:"wxpay"` // 微信支付渠道配置
|
||||
}
|
||||
|
||||
// AlipayConfig 支付宝支付配置
|
||||
type AlipayConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用该支付通道
|
||||
SandBox bool `json:"sandbox"` // 是否沙盒环境
|
||||
AppId string `json:"app_id"` // 应用 ID
|
||||
PrivateKey string `json:"private_key"` // 应用私钥
|
||||
AlipayPublicKey string `json:"alipay_public_key"` // 支付宝公钥
|
||||
Domain string `json:"domain"` // 支付回调域名
|
||||
}
|
||||
|
||||
func (c *AlipayConfig) Equal(other *AlipayConfig) bool {
|
||||
return c.AppId == other.AppId &&
|
||||
c.PrivateKey == other.PrivateKey &&
|
||||
c.AlipayPublicKey == other.AlipayPublicKey &&
|
||||
c.Domain == other.Domain
|
||||
}
|
||||
|
||||
// WxPayConfig 微信支付配置
|
||||
type WxPayConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用该支付通道
|
||||
AppId string `json:"app_id"` // 公众号的APPID,如:wxd678efh567hg6787
|
||||
MchId string `json:"mch_id"` // 直连商户的商户号,由微信支付生成并下发
|
||||
SerialNo string `json:"serial_no"` // 商户证书的证书序列号
|
||||
PrivateKey string `json:"private_key"` // 商户证书私钥
|
||||
ApiV3Key string `json:"api_v3_key"` // API V3 秘钥
|
||||
Domain string `json:"domain"` // 支付回调域名
|
||||
}
|
||||
|
||||
func (c *WxPayConfig) Equal(other *WxPayConfig) bool {
|
||||
return c.AppId == other.AppId &&
|
||||
c.MchId == other.MchId &&
|
||||
c.SerialNo == other.SerialNo &&
|
||||
c.PrivateKey == other.PrivateKey &&
|
||||
c.ApiV3Key == other.ApiV3Key &&
|
||||
c.Domain == other.Domain
|
||||
}
|
||||
|
||||
// EpayConfig 易支付配置
|
||||
type EpayConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用该支付通道
|
||||
AppId string `json:"app_id"` // 商户 ID
|
||||
PrivateKey string `json:"private_key"` // 私钥
|
||||
ApiURL string `json:"api_url"` // z支付 API 网关
|
||||
Domain string `json:"domain"` // 支付回调域名
|
||||
}
|
||||
|
||||
func (c *EpayConfig) Equal(other *EpayConfig) bool {
|
||||
return c.AppId == other.AppId &&
|
||||
c.PrivateKey == other.PrivateKey &&
|
||||
c.ApiURL == other.ApiURL &&
|
||||
c.Domain == other.Domain
|
||||
}
|
||||
@@ -8,7 +8,6 @@ package types
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
const LoginUserID = "LOGIN_USER_ID"
|
||||
const AdminUserID = "ADMIN_USER_ID"
|
||||
const LoginUserCache = "LOGIN_USER_CACHE"
|
||||
|
||||
const UserAuthHeader = "Authorization"
|
||||
|
||||
@@ -8,23 +8,26 @@ package types
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type SMSConfig struct {
|
||||
Active string `json:"active"`
|
||||
Ali SmsConfigAli `json:"aliyun"`
|
||||
Bao SmsConfigBao `json:"bao"`
|
||||
Active string
|
||||
Ali SmsConfigAli
|
||||
Bao SmsConfigBao
|
||||
}
|
||||
|
||||
// SmsConfigAli 阿里云短信平台配置
|
||||
type SmsConfigAli struct {
|
||||
AccessKey string `json:"access_key"`
|
||||
AccessSecret string `json:"access_secret"`
|
||||
Sign string `json:"sign"` // 短信签名
|
||||
CodeTempId string `json:"code_temp_id"` // 验证码短信模板 ID
|
||||
AccessKey string
|
||||
AccessSecret string
|
||||
Product string
|
||||
Domain string
|
||||
Sign string // 短信签名
|
||||
CodeTempId string // 验证码短信模板 ID
|
||||
}
|
||||
|
||||
// SmsConfigBao 短信宝平台配置
|
||||
type SmsConfigBao struct {
|
||||
Username string `json:"username"` //短信宝平台注册的用户名
|
||||
Password string `json:"password"` //短信宝平台注册的密码
|
||||
Sign string `json:"sign"` // 短信签名
|
||||
CodeTemplate string `json:"code_template"` // 验证码短信模板 匹配
|
||||
Username string //短信宝平台注册的用户名
|
||||
Password string //短信宝平台注册的密码
|
||||
Domain string //域名
|
||||
Sign string // 短信签名
|
||||
CodeTemplate string // 验证码短信模板 匹配
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type SmtpConfig struct {
|
||||
UseTls bool `json:"use_tls"` // 是否使用 TLS 发送
|
||||
Host string `json:"host"` // 邮件服务器地址
|
||||
Port int `json:"port"` // 邮件服务器端口
|
||||
AppName string `json:"app_name"` // 应用名称
|
||||
From string `json:"from"` // 发件人邮箱地址
|
||||
Password string `json:"password"` // 发件人邮箱密码
|
||||
}
|
||||
|
||||
func (s *SmtpConfig) Equal(other *SmtpConfig) bool {
|
||||
return s.UseTls == other.UseTls &&
|
||||
s.Host == other.Host &&
|
||||
s.Port == other.Port &&
|
||||
s.AppName == other.AppName &&
|
||||
s.From == other.From &&
|
||||
s.Password == other.Password
|
||||
}
|
||||
@@ -26,6 +26,7 @@ const (
|
||||
type MjTask struct {
|
||||
Id uint `json:"id"` // 任务ID
|
||||
TaskId string `json:"task_id"` // 中转任务ID
|
||||
ClientId string `json:"client_id"`
|
||||
ImgArr []string `json:"img_arr"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
@@ -43,6 +44,7 @@ type MjTask struct {
|
||||
type SdTask struct {
|
||||
Id int `json:"id"` // job 数据库ID
|
||||
Type TaskType `json:"type"`
|
||||
ClientId string `json:"client_id"`
|
||||
UserId int `json:"user_id"`
|
||||
Params SdTaskParams `json:"params"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
@@ -50,6 +52,7 @@ type SdTask struct {
|
||||
}
|
||||
|
||||
type SdTaskParams struct {
|
||||
ClientId string `json:"client_id"` // 客户端ID
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
||||
@@ -70,21 +73,21 @@ type SdTaskParams struct {
|
||||
|
||||
// DallTask DALL-E task
|
||||
type DallTask struct {
|
||||
ModelId uint `json:"model_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
Image []string `json:"image,omitempty"`
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
Power int `json:"power"`
|
||||
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
|
||||
Power int `json:"power"`
|
||||
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||
}
|
||||
|
||||
type SunoTask struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
@@ -92,8 +95,7 @@ type SunoTask struct {
|
||||
Title string `json:"title"`
|
||||
RefTaskId string `json:"ref_task_id,omitempty"`
|
||||
RefSongId string `json:"ref_song_id,omitempty"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Lyrics string `json:"lyrics,omitempty"` // 歌词
|
||||
Prompt string `json:"prompt"` // 提示词/歌词
|
||||
Tags string `json:"tags"`
|
||||
Model string `json:"model"`
|
||||
Instrumental bool `json:"instrumental"` // 是否纯音乐
|
||||
@@ -106,21 +108,21 @@ const (
|
||||
VideoLuma = "luma"
|
||||
VideoRunway = "runway"
|
||||
VideoCog = "cog"
|
||||
VideoKeLing = "keling"
|
||||
)
|
||||
|
||||
type VideoTask struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
Type string `json:"type"`
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Params interface{} `json:"params"`
|
||||
Params VideoParams `json:"params"`
|
||||
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||
}
|
||||
|
||||
type LumaVideoParams struct {
|
||||
type VideoParams struct {
|
||||
PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词
|
||||
Loop bool `json:"loop"` // 是否循环参考图
|
||||
StartImgURL string `json:"start_img_url"` // 第一帧参考图地址
|
||||
@@ -130,33 +132,3 @@ type LumaVideoParams struct {
|
||||
Style string `json:"style"` // 风格
|
||||
Duration int `json:"duration"` // 视频时长(秒)
|
||||
}
|
||||
|
||||
type KeLingVideoParams struct {
|
||||
TaskType string `json:"task_type"` // 任务类型: text2video/image2video
|
||||
Model string `json:"model"` // 模型: default/anime
|
||||
Prompt string `json:"prompt"` // 视频描述
|
||||
NegPrompt string `json:"negative_prompt"` // 负面提示词
|
||||
CfgScale float64 `json:"cfg_scale"` // 相关性系数(0-1)
|
||||
Mode string `json:"mode"` // 生成模式: std/pro
|
||||
AspectRatio string `json:"aspect_ratio"` // 画面比例: 16:9/9:16/1:1
|
||||
Duration string `json:"duration"` // 视频时长: 5/10
|
||||
CameraControl CameraControl `json:"camera_control"` // 摄像机控制
|
||||
Image string `json:"image"` // 参考图片URL(image2video)
|
||||
ImageTail string `json:"image_tail"` // 尾帧图片URL(image2video)
|
||||
}
|
||||
|
||||
// CameraControl 摄像机控制
|
||||
type CameraControl struct {
|
||||
Type string `json:"type"` // 控制类型: simple/down_back/forward_up/right_turn_forward/left_turn_forward
|
||||
Config CameraConfig `json:"config"` // 控制参数(仅simple类型时使用)
|
||||
}
|
||||
|
||||
// CameraConfig 摄像机参数
|
||||
type CameraConfig struct {
|
||||
Horizontal int `json:"horizontal"` // 水平移动(-10到10)
|
||||
Vertical int `json:"vertical"` // 垂直移动(-10到10)
|
||||
Pan int `json:"pan"` // 左右旋转(-10到10)
|
||||
Tilt int `json:"tilt"` // 上下旋转(-10到10)
|
||||
Roll int `json:"roll"` // 横向翻转(-10到10)
|
||||
Zoom int `json:"zoom"` // 镜头缩放(-10到10)
|
||||
}
|
||||
|
||||
@@ -34,14 +34,13 @@ const (
|
||||
MsgTypeErr = WsMsgType("error")
|
||||
MsgTypePing = WsMsgType("ping") // 心跳消息
|
||||
|
||||
ChPing = WsChannel("ping")
|
||||
ChChat = WsChannel("chat")
|
||||
ChMj = WsChannel("mj")
|
||||
ChSd = WsChannel("sd")
|
||||
ChDall = WsChannel("dall")
|
||||
ChSuno = WsChannel("suno")
|
||||
ChLuma = WsChannel("luma")
|
||||
ChKeLing = WsChannel("keling")
|
||||
ChPing = WsChannel("ping")
|
||||
ChChat = WsChannel("chat")
|
||||
ChMj = WsChannel("mj")
|
||||
ChSd = WsChannel("sd")
|
||||
ChDall = WsChannel("dall")
|
||||
ChSuno = WsChannel("suno")
|
||||
ChLuma = WsChannel("luma")
|
||||
)
|
||||
|
||||
// InputMessage 对话输入消息结构
|
||||
|
||||
@@ -4,7 +4,7 @@ build_name: runner-build
|
||||
build_log: runner-build-errors.log
|
||||
valid_ext: .go, .tpl, .tmpl, .html
|
||||
no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue
|
||||
ignored: assets, tmp, web, .git, .idea, test, data, static
|
||||
ignored: assets, tmp, web, .git, .idea, test, data
|
||||
build_delay: 600
|
||||
colors: 1
|
||||
log_color_main: cyan
|
||||
|
||||
14
api/go.mod
14
api/go.mod
@@ -18,19 +18,17 @@ require (
|
||||
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23
|
||||
go.uber.org/zap v1.23.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gorm.io/driver/mysql v1.4.7
|
||||
)
|
||||
|
||||
require github.com/xxl-job/xxl-job-executor-go v1.2.0
|
||||
|
||||
require (
|
||||
github.com/go-pay/gopay v1.5.101
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/go-tika v0.3.1
|
||||
github.com/microcosm-cc/bluemonday v1.0.26
|
||||
github.com/sashabaranov/go-openai v1.38.1
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.3.1
|
||||
github.com/syndtr/goleveldb v1.0.0
|
||||
golang.org/x/image v0.15.0
|
||||
@@ -44,8 +42,9 @@ require (
|
||||
github.com/go-pay/util v0.0.2 // indirect
|
||||
github.com/go-pay/xlog v0.0.2 // indirect
|
||||
github.com/go-pay/xtime v0.0.2 // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
||||
github.com/gorilla/css v1.0.0 // indirect
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
||||
github.com/tklauser/numcpus v0.7.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
@@ -62,6 +61,7 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gaukas/godicttls v0.0.3 // indirect
|
||||
github.com/go-basic/ipv4 v1.0.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
@@ -71,7 +71,7 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af // indirect
|
||||
github.com/klauspost/compress v1.16.7 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
@@ -113,7 +113,7 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/fx v1.19.3
|
||||
go.uber.org/multierr v1.7.0 // indirect
|
||||
go.uber.org/multierr v1.6.0 // indirect
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
gorm.io/gorm v1.25.1
|
||||
|
||||
87
api/go.sum
87
api/go.sum
@@ -1,5 +1,3 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I=
|
||||
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 h1:cKNFQmeCQFN0WNfjScKoVrGi7vXxTVbkCvCqSrOf+P4=
|
||||
@@ -8,7 +6,6 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible h1:Sg/2xHwDrioHpxTN6WMiw
|
||||
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
|
||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||
@@ -16,13 +13,11 @@ github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -33,8 +28,6 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
|
||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
@@ -46,6 +39,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
|
||||
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
@@ -85,31 +80,13 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
|
||||
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
|
||||
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
|
||||
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
|
||||
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
||||
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
|
||||
@@ -136,11 +113,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
|
||||
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
|
||||
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
|
||||
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
|
||||
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
|
||||
github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
@@ -151,7 +125,6 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02
|
||||
github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
@@ -204,7 +177,6 @@ github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 h1:IFhPCcB0/H
|
||||
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480/go.mod h1:BijIqAP84FMYC4XbdJgjyMpiSjusU8x0Y0W9K2t0QtU=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdkk=
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1 h1:UoQv7fBKtzAiD1qZPIvTy62Se48YLKxcCYP9nAwWMa0=
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYXOwye868w=
|
||||
@@ -220,8 +192,6 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/sashabaranov/go-openai v1.38.1 h1:TtZabbFQZa1nEni/IhVtDF/WQjVqDgd+cWR5OeddzF8=
|
||||
github.com/sashabaranov/go-openai v1.38.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
|
||||
@@ -234,7 +204,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@@ -257,8 +226,8 @@ github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVK
|
||||
github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
@@ -273,8 +242,8 @@ go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
|
||||
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec=
|
||||
go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
|
||||
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
|
||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
|
||||
go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -288,23 +257,15 @@ golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDf
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
@@ -315,15 +276,12 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -362,39 +320,16 @@ golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
|
||||
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
|
||||
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
@@ -407,10 +342,6 @@ gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYs
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
@@ -420,6 +351,4 @@ gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8o
|
||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
||||
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
|
||||
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
logger2 "geekai/logger"
|
||||
@@ -20,10 +19,9 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -47,26 +45,6 @@ func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, cap
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ManagerHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.POST("login", h.Login)
|
||||
group.GET("logout", h.Logout)
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("session", h.Session)
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("enable", h.Enable)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
}
|
||||
}
|
||||
|
||||
// Login 登录
|
||||
func (h *ManagerHandler) Login(c *gin.Context) {
|
||||
var data struct {
|
||||
@@ -81,6 +59,19 @@ func (h *ManagerHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.EnabledVerify {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captcha.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captcha.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "请先完人机验证")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var manager model.AdminUser
|
||||
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
|
||||
if res.Error != nil {
|
||||
@@ -144,15 +135,16 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
|
||||
|
||||
// Session 会话检测
|
||||
func (h *ManagerHandler) Session(c *gin.Context) {
|
||||
id := h.GetAdminId(c)
|
||||
if id == 0 {
|
||||
resp.NotAuth(c, "当前用户已退出登录")
|
||||
id := h.GetLoginUserId(c)
|
||||
key := fmt.Sprintf("admin/%d", id)
|
||||
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
var manager model.AdminUser
|
||||
err := h.DB.Where("id", id).First(&manager).Error
|
||||
if err != nil {
|
||||
resp.NotAuth(c, "当前用户已退出登录")
|
||||
res := h.DB.Where("id", id).First(&manager)
|
||||
if res.Error != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -10,14 +10,12 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -31,20 +29,6 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
||||
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ApiKeyHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/apikey/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
@@ -87,18 +71,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||
resp.SUCCESS(c, keyVo)
|
||||
}
|
||||
|
||||
// List 获取 API KEY 列表
|
||||
func (h *ApiKeyHandler) List(c *gin.Context) {
|
||||
status := h.GetBool(c, "status")
|
||||
t := c.Query("type")
|
||||
t := h.GetTrim(c, "type")
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if status {
|
||||
session = session.Where("enabled", true)
|
||||
}
|
||||
if t != "" {
|
||||
types := strings.Split(t, "|")
|
||||
session = session.Where("type IN ?", types)
|
||||
session = session.Where("type", t)
|
||||
}
|
||||
|
||||
var items []model.ApiKey
|
||||
|
||||
@@ -10,7 +10,6 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
@@ -31,29 +30,14 @@ func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
|
||||
return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatAppHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/role/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("sort", h.Sort)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
}
|
||||
|
||||
// Save 创建或者更新某个角色
|
||||
func (h *ChatAppHandler) Save(c *gin.Context) {
|
||||
var data vo.ChatApp
|
||||
var data vo.ChatRole
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
var role model.ChatApp
|
||||
var role model.ChatRole
|
||||
err := utils.CopyObject(data, &role)
|
||||
if err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -81,8 +65,8 @@ func (h *ChatAppHandler) Save(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
var items []model.ChatApp
|
||||
var roles = make([]vo.ChatApp, 0)
|
||||
var items []model.ChatRole
|
||||
var roles = make([]vo.ChatRole, 0)
|
||||
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "No data found")
|
||||
@@ -94,10 +78,10 @@ func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
typeIds := make([]int, 0)
|
||||
for _, v := range items {
|
||||
if v.ModelId > 0 {
|
||||
modelIds = append(modelIds, int(v.ModelId))
|
||||
modelIds = append(modelIds, v.ModelId)
|
||||
}
|
||||
if v.Tid > 0 {
|
||||
typeIds = append(typeIds, int(v.Tid))
|
||||
typeIds = append(typeIds, v.Tid)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,14 +107,14 @@ func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
var role vo.ChatApp
|
||||
var role vo.ChatRole
|
||||
err := utils.CopyObject(v, &role)
|
||||
if err == nil {
|
||||
role.Id = v.Id
|
||||
role.CreatedAt = v.CreatedAt.Unix()
|
||||
role.UpdatedAt = v.UpdatedAt.Unix()
|
||||
role.ModelName = modelNameMap[int(role.ModelId)]
|
||||
role.TypeName = typeNameMap[int(role.Tid)]
|
||||
role.ModelName = modelNameMap[role.ModelId]
|
||||
role.TypeName = typeNameMap[role.Tid]
|
||||
roles = append(roles, role)
|
||||
}
|
||||
}
|
||||
@@ -151,7 +135,7 @@ func (h *ChatAppHandler) Sort(c *gin.Context) {
|
||||
}
|
||||
|
||||
for index, id := range data.Ids {
|
||||
err := h.DB.Model(&model.ChatApp{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
|
||||
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -173,7 +157,7 @@ func (h *ChatAppHandler) Set(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.DB.Model(&model.ChatApp{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -188,8 +172,9 @@ func (h *ChatAppHandler) Remove(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
res := h.DB.Where("id", id).Delete(&model.ChatApp{})
|
||||
res := h.DB.Where("id", id).Delete(&model.ChatRole{})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "删除失败!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,14 +2,12 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -22,21 +20,6 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
|
||||
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatAppTypeHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/app/type/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
}
|
||||
}
|
||||
|
||||
// Save 创建或更新App类型
|
||||
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
|
||||
@@ -9,14 +9,12 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -29,31 +27,16 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||||
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/chat/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("list", h.List)
|
||||
group.POST("message", h.Messages)
|
||||
group.GET("history", h.History)
|
||||
group.GET("remove", h.RemoveChat)
|
||||
group.GET("message/remove", h.RemoveMessage)
|
||||
}
|
||||
}
|
||||
|
||||
type chatItemVo struct {
|
||||
Username string `json:"username"`
|
||||
UserId uint `json:"user_id"`
|
||||
ChatId string `json:"chat_id"`
|
||||
Title string `json:"title"`
|
||||
Role vo.ChatApp `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Token int `json:"token"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
MsgNum int `json:"msg_num"` // 消息数量
|
||||
Username string `json:"username"`
|
||||
UserId uint `json:"user_id"`
|
||||
ChatId string `json:"chat_id"`
|
||||
Title string `json:"title"`
|
||||
Role vo.ChatRole `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Token int `json:"token"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
MsgNum int `json:"msg_num"` // 消息数量
|
||||
}
|
||||
|
||||
func (h *ChatHandler) List(c *gin.Context) {
|
||||
@@ -103,7 +86,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
}
|
||||
var messages []model.ChatMessage
|
||||
var users []model.User
|
||||
var roles []model.ChatApp
|
||||
var roles []model.ChatRole
|
||||
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
|
||||
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||
@@ -111,7 +94,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
tokenMap := make(map[string]int)
|
||||
userMap := make(map[uint]string)
|
||||
msgMap := make(map[string]int)
|
||||
roleMap := make(map[uint]vo.ChatApp)
|
||||
roleMap := make(map[uint]vo.ChatRole)
|
||||
for _, msg := range messages {
|
||||
tokenMap[msg.ChatId] += msg.Tokens
|
||||
msgMap[msg.ChatId] += 1
|
||||
@@ -120,7 +103,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
userMap[user.Id] = user.Username
|
||||
}
|
||||
for _, r := range roles {
|
||||
var roleVo vo.ChatApp
|
||||
var roleVo vo.ChatRole
|
||||
err := utils.CopyObject(r, &roleVo)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -206,7 +189,7 @@ func (h *ChatHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
for _, item := range items {
|
||||
list = append(list, chatMessageVo{
|
||||
Id: uint(item.Id),
|
||||
Id: item.Id,
|
||||
UserId: item.UserId,
|
||||
Username: userMap[item.UserId],
|
||||
Content: item.Content,
|
||||
@@ -225,28 +208,20 @@ func (h *ChatHandler) Messages(c *gin.Context) {
|
||||
func (h *ChatHandler) History(c *gin.Context) {
|
||||
chatId := c.Query("chat_id") // 会话 ID
|
||||
var items []model.ChatMessage
|
||||
var messages = make([]vo.ChatMessage, 0)
|
||||
var messages = make([]vo.HistoryMessage, 0)
|
||||
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "No history message")
|
||||
return
|
||||
} else {
|
||||
for _, item := range items {
|
||||
var v vo.ChatMessage
|
||||
var v vo.HistoryMessage
|
||||
err := utils.CopyObject(item, &v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// 解析内容
|
||||
var content vo.MsgContent
|
||||
err = utils.JsonDecode(item.Content, &content)
|
||||
if err != nil {
|
||||
content.Text = item.Content
|
||||
}
|
||||
v.Content = content
|
||||
v.CreatedAt = item.CreatedAt.Unix()
|
||||
v.UpdatedAt = item.UpdatedAt.Unix()
|
||||
messages = append(messages, v)
|
||||
if err == nil {
|
||||
messages = append(messages, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,9 +8,7 @@ package admin
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
@@ -30,41 +28,21 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatModelHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/model/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("set", h.Set)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("batch-remove", h.BatchRemove)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Enabled bool `json:"enabled"`
|
||||
SortNum int `json:"sort_num"`
|
||||
Open bool `json:"open"`
|
||||
Platform string `json:"platform"`
|
||||
Power int `json:"power"`
|
||||
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||
Desc string `json:"desc"` //模型描述
|
||||
Tag string `json:"tag"` //模型标签
|
||||
Temperature float32 `json:"temperature"` // 模型温度
|
||||
KeyId int `json:"key_id,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Type string `json:"type"`
|
||||
Options map[string]string `json:"options"`
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Enabled bool `json:"enabled"`
|
||||
SortNum int `json:"sort_num"`
|
||||
Open bool `json:"open"`
|
||||
Platform string `json:"platform"`
|
||||
Power int `json:"power"`
|
||||
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||
Temperature float32 `json:"temperature"` // 模型温度
|
||||
KeyId int `json:"key_id,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -80,16 +58,14 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||
item.Name = data.Name
|
||||
item.Value = data.Value
|
||||
item.Enabled = data.Enabled
|
||||
item.SortNum = data.SortNum
|
||||
item.Open = data.Open
|
||||
item.Power = data.Power
|
||||
item.MaxTokens = data.MaxTokens
|
||||
item.MaxContext = data.MaxContext
|
||||
item.Desc = data.Desc
|
||||
item.Tag = data.Tag
|
||||
item.Temperature = data.Temperature
|
||||
item.KeyId = uint(data.KeyId)
|
||||
item.Type = data.Type
|
||||
item.Options = utils.JsonEncode(data.Options)
|
||||
item.KeyId = data.KeyId
|
||||
|
||||
var res *gorm.DB
|
||||
if data.Id > 0 {
|
||||
res = h.DB.Save(&item)
|
||||
@@ -118,16 +94,12 @@ func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
enable := h.GetBool(c, "enable")
|
||||
name := h.GetTrim(c, "name")
|
||||
modelType := h.GetTrim(c, "type")
|
||||
if enable {
|
||||
session = session.Where("enabled", enable)
|
||||
}
|
||||
if name != "" {
|
||||
session = session.Where("name LIKE ?", name+"%")
|
||||
}
|
||||
if modelType != "" {
|
||||
session = session.Where("type", modelType)
|
||||
}
|
||||
var items []model.ChatModel
|
||||
var cms = make([]vo.ChatModel, 0)
|
||||
res := session.Order("sort_num ASC").Find(&items)
|
||||
@@ -139,7 +111,7 @@ func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
// initialize key name
|
||||
keyIds := make([]int, 0)
|
||||
for _, v := range items {
|
||||
keyIds = append(keyIds, int(v.KeyId))
|
||||
keyIds = append(keyIds, v.KeyId)
|
||||
}
|
||||
var keys []model.ApiKey
|
||||
keyMap := make(map[uint]string)
|
||||
@@ -219,33 +191,3 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// BatchRemove 批量删除模型
|
||||
func (h *ChatModelHandler) BatchRemove(c *gin.Context) {
|
||||
var data struct {
|
||||
Ids []uint `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if len(data.Ids) == 0 {
|
||||
resp.ERROR(c, "请选择要删除的模型")
|
||||
return
|
||||
}
|
||||
|
||||
// 执行批量删除
|
||||
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.ChatModel{}).Error
|
||||
if err != nil {
|
||||
logger.Error("批量删除模型失败:", err)
|
||||
resp.ERROR(c, "批量删除失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"message": fmt.Sprintf("成功删除 %d 个模型", len(data.Ids)),
|
||||
"deleted_count": len(data.Ids),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,220 +9,42 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/payment"
|
||||
"geekai/service/sms"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ConfigHandler struct {
|
||||
handler.BaseHandler
|
||||
licenseService *service.LicenseService
|
||||
sysConfig *types.SystemConfig
|
||||
alipayService *payment.AlipayService
|
||||
wxpayService *payment.WxPayService
|
||||
epayService *payment.EPayService
|
||||
smsManager *sms.SmsManager
|
||||
uploaderManager *oss.UploaderManager
|
||||
smtpService *service.SmtpService
|
||||
captchaService *service.CaptchaService
|
||||
wxLoginService *service.WxLoginService
|
||||
levelDB *store.LevelDB
|
||||
licenseService *service.LicenseService
|
||||
}
|
||||
|
||||
func NewConfigHandler(
|
||||
app *core.AppServer,
|
||||
db *gorm.DB,
|
||||
licenseService *service.LicenseService,
|
||||
sysConfig *types.SystemConfig,
|
||||
alipayService *payment.AlipayService,
|
||||
wxpayService *payment.WxPayService,
|
||||
epayService *payment.EPayService,
|
||||
smsManager *sms.SmsManager,
|
||||
uploaderManager *oss.UploaderManager,
|
||||
smtpService *service.SmtpService,
|
||||
captchaService *service.CaptchaService,
|
||||
wxLoginService *service.WxLoginService,
|
||||
) *ConfigHandler {
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
licenseService: licenseService,
|
||||
sysConfig: sysConfig,
|
||||
alipayService: alipayService,
|
||||
wxpayService: wxpayService,
|
||||
epayService: epayService,
|
||||
smsManager: smsManager,
|
||||
uploaderManager: uploaderManager,
|
||||
smtpService: smtpService,
|
||||
captchaService: captchaService,
|
||||
wxLoginService: wxLoginService,
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
levelDB: levelDB,
|
||||
licenseService: licenseService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ConfigHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/admin/config")
|
||||
|
||||
// 需要管理员登录的接口
|
||||
rg.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
rg.POST("update/base", h.UpdateBase)
|
||||
rg.POST("update/power", h.UpdatePower)
|
||||
rg.POST("update/notice", h.UpdateNotice)
|
||||
rg.POST("update/agreement", h.UpdateAgreement)
|
||||
rg.POST("update/privacy", h.UpdatePrivacy)
|
||||
rg.POST("update/mark_map", h.UpdateMarkMap)
|
||||
rg.POST("update/captcha", h.UpdateCaptcha)
|
||||
rg.POST("update/wx_login", h.UpdateWxLogin)
|
||||
rg.POST("update/payment", h.UpdatePayment)
|
||||
rg.POST("update/sms", h.UpdateSms)
|
||||
rg.POST("update/oss", h.UpdateOss)
|
||||
rg.POST("update/smtp", h.UpdateStmp)
|
||||
rg.GET("get", h.Get)
|
||||
rg.POST("license/active", h.Active)
|
||||
rg.GET("license/get", h.GetLicense)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateBase 更新基础配置
|
||||
func (h *ConfigHandler) UpdateBase(c *gin.Context) {
|
||||
var data types.BaseConfig
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 未授权的话不允许修改版权
|
||||
license := h.licenseService.GetLicense()
|
||||
if !license.IsActive && data.Copyright != h.sysConfig.Base.Copyright {
|
||||
resp.ERROR(c, "未授权系统不允许修改版权信息")
|
||||
return
|
||||
}
|
||||
|
||||
// 未授权的话不允许修改 Logo
|
||||
if !license.IsActive && data.Logo != h.sysConfig.Base.Logo {
|
||||
resp.ERROR(c, "未授权系统不允许修改 Logo")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeySystem, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.sysConfig.Base = data
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdatePower 更新系统配置
|
||||
func (h *ConfigHandler) UpdatePower(c *gin.Context) {
|
||||
func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
var data struct {
|
||||
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
|
||||
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
h.sysConfig.Base.InitPower = data.InitPower
|
||||
h.sysConfig.Base.DailyPower = data.DailyPower
|
||||
h.sysConfig.Base.InvitePower = data.InvitePower
|
||||
h.sysConfig.Base.MjPower = data.MjPower
|
||||
h.sysConfig.Base.MjActionPower = data.MjActionPower
|
||||
h.sysConfig.Base.SdPower = data.SdPower
|
||||
h.sysConfig.Base.SunoPower = data.SunoPower
|
||||
h.sysConfig.Base.LumaPower = data.LumaPower
|
||||
h.sysConfig.Base.KeLingPowers = data.KeLingPowers
|
||||
|
||||
err := h.Update(types.ConfigKeySystem, h.sysConfig.Base)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, h.sysConfig.Base)
|
||||
}
|
||||
|
||||
// UpdateNotice 更新公告配置
|
||||
func (h *ConfigHandler) UpdateNotice(c *gin.Context) {
|
||||
var data struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyNotice, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateAgreement 更新用户协议配置
|
||||
func (h *ConfigHandler) UpdateAgreement(c *gin.Context) {
|
||||
var data struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyAgreement, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdatePrivacy 更新隐私政策配置
|
||||
func (h *ConfigHandler) UpdatePrivacy(c *gin.Context) {
|
||||
var data struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyPrivacy, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateMarkMap 更新思维导图配置
|
||||
func (h *ConfigHandler) UpdateMarkMap(c *gin.Context) {
|
||||
var data struct {
|
||||
Content string `json:"content"`
|
||||
Key string `json:"key"`
|
||||
Config struct {
|
||||
types.SystemConfig
|
||||
Content string `json:"content,omitempty"`
|
||||
Updated bool `json:"updated,omitempty"`
|
||||
} `json:"config"`
|
||||
ConfigBak types.SystemConfig `json:"config_bak,omitempty"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
@@ -230,179 +52,57 @@ func (h *ConfigHandler) UpdateMarkMap(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyMarkMap, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
// ONLY authorized user can change the copyright
|
||||
if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
|
||||
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateCaptcha 更新行为验证码配置
|
||||
func (h *ConfigHandler) UpdateCaptcha(c *gin.Context) {
|
||||
var data types.CaptchaConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyCaptcha, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
h.captchaService.UpdateConfig(data)
|
||||
resp.SUCCESS(c, data)
|
||||
|
||||
}
|
||||
|
||||
// UpdatePayment 更新支付配置
|
||||
func (h *ConfigHandler) UpdatePayment(c *gin.Context) {
|
||||
var data types.PaymentConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyPayment, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 如果启用状态发生改变,则需要更新支付服务配置
|
||||
if data.WxPay.Enabled {
|
||||
err = h.wxpayService.UpdateConfig(&data.WxPay)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
if data.Epay.Enabled {
|
||||
h.epayService.UpdateConfig(&data.Epay)
|
||||
}
|
||||
if data.Alipay.Enabled {
|
||||
err = h.alipayService.UpdateConfig(&data.Alipay)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.sysConfig.Payment = data
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateSms 更新短信配置
|
||||
func (h *ConfigHandler) UpdateSms(c *gin.Context) {
|
||||
var data types.SMSConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeySms, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新服务配置
|
||||
h.smsManager.UpdateConfig(data)
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateOss 更新 Oss 配置
|
||||
func (h *ConfigHandler) UpdateOss(c *gin.Context) {
|
||||
var data types.OSSConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyOss, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新服务配置
|
||||
h.uploaderManager.UpdateConfig(data)
|
||||
h.sysConfig.OSS = data
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateStmp 更新 Stmp 配置
|
||||
func (h *ConfigHandler) UpdateStmp(c *gin.Context) {
|
||||
var data types.SmtpConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeySmtp, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新服务配置
|
||||
h.smtpService.UpdateConfig(&data)
|
||||
h.sysConfig.SMTP = data
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateWxLogin 更新微信登录配置
|
||||
func (h *ConfigHandler) UpdateWxLogin(c *gin.Context) {
|
||||
var data types.WxLoginConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
err := h.Update(types.ConfigKeyWxLogin, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if data.Enabled {
|
||||
h.wxLoginService.UpdateConfig(data)
|
||||
}
|
||||
|
||||
h.sysConfig.WxLogin = data
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// Update 更新系统配置
|
||||
func (h *ConfigHandler) Update(name string, value any) error {
|
||||
var config model.Config
|
||||
err := h.DB.Where("name", name).First(&config).Error
|
||||
if err != nil { // 不存在则创建
|
||||
config.Name = name
|
||||
config.Value = utils.JsonEncode(value)
|
||||
return h.DB.Create(&config).Error
|
||||
} else { // 存在则更新
|
||||
config.Value = utils.JsonEncode(value)
|
||||
return h.DB.Updates(&config).Error
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Get 获取指定名称的系统配置
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
name := c.Query("key")
|
||||
var config model.Config
|
||||
res := h.DB.Where("name", name).First(&config)
|
||||
value := utils.JsonEncode(&data.Config)
|
||||
config := model.Config{Key: data.Key, Config: value}
|
||||
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var value map[string]any
|
||||
err := utils.JsonDecode(config.Value, &value)
|
||||
if config.Id > 0 {
|
||||
config.Config = value
|
||||
res := h.DB.Updates(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// update config cache for AppServer
|
||||
var cfg model.Config
|
||||
h.DB.Where("marker", data.Key).First(&cfg)
|
||||
var err error
|
||||
if data.Key == "system" {
|
||||
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
|
||||
}
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to update config cache: "+err.Error())
|
||||
return
|
||||
}
|
||||
logger.Infof("Update AppServer's config successfully: %v", config.Config)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, config)
|
||||
}
|
||||
|
||||
// Get 获取指定的系统配置
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
var config model.Config
|
||||
res := h.DB.Where("marker", key).First(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var value map[string]interface{}
|
||||
err := utils.JsonDecode(config.Config, &value)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -420,22 +120,19 @@ func (h *ConfigHandler) Active(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.licenseService.ActiveLicense(data.License)
|
||||
license := h.licenseService.GetLicense()
|
||||
info, err := host.Info()
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
if err := h.Update(types.ConfigKeyLicense, license); err != nil {
|
||||
|
||||
err = h.licenseService.ActiveLicense(data.License, info.HostID)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 更新系统配置
|
||||
h.sysConfig.License = *license
|
||||
|
||||
resp.SUCCESS(c, license.MachineId)
|
||||
|
||||
resp.SUCCESS(c, info.HostID)
|
||||
}
|
||||
|
||||
// GetLicense 获取 License 信息
|
||||
@@ -443,3 +140,70 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
||||
license := h.licenseService.GetLicense()
|
||||
resp.SUCCESS(c, license)
|
||||
}
|
||||
|
||||
// FixData 修复数据
|
||||
func (h *ConfigHandler) FixData(c *gin.Context) {
|
||||
resp.ERROR(c, "当前升级版本没有数据需要修正!")
|
||||
return
|
||||
//var fixed bool
|
||||
//version := "data_fix_4.1.4"
|
||||
//err := h.levelDB.Get(version, &fixed)
|
||||
//if err == nil || fixed {
|
||||
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
|
||||
// return
|
||||
//}
|
||||
//tx := h.DB.Begin()
|
||||
//var users []model.User
|
||||
//err = tx.Find(&users).Error
|
||||
//if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
//}
|
||||
//for _, user := range users {
|
||||
// if user.Email != "" || user.Mobile != "" {
|
||||
// continue
|
||||
// }
|
||||
// if utils.IsValidEmail(user.Username) {
|
||||
// user.Email = user.Username
|
||||
// } else if utils.IsValidMobile(user.Username) {
|
||||
// user.Mobile = user.Username
|
||||
// }
|
||||
// err = tx.Save(&user).Error
|
||||
// if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// tx.Rollback()
|
||||
// return
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//var orders []model.Order
|
||||
//err = h.DB.Find(&orders).Error
|
||||
//if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
//}
|
||||
//for _, order := range orders {
|
||||
// if order.PayWay == "支付宝" {
|
||||
// order.PayWay = "alipay"
|
||||
// order.PayType = "alipay"
|
||||
// } else if order.PayWay == "微信支付" {
|
||||
// order.PayWay = "wechat"
|
||||
// order.PayType = "wxpay"
|
||||
// } else if order.PayWay == "hupi" {
|
||||
// order.PayType = "wxpay"
|
||||
// }
|
||||
// err = tx.Save(&order).Error
|
||||
// if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// tx.Rollback()
|
||||
// return
|
||||
// }
|
||||
//}
|
||||
//tx.Commit()
|
||||
//err = h.levelDB.Put(version, true)
|
||||
//if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
//}
|
||||
//resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -13,11 +13,10 @@ import (
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DashboardHandler struct {
|
||||
@@ -28,161 +27,46 @@ func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
|
||||
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *DashboardHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/dashboard/")
|
||||
group.GET("stats", h.Stats)
|
||||
}
|
||||
|
||||
// statsVo 增加 recentOrders、recentUsers 字段
|
||||
// 最近订单
|
||||
type OrderBrief struct {
|
||||
OrderNo string `json:"order_no"`
|
||||
Amount float64 `json:"amount"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// 最近用户
|
||||
type UserBrief struct {
|
||||
Nickname string `json:"nickname"`
|
||||
Avatar string `json:"avatar"`
|
||||
LastActive time.Time `json:"last_active"`
|
||||
}
|
||||
|
||||
type statsVo struct {
|
||||
Users int64 `json:"users"`
|
||||
Chats int64 `json:"chats"`
|
||||
Tokens int `json:"tokens"`
|
||||
Income float64 `json:"income"`
|
||||
Chart map[string]map[string]float64 `json:"chart"`
|
||||
TodayUsers int64 `json:"todayUsers"`
|
||||
TodayChats int64 `json:"todayChats"`
|
||||
TodayTokens int `json:"todayTokens"`
|
||||
TodayIncome float64 `json:"todayIncome"`
|
||||
TodayOrders int64 `json:"todayOrders"`
|
||||
TodayImageJobs int64 `json:"todayImageJobs"`
|
||||
TodayVideoJobs int64 `json:"todayVideoJobs"`
|
||||
TodayMusicJobs int64 `json:"todayMusicJobs"`
|
||||
Orders int64 `json:"orders"`
|
||||
ImageJobs int64 `json:"imageJobs"`
|
||||
VideoJobs int64 `json:"videoJobs"`
|
||||
MusicJobs int64 `json:"musicJobs"`
|
||||
RecentOrders []OrderBrief `json:"recentOrders"`
|
||||
RecentUsers []UserBrief `json:"recentUsers"`
|
||||
Users int64 `json:"users"`
|
||||
Chats int64 `json:"chats"`
|
||||
Tokens int `json:"tokens"`
|
||||
Income float64 `json:"income"`
|
||||
Chart map[string]map[string]float64 `json:"chart"`
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) Stats(c *gin.Context) {
|
||||
stats := statsVo{}
|
||||
// new users statistic
|
||||
var userCount int64
|
||||
now := time.Now()
|
||||
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
|
||||
// 总用户数
|
||||
h.DB.Model(&model.User{}).Count(&stats.Users)
|
||||
|
||||
// 今日新增用户
|
||||
h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&stats.TodayUsers)
|
||||
|
||||
// 总对话数
|
||||
h.DB.Model(&model.ChatItem{}).Count(&stats.Chats)
|
||||
|
||||
// 今日新增对话
|
||||
h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&stats.TodayChats)
|
||||
|
||||
// 总算力消耗
|
||||
var powerLogs []model.PowerLog
|
||||
h.DB.Where("mark = ?", types.PowerSub).Find(&powerLogs)
|
||||
for _, item := range powerLogs {
|
||||
stats.Tokens += item.Amount
|
||||
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
|
||||
if res.Error == nil {
|
||||
stats.Users = userCount
|
||||
}
|
||||
|
||||
// 今日算力消耗
|
||||
var todayPowerLogs []model.PowerLog
|
||||
h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", zeroTime).Find(&todayPowerLogs)
|
||||
for _, item := range todayPowerLogs {
|
||||
stats.TodayTokens += item.Amount
|
||||
// new chats statistic
|
||||
var chatCount int64
|
||||
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
|
||||
if res.Error == nil {
|
||||
stats.Chats = chatCount
|
||||
}
|
||||
|
||||
// 总收入
|
||||
var allOrders []model.Order
|
||||
h.DB.Where("status = ?", types.OrderPaidSuccess).Find(&allOrders)
|
||||
for _, item := range allOrders {
|
||||
// tokens took stats
|
||||
var historyMessages []model.ChatMessage
|
||||
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
|
||||
for _, item := range historyMessages {
|
||||
stats.Tokens += item.Tokens
|
||||
}
|
||||
|
||||
// 订单收入
|
||||
var orders []model.Order
|
||||
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
|
||||
for _, item := range orders {
|
||||
stats.Income += item.Amount
|
||||
}
|
||||
|
||||
// 今日收入
|
||||
var todayOrders []model.Order
|
||||
h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&todayOrders)
|
||||
for _, item := range todayOrders {
|
||||
stats.TodayIncome += item.Amount
|
||||
}
|
||||
|
||||
// 订单总数
|
||||
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Count(&stats.Orders)
|
||||
|
||||
// 今日订单数
|
||||
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Count(&stats.TodayOrders)
|
||||
|
||||
// 图片生成任务统计
|
||||
var mjJobs, sdJobs, dallJobs, jimengImageJobs int64
|
||||
h.DB.Model(&model.MidJourneyJob{}).Count(&mjJobs)
|
||||
h.DB.Model(&model.SdJob{}).Count(&sdJobs)
|
||||
h.DB.Model(&model.DallJob{}).Count(&dallJobs)
|
||||
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Count(&jimengImageJobs)
|
||||
stats.ImageJobs = mjJobs + sdJobs + dallJobs + jimengImageJobs
|
||||
|
||||
logger.Info("stats.ImageJobs", stats.ImageJobs)
|
||||
|
||||
// 今日图片生成任务统计
|
||||
var todayMjJobs, todaySdJobs, todayDallJobs, todayJimengImageJobs int64
|
||||
h.DB.Model(&model.MidJourneyJob{}).Where("created_at > ?", zeroTime).Count(&todayMjJobs)
|
||||
h.DB.Model(&model.SdJob{}).Where("created_at > ?", zeroTime).Count(&todaySdJobs)
|
||||
h.DB.Model(&model.DallJob{}).Where("created_at > ?", zeroTime).Count(&todayDallJobs)
|
||||
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Where("created_at > ?", zeroTime).Count(&todayJimengImageJobs)
|
||||
stats.TodayImageJobs = todayMjJobs + todaySdJobs + todayDallJobs + todayJimengImageJobs
|
||||
|
||||
// 视频生成任务统计
|
||||
var videoJobs, jimengVideoJobs int64
|
||||
h.DB.Model(&model.VideoJob{}).Count(&videoJobs)
|
||||
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Count(&jimengVideoJobs)
|
||||
stats.VideoJobs = videoJobs + jimengVideoJobs
|
||||
|
||||
// 今日视频生成任务统计
|
||||
var todayVideoJobs, todayJimengVideoJobs int64
|
||||
h.DB.Model(&model.VideoJob{}).Where("created_at > ?", zeroTime).Count(&todayVideoJobs)
|
||||
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Where("created_at > ?", zeroTime).Count(&todayJimengVideoJobs)
|
||||
stats.TodayVideoJobs = todayVideoJobs + todayJimengVideoJobs
|
||||
|
||||
// 音乐生成任务统计
|
||||
h.DB.Model(&model.SunoJob{}).Count(&stats.MusicJobs)
|
||||
|
||||
// 今日音乐生成任务统计
|
||||
h.DB.Model(&model.SunoJob{}).Where("created_at > ?", zeroTime).Count(&stats.TodayMusicJobs)
|
||||
|
||||
// recentOrders: 最近10条已支付订单
|
||||
var orderList []model.Order
|
||||
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Order("created_at desc").Limit(10).Find(&orderList)
|
||||
for _, o := range orderList {
|
||||
stats.RecentOrders = append(stats.RecentOrders, OrderBrief{
|
||||
OrderNo: o.OrderNo,
|
||||
Amount: o.Amount,
|
||||
CreatedAt: o.CreatedAt,
|
||||
})
|
||||
}
|
||||
// recentUsers: 最近10个注册用户
|
||||
var userList []model.User
|
||||
h.DB.Model(&model.User{}).Order("created_at desc").Limit(10).Find(&userList)
|
||||
for _, u := range userList {
|
||||
lastActive := u.UpdatedAt
|
||||
if lastActive.IsZero() {
|
||||
lastActive = u.CreatedAt
|
||||
}
|
||||
stats.RecentUsers = append(stats.RecentUsers, UserBrief{
|
||||
Nickname: u.Nickname,
|
||||
Avatar: u.Avatar,
|
||||
LastActive: lastActive,
|
||||
})
|
||||
}
|
||||
|
||||
// 统计7天的订单的图表
|
||||
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
|
||||
var statsChart = make(map[string]map[string]float64)
|
||||
@@ -197,29 +81,23 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
|
||||
|
||||
// 统计用户7天增加的曲线
|
||||
var users []model.User
|
||||
err := h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users).Error
|
||||
if err == nil {
|
||||
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
|
||||
if res.Error == nil {
|
||||
for _, item := range users {
|
||||
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
|
||||
}
|
||||
}
|
||||
|
||||
// 统计7天算力消耗
|
||||
var chartPowerLogs []model.PowerLog
|
||||
err = h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", startDate).Find(&chartPowerLogs).Error
|
||||
if err == nil {
|
||||
for _, item := range chartPowerLogs {
|
||||
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Amount)
|
||||
}
|
||||
// 统计7天Token 消耗
|
||||
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
|
||||
for _, item := range historyMessages {
|
||||
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
|
||||
}
|
||||
|
||||
// 统计最近7天的订单
|
||||
var orders []model.Order
|
||||
err = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders).Error
|
||||
if err == nil {
|
||||
for _, item := range orders {
|
||||
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
||||
}
|
||||
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
|
||||
for _, item := range orders {
|
||||
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
||||
}
|
||||
|
||||
statsChart["users"] = userStatistic
|
||||
|
||||
@@ -9,7 +9,6 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
@@ -31,21 +30,6 @@ func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
|
||||
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *FunctionHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/function/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("token", h.GenToken)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FunctionHandler) Save(c *gin.Context) {
|
||||
var data vo.Function
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
@@ -135,6 +119,7 @@ func (h *FunctionHandler) GenToken(c *gin.Context) {
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
if err != nil {
|
||||
logger.Error("error with generate token", err)
|
||||
resp.ERROR(c)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -34,20 +32,6 @@ func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.User
|
||||
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ImageHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/image/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("list/mj", h.MjList)
|
||||
group.POST("list/sd", h.SdList)
|
||||
group.POST("list/dall", h.DallList)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
}
|
||||
|
||||
type imageQuery struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Username string `json:"username"`
|
||||
@@ -205,10 +189,11 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
tx.Delete(&job)
|
||||
md = "mid-journey"
|
||||
power = job.Power
|
||||
userId = int(job.UserId)
|
||||
userId = job.UserId
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "sd":
|
||||
var job model.SdJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
@@ -220,10 +205,11 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
tx.Delete(&job)
|
||||
md = "stable-diffusion"
|
||||
power = job.Power
|
||||
userId = int(job.UserId)
|
||||
userId = job.UserId
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "dall":
|
||||
var job model.DallJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
@@ -239,13 +225,14 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
default:
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if progress != 100 {
|
||||
err := h.userService.IncreasePower(uint(userId), power, model.PowerLog{
|
||||
err := h.userService.IncreasePower(userId, power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: md,
|
||||
Remark: remark,
|
||||
|
||||
@@ -1,293 +0,0 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AdminJimengHandler 管理后台即梦AI处理器
|
||||
type AdminJimengHandler struct {
|
||||
handler.BaseHandler
|
||||
jimengClient *jimeng.Client
|
||||
userService *service.UserService
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
// NewAdminJimengHandler 创建管理后台即梦AI处理器
|
||||
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengClient *jimeng.Client, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
|
||||
return &AdminJimengHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
jimengClient: jimengClient,
|
||||
userService: userService,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册即梦AI管理后台路由
|
||||
func (h *AdminJimengHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/admin/jimeng/")
|
||||
rg.GET("/jobs", h.Jobs)
|
||||
rg.GET("/jobs/:id", h.JobDetail)
|
||||
rg.POST("/jobs/remove", h.BatchRemove)
|
||||
rg.GET("/stats", h.Stats)
|
||||
rg.POST("/config/update", h.UpdateConfig)
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
func (h *AdminJimengHandler) Jobs(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
taskType := h.GetTrim(c, "type")
|
||||
status := h.GetTrim(c, "status")
|
||||
|
||||
var tasks []model.JimengJob
|
||||
var total int64
|
||||
|
||||
session := h.DB.Model(&model.JimengJob{})
|
||||
|
||||
// 构建查询条件
|
||||
if userId > 0 {
|
||||
session = session.Where("user_id = ?", userId)
|
||||
}
|
||||
if taskType != "" {
|
||||
session = session.Where("type = ?", taskType)
|
||||
}
|
||||
if status != "" {
|
||||
session = session.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := session.Count(&total).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取任务数量失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = session.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&tasks).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取任务列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"jobs": tasks,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// JobDetail 获取任务详情
|
||||
func (h *AdminJimengHandler) JobDetail(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
jobId, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
var job model.JimengJob
|
||||
err = h.DB.Where("id = ?", jobId).First(&job).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// BatchRemove 批量删除任务
|
||||
func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
|
||||
var req struct {
|
||||
JobIds []uint `json:"job_ids" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
var deletedCount int64 = 0
|
||||
for _, jobId := range req.JobIds {
|
||||
var job model.JimengJob
|
||||
err := h.DB.Where("id = ?", jobId).First(&job).Error
|
||||
if err != nil {
|
||||
continue // 跳过不存在的
|
||||
}
|
||||
tx := h.DB.Begin()
|
||||
if job.Status != model.JMTaskStatusSuccess && job.Power > 0 {
|
||||
remark := fmt.Sprintf("任务未成功,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: remark,
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
}
|
||||
err = tx.Where("id = ?", jobId).Delete(&model.JimengJob{}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
tx.Commit()
|
||||
deletedCount++
|
||||
if job.ImgURL != "" {
|
||||
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||
if err != nil {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
}
|
||||
if job.VideoURL != "" {
|
||||
err = h.uploader.GetUploadHandler().Delete(job.VideoURL)
|
||||
if err != nil {
|
||||
logger.Error("remove video failed: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"message": "批量删除成功",
|
||||
"deleted_count": deletedCount,
|
||||
})
|
||||
}
|
||||
|
||||
// Stats 获取统计信息
|
||||
func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
type StatResult struct {
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
err := h.DB.Model(&model.JimengJob{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&stats).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取统计信息失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 整理统计数据
|
||||
result := gin.H{
|
||||
"totalTasks": int64(0),
|
||||
"completedTasks": int64(0),
|
||||
"processingTasks": int64(0),
|
||||
"failedTasks": int64(0),
|
||||
"pendingTasks": int64(0),
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
result["totalTasks"] = result["totalTasks"].(int64) + stat.Count
|
||||
switch stat.Status {
|
||||
case model.JMTaskStatusInQueue:
|
||||
result["pendingTasks"] = stat.Count
|
||||
case model.JMTaskStatusSuccess:
|
||||
result["completedTasks"] = stat.Count
|
||||
case model.JMTaskStatusGenerating:
|
||||
result["processingTasks"] = stat.Count
|
||||
case model.JMTaskStatusFailed:
|
||||
result["failedTasks"] = stat.Count
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, result)
|
||||
}
|
||||
|
||||
// UpdateConfig 更新即梦AI配置
|
||||
func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||
var req types.JimengConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if req.AccessKey == "" {
|
||||
resp.ERROR(c, "AccessKey不能为空")
|
||||
return
|
||||
}
|
||||
if req.SecretKey == "" {
|
||||
resp.ERROR(c, "SecretKey不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证算力配置
|
||||
if req.Power.TextToImage <= 0 {
|
||||
resp.ERROR(c, "文生图算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageToImage <= 0 {
|
||||
resp.ERROR(c, "图生图算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageEdit <= 0 {
|
||||
resp.ERROR(c, "图片编辑算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageEffects <= 0 {
|
||||
resp.ERROR(c, "图片特效算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.TextToVideo <= 0 {
|
||||
resp.ERROR(c, "文生视频算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageToVideo <= 0 {
|
||||
resp.ERROR(c, "图生视频算力必须大于0")
|
||||
return
|
||||
}
|
||||
|
||||
// 保存配置
|
||||
tx := h.DB.Begin()
|
||||
value := utils.JsonEncode(&req)
|
||||
var exist model.Config
|
||||
tx.Where("name", types.ConfigKeyJimeng).First(&exist)
|
||||
|
||||
if exist.Id > 0 {
|
||||
exist.Value = value
|
||||
err := tx.Updates(&exist).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "更新配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
exist.Name = types.ConfigKeyJimeng
|
||||
exist.Value = value
|
||||
err := tx.Create(&exist).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "创建配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新服务中的客户端配置
|
||||
err := h.jimengClient.UpdateConfig(req)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
h.App.SysConfig.Jimeng = req
|
||||
|
||||
resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
|
||||
}
|
||||
@@ -10,7 +10,6 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -34,21 +32,7 @@ func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.User
|
||||
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MediaHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/media/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("suno", h.SunoList)
|
||||
group.POST("videos", h.Videos)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
}
|
||||
|
||||
type mediaQuery struct {
|
||||
Type string `json:"type"` // 任务类型 luma, keling
|
||||
Prompt string `json:"prompt"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt []string `json:"created_at"`
|
||||
@@ -100,15 +84,15 @@ func (h *MediaHandler) SunoList(c *gin.Context) {
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||
}
|
||||
|
||||
// Videos 视频任务列表
|
||||
func (h *MediaHandler) Videos(c *gin.Context) {
|
||||
// LumaList Luma 视频任务列表
|
||||
func (h *MediaHandler) LumaList(c *gin.Context) {
|
||||
var data mediaQuery
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{}).Where("type", data.Type)
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if data.Username != "" {
|
||||
var user model.User
|
||||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||
@@ -164,12 +148,12 @@ func (h *MediaHandler) Remove(c *gin.Context) {
|
||||
tx.Delete(&job)
|
||||
md = "suno"
|
||||
power = job.Power
|
||||
userId = int(job.UserId)
|
||||
userId = job.UserId
|
||||
remark = fmt.Sprintf("SUNO 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
fileURL = job.AudioURL
|
||||
break
|
||||
case "luma":
|
||||
case "keling":
|
||||
var job model.VideoJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
@@ -180,20 +164,21 @@ func (h *MediaHandler) Remove(c *gin.Context) {
|
||||
tx.Delete(&job)
|
||||
md = job.Type
|
||||
power = job.Power
|
||||
userId = int(job.UserId)
|
||||
userId = job.UserId
|
||||
remark = fmt.Sprintf("LUMA 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
fileURL = job.VideoURL
|
||||
if fileURL == "" {
|
||||
fileURL = job.WaterURL
|
||||
}
|
||||
break
|
||||
default:
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if progress != 100 {
|
||||
err := h.userService.IncreasePower(uint(userId), power, model.PowerLog{
|
||||
err := h.userService.IncreasePower(userId, power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: md,
|
||||
Remark: remark,
|
||||
|
||||
@@ -27,16 +27,6 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
||||
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MenuHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/menu/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service/moderation"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ModerationHandler struct {
|
||||
handler.BaseHandler
|
||||
sysConfig *types.SystemConfig
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
func NewModerationHandler(app *core.AppServer, db *gorm.DB, sysConfig *types.SystemConfig, moderationManager *moderation.ServiceManager) *ModerationHandler {
|
||||
return &ModerationHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, sysConfig: sysConfig, moderationManager: moderationManager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ModerationHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/moderation/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("batch-remove", h.BatchRemove)
|
||||
group.GET("source-list", h.GetSourceList)
|
||||
group.POST("config", h.UpdateModeration)
|
||||
group.POST("test", h.TestModeration)
|
||||
}
|
||||
}
|
||||
|
||||
// List 获取文本审核记录列表
|
||||
func (h *ModerationHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
Username string `json:"username"`
|
||||
Source string `json:"source"`
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
|
||||
// 构建查询条件
|
||||
if data.Username != "" {
|
||||
// 通过用户名查找用户ID
|
||||
var user model.User
|
||||
if err := h.DB.Where("username LIKE ?", "%"+data.Username+"%").First(&user).Error; err == nil {
|
||||
session = session.Where("user_id", user.Id)
|
||||
}
|
||||
}
|
||||
|
||||
if data.Source != "" {
|
||||
session = session.Where("source", data.Source)
|
||||
}
|
||||
|
||||
if data.StartDate != "" && data.EndDate != "" {
|
||||
startTime := data.StartDate + " 00:00:00"
|
||||
endTime := data.EndDate + " 23:59:59"
|
||||
session = session.Where("created_at >= ? AND created_at <= ?", startTime, endTime)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
var total int64
|
||||
session.Model(&model.Moderation{}).Count(&total)
|
||||
|
||||
// 分页
|
||||
page := data.Page
|
||||
pageSize := data.PageSize
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
session = session.Offset(offset).Limit(pageSize)
|
||||
|
||||
// 查询数据
|
||||
var items []model.Moderation
|
||||
err := session.Order("id DESC").Find(&items).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
userIds := make([]uint, 0)
|
||||
for _, item := range items {
|
||||
userIds = append(userIds, item.UserId)
|
||||
}
|
||||
|
||||
var users []model.User
|
||||
if len(userIds) > 0 {
|
||||
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||
}
|
||||
|
||||
userMap := make(map[uint]string)
|
||||
for _, user := range users {
|
||||
userMap[user.Id] = user.Username
|
||||
}
|
||||
|
||||
// 转换为响应数据
|
||||
list := make([]map[string]any, 0)
|
||||
for _, item := range items {
|
||||
var moderation types.ModerationResult
|
||||
err := utils.JsonDecode(item.Result, &moderation)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var result []string
|
||||
for value, label := range types.ModerationCategories {
|
||||
if moderation.Categories[value] {
|
||||
result = append(result, label)
|
||||
}
|
||||
}
|
||||
list = append(list, map[string]any{
|
||||
"id": item.Id,
|
||||
"user_id": item.UserId,
|
||||
"username": userMap[item.UserId],
|
||||
"source": item.Source,
|
||||
"input": item.Input,
|
||||
"output": item.Output,
|
||||
"result": result,
|
||||
"created_at": item.CreatedAt.Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, map[string]any{
|
||||
"items": list,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ModerationHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
if id <= 0 {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.DB.Where("id", id).Delete(&model.Moderation{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// BatchRemove 批量删除文本审核记录
|
||||
func (h *ModerationHandler) BatchRemove(c *gin.Context) {
|
||||
var data struct {
|
||||
Ids []uint `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if len(data.Ids) == 0 {
|
||||
resp.ERROR(c, "请选择要删除的记录")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.Moderation{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// 获取 source 列表
|
||||
func (h *ModerationHandler) GetSourceList(c *gin.Context) {
|
||||
sources := []gin.H{
|
||||
{
|
||||
"id": types.ModerationSourceChat,
|
||||
"name": "AI对话",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceMJ,
|
||||
"name": "Midjourney 绘图",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceDalle,
|
||||
"name": "Dalle 绘图",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceSD,
|
||||
"name": "StableDiffusion 绘图",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceSuno,
|
||||
"name": "Suno 音乐",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceVideo,
|
||||
"name": "视频生成",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceJiMeng,
|
||||
"name": "即梦AI",
|
||||
},
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, sources)
|
||||
}
|
||||
|
||||
// UpdateModeration 更新文本审查配置
|
||||
func (h *ModerationHandler) UpdateModeration(c *gin.Context) {
|
||||
var data types.ModerationConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
var config model.Config
|
||||
err := h.DB.Where("name", types.ConfigKeyModeration).First(&config).Error
|
||||
if err != nil {
|
||||
config.Name = types.ConfigKeyModeration
|
||||
config.Value = utils.JsonEncode(data)
|
||||
err = h.DB.Create(&config).Error
|
||||
} else {
|
||||
config.Value = utils.JsonEncode(data)
|
||||
err = h.DB.Updates(&config).Error
|
||||
}
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.moderationManager.UpdateConfig(data)
|
||||
h.sysConfig.Moderation = data
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// 测试结果类型,用于前端显示
|
||||
type ModerationTestResult struct {
|
||||
IsAbnormal bool `json:"isAbnormal"`
|
||||
Details []ModerationTestDetail `json:"details"`
|
||||
}
|
||||
|
||||
type ModerationTestDetail struct {
|
||||
Category string `json:"category"`
|
||||
Description string `json:"description"`
|
||||
Confidence string `json:"confidence"`
|
||||
IsCategory bool `json:"isCategory"`
|
||||
}
|
||||
|
||||
// TestModeration 测试文本审查服务
|
||||
func (h *ModerationHandler) TestModeration(c *gin.Context) {
|
||||
var data struct {
|
||||
Text string `json:"text"`
|
||||
Service string `json:"service"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Text == "" {
|
||||
resp.ERROR(c, "测试文本不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否启用了文本审查
|
||||
if !h.sysConfig.Moderation.Enable {
|
||||
resp.ERROR(c, "文本审查服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前激活的审核服务
|
||||
service := h.moderationManager.GetService()
|
||||
// 执行文本审核
|
||||
result, err := service.Moderate(data.Text)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "审核服务调用失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为前端需要的格式
|
||||
testResult := ModerationTestResult{
|
||||
IsAbnormal: result.Flagged,
|
||||
Details: make([]ModerationTestDetail, 0),
|
||||
}
|
||||
|
||||
// 构建详细信息
|
||||
for category, description := range types.ModerationCategories {
|
||||
score := result.CategoryScores[category]
|
||||
isCategory := result.Categories[category]
|
||||
|
||||
testResult.Details = append(testResult.Details, ModerationTestDetail{
|
||||
Category: category,
|
||||
Description: description,
|
||||
Confidence: fmt.Sprintf("%.2f", score),
|
||||
IsCategory: isCategory,
|
||||
})
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, testResult)
|
||||
}
|
||||
@@ -29,14 +29,6 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *OrderHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/order/")
|
||||
group.POST("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("clear", h.Clear)
|
||||
}
|
||||
|
||||
func (h *OrderHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
OrderNo string `json:"order_no"`
|
||||
@@ -76,16 +68,16 @@ func (h *OrderHandler) List(c *gin.Context) {
|
||||
order.Id = item.Id
|
||||
order.CreatedAt = item.CreatedAt.Unix()
|
||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||
payChannel, ok := types.PayChannel[item.Channel]
|
||||
payMethod, ok := types.PayMethods[item.PayWay]
|
||||
if !ok {
|
||||
payChannel = item.Channel
|
||||
payMethod = item.PayWay
|
||||
}
|
||||
payWays, ok := types.PayWays[item.PayWay]
|
||||
payName, ok := types.PayNames[item.PayType]
|
||||
if !ok {
|
||||
payWays = item.PayWay
|
||||
payName = item.PayWay
|
||||
}
|
||||
order.ChannelName = payChannel
|
||||
order.PayName = payWays
|
||||
order.PayMethod = payMethod
|
||||
order.PayName = payName
|
||||
list = append(list, order)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
@@ -129,8 +121,8 @@ func (h *OrderHandler) Clear(c *gin.Context) {
|
||||
}
|
||||
deleteIds := make([]uint, 0)
|
||||
for _, order := range orders {
|
||||
// 只删除超时的未支付订单
|
||||
if time.Now().After(order.CreatedAt.Add(time.Minute * time.Duration(h.App.SysConfig.Base.OrderPayTimeout))) {
|
||||
// 只删除 15 分钟内的未支付订单
|
||||
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
|
||||
deleteIds = append(deleteIds, order.Id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,16 +28,9 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
||||
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *PowerLogHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/powerLog/")
|
||||
group.POST("list", h.List)
|
||||
}
|
||||
|
||||
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
Username string `json:"username"`
|
||||
UserId uint `json:"userid"`
|
||||
Type int `json:"type"`
|
||||
Model string `json:"model"`
|
||||
Date []string `json:"date"`
|
||||
@@ -56,12 +49,6 @@ func (h *PowerLogHandler) List(c *gin.Context) {
|
||||
if data.Type > 0 {
|
||||
session = session.Where("type", data.Type)
|
||||
}
|
||||
if data.UserId > 0 {
|
||||
session = session.Where("user_id", data.UserId)
|
||||
}
|
||||
if data.Username != "" {
|
||||
session = session.Where("username", data.Username)
|
||||
}
|
||||
if len(data.Date) == 2 {
|
||||
start := data.Date[0] + " 00:00:00"
|
||||
end := data.Date[1] + " 00:00:00"
|
||||
|
||||
@@ -15,10 +15,9 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProductHandler struct {
|
||||
@@ -29,22 +28,14 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
||||
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ProductHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/product/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
|
||||
func (h *ProductHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Price float64 `json:"price"`
|
||||
Discount float64 `json:"discount"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Days int `json:"days"`
|
||||
Power int `json:"power"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
@@ -54,10 +45,12 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
||||
}
|
||||
|
||||
item := model.Product{
|
||||
Name: data.Name,
|
||||
Price: data.Price,
|
||||
Power: data.Power,
|
||||
Enabled: data.Enabled}
|
||||
Name: data.Name,
|
||||
Price: data.Price,
|
||||
Discount: data.Discount,
|
||||
Days: data.Days,
|
||||
Power: data.Power,
|
||||
Enabled: data.Enabled}
|
||||
item.Id = data.Id
|
||||
if item.Id > 0 {
|
||||
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||
|
||||
@@ -29,16 +29,6 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
|
||||
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *RedeemHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/redeem/")
|
||||
group.GET("list", h.List)
|
||||
group.POST("create", h.Create)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("export", h.Export)
|
||||
}
|
||||
|
||||
func (h *RedeemHandler) List(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
@@ -116,8 +106,8 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 设置响应头,告诉浏览器这是一个附件,需要下载
|
||||
c.Header("Prompt-Disposition", "attachment; filename=output.csv")
|
||||
c.Header("Prompt-Type", "text/csv")
|
||||
c.Header("Content-Disposition", "attachment; filename=output.csv")
|
||||
c.Header("Content-Type", "text/csv")
|
||||
|
||||
// 创建一个 CSV writer
|
||||
writer := csv.NewWriter(c.Writer)
|
||||
|
||||
@@ -9,15 +9,13 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/handler"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UploadHandler struct {
|
||||
@@ -29,39 +27,15 @@ func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderMan
|
||||
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *UploadHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/upload")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("", h.Upload)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||
// 判断文件大小
|
||||
f, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.Base.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.Base.MaxFileSize)*1024*1024 {
|
||||
resp.ERROR(c, "文件大小超过限制")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
userId := 0
|
||||
res := h.DB.Create(&model.File{
|
||||
UserId: uint(userId),
|
||||
UserId: userId,
|
||||
Name: file.Name,
|
||||
ObjKey: file.ObjKey,
|
||||
URL: file.URL,
|
||||
|
||||
@@ -10,7 +10,6 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
@@ -18,11 +17,10 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -36,29 +34,11 @@ func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.Li
|
||||
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *UserHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/user/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("loginLog", h.LoginLog)
|
||||
group.GET("genLoginLink", h.GenLoginLink)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
}
|
||||
}
|
||||
|
||||
// List 用户列表
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
username := h.GetTrim(c, "username")
|
||||
mobile := h.GetTrim(c, "mobile")
|
||||
email := h.GetTrim(c, "email")
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
var items []model.User
|
||||
@@ -69,12 +49,6 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
if username != "" {
|
||||
session = session.Where("username LIKE ?", "%"+username+"%")
|
||||
}
|
||||
if mobile != "" {
|
||||
session = session.Where("mobile LIKE ?", "%"+mobile+"%")
|
||||
}
|
||||
if email != "" {
|
||||
session = session.Where("email LIKE ?", "%"+email+"%")
|
||||
}
|
||||
|
||||
session.Model(&model.User{}).Count(&total)
|
||||
res := session.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
|
||||
@@ -194,7 +168,6 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
Power: data.Power,
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode(data.ChatRoles),
|
||||
ChatConfig: "{}",
|
||||
ChatModels: utils.JsonEncode(data.ChatModels),
|
||||
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
|
||||
}
|
||||
@@ -338,36 +311,3 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, logs))
|
||||
}
|
||||
|
||||
// GenLoginLink 生成登录链接
|
||||
func (h *UserHandler) GenLoginLink(c *gin.Context) {
|
||||
id := c.Query("id")
|
||||
if id == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
if err := h.DB.Where("id = ?", id).First(&user).Error; err != nil {
|
||||
resp.ERROR(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 创建 token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.Id,
|
||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||
return
|
||||
}
|
||||
// 保存到 redis
|
||||
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, tokenString)
|
||||
}
|
||||
|
||||
@@ -15,9 +15,8 @@ import (
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -70,14 +69,6 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
|
||||
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
||||
}
|
||||
|
||||
func (h *BaseHandler) GetAdminId(c *gin.Context) uint {
|
||||
userId, ok := c.Get(types.AdminUserID)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
||||
}
|
||||
|
||||
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
|
||||
return h.GetLoginUserId(c) > 0
|
||||
}
|
||||
|
||||
@@ -8,45 +8,23 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 今日头条函数实现
|
||||
|
||||
type CaptchaHandler struct {
|
||||
App *core.AppServer
|
||||
service *service.CaptchaService
|
||||
}
|
||||
|
||||
func NewCaptchaHandler(app *core.AppServer, s *service.CaptchaService, sysConfig *types.SystemConfig) *CaptchaHandler {
|
||||
return &CaptchaHandler{App: app, service: s}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *CaptchaHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/captcha/")
|
||||
|
||||
// 无需授权的接口
|
||||
group.GET("get", h.Get)
|
||||
group.POST("check", h.Check)
|
||||
group.GET("slide/get", h.SlideGet)
|
||||
group.POST("slide/check", h.SlideCheck)
|
||||
group.GET("config", h.GetConfig)
|
||||
}
|
||||
|
||||
func (h *CaptchaHandler) GetConfig(c *gin.Context) {
|
||||
resp.SUCCESS(c, gin.H{"enabled": h.service.GetConfig().Enabled, "type": h.service.GetConfig().Type})
|
||||
func NewCaptchaHandler(s *service.CaptchaService) *CaptchaHandler {
|
||||
return &CaptchaHandler{service: s}
|
||||
}
|
||||
|
||||
func (h *CaptchaHandler) Get(c *gin.Context) {
|
||||
if !h.service.GetConfig().Enabled {
|
||||
resp.ERROR(c, "验证码服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.service.Get()
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
@@ -58,11 +36,6 @@ func (h *CaptchaHandler) Get(c *gin.Context) {
|
||||
|
||||
// Check verify the captcha data
|
||||
func (h *CaptchaHandler) Check(c *gin.Context) {
|
||||
if !h.service.GetConfig().Enabled {
|
||||
resp.ERROR(c, "验证码服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Key string `json:"key"`
|
||||
Dots string `json:"dots"`
|
||||
@@ -82,11 +55,6 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
|
||||
|
||||
// SlideGet 获取滑动验证图片
|
||||
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
|
||||
if !h.service.GetConfig().Enabled {
|
||||
resp.ERROR(c, "验证码服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.service.SlideGet()
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
@@ -98,11 +66,6 @@ func (h *CaptchaHandler) SlideGet(c *gin.Context) {
|
||||
|
||||
// SlideCheck 滑动验证结果校验
|
||||
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
|
||||
if !h.service.GetConfig().Enabled {
|
||||
resp.ERROR(c, "验证码服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Key string `json:"key"`
|
||||
X int `json:"x"`
|
||||
|
||||
@@ -19,12 +19,6 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
|
||||
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatAppTypeHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/app/type/")
|
||||
group.GET("list", h.List)
|
||||
}
|
||||
|
||||
// List 获取App类型列表
|
||||
func (h *ChatAppTypeHandler) List(c *gin.Context) {
|
||||
var items []model.AppType
|
||||
|
||||
@@ -14,171 +14,59 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"io"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ChatEventStart = "start"
|
||||
ChatEventEnd = "end"
|
||||
ChatEventComplete = "complete"
|
||||
ChatEventError = "error"
|
||||
ChatEventMessageDelta = "message_delta"
|
||||
ChatEventTitle = "title"
|
||||
)
|
||||
|
||||
type ChatInput struct {
|
||||
UserId uint `json:"user_id"`
|
||||
RoleId uint `json:"role_id"`
|
||||
ModelId uint `json:"model_id"`
|
||||
ChatId string `json:"chat_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Tools []uint `json:"tools"`
|
||||
Stream bool `json:"stream"`
|
||||
Files []vo.File `json:"files"`
|
||||
ChatModel model.ChatModel `json:"chat_model,omitempty"`
|
||||
ChatRole model.ChatApp `json:"chat_role,omitempty"`
|
||||
LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID,用于重新生成答案的时候过滤上下文
|
||||
}
|
||||
|
||||
type ChatHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
uploadManager *oss.UploaderManager
|
||||
licenseService *service.LicenseService
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
redis *redis.Client
|
||||
uploadManager *oss.UploaderManager
|
||||
licenseService *service.LicenseService
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService, moderationManager *moderation.ServiceManager) *ChatHandler {
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
|
||||
return &ChatHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
redis: redis,
|
||||
uploadManager: manager,
|
||||
licenseService: licenseService,
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
redis: redis,
|
||||
uploadManager: manager,
|
||||
licenseService: licenseService,
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
ChatContexts: types.NewLMap[string, []interface{}](),
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/chat/")
|
||||
|
||||
// 聊天接口不需要授权(已在authConfig中配置)
|
||||
group.Any("message", h.Chat)
|
||||
|
||||
// 其他接口需要用户授权
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.GET("detail", h.Detail)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("history", h.History)
|
||||
group.GET("clear", h.Clear)
|
||||
group.POST("tokens", h.Tokens)
|
||||
group.GET("stop", h.StopGenerate)
|
||||
group.POST("tts", h.TextToSpeech)
|
||||
}
|
||||
}
|
||||
|
||||
// Chat 处理聊天请求
|
||||
func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
// 设置SSE响应头
|
||||
c.Header("Prompt-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
var input ChatInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
pushMessage(c, ChatEventError, types.InvalidArgs)
|
||||
c.Abort()
|
||||
return
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
|
||||
if !h.App.Debug {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Recover message from error: ", r)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
// 这里做个全局的异常处理,防止整个请求异常,导致 SSE 连接断开
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Errorf("chat handler error: %v", err)
|
||||
pushMessage(c, ChatEventError, err)
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
|
||||
// 使用旧的聊天数据覆盖模型和角色ID
|
||||
var chat model.ChatItem
|
||||
h.DB.Where("chat_id", input.ChatId).First(&chat)
|
||||
if chat.Id > 0 {
|
||||
input.ModelId = chat.ModelId
|
||||
input.RoleId = chat.RoleId
|
||||
}
|
||||
|
||||
// 验证聊天角色
|
||||
var chatRole model.ChatApp
|
||||
err := h.DB.First(&chatRole, input.RoleId).Error
|
||||
if err != nil || !chatRole.Enable {
|
||||
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
|
||||
return
|
||||
}
|
||||
input.ChatRole = chatRole
|
||||
|
||||
// 获取模型信息
|
||||
var chatModel model.ChatModel
|
||||
err = h.DB.Where("id", input.ModelId).First(&chatModel).Error
|
||||
if err != nil || !chatModel.Enabled {
|
||||
pushMessage(c, ChatEventError, "当前AI模型暂未启用,请更换模型后再发起对话!")
|
||||
return
|
||||
}
|
||||
input.ChatModel = chatModel
|
||||
|
||||
// 发送消息
|
||||
err = h.sendMessage(ctx, input, c)
|
||||
if err != nil {
|
||||
pushMessage(c, ChatEventError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pushMessage(c, ChatEventEnd, "对话完成")
|
||||
}
|
||||
|
||||
func pushMessage(c *gin.Context, msgType string, content interface{}) {
|
||||
c.SSEvent("message", map[string]interface{}{
|
||||
"type": msgType,
|
||||
"body": content,
|
||||
})
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.Context) error {
|
||||
var user model.User
|
||||
res := h.DB.Model(&model.User{}).First(&user, input.UserId)
|
||||
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
|
||||
if res.Error != nil {
|
||||
return errors.New("未授权用户,您正在进行非法操作!")
|
||||
}
|
||||
@@ -189,12 +77,12 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
return errors.New("User 对象转换失败," + err.Error())
|
||||
}
|
||||
|
||||
if !userVo.Status {
|
||||
if userVo.Status == false {
|
||||
return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
|
||||
}
|
||||
|
||||
if userVo.Power < input.ChatModel.Power {
|
||||
return fmt.Errorf("您的算力不足,请购买算力。")
|
||||
if userVo.Power < session.Model.Power {
|
||||
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, session.Model.Power)
|
||||
}
|
||||
|
||||
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
||||
@@ -202,29 +90,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
}
|
||||
|
||||
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
|
||||
promptTokens, _ := utils.CalcTokens(input.Prompt, input.ChatModel.Value)
|
||||
if promptTokens > input.ChatModel.MaxContext {
|
||||
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
|
||||
if promptTokens > session.Model.MaxContext {
|
||||
|
||||
return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
|
||||
}
|
||||
|
||||
var req = types.ApiRequest{
|
||||
Model: input.ChatModel.Value,
|
||||
Stream: input.Stream,
|
||||
Temperature: input.ChatModel.Temperature,
|
||||
Model: session.Model.Value,
|
||||
}
|
||||
// 兼容 OpenAI 模型
|
||||
if strings.HasPrefix(input.ChatModel.Value, "o1-") ||
|
||||
strings.HasPrefix(input.ChatModel.Value, "o3-") ||
|
||||
strings.HasPrefix(input.ChatModel.Value, "gpt") {
|
||||
req.MaxCompletionTokens = input.ChatModel.MaxTokens
|
||||
// 兼容 GPT-O1 模型
|
||||
if strings.HasPrefix(session.Model.Value, "o1-") {
|
||||
utils.SendChunkMsg(ws, "AI 正在思考...\n")
|
||||
req.Stream = false
|
||||
session.Start = time.Now().Unix()
|
||||
} else {
|
||||
req.MaxTokens = input.ChatModel.MaxTokens
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
req.Temperature = session.Model.Temperature
|
||||
req.Stream = session.Stream
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 && !strings.HasPrefix(input.ChatModel.Value, "o1-") {
|
||||
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
|
||||
var items []model.Function
|
||||
res = h.DB.Where("enabled", true).Where("id IN ?", input.Tools).Find(&items)
|
||||
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
|
||||
if res.Error == nil {
|
||||
var tools = make([]types.Tool, 0)
|
||||
for _, v := range items {
|
||||
@@ -255,34 +143,25 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
}
|
||||
|
||||
// 加载聊天上下文
|
||||
chatCtx := make([]any, 0)
|
||||
messages := make([]any, 0)
|
||||
if h.App.SysConfig.Base.EnableContext {
|
||||
_ = utils.JsonDecode(input.ChatRole.Context, &messages)
|
||||
if h.App.SysConfig.Base.ContextDeep > 0 {
|
||||
var historyMessages []model.ChatMessage
|
||||
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
|
||||
if input.LastMsgId > 0 { // 重新生成逻辑
|
||||
var lastMessage model.ChatMessage
|
||||
err = dbSession.Where("id <= ?", input.LastMsgId).Where("type", types.PromptMsg).First(&lastMessage).Error
|
||||
if err != nil {
|
||||
input.LastMsgId = 0
|
||||
} else {
|
||||
input.LastMsgId = lastMessage.Id
|
||||
}
|
||||
dbSession = dbSession.Where("id < ?", input.LastMsgId)
|
||||
// 删除对应的聊天记录
|
||||
h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{})
|
||||
}
|
||||
err = dbSession.Limit(h.App.SysConfig.Base.ContextDeep).Order("id DESC").Find(&historyMessages).Error
|
||||
if err == nil {
|
||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||
msg := historyMessages[i]
|
||||
ms := types.Message{Role: "user", Content: msg.Content}
|
||||
if msg.Type == types.ReplyMsg {
|
||||
ms.Role = "assistant"
|
||||
chatCtx := make([]interface{}, 0)
|
||||
messages := make([]interface{}, 0)
|
||||
if h.App.SysConfig.EnableContext {
|
||||
if h.ChatContexts.Has(session.ChatId) {
|
||||
messages = h.ChatContexts.Get(session.ChatId)
|
||||
} else {
|
||||
_ = utils.JsonDecode(role.Context, &messages)
|
||||
if h.App.SysConfig.ContextDeep > 0 {
|
||||
var historyMessages []model.ChatMessage
|
||||
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
|
||||
if res.Error == nil {
|
||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||
msg := historyMessages[i]
|
||||
ms := types.Message{Role: "user", Content: msg.Content}
|
||||
if msg.Type == types.ReplyMsg {
|
||||
ms.Role = "assistant"
|
||||
}
|
||||
chatCtx = append(chatCtx, ms)
|
||||
}
|
||||
chatCtx = append(chatCtx, ms)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -297,124 +176,90 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
v := messages[i]
|
||||
tks, _ = utils.CalcTokens(utils.JsonEncode(v), req.Model)
|
||||
// 上下文 token 超出了模型的最大上下文长度
|
||||
if tokens+tks >= input.ChatModel.MaxContext {
|
||||
if tokens+tks >= session.Model.MaxContext {
|
||||
break
|
||||
}
|
||||
|
||||
// 上下文的深度超出了模型的最大上下文深度
|
||||
if len(chatCtx) >= h.App.SysConfig.Base.ContextDeep {
|
||||
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
|
||||
break
|
||||
}
|
||||
|
||||
tokens += tks
|
||||
chatCtx = append(chatCtx, v)
|
||||
}
|
||||
}
|
||||
reqMgs := make([]any, 0)
|
||||
|
||||
// 添加引导提示词,防止模型生成违规内容
|
||||
if h.App.SysConfig.Moderation.EnableGuide {
|
||||
reqMgs = append(reqMgs, map[string]any{
|
||||
"role": "system",
|
||||
"content": h.App.SysConfig.Moderation.GuidePrompt,
|
||||
})
|
||||
logger.Debugf("聊天上下文:%+v", chatCtx)
|
||||
}
|
||||
reqMgs := make([]interface{}, 0)
|
||||
|
||||
for i := len(chatCtx) - 1; i >= 0; i-- {
|
||||
reqMgs = append(reqMgs, chatCtx[i])
|
||||
}
|
||||
|
||||
fileContents := make([]string, 0) // 文件内容
|
||||
var finalPrompt = input.Prompt
|
||||
imgList := make([]any, 0)
|
||||
for _, file := range input.Files {
|
||||
logger.Debugf("detected file: %+v", file.URL)
|
||||
// 处理图片
|
||||
if isImageURL(file.URL) {
|
||||
imgList = append(imgList, gin.H{
|
||||
"type": "image_url",
|
||||
"image_url": gin.H{
|
||||
"url": file.URL,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// 处理文件,提取文件内容
|
||||
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
|
||||
fullPrompt := prompt
|
||||
text := prompt
|
||||
// extract files in prompt
|
||||
files := utils.ExtractFileURLs(prompt)
|
||||
logger.Debugf("detected FILES: %+v", files)
|
||||
// 如果不是逆向模型,则提取文件内容
|
||||
if len(files) > 0 && !(session.Model.Value == "gpt-4-all" ||
|
||||
strings.HasPrefix(session.Model.Value, "gpt-4-gizmo") ||
|
||||
strings.HasSuffix(session.Model.Value, "claude-3")) {
|
||||
contents := make([]string, 0)
|
||||
var file model.File
|
||||
for _, v := range files {
|
||||
h.DB.Where("url = ?", v).First(&file)
|
||||
content, err := utils.ReadFileContent(v, h.App.Config.TikaHost)
|
||||
if err != nil {
|
||||
logger.Error("error with read file: ", err)
|
||||
continue
|
||||
} else {
|
||||
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
logger.Debugf("fileContents: %s", fileContents)
|
||||
contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
}
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
|
||||
}
|
||||
}
|
||||
|
||||
if len(fileContents) > 0 {
|
||||
finalPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(fileContents, "\n"), input.Prompt)
|
||||
tokens, _ := utils.CalcTokens(finalPrompt, req.Model)
|
||||
if tokens > input.ChatModel.MaxContext {
|
||||
tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
|
||||
if tokens > session.Model.MaxContext {
|
||||
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
|
||||
}
|
||||
} else {
|
||||
finalPrompt = input.Prompt
|
||||
}
|
||||
logger.Debug("最终Prompt:", fullPrompt)
|
||||
|
||||
if len(imgList) > 0 {
|
||||
imgList = append(imgList, map[string]any{
|
||||
// extract images from prompt
|
||||
imgURLs := utils.ExtractImgURLs(prompt)
|
||||
logger.Debugf("detected IMG: %+v", imgURLs)
|
||||
var content interface{}
|
||||
if len(imgURLs) > 0 {
|
||||
data := make([]interface{}, 0)
|
||||
for _, v := range imgURLs {
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
data = append(data, gin.H{
|
||||
"type": "image_url",
|
||||
"image_url": gin.H{
|
||||
"url": v,
|
||||
},
|
||||
})
|
||||
}
|
||||
data = append(data, gin.H{
|
||||
"type": "text",
|
||||
"text": input.Prompt,
|
||||
})
|
||||
req.Messages = append(reqMgs, map[string]any{
|
||||
"role": "user",
|
||||
"content": imgList,
|
||||
"text": strings.TrimSpace(text),
|
||||
})
|
||||
content = data
|
||||
} else {
|
||||
req.Messages = append(reqMgs, map[string]any{
|
||||
"role": "user",
|
||||
"content": finalPrompt,
|
||||
})
|
||||
content = fullPrompt
|
||||
}
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
})
|
||||
|
||||
return h.sendOpenAiMessage(req, userVo, ctx, input, c)
|
||||
}
|
||||
logger.Debugf("%+v", req.Messages)
|
||||
|
||||
// 判断一个 URL 是否图片链接
|
||||
func isImageURL(url string) bool {
|
||||
// 检查是否是有效的URL
|
||||
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查文件扩展名
|
||||
ext := strings.ToLower(path.Ext(url))
|
||||
validImageExts := map[string]bool{
|
||||
".jpg": true,
|
||||
".jpeg": true,
|
||||
".png": true,
|
||||
".gif": true,
|
||||
".bmp": true,
|
||||
".webp": true,
|
||||
".svg": true,
|
||||
".ico": true,
|
||||
}
|
||||
|
||||
if !validImageExts[ext] {
|
||||
return false
|
||||
}
|
||||
|
||||
// 发送HEAD请求检查Content-Type
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
resp, err := client.Head(url)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
return strings.HasPrefix(contentType, "image/")
|
||||
return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, ws)
|
||||
}
|
||||
|
||||
// Tokens 统计 token 数量
|
||||
@@ -430,17 +275,17 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
|
||||
if data.Text == "" && data.ChatId != "" {
|
||||
var item model.ChatMessage
|
||||
userId, _ := c.Get(types.LoginUserID)
|
||||
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c, item.Tokens)
|
||||
return
|
||||
}
|
||||
//if data.Text == "" && data.ChatId != "" {
|
||||
// var item model.ChatMessage
|
||||
// userId, _ := c.Get(types.LoginUserID)
|
||||
// res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
|
||||
// if res.Error != nil {
|
||||
// resp.ERROR(c, res.Error.Error())
|
||||
// return
|
||||
// }
|
||||
// resp.SUCCESS(c, item.Tokens)
|
||||
// return
|
||||
//}
|
||||
|
||||
tokens, err := utils.CalcTokens(data.Text, data.Model)
|
||||
if err != nil {
|
||||
@@ -483,14 +328,15 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
|
||||
// 发送请求到 OpenAI 服务器
|
||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if input.ChatModel.KeyId > 0 {
|
||||
h.DB.Where("id", input.ChatModel.KeyId).Where("enabled", true).Find(apiKey)
|
||||
} else { // use the last unused key
|
||||
if session.Model.KeyId > 0 {
|
||||
h.DB.Where("id", session.Model.KeyId).Find(apiKey)
|
||||
}
|
||||
// use the last unused key
|
||||
if apiKey.Id == 0 {
|
||||
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
|
||||
}
|
||||
|
||||
if apiKey.Id == 0 {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
@@ -501,14 +347,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input
|
||||
return nil, err
|
||||
}
|
||||
logger.Debugf("对话请求消息体:%+v", req)
|
||||
var apiURL string
|
||||
p, _ := url.Parse(apiKey.ApiURL)
|
||||
// 如果设置的是 BASE_URL 没有路径,则添加 /v1/chat/completions
|
||||
if p.Path == "" {
|
||||
apiURL = fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
} else {
|
||||
apiURL = apiKey.ApiURL
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
// 创建 HttpClient 请求对象
|
||||
var client *http.Client
|
||||
requestBody, err := json.Marshal(req)
|
||||
@@ -540,16 +380,16 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input
|
||||
}
|
||||
|
||||
// 扣减用户算力
|
||||
func (h *ChatHandler) subUserPower(userVo vo.User, input ChatInput, promptTokens int, replyTokens int) {
|
||||
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
|
||||
power := 1
|
||||
if input.ChatModel.Power > 0 {
|
||||
power = input.ChatModel.Power
|
||||
if session.Model.Power > 0 {
|
||||
power = session.Model.Power
|
||||
}
|
||||
|
||||
err := h.userService.DecreasePower(userVo.Id, power, model.PowerLog{
|
||||
err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: input.ChatModel.Value,
|
||||
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", input.ChatModel.Name, promptTokens, replyTokens),
|
||||
Model: session.Model.Value,
|
||||
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
@@ -557,43 +397,22 @@ func (h *ChatHandler) subUserPower(userVo vo.User, input ChatInput, promptTokens
|
||||
}
|
||||
|
||||
func (h *ChatHandler) saveChatHistory(
|
||||
c *gin.Context,
|
||||
req types.ApiRequest,
|
||||
usage Usage,
|
||||
message types.Message,
|
||||
input ChatInput,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
userVo vo.User,
|
||||
promptCreatedAt time.Time,
|
||||
replyCreatedAt time.Time) {
|
||||
|
||||
// 文本审核
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(usage.Content)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
logger.Debugf("moderationResult: %+v", moderationResult)
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: userVo.Id,
|
||||
Source: types.ModerationSourceChat,
|
||||
Input: usage.Prompt,
|
||||
Output: usage.Content,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
pushMessage(c, ChatEventError, "很抱歉,内容触发敏感词预警,AI 无法回答!!!")
|
||||
// 更新用户算力
|
||||
if input.ChatModel.Power > 0 {
|
||||
h.subUserPower(userVo, input, 0, 0)
|
||||
}
|
||||
return
|
||||
}
|
||||
// 更新上下文消息
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx := req.Messages // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
var promptTokens, replyTokens, totalTokens int
|
||||
@@ -604,15 +423,12 @@ func (h *ChatHandler) saveChatHistory(
|
||||
}
|
||||
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: input.ChatId,
|
||||
RoleId: input.RoleId,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: utils.JsonEncode(vo.MsgContent{
|
||||
Text: usage.Prompt,
|
||||
Files: input.Files,
|
||||
}),
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(usage.Prompt),
|
||||
Tokens: promptTokens,
|
||||
TotalTokens: promptTokens,
|
||||
UseContext: true,
|
||||
@@ -635,15 +451,12 @@ func (h *ChatHandler) saveChatHistory(
|
||||
totalTokens = replyTokens + getTotalTokens(req)
|
||||
}
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: input.ChatId,
|
||||
RoleId: input.RoleId,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: input.ChatRole.Icon,
|
||||
Content: utils.JsonEncode(vo.MsgContent{
|
||||
Text: message.Content,
|
||||
Files: input.Files,
|
||||
}),
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: usage.Content,
|
||||
Tokens: replyTokens,
|
||||
TotalTokens: totalTokens,
|
||||
UseContext: true,
|
||||
@@ -656,34 +469,18 @@ func (h *ChatHandler) saveChatHistory(
|
||||
logger.Error("failed to save reply history message: ", err)
|
||||
}
|
||||
|
||||
// 发送完整聊天记录给前端
|
||||
var messageVo vo.ChatMessage
|
||||
err = utils.CopyObject(historyReplyMsg, &messageVo)
|
||||
if err == nil {
|
||||
// 解析内容
|
||||
var content vo.MsgContent
|
||||
err = utils.JsonDecode(historyReplyMsg.Content, &content)
|
||||
if err != nil {
|
||||
content.Text = historyReplyMsg.Content
|
||||
}
|
||||
messageVo.Content = content
|
||||
messageVo.CreatedAt = historyReplyMsg.CreatedAt.Unix()
|
||||
messageVo.UpdatedAt = historyReplyMsg.UpdatedAt.Unix()
|
||||
pushMessage(c, ChatEventComplete, messageVo)
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
if input.ChatModel.Power > 0 {
|
||||
h.subUserPower(userVo, input, promptTokens, replyTokens)
|
||||
if session.Model.Power > 0 {
|
||||
h.subUserPower(userVo, session, promptTokens, replyTokens)
|
||||
}
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
err = h.DB.Where("chat_id = ?", input.ChatId).First(&chatItem).Error
|
||||
err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
|
||||
if err != nil {
|
||||
chatItem.ChatId = input.ChatId
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = userVo.Id
|
||||
chatItem.RoleId = input.RoleId
|
||||
chatItem.ModelId = input.ModelId
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
if utf8.RuneCountInString(usage.Prompt) > 30 {
|
||||
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
|
||||
} else {
|
||||
@@ -697,102 +494,28 @@ func (h *ChatHandler) saveChatHistory(
|
||||
}
|
||||
}
|
||||
|
||||
// TextToSpeech 文本生成语音
|
||||
func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
var data struct {
|
||||
ModelId int `json:"model_id"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
// 将AI回复消息中生成的图片链接下载到本地
|
||||
func (h *ChatHandler) extractImgUrl(text string) string {
|
||||
pattern := `!\[([^\]]*)]\(([^)]+)\)`
|
||||
re := regexp.MustCompile(pattern)
|
||||
matches := re.FindAllStringSubmatch(text, -1)
|
||||
|
||||
textHash := utils.Sha256(fmt.Sprintf("%d/%s", data.ModelId, data.Text))
|
||||
audioFile := fmt.Sprintf("%s/audio", h.App.Config.StaticDir)
|
||||
if _, err := os.Stat(audioFile); err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 下载图片并替换链接地址
|
||||
for _, match := range matches {
|
||||
imageURL := match[2]
|
||||
logger.Debug(imageURL)
|
||||
// 对于相同地址的图片,已经被替换了,就不再重复下载了
|
||||
if !strings.Contains(text, imageURL) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(audioFile, 0755); err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
audioFile = fmt.Sprintf("%s/%s.mp3", audioFile, textHash)
|
||||
if _, err := os.Stat(audioFile); err == nil {
|
||||
// 设置响应头
|
||||
c.Header("Prompt-Type", "audio/mpeg")
|
||||
c.Header("Prompt-Disposition", "attachment; filename=speech.mp3")
|
||||
c.File(audioFile)
|
||||
return
|
||||
}
|
||||
newImgURL, err := h.uploadManager.GetUploadHandler().PutUrlFile(imageURL, false)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 查询模型
|
||||
var chatModel model.ChatModel
|
||||
err := h.DB.Where("id", data.ModelId).First(&chatModel).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "找不到语音模型")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用 DeepSeek 的 API 接口
|
||||
var apiKey model.ApiKey
|
||||
if chatModel.KeyId > 0 {
|
||||
h.DB.Where("id", chatModel.KeyId).First(&apiKey)
|
||||
}
|
||||
if apiKey.Id == 0 {
|
||||
h.DB.Where("type", "tts").Where("enabled", true).First(&apiKey)
|
||||
}
|
||||
if apiKey.Id == 0 {
|
||||
resp.ERROR(c, "no TTS API key, please import key")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("chatModel: %+v, apiKey: %+v", chatModel, apiKey)
|
||||
|
||||
// 调用 openai tts api
|
||||
config := openai.DefaultConfig(apiKey.Value)
|
||||
config.BaseURL = apiKey.ApiURL + "/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
voice := openai.VoiceAlloy
|
||||
var options map[string]string
|
||||
err = utils.JsonDecode(chatModel.Options, &options)
|
||||
if err == nil {
|
||||
voice = openai.SpeechVoice(options["voice"])
|
||||
}
|
||||
req := openai.CreateSpeechRequest{
|
||||
Model: openai.SpeechModel(chatModel.Value),
|
||||
Input: data.Text,
|
||||
Voice: voice,
|
||||
}
|
||||
|
||||
audioData, err := client.CreateSpeech(context.Background(), req)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 先将音频数据读取到内存
|
||||
audioBytes, err := io.ReadAll(audioData)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 保存到音频文件
|
||||
err = os.WriteFile(audioFile, audioBytes, 0644)
|
||||
if err != nil {
|
||||
logger.Error("failed to save audio file: ", err)
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Prompt-Type", "audio/mpeg")
|
||||
c.Header("Prompt-Disposition", "attachment; filename=speech.mp3")
|
||||
|
||||
// 直接写入完整的音频数据到响应
|
||||
_, err = c.Writer.Write(audioBytes)
|
||||
if err != nil {
|
||||
logger.Error("写入音频数据到响应失败:", err)
|
||||
text = strings.ReplaceAll(text, imageURL, newImgURL)
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
// List 获取会话列表
|
||||
func (h *ChatHandler) List(c *gin.Context) {
|
||||
logger.Info(h.GetLoginUserId(c))
|
||||
if !h.IsLogin(c) {
|
||||
resp.SUCCESS(c)
|
||||
return
|
||||
@@ -29,7 +28,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
var items = make([]vo.ChatItem, 0)
|
||||
var chats []model.ChatItem
|
||||
h.DB.Debug().Where("user_id", userId).Order("id DESC").Find(&chats)
|
||||
h.DB.Where("user_id", userId).Order("id DESC").Find(&chats)
|
||||
if len(chats) == 0 {
|
||||
resp.SUCCESS(c, items)
|
||||
return
|
||||
@@ -42,9 +41,9 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
modelValues = append(modelValues, chat.Model)
|
||||
}
|
||||
|
||||
var roles []model.ChatApp
|
||||
var roles []model.ChatRole
|
||||
var models []model.ChatModel
|
||||
roleMap := make(map[uint]model.ChatApp)
|
||||
roleMap := make(map[uint]model.ChatRole)
|
||||
modelMap := make(map[string]model.ChatModel)
|
||||
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||
h.DB.Where("value IN ?", modelValues).Find(&models)
|
||||
@@ -105,6 +104,8 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
||||
var chatIds = make([]string, 0)
|
||||
for _, chat := range chats {
|
||||
chatIds = append(chatIds, chat.ChatId)
|
||||
// 清空会话上下文
|
||||
h.ChatContexts.Delete(chat.ChatId)
|
||||
}
|
||||
err = h.DB.Transaction(func(tx *gorm.DB) error {
|
||||
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
|
||||
@@ -132,28 +133,20 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
||||
func (h *ChatHandler) History(c *gin.Context) {
|
||||
chatId := c.Query("chat_id") // 会话 ID
|
||||
var items []model.ChatMessage
|
||||
var messages = make([]vo.ChatMessage, 0)
|
||||
var messages = make([]vo.HistoryMessage, 0)
|
||||
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "No history message")
|
||||
return
|
||||
} else {
|
||||
for _, item := range items {
|
||||
var v vo.ChatMessage
|
||||
var v vo.HistoryMessage
|
||||
err := utils.CopyObject(item, &v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// 解析内容
|
||||
var content vo.MsgContent
|
||||
err = utils.JsonDecode(item.Content, &content)
|
||||
if err != nil {
|
||||
content.Text = item.Content
|
||||
}
|
||||
v.Content = content
|
||||
v.CreatedAt = item.CreatedAt.Unix()
|
||||
v.UpdatedAt = item.UpdatedAt.Unix()
|
||||
messages = append(messages, v)
|
||||
if err == nil {
|
||||
messages = append(messages, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,6 +179,10 @@ func (h *ChatHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 是否要删除 MidJourney 绘画记录和图片文件?
|
||||
|
||||
// 清空会话上下文
|
||||
h.ChatContexts.Delete(chatId)
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
@@ -205,7 +202,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 填充角色名称
|
||||
var role model.ChatApp
|
||||
var role model.ChatRole
|
||||
res = h.DB.Where("id", chatItem.RoleId).First(&role)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Role not found")
|
||||
|
||||
@@ -26,12 +26,6 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatModelHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/model/")
|
||||
group.GET("list", h.List)
|
||||
}
|
||||
|
||||
// List 模型列表
|
||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
var items []model.ChatModel
|
||||
@@ -40,12 +34,10 @@ func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
t := c.Query("type")
|
||||
if t != "" {
|
||||
session = session.Where("type", t)
|
||||
} else {
|
||||
session = session.Where("type", "chat")
|
||||
}
|
||||
|
||||
session = session.Where("open", true)
|
||||
if h.IsLogin(c) && t == "chat" {
|
||||
if h.IsLogin(c) {
|
||||
user, _ := h.GetLoginUser(c)
|
||||
var models []int
|
||||
err := utils.JsonDecode(user.ChatModels, &models)
|
||||
|
||||
@@ -9,7 +9,6 @@ package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
@@ -20,31 +19,18 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ChatAppHandler struct {
|
||||
type ChatRoleHandler struct {
|
||||
BaseHandler
|
||||
}
|
||||
|
||||
func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
|
||||
return &ChatAppHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatAppHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/app/")
|
||||
group.GET("list", h.List)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list/user", h.ListByUser)
|
||||
group.POST("update", h.UpdateApp)
|
||||
}
|
||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
||||
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// List 获取用户聊天应用列表
|
||||
func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
tid := h.GetInt(c, "tid", 0)
|
||||
var roles []model.ChatApp
|
||||
var roles []model.ChatRole
|
||||
session := h.DB.Where("enable", true)
|
||||
if tid > 0 {
|
||||
session = session.Where("tid", tid)
|
||||
@@ -55,9 +41,9 @@ func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var roleVos = make([]vo.ChatApp, 0)
|
||||
var roleVos = make([]vo.ChatRole, 0)
|
||||
for _, r := range roles {
|
||||
var v vo.ChatApp
|
||||
var v vo.ChatRole
|
||||
err := utils.CopyObject(r, &v)
|
||||
if err == nil {
|
||||
v.Id = r.Id
|
||||
@@ -68,22 +54,20 @@ func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
// ListByUser 获取用户添加的角色列表
|
||||
func (h *ChatAppHandler) ListByUser(c *gin.Context) {
|
||||
func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetLoginUserId(c)
|
||||
var roles []model.ChatApp
|
||||
var roles []model.ChatRole
|
||||
session := h.DB.Where("enable", true)
|
||||
// 如果用户没登录,则获取所有角色
|
||||
if userId > 0 {
|
||||
var user model.User
|
||||
h.DB.First(&user, userId)
|
||||
var roleKeys []string
|
||||
if user.ChatRoles != "" {
|
||||
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "角色解析失败!")
|
||||
return
|
||||
}
|
||||
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "角色解析失败!")
|
||||
return
|
||||
}
|
||||
// 保证用户至少有一个角色可用
|
||||
if len(roleKeys) > 0 {
|
||||
@@ -100,9 +84,9 @@ func (h *ChatAppHandler) ListByUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var roleVos = make([]vo.ChatApp, 0)
|
||||
var roleVos = make([]vo.ChatRole, 0)
|
||||
for _, r := range roles {
|
||||
var v vo.ChatApp
|
||||
var v vo.ChatRole
|
||||
err := utils.CopyObject(r, &v)
|
||||
if err == nil {
|
||||
v.Id = r.Id
|
||||
@@ -112,8 +96,8 @@ func (h *ChatAppHandler) ListByUser(c *gin.Context) {
|
||||
resp.SUCCESS(c, roleVos)
|
||||
}
|
||||
|
||||
// UpdateApp 更新用户聊天应用
|
||||
func (h *ChatAppHandler) UpdateApp(c *gin.Context) {
|
||||
// UpdateRole 更新用户聊天角色
|
||||
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
@@ -27,27 +27,18 @@ func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.
|
||||
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ConfigHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/config/")
|
||||
|
||||
// 无需授权的接口
|
||||
group.GET("get", h.Get)
|
||||
group.GET("license", h.License)
|
||||
}
|
||||
|
||||
// Get 获取指定的系统配置
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
var config model.Config
|
||||
res := h.DB.Where("name", key).First(&config)
|
||||
res := h.DB.Where("marker", key).First(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var value map[string]any
|
||||
err := utils.JsonDecode(config.Value, &value)
|
||||
var value map[string]interface{}
|
||||
err := utils.JsonDecode(config.Config, &value)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -8,37 +8,33 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DallJobHandler struct {
|
||||
BaseHandler
|
||||
dallService *dalle.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
redis *redis.Client
|
||||
dallService *dalle.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *DallJobHandler {
|
||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
|
||||
return &DallJobHandler{
|
||||
dallService: service,
|
||||
uploader: manager,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
dallService: service,
|
||||
uploader: manager,
|
||||
userService: userService,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
@@ -46,90 +42,49 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *DallJobHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/dall/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.GET("models", h.GetModels)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return false
|
||||
}
|
||||
if user.Power < h.App.SysConfig.DallPower {
|
||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
if !h.preCheck(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var data types.DallTask
|
||||
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceDalle,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,提示词未通过文本审核,请重新输入!")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var chatModel model.ChatModel
|
||||
if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil {
|
||||
resp.ERROR(c, "模型不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户剩余算力
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
if user.Power < chatModel.Power {
|
||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||
return
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
task := types.DallTask{
|
||||
ClientId: data.ClientId,
|
||||
UserId: uint(userId),
|
||||
ModelId: chatModel.Id,
|
||||
ModelName: chatModel.Value,
|
||||
Image: data.Image,
|
||||
Prompt: data.Prompt,
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
Power: chatModel.Power,
|
||||
Power: h.App.SysConfig.DallPower,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
}
|
||||
job := model.DallJob{
|
||||
UserId: uint(userId),
|
||||
Prompt: data.Prompt,
|
||||
Power: chatModel.Power,
|
||||
Power: task.Power,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
}
|
||||
res := h.DB.Create(&job)
|
||||
@@ -140,17 +95,6 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
|
||||
task.Id = job.Id
|
||||
h.dallService.PushTask(task)
|
||||
|
||||
// 扣减算力
|
||||
err = h.userService.DecreasePower(user.Id, chatModel.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: chatModel.Value,
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with decrease power: "+err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -266,25 +210,3 @@ func (h *DallJobHandler) Publish(c *gin.Context) {
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *DallJobHandler) GetModels(c *gin.Context) {
|
||||
var models []model.ChatModel
|
||||
err := h.DB.Where("type", "img").Where("enabled", true).Find(&models).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var modelVos []vo.ChatModel
|
||||
for _, v := range models {
|
||||
var modelVo vo.ChatModel
|
||||
err := utils.CopyObject(v, &modelVo)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
modelVo.Id = v.Id
|
||||
modelVos = append(modelVos, modelVo)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, modelVos)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
@@ -30,9 +29,9 @@ import (
|
||||
|
||||
type FunctionHandler struct {
|
||||
BaseHandler
|
||||
config types.ApiConfig
|
||||
uploadManager *oss.UploaderManager
|
||||
dallService *dalle.Service
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewFunctionHandler(
|
||||
@@ -40,30 +39,18 @@ func NewFunctionHandler(
|
||||
db *gorm.DB,
|
||||
config *types.AppConfig,
|
||||
manager *oss.UploaderManager,
|
||||
dallService *dalle.Service,
|
||||
userService *service.UserService) *FunctionHandler {
|
||||
dallService *dalle.Service) *FunctionHandler {
|
||||
return &FunctionHandler{
|
||||
BaseHandler: BaseHandler{
|
||||
App: server,
|
||||
DB: db,
|
||||
},
|
||||
config: config.ApiConfig,
|
||||
uploadManager: manager,
|
||||
dallService: dallService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *FunctionHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/function/")
|
||||
group.GET("list", h.List)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.POST("weibo", h.WeiBo)
|
||||
group.POST("zaobao", h.ZaoBao)
|
||||
group.POST("dalle3", h.Dall3)
|
||||
}
|
||||
|
||||
type resVo struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
@@ -115,10 +102,16 @@ func (h *FunctionHandler) WeiBo(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/weibo/fetch", types.GeekAPIURL)
|
||||
if h.config.Token == "" {
|
||||
resp.ERROR(c, "无效的 API Token")
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
|
||||
var res resVo
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer geekai-plus").
|
||||
SetHeader("AppId", h.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||||
@@ -148,17 +141,19 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/zaobao/fetch", types.GeekAPIURL)
|
||||
var res resVo
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer geekai-plus").
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||||
if h.config.Token == "" {
|
||||
resp.ERROR(c, "无效的 API Token")
|
||||
return
|
||||
}
|
||||
if r.IsErrorState() {
|
||||
resp.ERROR(c, fmt.Sprintf("%v", r.Err))
|
||||
|
||||
url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
|
||||
var res resVo
|
||||
r, err := req.C().R().
|
||||
SetHeader("AppId", h.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -172,7 +167,7 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
||||
for _, v := range res.Data.Items {
|
||||
builder = append(builder, v.Title)
|
||||
}
|
||||
builder = append(builder, res.Data.Title)
|
||||
builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
|
||||
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
||||
}
|
||||
|
||||
@@ -189,70 +184,48 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var chatModel model.ChatModel
|
||||
res := h.DB.Where("type = ?", "img").Where("enabled", true).First(&chatModel)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "没有找到可用的AI绘图模型!")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("绘画参数:%+v", params)
|
||||
var user model.User
|
||||
res = h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||||
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "当前用户不存在!")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < chatModel.Power {
|
||||
resp.ERROR(c, "创建绘图任务失败,算力不足")
|
||||
if user.Power < h.App.SysConfig.DallPower {
|
||||
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
|
||||
return
|
||||
}
|
||||
|
||||
// create dall task
|
||||
prompt := utils.InterfaceToString(params["prompt"])
|
||||
task := types.DallTask{
|
||||
UserId: user.Id,
|
||||
Prompt: prompt,
|
||||
ModelId: chatModel.Id,
|
||||
ModelName: chatModel.Value,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
N: 1,
|
||||
Quality: "standard",
|
||||
Size: "1024x1024",
|
||||
Style: "vivid",
|
||||
Power: chatModel.Power,
|
||||
}
|
||||
job := model.DallJob{
|
||||
UserId: user.Id,
|
||||
Prompt: prompt,
|
||||
Power: chatModel.Power,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
UserId: user.Id,
|
||||
Prompt: prompt,
|
||||
Power: h.App.SysConfig.DallPower,
|
||||
}
|
||||
err := h.DB.Create(&job).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "创建绘图任务失败:"+err.Error())
|
||||
res = h.DB.Create(&job)
|
||||
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
task.Id = job.Id
|
||||
content, err := h.dallService.Image(task, true)
|
||||
content, err := h.dallService.Image(types.DallTask{
|
||||
Id: job.Id,
|
||||
UserId: user.Id,
|
||||
Prompt: job.Prompt,
|
||||
N: 1,
|
||||
Quality: "standard",
|
||||
Size: "1024x1024",
|
||||
Style: "vivid",
|
||||
Power: job.Power,
|
||||
}, true)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务执行失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 扣减算力
|
||||
err = h.userService.DecreasePower(user.Id, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: task.ModelName,
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(job.Prompt, 10)),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "扣减算力失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,18 +8,14 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// InviteHandler 用户邀请
|
||||
@@ -31,23 +27,6 @@ func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
|
||||
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *InviteHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/invite/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("hits", h.Hits)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("code", h.Code)
|
||||
group.GET("list", h.List)
|
||||
group.GET("stats", h.Stats)
|
||||
group.GET("rules", h.Rules)
|
||||
}
|
||||
}
|
||||
|
||||
// Code 获取当前用户邀请码
|
||||
func (h *InviteHandler) Code(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
@@ -86,34 +65,21 @@ func (h *InviteHandler) List(c *gin.Context) {
|
||||
var total int64
|
||||
session.Model(&model.InviteLog{}).Count(&total)
|
||||
var items []model.InviteLog
|
||||
offset := (page - 1) * pageSize
|
||||
err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
userIds := make([]uint, 0)
|
||||
for _, item := range items {
|
||||
userIds = append(userIds, item.UserId)
|
||||
}
|
||||
userMap := make(map[uint]model.User)
|
||||
var users []model.User
|
||||
h.DB.Model(&model.User{}).Where("id IN (?)", userIds).Find(&users)
|
||||
for _, user := range users {
|
||||
userMap[user.Id] = user
|
||||
}
|
||||
|
||||
var list = make([]vo.InviteLog, 0)
|
||||
for _, item := range items {
|
||||
var v vo.InviteLog
|
||||
err := utils.CopyObject(item, &v)
|
||||
if err != nil {
|
||||
continue
|
||||
offset := (page - 1) * pageSize
|
||||
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
|
||||
if res.Error == nil {
|
||||
for _, item := range items {
|
||||
var v vo.InviteLog
|
||||
err := utils.CopyObject(item, &v)
|
||||
if err == nil {
|
||||
v.Id = item.Id
|
||||
v.CreatedAt = item.CreatedAt.Unix()
|
||||
list = append(list, v)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
v.CreatedAt = item.CreatedAt.Unix()
|
||||
v.Avatar = userMap[item.UserId].Avatar
|
||||
list = append(list, v)
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
||||
}
|
||||
@@ -124,89 +90,3 @@ func (h *InviteHandler) Hits(c *gin.Context) {
|
||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// Stats 获取邀请统计
|
||||
func (h *InviteHandler) Stats(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
// 获取邀请码
|
||||
var inviteCode model.InviteCode
|
||||
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "邀请码不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 统计累计邀请数
|
||||
var totalInvite int64
|
||||
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ?", userId).Count(&totalInvite)
|
||||
|
||||
// 统计今日邀请数
|
||||
today := time.Now().Format("2006-01-02")
|
||||
var todayInvite int64
|
||||
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ? AND DATE(created_at) = ?", userId, today).Count(&todayInvite)
|
||||
|
||||
// 获取系统配置中的邀请奖励
|
||||
var config model.Config
|
||||
var invitePower int = 200 // 默认值
|
||||
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
|
||||
var configMap map[string]any
|
||||
if utils.JsonDecode(config.Value, &configMap) == nil {
|
||||
if power, ok := configMap["invite_power"].(float64); ok {
|
||||
invitePower = int(power)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算获得奖励总数
|
||||
rewardTotal := int(totalInvite) * invitePower
|
||||
|
||||
// 构建邀请链接
|
||||
inviteLink := fmt.Sprintf("%s/register?invite=%s", h.App.Config.StaticUrl, inviteCode.Code)
|
||||
|
||||
stats := vo.InviteStats{
|
||||
InviteCount: int(totalInvite),
|
||||
RewardTotal: rewardTotal,
|
||||
TodayInvite: int(todayInvite),
|
||||
InviteCode: inviteCode.Code,
|
||||
InviteLink: inviteLink,
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, stats)
|
||||
}
|
||||
|
||||
// Rules 获取奖励规则
|
||||
func (h *InviteHandler) Rules(c *gin.Context) {
|
||||
// 获取系统配置中的邀请奖励
|
||||
var config model.Config
|
||||
var invitePower int = 200 // 默认值
|
||||
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
|
||||
var configMap map[string]interface{}
|
||||
if utils.JsonDecode(config.Value, &configMap) == nil {
|
||||
if power, ok := configMap["invite_power"].(float64); ok {
|
||||
invitePower = int(power)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rules := []vo.RewardRule{
|
||||
{
|
||||
Id: 1,
|
||||
Title: "好友注册",
|
||||
Desc: "好友通过邀请链接成功注册",
|
||||
Icon: "icon-user-fill",
|
||||
Color: "#1989fa",
|
||||
Reward: invitePower,
|
||||
},
|
||||
{
|
||||
Id: 2,
|
||||
Title: "好友首次充值",
|
||||
Desc: "好友首次充值任意金额",
|
||||
Icon: "icon-money",
|
||||
Color: "#07c160",
|
||||
Reward: invitePower * 2, // 假设首次充值奖励是注册奖励的2倍
|
||||
},
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, rules)
|
||||
}
|
||||
|
||||
@@ -1,469 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/moderation"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// JimengHandler 即梦AI处理器
|
||||
type JimengHandler struct {
|
||||
BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
// NewJimengHandler 创建即梦AI处理器
|
||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService, moderationManager *moderation.ServiceManager) *JimengHandler {
|
||||
return &JimengHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由,新增统一任务接口
|
||||
func (h *JimengHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/jimeng/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("task", h.CreateTask)
|
||||
group.GET("power-config", h.GetPowerConfig)
|
||||
group.POST("jobs", h.Jobs)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("retry", h.Retry)
|
||||
}
|
||||
}
|
||||
|
||||
// JimengTaskRequest 统一任务请求结构体
|
||||
// 支持所有生图和生成视频类型
|
||||
type JimengTaskRequest struct {
|
||||
TaskType string `json:"task_type" binding:"required"`
|
||||
Prompt string `json:"prompt"`
|
||||
ImageInput string `json:"image_input"`
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
Scale float64 `json:"scale"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Gpen float64 `json:"gpen"`
|
||||
Skin float64 `json:"skin"`
|
||||
SkinUnifi float64 `json:"skin_unifi"`
|
||||
GenMode string `json:"gen_mode"`
|
||||
Seed int64 `json:"seed"`
|
||||
UsePreLLM bool `json:"use_pre_llm"`
|
||||
TemplateId string `json:"template_id"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
}
|
||||
|
||||
// CreateTask 统一任务创建接口
|
||||
func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
var req JimengTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(req.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceJiMeng,
|
||||
Input: req.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 新增:除图像特效外,其他任务类型必须有提示词
|
||||
if req.TaskType != "image_effects" && req.Prompt == "" {
|
||||
resp.ERROR(c, "提示词不能为空")
|
||||
return
|
||||
}
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
|
||||
var powerCost int
|
||||
var taskType model.JMTaskType
|
||||
var params map[string]any
|
||||
var reqKey string
|
||||
var modelName string
|
||||
|
||||
switch req.TaskType {
|
||||
case "text_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToImage)
|
||||
taskType = model.JMTaskTypeTextToImage
|
||||
reqKey = jimeng.ReqKeyTextToImage
|
||||
modelName = "即梦文生图"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 2.5
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"use_pre_llm": req.UsePreLLM,
|
||||
}
|
||||
case "image_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToImage)
|
||||
taskType = model.JMTaskTypeImageToImage
|
||||
reqKey = jimeng.ReqKeyImageToImagePortrait
|
||||
modelName = "即梦图生图"
|
||||
if req.Gpen == 0 {
|
||||
req.Gpen = 0.4
|
||||
}
|
||||
if req.Skin == 0 {
|
||||
req.Skin = 0.3
|
||||
}
|
||||
if req.GenMode == "" {
|
||||
if req.Prompt != "" {
|
||||
req.GenMode = jimeng.GenModeCreative
|
||||
} else {
|
||||
req.GenMode = jimeng.GenModeReference
|
||||
}
|
||||
}
|
||||
params = map[string]any{
|
||||
"image_input": req.ImageInput,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"gpen": req.Gpen,
|
||||
"skin": req.Skin,
|
||||
"skin_unifi": req.SkinUnifi,
|
||||
"gen_mode": req.GenMode,
|
||||
"seed": req.Seed,
|
||||
}
|
||||
case "image_edit":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEdit)
|
||||
taskType = model.JMTaskTypeImageEdit
|
||||
reqKey = jimeng.ReqKeyImageEdit
|
||||
modelName = "即梦图像编辑"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 0.5
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
}
|
||||
params["image_urls"] = []string{req.ImageInput}
|
||||
case "image_effects":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
|
||||
taskType = model.JMTaskTypeImageEffects
|
||||
reqKey = jimeng.ReqKeyImageEffects
|
||||
modelName = "即梦图像特效"
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
params = map[string]any{
|
||||
"image_input1": req.ImageInput,
|
||||
"template_id": req.TemplateId,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
}
|
||||
case "text_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToVideo)
|
||||
taskType = model.JMTaskTypeTextToVideo
|
||||
reqKey = jimeng.ReqKeyTextToVideo
|
||||
modelName = "即梦文生视频"
|
||||
if req.AspectRatio == "" {
|
||||
req.AspectRatio = jimeng.AspectRatio16_9
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
case "image_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToVideo)
|
||||
taskType = model.JMTaskTypeImageToVideo
|
||||
reqKey = jimeng.ReqKeyImageToVideo
|
||||
modelName = "即梦图生视频"
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
default:
|
||||
resp.ERROR(c, "不支持的任务类型")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: taskType,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: reqKey,
|
||||
Power: powerCost,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("%s,任务ID:%d", modelName, job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
func (h *JimengHandler) Jobs(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
var req struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Filter string `json:"filter"`
|
||||
Ids []uint `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
var jobs []model.JimengJob
|
||||
var total int64
|
||||
query := h.DB.Model(&model.JimengJob{}).Where("user_id = ?", userId)
|
||||
|
||||
switch req.Filter {
|
||||
case "image":
|
||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||
model.JMTaskTypeTextToImage,
|
||||
model.JMTaskTypeImageToImage,
|
||||
model.JMTaskTypeImageEdit,
|
||||
model.JMTaskTypeImageEffects,
|
||||
})
|
||||
case "video":
|
||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||
model.JMTaskTypeTextToVideo,
|
||||
model.JMTaskTypeImageToVideo,
|
||||
})
|
||||
}
|
||||
|
||||
if len(req.Ids) > 0 {
|
||||
query = query.Where("id IN (?)", req.Ids)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
if err := query.Order("updated_at DESC").Offset(offset).Limit(req.PageSize).Find(&jobs).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 填充 VO
|
||||
var jobVos []vo.JimengJob
|
||||
for _, job := range jobs {
|
||||
var jobVo vo.JimengJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
jobVo.CreatedAt = job.CreatedAt.Unix()
|
||||
jobVos = append(jobVos, jobVo)
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, req.Page, req.PageSize, jobVos))
|
||||
}
|
||||
|
||||
// Remove 删除任务
|
||||
func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
jobId := h.GetInt(c, "id", 0)
|
||||
if jobId == 0 {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取任务,判断状态
|
||||
job, err := h.jimengService.GetJob(uint(jobId))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
if job.UserId != user.Id {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
|
||||
// 正在运行中的任务不能删除
|
||||
if job.Status == model.JMTaskStatusGenerating || job.Status == model.JMTaskStatusInQueue {
|
||||
resp.ERROR(c, "正在运行中的任务不能删除,否则无法退回算力")
|
||||
return
|
||||
}
|
||||
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Where("id = ? AND user_id = ?", jobId, user.Id).Delete(&model.JimengJob{}).Error; err != nil {
|
||||
logger.Errorf("delete jimeng job failed: %v", err)
|
||||
resp.ERROR(c, "删除任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 失败任务删除后退回算力
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "退回算力失败")
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
resp.SUCCESS(c, gin.H{})
|
||||
}
|
||||
|
||||
// Retry 重试任务
|
||||
func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
jobId := h.GetInt(c, "id", 0)
|
||||
if jobId == 0 {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查任务是否存在且属于当前用户
|
||||
job, err := h.jimengService.GetJob(uint(jobId))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if job.UserId != userId {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
|
||||
// 只有失败的任务才能重试
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
resp.ERROR(c, "只有失败的任务才能重试")
|
||||
return
|
||||
}
|
||||
|
||||
// 重置任务状态
|
||||
if err := h.jimengService.UpdateJobStatus(uint(jobId), model.JMTaskStatusInQueue, ""); err != nil {
|
||||
logger.Errorf("reset job status failed: %v", err)
|
||||
resp.ERROR(c, "重置任务状态失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 重新推送到队列
|
||||
if err := h.jimengService.PushTaskToQueue(uint(jobId)); err != nil {
|
||||
logger.Errorf("push retry task to queue failed: %v", err)
|
||||
resp.ERROR(c, "推送重试任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"message": "重试任务已提交"})
|
||||
}
|
||||
|
||||
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
||||
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||
config := h.App.SysConfig.Jimeng
|
||||
|
||||
switch taskType {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
return config.Power.TextToImage
|
||||
case model.JMTaskTypeImageToImage:
|
||||
return config.Power.ImageToImage
|
||||
case model.JMTaskTypeImageEdit:
|
||||
return config.Power.ImageEdit
|
||||
case model.JMTaskTypeImageEffects:
|
||||
return config.Power.ImageEffects
|
||||
case model.JMTaskTypeTextToVideo:
|
||||
return config.Power.TextToVideo
|
||||
case model.JMTaskTypeImageToVideo:
|
||||
return config.Power.ImageToVideo
|
||||
default:
|
||||
return 10
|
||||
}
|
||||
}
|
||||
|
||||
// GetPowerConfig 获取即梦各任务类型算力消耗配置
|
||||
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
||||
config := h.App.SysConfig.Jimeng
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"text_to_image": config.Power.TextToImage,
|
||||
"image_to_image": config.Power.ImageToImage,
|
||||
"image_edit": config.Power.ImageEdit,
|
||||
"image_effects": config.Power.ImageEffects,
|
||||
"text_to_video": config.Power.TextToVideo,
|
||||
"image_to_video": config.Power.ImageToVideo,
|
||||
})
|
||||
}
|
||||
@@ -10,13 +10,11 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -36,17 +34,6 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MarkMapHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/markMap/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("gen", h.Generate)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate 生成思维导图
|
||||
func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||
var data struct {
|
||||
@@ -108,7 +95,7 @@ func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||
|
||||
// 扣减算力
|
||||
if chatModel.Power > 0 {
|
||||
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
|
||||
err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: chatModel.Value,
|
||||
Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -26,12 +25,6 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
||||
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MenuHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/menu/")
|
||||
group.GET("list", h.List)
|
||||
}
|
||||
|
||||
// List 数据列表
|
||||
func (h *MenuHandler) List(c *gin.Context) {
|
||||
index := h.GetBool(c, "index")
|
||||
@@ -40,7 +33,7 @@ func (h *MenuHandler) List(c *gin.Context) {
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
session = session.Where("enabled", true)
|
||||
if index {
|
||||
session = session.Where("id IN ?", h.App.SysConfig.Base.IndexNavs)
|
||||
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
|
||||
}
|
||||
res := session.Order("sort_num ASC").Find(&items)
|
||||
if res.Error == nil {
|
||||
|
||||
@@ -10,11 +10,9 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
@@ -29,20 +27,18 @@ import (
|
||||
|
||||
type MidJourneyHandler struct {
|
||||
BaseHandler
|
||||
mjService *mj.Service
|
||||
snowflake *service.Snowflake
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
mjService *mj.Service
|
||||
snowflake *service.Snowflake
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *MidJourneyHandler {
|
||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
|
||||
return &MidJourneyHandler{
|
||||
snowflake: snowflake,
|
||||
mjService: service,
|
||||
uploader: manager,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
snowflake: snowflake,
|
||||
mjService: service,
|
||||
uploader: manager,
|
||||
userService: userService,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
@@ -50,25 +46,6 @@ func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.S
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MidJourneyHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/mj/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("image", h.Image)
|
||||
group.POST("upscale", h.Upscale)
|
||||
group.POST("variation", h.Variation)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
@@ -76,7 +53,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.Base.MjPower {
|
||||
if user.Power < h.App.SysConfig.MjPower {
|
||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||
return false
|
||||
}
|
||||
@@ -89,6 +66,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
var data struct {
|
||||
TaskType string `json:"task_type"`
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
NegPrompt string `json:"neg_prompt"`
|
||||
Rate string `json:"rate"`
|
||||
@@ -113,29 +91,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceMJ,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var params = ""
|
||||
if data.Rate != "" && !strings.Contains(params, "--ar") {
|
||||
params += " --ar " + data.Rate
|
||||
@@ -198,6 +153,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
task := types.MjTask{
|
||||
ClientId: data.ClientId,
|
||||
TaskId: taskId,
|
||||
Type: types.TaskType(data.TaskType),
|
||||
Prompt: data.Prompt,
|
||||
@@ -205,17 +161,17 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
ImgArr: data.ImgArr,
|
||||
Mode: h.App.SysConfig.Base.MjMode,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: data.TaskType,
|
||||
UserId: uint(userId),
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Progress: 0,
|
||||
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
|
||||
Power: h.App.SysConfig.Base.MjPower,
|
||||
Power: h.App.SysConfig.MjPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
opt := "绘图"
|
||||
@@ -251,6 +207,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
|
||||
type reqVo struct {
|
||||
Index int `json:"index"`
|
||||
ClientId string `json:"client_id"`
|
||||
ChannelId string `json:"channel_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
MessageHash string `json:"message_hash"`
|
||||
@@ -272,21 +229,22 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
taskId, _ := h.snowflake.Next(true)
|
||||
task := types.MjTask{
|
||||
ClientId: data.ClientId,
|
||||
Type: types.TaskUpscale,
|
||||
UserId: userId,
|
||||
ChannelId: data.ChannelId,
|
||||
Index: data.Index,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
Mode: h.App.SysConfig.Base.MjMode,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskUpscale.String(),
|
||||
UserId: uint(userId),
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Progress: 0,
|
||||
Power: h.App.SysConfig.Base.MjActionPower,
|
||||
Power: h.App.SysConfig.MjActionPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||
@@ -328,21 +286,22 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
taskId, _ := h.snowflake.Next(true)
|
||||
task := types.MjTask{
|
||||
Type: types.TaskVariation,
|
||||
ClientId: data.ClientId,
|
||||
UserId: userId,
|
||||
Index: data.Index,
|
||||
ChannelId: data.ChannelId,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
Mode: h.App.SysConfig.Base.MjMode,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskVariation.String(),
|
||||
ChannelId: data.ChannelId,
|
||||
UserId: uint(userId),
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Progress: 0,
|
||||
Power: h.App.SysConfig.Base.MjActionPower,
|
||||
Power: h.App.SysConfig.MjActionPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||
@@ -468,7 +427,7 @@ func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||
err := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: uint(userId)}).UpdateColumn("publish", action).Error
|
||||
err := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -9,19 +9,17 @@ package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type NetHandler struct {
|
||||
@@ -33,22 +31,6 @@ func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManage
|
||||
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *NetHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/upload")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("", h.Upload)
|
||||
group.POST("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
|
||||
// 公开接口,不需要授权
|
||||
h.App.Engine.GET("/api/download", h.Download)
|
||||
}
|
||||
|
||||
func (h *NetHandler) Upload(c *gin.Context) {
|
||||
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
||||
if err != nil {
|
||||
@@ -64,7 +46,7 @@ func (h *NetHandler) Upload(c *gin.Context) {
|
||||
|
||||
userId := h.GetLoginUserId(c)
|
||||
res := h.DB.Create(&model.File{
|
||||
UserId: uint(userId),
|
||||
UserId: int(userId),
|
||||
Name: file.Name,
|
||||
ObjKey: file.ObjKey,
|
||||
URL: file.URL,
|
||||
@@ -161,15 +143,7 @@ func (h *NetHandler) Download(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 使用http.Get下载文件
|
||||
req, err := http.NewRequest("GET", fileUrl, nil)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 模拟浏览器 UA
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
|
||||
client := &http.Client{}
|
||||
r, err := client.Do(req)
|
||||
r, err := http.Get(fileUrl)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -182,5 +156,6 @@ func (h *NetHandler) Download(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
// 将下载的文件内容写入响应
|
||||
_, _ = io.Copy(c.Writer, r.Body)
|
||||
}
|
||||
|
||||
@@ -17,12 +17,10 @@ import (
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
req2 "github.com/imroc/req/v3"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
req2 "github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type Usage struct {
|
||||
@@ -56,16 +54,18 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
req types.ApiRequest,
|
||||
userVo vo.User,
|
||||
ctx context.Context,
|
||||
input ChatInput,
|
||||
c *gin.Context) error {
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
prompt string,
|
||||
ws *types.WsClient) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, input, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Since(start))
|
||||
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
return fmt.Errorf("用户取消了请求:%s", input.Prompt)
|
||||
return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||
} else if strings.Contains(err.Error(), "no available key") {
|
||||
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
}
|
||||
@@ -88,8 +88,6 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
var function model.Function
|
||||
var toolCall = false
|
||||
var arguments = make([]string, 0)
|
||||
var reasoning = false
|
||||
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -104,14 +102,12 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
continue
|
||||
}
|
||||
if responseBody.Choices[0].Delta.Content == nil &&
|
||||
responseBody.Choices[0].Delta.ToolCalls == nil &&
|
||||
responseBody.Choices[0].Delta.ReasoningContent == "" {
|
||||
if responseBody.Choices[0].Delta.Content == nil && responseBody.Choices[0].Delta.ToolCalls == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
||||
pushMessage(c, "text", "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||
utils.SendChunkMsg(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -139,7 +135,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
if res.Error == nil {
|
||||
toolCall = true
|
||||
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||
pushMessage(c, "text", callMsg)
|
||||
utils.SendChunkMsg(ws, callMsg)
|
||||
contents = append(contents, callMsg)
|
||||
}
|
||||
continue
|
||||
@@ -153,80 +149,58 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
// output stopped
|
||||
if responseBody.Choices[0].FinishReason != "" {
|
||||
break // 输出完成或者输出中断了
|
||||
} else { // 正常输出结果
|
||||
// 兼容思考过程
|
||||
if responseBody.Choices[0].Delta.ReasoningContent != "" {
|
||||
reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
|
||||
if !reasoning {
|
||||
reasoningContent = fmt.Sprintf("<think>%s", reasoningContent)
|
||||
reasoning = true
|
||||
}
|
||||
|
||||
pushMessage(c, "text", reasoningContent)
|
||||
contents = append(contents, reasoningContent)
|
||||
} else if responseBody.Choices[0].Delta.Content != "" {
|
||||
finalContent := responseBody.Choices[0].Delta.Content
|
||||
if reasoning {
|
||||
finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Content)
|
||||
reasoning = false
|
||||
}
|
||||
contents = append(contents, utils.InterfaceToString(finalContent))
|
||||
pushMessage(c, "text", finalContent)
|
||||
}
|
||||
} else {
|
||||
content := responseBody.Choices[0].Delta.Content
|
||||
contents = append(contents, utils.InterfaceToString(content))
|
||||
utils.SendChunkMsg(ws, responseBody.Choices[0].Delta.Content)
|
||||
}
|
||||
} // end for
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", input.Prompt)
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
} else {
|
||||
logger.Error("信息读取出错:", err)
|
||||
}
|
||||
}
|
||||
|
||||
if toolCall { // 调用函数完成任务
|
||||
params := make(map[string]any)
|
||||
params := make(map[string]interface{})
|
||||
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
||||
params["user_id"] = userVo.Id
|
||||
var apiRes types.BizVo
|
||||
r, err := req2.C().R().SetHeader("Body-Type", "application/json").
|
||||
SetHeader("Authorization", function.Token).
|
||||
SetBody(params).Post(function.Action)
|
||||
SetBody(params).
|
||||
SetSuccessResult(&apiRes).Post(function.Action)
|
||||
errMsg := ""
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
} else {
|
||||
all, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(all, &apiRes)
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
} else if apiRes.Code != types.Success {
|
||||
errMsg = apiRes.Message
|
||||
}
|
||||
} else if r.IsErrorState() {
|
||||
errMsg = r.Status
|
||||
}
|
||||
|
||||
if errMsg != "" {
|
||||
errMsg = "调用函数工具出错:" + errMsg
|
||||
if errMsg != "" || apiRes.Code != types.Success {
|
||||
errMsg = "调用函数工具出错:" + apiRes.Message + errMsg
|
||||
contents = append(contents, errMsg)
|
||||
} else {
|
||||
errMsg = utils.InterfaceToString(apiRes.Data)
|
||||
contents = append(contents, errMsg)
|
||||
}
|
||||
pushMessage(c, "text", errMsg)
|
||||
utils.SendChunkMsg(ws, errMsg)
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
usage := Usage{
|
||||
Prompt: input.Prompt,
|
||||
Prompt: prompt,
|
||||
Content: strings.Join(contents, ""),
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
}
|
||||
message.Content = usage.Content
|
||||
h.saveChatHistory(c, req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt)
|
||||
h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||
}
|
||||
} else { // 非流式输出
|
||||
var respVo OpenAIResVo
|
||||
@@ -239,10 +213,13 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
return fmt.Errorf("解析响应失败:%v", body)
|
||||
}
|
||||
content := respVo.Choices[0].Message.Content
|
||||
pushMessage(c, "text", content)
|
||||
respVo.Usage.Prompt = input.Prompt
|
||||
if strings.HasPrefix(req.Model, "o1-") {
|
||||
content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
|
||||
}
|
||||
utils.SendChunkMsg(ws, content)
|
||||
respVo.Usage.Prompt = prompt
|
||||
respVo.Usage.Content = content
|
||||
h.saveChatHistory(c, req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now())
|
||||
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -9,12 +9,12 @@ package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -28,18 +28,6 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *OrderHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/order/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.GET("query", h.Query)
|
||||
}
|
||||
}
|
||||
|
||||
// List 订单列表
|
||||
func (h *OrderHandler) List(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
@@ -60,21 +48,20 @@ func (h *OrderHandler) List(c *gin.Context) {
|
||||
order.Id = item.Id
|
||||
order.CreatedAt = item.CreatedAt.Unix()
|
||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||
payChannel, ok := types.PayChannel[item.Channel]
|
||||
payMethod, ok := types.PayMethods[item.PayWay]
|
||||
if !ok {
|
||||
payChannel = item.PayWay
|
||||
payMethod = item.PayWay
|
||||
}
|
||||
payWays, ok := types.PayWays[item.PayWay]
|
||||
payName, ok := types.PayNames[item.PayType]
|
||||
if !ok {
|
||||
payWays = item.PayWay
|
||||
payName = item.PayWay
|
||||
}
|
||||
order.ChannelName = payChannel
|
||||
order.PayName = payWays
|
||||
order.PayMethod = payMethod
|
||||
order.PayName = payName
|
||||
list = append(list, order)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
||||
@@ -95,8 +82,17 @@ func (h *OrderHandler) Query(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var item model.Order
|
||||
h.DB.Where("order_no = ?", orderNo).First(&item)
|
||||
counter := 0
|
||||
for {
|
||||
time.Sleep(time.Second)
|
||||
var item model.Order
|
||||
h.DB.Where("order_no = ?", orderNo).First(&item)
|
||||
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
|
||||
order.Status = item.Status
|
||||
break
|
||||
}
|
||||
counter++
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"status": order.Status})
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/payment"
|
||||
@@ -34,148 +33,52 @@ type PayWay struct {
|
||||
// PaymentHandler 支付服务回调 handler
|
||||
type PaymentHandler struct {
|
||||
BaseHandler
|
||||
alipayService *payment.AlipayService
|
||||
epayService *payment.EPayService
|
||||
wxpayService *payment.WxPayService
|
||||
snowflake *service.Snowflake
|
||||
userService *service.UserService
|
||||
fs embed.FS
|
||||
lock sync.Mutex
|
||||
config *types.PaymentConfig
|
||||
alipayService *payment.AlipayService
|
||||
huPiPayService *payment.HuPiPayService
|
||||
geekPayService *payment.GeekPayService
|
||||
wechatPayService *payment.WechatPayService
|
||||
snowflake *service.Snowflake
|
||||
userService *service.UserService
|
||||
fs embed.FS
|
||||
lock sync.Mutex
|
||||
signKey string // 用来签名的随机秘钥
|
||||
}
|
||||
|
||||
func NewPaymentHandler(
|
||||
server *core.AppServer,
|
||||
alipayService *payment.AlipayService,
|
||||
geekPayService *payment.EPayService,
|
||||
wxpayService *payment.WxPayService,
|
||||
huPiPayService *payment.HuPiPayService,
|
||||
geekPayService *payment.GeekPayService,
|
||||
wechatPayService *payment.WechatPayService,
|
||||
db *gorm.DB,
|
||||
userService *service.UserService,
|
||||
snowflake *service.Snowflake,
|
||||
fs embed.FS,
|
||||
sysConfig *types.SystemConfig) *PaymentHandler {
|
||||
fs embed.FS) *PaymentHandler {
|
||||
return &PaymentHandler{
|
||||
alipayService: alipayService,
|
||||
epayService: geekPayService,
|
||||
wxpayService: wxpayService,
|
||||
snowflake: snowflake,
|
||||
userService: userService,
|
||||
fs: fs,
|
||||
lock: sync.Mutex{},
|
||||
alipayService: alipayService,
|
||||
huPiPayService: huPiPayService,
|
||||
geekPayService: geekPayService,
|
||||
wechatPayService: wechatPayService,
|
||||
snowflake: snowflake,
|
||||
userService: userService,
|
||||
fs: fs,
|
||||
lock: sync.Mutex{},
|
||||
BaseHandler: BaseHandler{
|
||||
App: server,
|
||||
DB: db,
|
||||
},
|
||||
config: &sysConfig.Payment,
|
||||
signKey: utils.RandString(32),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *PaymentHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/payment/")
|
||||
|
||||
// 支付回调接口(公开)
|
||||
rg.POST("notify/alipay", h.AlipayNotify)
|
||||
rg.GET("notify/epay", h.EPayNotify)
|
||||
rg.POST("notify/wxpay", h.WxpayNotify)
|
||||
|
||||
// 需要用户登录的接口
|
||||
rg.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
rg.POST("create", h.CreateOrder)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PaymentHandler) StartSyncOrders() {
|
||||
go func() {
|
||||
for {
|
||||
err := h.SyncOrders()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SyncOrders 同步订单状态
|
||||
func (h *PaymentHandler) SyncOrders() error {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Errorf("同步订单状态发生异常: %v", err)
|
||||
}
|
||||
}()
|
||||
var orders []model.Order
|
||||
err := h.DB.Where("status", types.OrderNotPaid).Where("checked", false).Find(&orders).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, order := range orders {
|
||||
time.Sleep(time.Second * 1)
|
||||
//超时15分钟的订单,直接标记为已关闭
|
||||
if time.Now().After(order.CreatedAt.Add(time.Minute * 5)) {
|
||||
h.DB.Model(&model.Order{}).Where("id", order.Id).Update("checked", true)
|
||||
logger.Errorf("订单超时:%v", order)
|
||||
continue
|
||||
}
|
||||
// 查询订单状态
|
||||
var res payment.OrderInfo
|
||||
switch order.Channel {
|
||||
case payment.PayChannelEpay:
|
||||
res, err = h.epayService.Query(order.OrderNo)
|
||||
if err != nil {
|
||||
logger.Errorf("error with query order info: %v", err)
|
||||
continue
|
||||
}
|
||||
// 微信支付
|
||||
case payment.PayChannelWX:
|
||||
res, err = h.wxpayService.Query(order.OrderNo)
|
||||
logger.Debugf("微信支付订单状态:%+v", res)
|
||||
if err != nil {
|
||||
logger.Errorf("error with query order info: %v", err)
|
||||
continue
|
||||
}
|
||||
case payment.PayChannelAL:
|
||||
res, err = h.alipayService.Query(order.OrderNo)
|
||||
if err != nil {
|
||||
logger.Errorf("error with query order info: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 订单已关闭
|
||||
if res.Closed() {
|
||||
h.DB.Model(&model.Order{}).Where("id", order.Id).Updates(map[string]any{
|
||||
"checked": true,
|
||||
"status": types.OrderPaidFailed,
|
||||
})
|
||||
logger.Errorf("订单已关闭:%v", order)
|
||||
continue
|
||||
}
|
||||
|
||||
// 订单未支付,不处理,继续轮询
|
||||
if !res.Success() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 订单支付成功
|
||||
err = h.paySuccess(res)
|
||||
if err != nil {
|
||||
logger.Errorf("error with deal order: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
||||
func (h *PaymentHandler) Pay(c *gin.Context) {
|
||||
var data struct {
|
||||
PayWay string `json:"pay_way,omitempty"` // 支付方式:支付宝,微信
|
||||
Pid int `json:"pid,omitempty"`
|
||||
Device string `json:"device,omitempty"`
|
||||
Domain string `json:"domain,omitempty"` // 支付回调域名
|
||||
Channel string `json:"channel,omitempty"`
|
||||
PayWay string `json:"pay_way"`
|
||||
PayType string `json:"pay_type"`
|
||||
ProductId int `json:"product_id"`
|
||||
UserId int `json:"user_id"`
|
||||
Device string `json:"device"`
|
||||
Host string `json:"host"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -183,7 +86,7 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
||||
}
|
||||
|
||||
var product model.Product
|
||||
err := h.DB.Where("id", data.Pid).First(&product).Error
|
||||
err := h.DB.Where("id", data.ProductId).First(&product).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Product not found")
|
||||
return
|
||||
@@ -194,118 +97,136 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||
return
|
||||
}
|
||||
userId := h.GetLoginUserId(c)
|
||||
var user model.User
|
||||
err = h.DB.Where("id", userId).First(&user).Error
|
||||
err = h.DB.Where("id", data.UserId).First(&user).Error
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
amount := product.Price
|
||||
var payURL, notifyURL string
|
||||
amount := product.Discount
|
||||
var payURL, returnURL, notifyURL string
|
||||
switch data.PayWay {
|
||||
case "wxpay":
|
||||
logger.Debugf("微信支付,%+v", data)
|
||||
data.Channel = payment.PayChannelWX
|
||||
// 优先使用微信官方支付
|
||||
if h.config.WxPay.Enabled {
|
||||
data.Channel = "wxpay"
|
||||
if h.config.WxPay.Domain != "" {
|
||||
data.Domain = h.config.WxPay.Domain
|
||||
}
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/wxpay", data.Domain)
|
||||
payURL, err = h.wxpayService.Pay(payment.PayRequest{
|
||||
OutTradeNo: orderNo,
|
||||
TotalFee: fmt.Sprintf("%d", int(amount*100)),
|
||||
Subject: product.Name,
|
||||
NotifyURL: notifyURL,
|
||||
ClientIP: c.ClientIP(),
|
||||
Device: data.Device,
|
||||
PayWay: payment.PayWayWX,
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
} else if h.config.Epay.Enabled { // 聚合支付
|
||||
logger.Debugf("聚合支付%+v", data)
|
||||
data.Channel = payment.PayChannelEpay
|
||||
if h.config.Epay.Domain != "" {
|
||||
data.Domain = h.config.Epay.Domain
|
||||
}
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
|
||||
params := payment.PayRequest{
|
||||
OutTradeNo: orderNo,
|
||||
Subject: product.Name,
|
||||
TotalFee: fmt.Sprintf("%f", amount),
|
||||
ClientIP: c.ClientIP(),
|
||||
Device: data.Device,
|
||||
PayWay: payment.PayWayWX,
|
||||
NotifyURL: notifyURL,
|
||||
}
|
||||
|
||||
r, err := h.epayService.Pay(params)
|
||||
logger.Debugf("请求支付结果,%+v", r)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
} else {
|
||||
payURL = r
|
||||
}
|
||||
} else {
|
||||
resp.ERROR(c, "系统没有配置可用的支付渠道!")
|
||||
return
|
||||
}
|
||||
case "alipay":
|
||||
if h.config.Alipay.Enabled {
|
||||
logger.Debugf("支付宝,%+v", data)
|
||||
data.Channel = payment.PayChannelAL
|
||||
if h.config.Alipay.Domain != "" { // 用于本地调试支付
|
||||
data.Domain = h.config.Alipay.Domain
|
||||
}
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Domain)
|
||||
money := fmt.Sprintf("%.2f", amount)
|
||||
payURL, err = h.alipayService.Pay(payment.PayRequest{
|
||||
Device: data.Device,
|
||||
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
|
||||
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
|
||||
}
|
||||
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
|
||||
returnURL = h.App.Config.AlipayConfig.ReturnURL
|
||||
} else {
|
||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||
}
|
||||
money := fmt.Sprintf("%.2f", amount)
|
||||
if data.Device == "wechat" {
|
||||
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
|
||||
OutTradeNo: orderNo,
|
||||
Subject: product.Name,
|
||||
TotalFee: money,
|
||||
ReturnURL: returnURL,
|
||||
NotifyURL: notifyURL,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
||||
return
|
||||
}
|
||||
} else if h.config.Epay.Enabled { // 聚合支付
|
||||
logger.Debugf("聚合支付,%+v", data)
|
||||
data.Channel = payment.PayChannelEpay
|
||||
if h.config.Epay.Domain != "" {
|
||||
data.Domain = h.config.Epay.Domain
|
||||
}
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
|
||||
params := payment.PayRequest{
|
||||
} else {
|
||||
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
|
||||
OutTradeNo: orderNo,
|
||||
Subject: product.Name,
|
||||
TotalFee: fmt.Sprintf("%f", amount),
|
||||
ClientIP: c.ClientIP(),
|
||||
Device: data.Device,
|
||||
PayWay: data.PayWay,
|
||||
TotalFee: money,
|
||||
ReturnURL: returnURL,
|
||||
NotifyURL: notifyURL,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
r, err := h.epayService.Pay(params)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
} else {
|
||||
payURL = r
|
||||
}
|
||||
} else {
|
||||
resp.ERROR(c, "系统没有配置可用的支付渠道!")
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
||||
return
|
||||
}
|
||||
break
|
||||
case "wechat":
|
||||
if h.App.Config.WechatPayConfig.NotifyURL != "" {
|
||||
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
|
||||
}
|
||||
if data.Device == "wechat" {
|
||||
payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
|
||||
OutTradeNo: orderNo,
|
||||
TotalFee: int(amount * 100),
|
||||
Subject: product.Name,
|
||||
NotifyURL: notifyURL,
|
||||
ClientIP: c.ClientIP(),
|
||||
})
|
||||
} else {
|
||||
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
|
||||
OutTradeNo: orderNo,
|
||||
TotalFee: int(amount * 100),
|
||||
Subject: product.Name,
|
||||
NotifyURL: notifyURL,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
break
|
||||
case "hupi":
|
||||
if h.App.Config.HuPiPayConfig.NotifyURL != "" {
|
||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
|
||||
}
|
||||
if h.App.Config.HuPiPayConfig.ReturnURL != "" {
|
||||
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
||||
} else {
|
||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||
}
|
||||
r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
|
||||
Version: "1.1",
|
||||
TradeOrderId: orderNo,
|
||||
TotalFee: fmt.Sprintf("%f", amount),
|
||||
Title: product.Name,
|
||||
NotifyURL: notifyURL,
|
||||
ReturnURL: returnURL,
|
||||
WapName: "GeekAI助手",
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
payURL = r.URL
|
||||
break
|
||||
case "geek":
|
||||
if h.App.Config.GeekPayConfig.NotifyURL != "" {
|
||||
notifyURL = h.App.Config.GeekPayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
|
||||
}
|
||||
if h.App.Config.GeekPayConfig.ReturnURL != "" {
|
||||
data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
|
||||
}
|
||||
if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
|
||||
returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
|
||||
} else {
|
||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||
}
|
||||
params := payment.GeekPayParams{
|
||||
OutTradeNo: orderNo,
|
||||
Method: "web",
|
||||
Name: product.Name,
|
||||
Money: fmt.Sprintf("%f", amount),
|
||||
ClientIP: c.ClientIP(),
|
||||
Device: data.Device,
|
||||
Type: data.PayType,
|
||||
ReturnURL: returnURL,
|
||||
NotifyURL: notifyURL,
|
||||
}
|
||||
|
||||
res, err := h.geekPayService.Pay(params)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
payURL = res.PayURL
|
||||
default:
|
||||
resp.ERROR(c, "不支持的支付渠道")
|
||||
return
|
||||
@@ -313,40 +234,43 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
||||
|
||||
// 创建订单
|
||||
remark := types.OrderRemark{
|
||||
Power: product.Power,
|
||||
Name: product.Name,
|
||||
Price: product.Price,
|
||||
Days: product.Days,
|
||||
Power: product.Power,
|
||||
Name: product.Name,
|
||||
Price: product.Price,
|
||||
Discount: product.Discount,
|
||||
}
|
||||
order := model.Order{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
OrderNo: orderNo,
|
||||
Subject: product.Name,
|
||||
Amount: amount,
|
||||
Status: types.OrderNotPaid,
|
||||
PayWay: data.PayWay,
|
||||
Channel: data.Channel,
|
||||
Remark: utils.JsonEncode(remark),
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
ProductId: product.Id,
|
||||
OrderNo: orderNo,
|
||||
Subject: product.Name,
|
||||
Amount: amount,
|
||||
Status: types.OrderNotPaid,
|
||||
PayWay: data.PayWay,
|
||||
PayType: data.PayType,
|
||||
Remark: utils.JsonEncode(remark),
|
||||
}
|
||||
err = h.DB.Create(&order).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with create order: "+err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c, gin.H{"pay_url": payURL, "order_no": orderNo})
|
||||
resp.SUCCESS(c, payURL)
|
||||
}
|
||||
|
||||
// 支付成功处理
|
||||
func (h *PaymentHandler) paySuccess(info payment.OrderInfo) error {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
// 异步通知回调公共逻辑
|
||||
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
var order model.Order
|
||||
err := h.DB.Where("order_no", info.OutTradeNo).First(&order).Error
|
||||
err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with fetch order: %v", err)
|
||||
}
|
||||
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
// 已支付订单,直接返回
|
||||
if order.Status == types.OrderPaidSuccess {
|
||||
return nil
|
||||
@@ -365,22 +289,20 @@ func (h *PaymentHandler) paySuccess(info payment.OrderInfo) error {
|
||||
}
|
||||
|
||||
// 增加用户算力
|
||||
err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{
|
||||
Type: types.PowerRecharge,
|
||||
Model: order.Subject,
|
||||
Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo),
|
||||
CreatedAt: time.Now(),
|
||||
err = h.userService.IncreasePower(int(order.UserId), remark.Power, model.PowerLog{
|
||||
Type: types.PowerRecharge,
|
||||
Model: order.PayWay,
|
||||
Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新订单状态
|
||||
order.PayTime = utils.Str2stamp(info.PayTime)
|
||||
order.PayTime = time.Now().Unix()
|
||||
order.Status = types.OrderPaidSuccess
|
||||
order.TradeNo = info.TradeId
|
||||
order.Checked = true
|
||||
err = h.DB.Debug().Updates(&order).Error
|
||||
order.TradeNo = tradeNo
|
||||
err = h.DB.Updates(&order).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with update order info: %v", err)
|
||||
}
|
||||
@@ -395,6 +317,54 @@ func (h *PaymentHandler) paySuccess(info payment.OrderInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPayWays 获取支付方式
|
||||
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
|
||||
payWays := make([]gin.H, 0)
|
||||
if h.App.Config.AlipayConfig.Enabled {
|
||||
payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
|
||||
}
|
||||
if h.App.Config.HuPiPayConfig.Enabled {
|
||||
payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
|
||||
}
|
||||
if h.App.Config.GeekPayConfig.Enabled {
|
||||
for _, v := range h.App.Config.GeekPayConfig.Methods {
|
||||
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
|
||||
}
|
||||
}
|
||||
if h.App.Config.WechatPayConfig.Enabled {
|
||||
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
|
||||
}
|
||||
resp.SUCCESS(c, payWays)
|
||||
}
|
||||
|
||||
// HuPiPayNotify 虎皮椒支付异步回调
|
||||
func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
|
||||
err := c.Request.ParseForm()
|
||||
if err != nil {
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
orderNo := c.Request.Form.Get("trade_order_id")
|
||||
tradeNo := c.Request.Form.Get("open_order_id")
|
||||
logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
|
||||
|
||||
if err = h.huPiPayService.Check(orderNo); err != nil {
|
||||
logger.Error("订单校验失败:", err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.notify(orderNo, tradeNo)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
|
||||
// AlipayNotify 支付宝支付回调
|
||||
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||
err := c.Request.ParseForm()
|
||||
@@ -403,15 +373,16 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
orderInfo, err := h.alipayService.Query(c.Request.Form.Get("out_trade_no"))
|
||||
logger.Infof("收到支付宝商号订单支付回调:%+v", orderInfo)
|
||||
if !orderInfo.Success() {
|
||||
logger.Errorf("订单校验失败:%v", err)
|
||||
result := h.alipayService.TradeVerify(c.Request)
|
||||
logger.Infof("收到支付宝商号订单支付回调:%+v", result)
|
||||
if !result.Success() {
|
||||
logger.Error("订单校验失败:", result.Message)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.paySuccess(orderInfo)
|
||||
tradeNo := c.Request.Form.Get("trade_no")
|
||||
err = h.notify(result.OutTradeNo, tradeNo)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
@@ -421,35 +392,28 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
|
||||
// EPayNotify 易支付支付异步回调
|
||||
func (h *PaymentHandler) EPayNotify(c *gin.Context) {
|
||||
// GeekPayNotify 支付异步回调
|
||||
func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
|
||||
var params = make(map[string]string)
|
||||
for k := range c.Request.URL.Query() {
|
||||
params[k] = c.Query(k)
|
||||
}
|
||||
|
||||
logger.Infof("收到易支付订单支付回调:%+v", params)
|
||||
// 检查支付状态, 如果未支付,则返回成功
|
||||
logger.Infof("收到GeekPay订单支付回调:%+v", params)
|
||||
// 检查支付状态
|
||||
if params["trade_status"] != "TRADE_SUCCESS" {
|
||||
c.String(http.StatusOK, "success")
|
||||
return
|
||||
}
|
||||
|
||||
sign := h.epayService.Sign(params)
|
||||
sign := h.geekPayService.Sign(params)
|
||||
if sign != c.Query("sign") {
|
||||
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
// 查询订单状态
|
||||
order, err := h.epayService.Query(params["out_trade_no"])
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.paySuccess(order)
|
||||
err := h.notify(params["out_trade_no"], params["trade_no"])
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
@@ -459,23 +423,26 @@ func (h *PaymentHandler) EPayNotify(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
|
||||
// WxpayNotify 微信商户支付异步回调
|
||||
func (h *PaymentHandler) WxpayNotify(c *gin.Context) {
|
||||
// WechatPayNotify 微信商户支付异步回调
|
||||
func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
|
||||
err := c.Request.ParseForm()
|
||||
if err != nil {
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
orderInfo, err := h.wxpayService.TradeVerify(c.Request)
|
||||
logger.Infof("收到微信商号订单支付回调:%+v", orderInfo)
|
||||
if err != nil {
|
||||
logger.Errorf("订单校验失败:%v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": "FAIL"})
|
||||
result := h.wechatPayService.TradeVerify(c.Request)
|
||||
logger.Infof("收到微信商号订单支付回调:%+v", result)
|
||||
if !result.Success() {
|
||||
logger.Error("订单校验失败:", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": "FAIL",
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.paySuccess(orderInfo)
|
||||
err = h.notify(result.OutTradeNo, result.TradeId)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
|
||||
@@ -9,13 +9,11 @@ package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -29,18 +27,6 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
||||
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *PowerLogHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/powerLog/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("list", h.List)
|
||||
group.GET("stats", h.Stats)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
Model string `json:"model"`
|
||||
@@ -86,45 +72,3 @@ func (h *PowerLogHandler) List(c *gin.Context) {
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||||
}
|
||||
|
||||
// Stats 获取用户算力统计
|
||||
func (h *PowerLogHandler) Stats(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
if userId == 0 {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息(包含余额)
|
||||
var user model.User
|
||||
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||
resp.ERROR(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 计算总消费(所有支出记录)
|
||||
var totalConsume int64
|
||||
h.DB.Model(&model.PowerLog{}).
|
||||
Where("user_id", userId).
|
||||
Where("mark", types.PowerSub).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&totalConsume)
|
||||
|
||||
// 计算今日消费
|
||||
today := time.Now().Format("2006-01-02")
|
||||
var todayConsume int64
|
||||
h.DB.Model(&model.PowerLog{}).
|
||||
Where("user_id", userId).
|
||||
Where("mark", types.PowerSub).
|
||||
Where("DATE(created_at) = ?", today).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&todayConsume)
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total": totalConsume,
|
||||
"today": todayConsume,
|
||||
"balance": user.Power,
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, stats)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -26,12 +25,6 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
||||
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ProductHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/product/")
|
||||
group.GET("list", h.List)
|
||||
}
|
||||
|
||||
// List 模型列表
|
||||
func (h *ProductHandler) List(c *gin.Context) {
|
||||
var items []model.Product
|
||||
|
||||
@@ -10,17 +10,15 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/suno"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 提示词生成 handler
|
||||
@@ -28,6 +26,8 @@ import (
|
||||
|
||||
type PromptHandler struct {
|
||||
BaseHandler
|
||||
sunoService *suno.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
@@ -41,20 +41,6 @@ func NewPromptHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *PromptHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/prompt/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis)).Use(middleware.RateLimitEvery(h.App.Redis, 30*time.Second))
|
||||
{
|
||||
group.POST("lyric", h.Lyric)
|
||||
group.POST("image", h.Image)
|
||||
group.POST("video", h.Video)
|
||||
group.POST("meta", h.MetaPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// Lyric 生成歌词
|
||||
func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||
var data struct {
|
||||
@@ -64,7 +50,7 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -82,7 +68,7 @@ func (h *PromptHandler) Image(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -100,7 +86,7 @@ func (h *PromptHandler) Video(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -135,12 +121,3 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
|
||||
|
||||
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||
}
|
||||
|
||||
func (h *PromptHandler) getPromptModel() string {
|
||||
if h.App.SysConfig.Base.AssistantModelId > 0 {
|
||||
var chatModel model.ChatModel
|
||||
h.DB.Where("id", h.App.SysConfig.Base.AssistantModelId).First(&chatModel)
|
||||
return chatModel.Value
|
||||
}
|
||||
return "gpt-4o"
|
||||
}
|
||||
|
||||
@@ -1,25 +1,15 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
@@ -33,23 +23,10 @@ import (
|
||||
|
||||
type RealtimeHandler struct {
|
||||
BaseHandler
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *service.UserService) *RealtimeHandler {
|
||||
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *RealtimeHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/realtime/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.Any("", h.Connection)
|
||||
group.POST("voice", h.VoiceChat)
|
||||
}
|
||||
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler {
|
||||
return &RealtimeHandler{BaseHandler{App: server, DB: db}}
|
||||
}
|
||||
|
||||
func (h *RealtimeHandler) Connection(c *gin.Context) {
|
||||
@@ -149,86 +126,3 @@ func sendError(ws *websocket.Conn, message string) {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI 实时语音对话,一次性对话
|
||||
func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
|
||||
var apiKey model.ApiKey
|
||||
err := h.DB.Session(&gorm.Session{}).Where("type", "realtime").Where("enabled", true).First(&apiKey).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("error with fetch OpenAI API KEY:%v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户是否还有算力
|
||||
userId := h.GetLoginUserId(c)
|
||||
var user model.User
|
||||
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("error with fetch user:%v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.Base.AdvanceVoicePower {
|
||||
resp.ERROR(c, "当前用户算力不足,无法使用该功能")
|
||||
return
|
||||
}
|
||||
|
||||
var response utils.OpenAIResponse
|
||||
client := req.C()
|
||||
if len(apiKey.ProxyURL) > 5 {
|
||||
client.SetProxyURL(apiKey.ApiURL)
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
logger.Infof("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, "advanced-voice")
|
||||
r, err := client.R().SetHeader("Body-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(types.ApiRequest{
|
||||
Model: "advanced-voice",
|
||||
Temperature: 0.9,
|
||||
MaxTokens: 1024,
|
||||
Stream: false,
|
||||
Messages: []interface{}{types.Message{
|
||||
Role: "user",
|
||||
Content: "实时语音通话",
|
||||
}},
|
||||
}).Post(apiURL)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败:%v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败:%v", r.Status))
|
||||
return
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("解析API数据失败:%v, %s", err, string(body)))
|
||||
}
|
||||
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
// 扣减算力
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.AdvanceVoicePower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "advanced-voice",
|
||||
Remark: "实时语音通话",
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("Response: %v", response.Choices[0].Message.Content)
|
||||
|
||||
// 提取链接
|
||||
re := regexp.MustCompile(`\[(.*?)\]\((.*?)\)`)
|
||||
links := re.FindAllStringSubmatch(response.Choices[0].Message.Content, -1)
|
||||
var url = ""
|
||||
if len(links) > 0 {
|
||||
url = links[0][2]
|
||||
}
|
||||
resp.SUCCESS(c, url)
|
||||
}
|
||||
|
||||
@@ -10,16 +10,14 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RedeemHandler struct {
|
||||
@@ -32,17 +30,6 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
|
||||
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *RedeemHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/redeem/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("verify", h.Verify)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RedeemHandler) Verify(c *gin.Context) {
|
||||
var data struct {
|
||||
Code string `json:"code"`
|
||||
@@ -74,7 +61,7 @@ func (h *RedeemHandler) Verify(c *gin.Context) {
|
||||
}
|
||||
|
||||
tx := h.DB.Begin()
|
||||
err := h.userService.IncreasePower(userId, item.Power, model.PowerLog{
|
||||
err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{
|
||||
Type: types.PowerRedeem,
|
||||
Model: "兑换码",
|
||||
Remark: fmt.Sprintf("兑换码核销,算力:%d,兑换码:%s...", item.Power, item.Code[:10]),
|
||||
|
||||
@@ -10,10 +10,8 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
@@ -30,13 +28,12 @@ import (
|
||||
|
||||
type SdJobHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
sdService *sd.Service
|
||||
uploader *oss.UploaderManager
|
||||
snowflake *service.Snowflake
|
||||
leveldb *store.LevelDB
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
redis *redis.Client
|
||||
sdService *sd.Service
|
||||
uploader *oss.UploaderManager
|
||||
snowflake *service.Snowflake
|
||||
leveldb *store.LevelDB
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewSdJobHandler(app *core.AppServer,
|
||||
@@ -45,15 +42,13 @@ func NewSdJobHandler(app *core.AppServer,
|
||||
manager *oss.UploaderManager,
|
||||
snowflake *service.Snowflake,
|
||||
userService *service.UserService,
|
||||
levelDB *store.LevelDB,
|
||||
moderationManager *moderation.ServiceManager) *SdJobHandler {
|
||||
levelDB *store.LevelDB) *SdJobHandler {
|
||||
return &SdJobHandler{
|
||||
sdService: service,
|
||||
uploader: manager,
|
||||
snowflake: snowflake,
|
||||
leveldb: levelDB,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
sdService: service,
|
||||
uploader: manager,
|
||||
snowflake: snowflake,
|
||||
leveldb: levelDB,
|
||||
userService: userService,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
@@ -61,23 +56,6 @@ func NewSdJobHandler(app *core.AppServer,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *SdJobHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/sd/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
@@ -85,7 +63,7 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.Base.SdPower {
|
||||
if user.Power < h.App.SysConfig.SdPower {
|
||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||
return false
|
||||
}
|
||||
@@ -106,29 +84,6 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceSD,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if data.Width <= 0 {
|
||||
data.Width = 512
|
||||
}
|
||||
@@ -147,7 +102,6 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
if data.Sampler == "" {
|
||||
data.Sampler = "Euler a"
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
taskId, err := h.snowflake.Next(true)
|
||||
@@ -157,7 +111,8 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
task := types.SdTask{
|
||||
Type: types.TaskImage,
|
||||
ClientId: data.ClientId,
|
||||
Type: types.TaskImage,
|
||||
Params: types.SdTaskParams{
|
||||
TaskId: taskId,
|
||||
Prompt: data.Prompt,
|
||||
@@ -176,18 +131,18 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
HdSteps: data.HdSteps,
|
||||
},
|
||||
UserId: userId,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
}
|
||||
|
||||
job := model.SdJob{
|
||||
UserId: uint(userId),
|
||||
UserId: userId,
|
||||
Type: types.TaskImage.String(),
|
||||
TaskId: taskId,
|
||||
Params: utils.JsonEncode(task.Params),
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Prompt: data.Prompt,
|
||||
Progress: 0,
|
||||
Power: h.App.SysConfig.Base.SdPower,
|
||||
Power: h.App.SysConfig.SdPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
res := h.DB.Create(&job)
|
||||
@@ -318,7 +273,7 @@ func (h *SdJobHandler) Publish(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||
|
||||
err := h.DB.Model(&model.SdJob{Id: uint(id), UserId: uint(userId)}).UpdateColumn("publish", action).Error
|
||||
err := h.DB.Model(&model.SdJob{Id: uint(id), UserId: int(userId)}).UpdateColumn("publish", action).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -24,31 +24,24 @@ const CodeStorePrefix = "/verify/codes/"
|
||||
|
||||
type SmsHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
sms *sms.SmsManager
|
||||
smtp *service.SmtpService
|
||||
captchaService *service.CaptchaService
|
||||
redis *redis.Client
|
||||
sms *sms.ServiceManager
|
||||
smtp *service.SmtpService
|
||||
captcha *service.CaptchaService
|
||||
}
|
||||
|
||||
func NewSmsHandler(
|
||||
app *core.AppServer,
|
||||
client *redis.Client,
|
||||
sms *sms.SmsManager,
|
||||
sms *sms.ServiceManager,
|
||||
smtp *service.SmtpService,
|
||||
captcha *service.CaptchaService) *SmsHandler {
|
||||
return &SmsHandler{
|
||||
redis: client,
|
||||
sms: sms,
|
||||
captchaService: captcha,
|
||||
smtp: smtp,
|
||||
BaseHandler: BaseHandler{App: app}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *SmsHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/sms/")
|
||||
// 无需授权的接口
|
||||
group.POST("code", h.SendCode)
|
||||
redis: client,
|
||||
sms: sms,
|
||||
captcha: captcha,
|
||||
smtp: smtp,
|
||||
BaseHandler: BaseHandler{App: app}}
|
||||
}
|
||||
|
||||
// SendCode 发送验证码
|
||||
@@ -63,12 +56,12 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if h.captchaService.GetConfig().Enabled {
|
||||
if h.App.SysConfig.EnabledVerify {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captchaService.SlideCheck(data)
|
||||
check = h.captcha.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captchaService.Check(data)
|
||||
check = h.captcha.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "请先完人机验证")
|
||||
@@ -79,14 +72,14 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
code := utils.RandomNumber(6)
|
||||
var err error
|
||||
if strings.Contains(data.Receiver, "@") { // email
|
||||
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "email") {
|
||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "email") {
|
||||
resp.ERROR(c, "系统已禁用邮箱注册!")
|
||||
return
|
||||
}
|
||||
// 检查邮箱后缀是否在白名单
|
||||
if len(h.App.SysConfig.Base.EmailWhiteList) > 0 {
|
||||
if len(h.App.SysConfig.EmailWhiteList) > 0 {
|
||||
inWhiteList := false
|
||||
for _, suffix := range h.App.SysConfig.Base.EmailWhiteList {
|
||||
for _, suffix := range h.App.SysConfig.EmailWhiteList {
|
||||
if strings.HasSuffix(data.Receiver, suffix) {
|
||||
inWhiteList = true
|
||||
break
|
||||
@@ -99,7 +92,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
}
|
||||
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
||||
} else {
|
||||
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "mobile") {
|
||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
|
||||
resp.ERROR(c, "系统已禁用手机号注册!")
|
||||
return
|
||||
}
|
||||
@@ -118,5 +111,9 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
if h.App.Debug {
|
||||
resp.SUCCESS(c, code)
|
||||
} else {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,65 +10,42 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/suno"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SunoHandler struct {
|
||||
BaseHandler
|
||||
sunoService *suno.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
sunoService *suno.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *SunoHandler {
|
||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
|
||||
return &SunoHandler{
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
},
|
||||
sunoService: service,
|
||||
uploader: uploader,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *SunoHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/suno/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("play", h.Play)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("create", h.Create)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("detail", h.Detail)
|
||||
sunoService: service,
|
||||
uploader: uploader,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Instrumental bool `json:"instrumental"`
|
||||
Lyrics string `json:"lyrics"`
|
||||
@@ -87,36 +64,13 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceSuno,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.Base.SunoPower {
|
||||
if user.Power < h.App.SysConfig.SunoPower {
|
||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||
return
|
||||
}
|
||||
@@ -136,6 +90,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
task := types.SunoTask{
|
||||
ClientId: data.ClientId,
|
||||
UserId: int(h.GetLoginUserId(c)),
|
||||
Type: data.Type,
|
||||
Title: data.Title,
|
||||
@@ -143,7 +98,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
RefSongId: data.RefSongId,
|
||||
ExtendSecs: data.ExtendSecs,
|
||||
Prompt: data.Prompt,
|
||||
Lyrics: data.Lyrics,
|
||||
Tags: data.Tags,
|
||||
Model: data.Model,
|
||||
Instrumental: data.Instrumental,
|
||||
@@ -153,7 +107,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
// 插入数据库
|
||||
job := model.SunoJob{
|
||||
UserId: uint(task.UserId),
|
||||
UserId: task.UserId,
|
||||
Prompt: data.Prompt,
|
||||
Instrumental: data.Instrumental,
|
||||
ModelName: data.Model,
|
||||
@@ -164,7 +118,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
RefSongId: data.RefSongId,
|
||||
RefTaskId: data.RefTaskId,
|
||||
ExtendSecs: data.ExtendSecs,
|
||||
Power: h.App.SysConfig.Base.SunoPower,
|
||||
Power: h.App.SysConfig.SunoPower,
|
||||
SongId: utils.RandString(32),
|
||||
}
|
||||
if data.Lyrics != "" {
|
||||
@@ -261,8 +215,8 @@ func (h *SunoHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 只有失败或者已完成的任务可以删除
|
||||
if !(job.Progress == service.FailTaskProgress || job.Progress == 100) {
|
||||
// 只有失败,或者超时的任务才能删除
|
||||
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*10)) {
|
||||
resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,36 +1,21 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/service"
|
||||
"geekai/service/payment"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type TestHandler struct {
|
||||
App *core.AppServer
|
||||
db *gorm.DB
|
||||
snowflake *service.Snowflake
|
||||
js *payment.EPayService
|
||||
js *payment.GeekPayService
|
||||
}
|
||||
|
||||
func NewTestHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, js *payment.EPayService) *TestHandler {
|
||||
return &TestHandler{App: app, db: db, snowflake: snowflake, js: js}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *TestHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/test/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.Any("sse", h.PostTest, h.SseTest)
|
||||
}
|
||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
|
||||
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
||||
}
|
||||
|
||||
func (h *TestHandler) SseTest(c *gin.Context) {
|
||||
|
||||
@@ -8,17 +8,15 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/imroc/req/v3"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -34,12 +32,9 @@ type UserHandler struct {
|
||||
BaseHandler
|
||||
searcher *xdb.Searcher
|
||||
redis *redis.Client
|
||||
levelDB *store.LevelDB
|
||||
licenseService *service.LicenseService
|
||||
captchaService *service.CaptchaService
|
||||
captcha *service.CaptchaService
|
||||
userService *service.UserService
|
||||
wxLoginService *service.WxLoginService
|
||||
ipSearcher *xdb.Searcher
|
||||
}
|
||||
|
||||
func NewUserHandler(
|
||||
@@ -47,48 +42,16 @@ func NewUserHandler(
|
||||
db *gorm.DB,
|
||||
searcher *xdb.Searcher,
|
||||
client *redis.Client,
|
||||
levelDB *store.LevelDB,
|
||||
captcha *service.CaptchaService,
|
||||
userService *service.UserService,
|
||||
wxLoginService *service.WxLoginService,
|
||||
ipSearcher *xdb.Searcher,
|
||||
licenseService *service.LicenseService) *UserHandler {
|
||||
return &UserHandler{
|
||||
BaseHandler: BaseHandler{DB: db, App: app},
|
||||
searcher: searcher,
|
||||
redis: client,
|
||||
levelDB: levelDB,
|
||||
captchaService: captcha,
|
||||
captcha: captcha,
|
||||
licenseService: licenseService,
|
||||
userService: userService,
|
||||
wxLoginService: wxLoginService,
|
||||
ipSearcher: ipSearcher,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *UserHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/user/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.POST("register", h.Register)
|
||||
group.POST("login", h.Login)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
group.GET("login/qrcode", h.GetWxLoginQRCode)
|
||||
group.POST("login/callback", h.WxLoginCallback)
|
||||
group.GET("login/status", h.GetWxLoginState)
|
||||
group.GET("logout", h.Logout)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("session", h.Session)
|
||||
group.GET("profile", h.Profile)
|
||||
group.POST("profile/update", h.ProfileUpdate)
|
||||
group.POST("password", h.UpdatePass)
|
||||
group.POST("bind/mobile", h.BindMobile)
|
||||
group.POST("bind/email", h.BindEmail)
|
||||
group.GET("signin", h.SignIn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,13 +75,12 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 人机验证
|
||||
if h.captchaService.GetConfig().Enabled {
|
||||
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captchaService.SlideCheck(data)
|
||||
check = h.captcha.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captchaService.Check(data)
|
||||
check = h.captcha.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "请先完人机验证")
|
||||
@@ -158,8 +120,28 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 验证邀请码
|
||||
inviteCode := model.InviteCode{}
|
||||
if data.InviteCode != "" {
|
||||
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "无效的邀请码")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
salt := utils.RandString(8)
|
||||
user := model.User{
|
||||
Username: data.Username,
|
||||
Password: utils.GenPassword(data.Password, salt),
|
||||
Avatar: "/images/avatar/user.png",
|
||||
Salt: salt,
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||
Power: h.App.SysConfig.InitPower,
|
||||
}
|
||||
|
||||
// check if the username is existing
|
||||
user := model.User{Username: data.Username, Password: data.Password}
|
||||
var item model.User
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if data.Mobile != "" {
|
||||
@@ -179,19 +161,73 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.createNewUser(user, data.InviteCode)
|
||||
if err != nil {
|
||||
// 被邀请人也获得赠送算力
|
||||
if data.InviteCode != "" {
|
||||
user.Power += h.App.SysConfig.InvitePower
|
||||
}
|
||||
if h.licenseService.GetLicense().Configs.DeCopy {
|
||||
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
|
||||
} else {
|
||||
user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
|
||||
}
|
||||
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.doLogin(&user, c.ClientIP())
|
||||
// 记录邀请关系
|
||||
if data.InviteCode != "" {
|
||||
// 增加邀请数量
|
||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||||
if h.App.SysConfig.InvitePower > 0 {
|
||||
err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{
|
||||
Type: types.PowerInvite,
|
||||
Model: "",
|
||||
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 添加邀请记录
|
||||
err := tx.Create(&model.InviteLog{
|
||||
InviterId: inviteCode.UserId,
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
InviteCode: inviteCode.Code,
|
||||
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
|
||||
}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
|
||||
// 自动登录创建 token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.Id,
|
||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
|
||||
// 保存到 redis
|
||||
key = fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
@@ -207,12 +243,15 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if h.captchaService.GetConfig().Enabled {
|
||||
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
|
||||
needVerify, err := h.redis.Get(c, verifyKey).Bool()
|
||||
|
||||
if h.App.SysConfig.EnabledVerify && needVerify {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captchaService.SlideCheck(data)
|
||||
check = h.captcha.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captchaService.Check(data)
|
||||
check = h.captcha.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "请先完人机验证")
|
||||
@@ -223,28 +262,54 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
var user model.User
|
||||
res := h.DB.Where("username = ?", data.Username).First(&user)
|
||||
if res.Error != nil {
|
||||
h.redis.Set(c, verifyKey, true, 0)
|
||||
resp.ERROR(c, "用户名不存在")
|
||||
return
|
||||
}
|
||||
|
||||
password := utils.GenPassword(data.Password, user.Salt)
|
||||
if password != user.Password {
|
||||
h.redis.Set(c, verifyKey, true, 0)
|
||||
resp.ERROR(c, "用户名或密码错误")
|
||||
return
|
||||
}
|
||||
|
||||
if !user.Status {
|
||||
if user.Status == false {
|
||||
resp.ERROR(c, "该用户已被禁止登录,请联系管理员")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.doLogin(&user, c.ClientIP())
|
||||
// 更新最后登录时间和IP
|
||||
user.LastLoginIp = c.ClientIP()
|
||||
user.LastLoginAt = time.Now().Unix()
|
||||
h.DB.Model(&user).Updates(user)
|
||||
|
||||
h.DB.Create(&model.UserLoginLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
LoginIp: c.ClientIP(),
|
||||
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
|
||||
})
|
||||
|
||||
// 创建 token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.Id,
|
||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
|
||||
// 保存到 redis
|
||||
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
// 移除登录行为验证码
|
||||
h.redis.Del(c, verifyKey)
|
||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
||||
}
|
||||
|
||||
// Logout 注 销
|
||||
@@ -256,165 +321,134 @@ func (h *UserHandler) Logout(c *gin.Context) {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// GetWxLoginQRCode 获取微信登录二维码URL
|
||||
func (h *UserHandler) GetWxLoginQRCode(c *gin.Context) {
|
||||
if !h.wxLoginService.GetConfig().Enabled {
|
||||
resp.ERROR(c, "微信登录功能未启用")
|
||||
return
|
||||
}
|
||||
|
||||
if h.wxLoginService.GetConfig().ApiKey == "" {
|
||||
resp.ERROR(c, "微信登录服务令牌未配置")
|
||||
return
|
||||
}
|
||||
|
||||
state := utils.RandString(32)
|
||||
qrCodeURL, err := h.wxLoginService.GetLoginQrCodeUrl(state)
|
||||
// CLogin 第三方登录请求二维码
|
||||
func (h *UserHandler) CLogin(c *gin.Context) {
|
||||
returnURL := h.GetTrim(c, "return_url")
|
||||
var res types.BizVo
|
||||
apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL)
|
||||
r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}).
|
||||
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
if r.IsErrorState() {
|
||||
resp.ERROR(c, "error with login http status: "+r.Status)
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"url": qrCodeURL,
|
||||
"state": state,
|
||||
})
|
||||
if res.Code != types.Success {
|
||||
resp.ERROR(c, "error with http response: "+res.Message)
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, res.Data)
|
||||
}
|
||||
|
||||
// 查询微信登录状态
|
||||
func (h *UserHandler) GetWxLoginState(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if state == "" {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
// CLoginCallback 第三方登录回调
|
||||
func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
||||
loginType := c.Query("login_type")
|
||||
code := c.Query("code")
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
action := c.Query("action")
|
||||
|
||||
status, err := h.wxLoginService.GetLoginStatus(state)
|
||||
var res types.BizVo
|
||||
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
|
||||
r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
|
||||
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if status.Status != service.LoginStatusSuccess {
|
||||
resp.SUCCESS(c, status)
|
||||
if r.IsErrorState() {
|
||||
resp.ERROR(c, "error with login http status: "+r.Status)
|
||||
return
|
||||
}
|
||||
|
||||
// 登录成功
|
||||
if res.Code != types.Success {
|
||||
resp.ERROR(c, "error with http response: "+res.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// login successfully
|
||||
data := res.Data.(map[string]interface{})
|
||||
var user model.User
|
||||
h.DB.Where("openid = ?", status.OpenID).First(&user)
|
||||
if user.Id == 0 {
|
||||
// 创建新用户
|
||||
user, err = h.createNewUser(model.User{OpenId: status.OpenID}, "")
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
if action == "bind" && userId > 0 {
|
||||
err = h.DB.Where("openid", data["openid"]).First(&user).Error
|
||||
if err == nil {
|
||||
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, err := h.doLogin(&user, c.ClientIP())
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
err = h.DB.Where("id", userId).First(&user).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "绑定用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "更新用户信息失败,"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"token": ""})
|
||||
return
|
||||
}
|
||||
|
||||
status.Status = service.LoginStatusExpired
|
||||
h.wxLoginService.SetLoginStatus(state, *status)
|
||||
|
||||
status.Status = service.LoginStatusSuccess
|
||||
status.Token = token
|
||||
resp.SUCCESS(c, status)
|
||||
}
|
||||
|
||||
// createNewUser 创建新用户
|
||||
func (h *UserHandler) createNewUser(user model.User, inviteCode string) (model.User, error) {
|
||||
if user.OpenId != "" {
|
||||
user.Platform = "wechat"
|
||||
user.Nickname = fmt.Sprintf("微信用户@%d", utils.RandomNumber(6))
|
||||
user.Username = fmt.Sprintf("wx@%d", utils.RandomNumber(8))
|
||||
user.Password = "geekai123"
|
||||
} else {
|
||||
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
|
||||
if user.Username == "" || user.Password == "" {
|
||||
return user, fmt.Errorf("用户名或密码不能为空")
|
||||
}
|
||||
}
|
||||
|
||||
salt := utils.RandString(8)
|
||||
user.Salt = salt
|
||||
user.Password = utils.GenPassword(user.Password, salt)
|
||||
user.Avatar = "/images/avatar/user.png"
|
||||
user.Status = true
|
||||
user.ChatRoles = utils.JsonEncode([]string{"gpt"})
|
||||
user.ChatConfig = "{}"
|
||||
user.ChatModels = "{}"
|
||||
user.Power = h.App.SysConfig.Base.InitPower
|
||||
|
||||
// 创建用户
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
// 记录邀请关系
|
||||
if inviteCode != "" {
|
||||
inviteCode := model.InviteCode{}
|
||||
err := h.DB.Where("code = ?", inviteCode).First(&inviteCode).Error
|
||||
if err != nil {
|
||||
return user, fmt.Errorf("无效的邀请码")
|
||||
session := gin.H{}
|
||||
tx := h.DB.Where("openid", data["openid"]).First(&user)
|
||||
if tx.Error != nil {
|
||||
// create new user
|
||||
var totalUser int64
|
||||
h.DB.Model(&model.User{}).Count(&totalUser)
|
||||
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
||||
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
|
||||
return
|
||||
}
|
||||
|
||||
// 增加邀请数量
|
||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", inviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||||
if h.App.SysConfig.Base.InvitePower > 0 {
|
||||
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.Base.InvitePower, model.PowerLog{
|
||||
Type: types.PowerInvite,
|
||||
Model: "Invite",
|
||||
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.Base.InvitePower, inviteCode.Code, user.Username),
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return user, err
|
||||
}
|
||||
|
||||
// 添加邀请记录
|
||||
err = tx.Create(&model.InviteLog{
|
||||
InviterId: inviteCode.UserId,
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
InviteCode: inviteCode.Code,
|
||||
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.Base.InvitePower),
|
||||
}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return user, err
|
||||
}
|
||||
salt := utils.RandString(8)
|
||||
password := fmt.Sprintf("%d", utils.RandomNumber(8))
|
||||
user = model.User{
|
||||
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
|
||||
Password: utils.GenPassword(password, salt),
|
||||
Avatar: fmt.Sprintf("%s", data["avatar"]),
|
||||
Salt: salt,
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||
Power: h.App.SysConfig.InitPower,
|
||||
OpenId: fmt.Sprintf("%s", data["openid"]),
|
||||
Nickname: fmt.Sprintf("%s", data["nickname"]),
|
||||
}
|
||||
|
||||
tx = h.DB.Create(&user)
|
||||
if tx.Error != nil {
|
||||
resp.ERROR(c, "保存数据失败")
|
||||
logger.Error(tx.Error)
|
||||
return
|
||||
}
|
||||
session["username"] = user.Username
|
||||
session["password"] = password
|
||||
} else { // login directly
|
||||
// 更新最后登录时间和IP
|
||||
user.LastLoginIp = c.ClientIP()
|
||||
user.LastLoginAt = time.Now().Unix()
|
||||
h.DB.Model(&user).Updates(user)
|
||||
|
||||
h.DB.Create(&model.UserLoginLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
LoginIp: c.ClientIP(),
|
||||
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
|
||||
})
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// doLogin 执行登录操作
|
||||
func (h *UserHandler) doLogin(user *model.User, ip string) (string, error) {
|
||||
// 更新最后登录时间和IP
|
||||
user.LastLoginIp = ip
|
||||
user.LastLoginAt = time.Now().Unix()
|
||||
err := h.DB.Model(user).Updates(user).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to update user: %v", err)
|
||||
}
|
||||
|
||||
// 记录登录日志
|
||||
h.DB.Create(&model.UserLoginLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
LoginIp: ip,
|
||||
LoginAddress: utils.Ip2Region(h.ipSearcher, ip),
|
||||
})
|
||||
|
||||
// 创建 token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.Id,
|
||||
@@ -422,42 +456,17 @@ func (h *UserHandler) doLogin(user *model.User, ip string) (string, error) {
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %v", err)
|
||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 保存到 redis
|
||||
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err = h.redis.Set(context.Background(), sessionKey, tokenString, 0).Result(); err != nil {
|
||||
return "", fmt.Errorf("error with save token: %v", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// WxLoginCallback 微信登录回调处理
|
||||
func (h *UserHandler) WxLoginCallback(c *gin.Context) {
|
||||
var data struct {
|
||||
OpenID string `json:"openid"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
key := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if data.OpenID == "" || data.State == "" {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置登录状态
|
||||
status := service.LoginStatus{
|
||||
Status: service.LoginStatusSuccess,
|
||||
OpenID: data.OpenID,
|
||||
}
|
||||
h.wxLoginService.SetLoginStatus(data.State, status)
|
||||
|
||||
resp.SUCCESS(c, status)
|
||||
session["token"] = tokenString
|
||||
resp.SUCCESS(c, session)
|
||||
}
|
||||
|
||||
// Session 获取/验证会话
|
||||
@@ -703,30 +712,3 @@ func (h *UserHandler) BindEmail(c *gin.Context) {
|
||||
_ = h.redis.Del(c, key) // 删除短信验证码
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// SignIn 每日签到
|
||||
func (h *UserHandler) SignIn(c *gin.Context) {
|
||||
// 获取当前日期
|
||||
date := time.Now().Format("2006-01-02")
|
||||
|
||||
// 检查是否已经签到
|
||||
userId := h.GetLoginUserId(c)
|
||||
key := fmt.Sprintf("signin/%d/%s", userId, date)
|
||||
var signIn bool
|
||||
err := h.levelDB.Get(key, &signIn)
|
||||
if err == nil && signIn {
|
||||
resp.ERROR(c, "今日已签到,请明日再来!")
|
||||
return
|
||||
}
|
||||
|
||||
// 签到
|
||||
h.levelDB.Put(key, true)
|
||||
if h.App.SysConfig.Base.DailyPower > 0 {
|
||||
h.userService.IncreasePower(userId, h.App.SysConfig.Base.DailyPower, model.PowerLog{
|
||||
Type: types.PowerSignIn,
|
||||
Model: "SignIn",
|
||||
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.Base.DailyPower),
|
||||
})
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -10,61 +10,42 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/video"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type VideoHandler struct {
|
||||
BaseHandler
|
||||
videoService *video.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
videoService *video.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *VideoHandler {
|
||||
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
|
||||
return &VideoHandler{
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
},
|
||||
videoService: service,
|
||||
uploader: uploader,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *VideoHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/video/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("luma/create", h.LumaCreate)
|
||||
group.POST("keling/create", h.KeLingCreate)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
videoService: service,
|
||||
uploader: uploader,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
FirstFrameImg string `json:"first_frame_img,omitempty"`
|
||||
EndFrameImg string `json:"end_frame_img,omitempty"`
|
||||
@@ -75,34 +56,6 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
// 检查 Prompt 长度
|
||||
if data.Prompt == "" {
|
||||
resp.ERROR(c, "prompt is needed")
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceVideo,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
@@ -110,31 +63,37 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.Base.LumaPower {
|
||||
if user.Power < h.App.SysConfig.LumaPower {
|
||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||
return
|
||||
}
|
||||
|
||||
if data.Prompt == "" {
|
||||
resp.ERROR(c, "prompt is needed")
|
||||
return
|
||||
}
|
||||
|
||||
userId := int(h.GetLoginUserId(c))
|
||||
params := types.LumaVideoParams{
|
||||
params := types.VideoParams{
|
||||
PromptOptimize: data.ExpandPrompt,
|
||||
Loop: data.Loop,
|
||||
StartImgURL: data.FirstFrameImg,
|
||||
EndImgURL: data.EndFrameImg,
|
||||
}
|
||||
task := types.VideoTask{
|
||||
ClientId: data.ClientId,
|
||||
UserId: userId,
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
}
|
||||
// 插入数据库
|
||||
job := model.VideoJob{
|
||||
UserId: uint(userId),
|
||||
UserId: userId,
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Power: h.App.SysConfig.Base.LumaPower,
|
||||
Power: h.App.SysConfig.LumaPower,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
}
|
||||
tx := h.DB.Create(&job)
|
||||
@@ -160,117 +119,20 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *VideoHandler) KeLingCreate(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
Channel string `json:"channel"`
|
||||
TaskType string `json:"task_type"` // 任务类型: text2video/image2video
|
||||
Model string `json:"model"` // 模型: kling-v1-5,kling-v1-6
|
||||
Prompt string `json:"prompt"` // 视频描述
|
||||
NegPrompt string `json:"negative_prompt"` // 负面提示词
|
||||
CfgScale float64 `json:"cfg_scale"` // 相关性系数(0-1)
|
||||
Mode string `json:"mode"` // 生成模式: std/pro
|
||||
AspectRatio string `json:"aspect_ratio"` // 画面比例: 16:9/9:16/1:1
|
||||
Duration string `json:"duration"` // 视频时长: 5/10
|
||||
CameraControl types.CameraControl `json:"camera_control"` // 摄像机控制
|
||||
Image string `json:"image"` // 参考图片URL(image2video)
|
||||
ImageTail string `json:"image_tail"` // 尾帧图片URL(image2video)
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 计算当前任务所需算力
|
||||
key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration)
|
||||
power := h.App.SysConfig.Base.KeLingPowers[key]
|
||||
if power == 0 {
|
||||
resp.ERROR(c, "当前模型暂不支持")
|
||||
return
|
||||
}
|
||||
if user.Power < power {
|
||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||
return
|
||||
}
|
||||
|
||||
if data.Prompt == "" {
|
||||
resp.ERROR(c, "prompt is needed")
|
||||
return
|
||||
}
|
||||
|
||||
userId := int(h.GetLoginUserId(c))
|
||||
params := types.KeLingVideoParams{
|
||||
TaskType: data.TaskType,
|
||||
Model: data.Model,
|
||||
Prompt: data.Prompt,
|
||||
NegPrompt: data.NegPrompt,
|
||||
CfgScale: data.CfgScale,
|
||||
Mode: data.Mode,
|
||||
AspectRatio: data.AspectRatio,
|
||||
Duration: data.Duration,
|
||||
CameraControl: data.CameraControl,
|
||||
Image: data.Image,
|
||||
ImageTail: data.ImageTail,
|
||||
}
|
||||
task := types.VideoTask{
|
||||
UserId: userId,
|
||||
Type: types.VideoKeLing,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
Channel: data.Channel,
|
||||
}
|
||||
// 插入数据库
|
||||
job := model.VideoJob{
|
||||
UserId: uint(userId),
|
||||
Type: types.VideoKeLing,
|
||||
Prompt: data.Prompt,
|
||||
Power: power,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
}
|
||||
tx := h.DB.Create(&job)
|
||||
if tx.Error != nil {
|
||||
resp.ERROR(c, tx.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
task.Id = job.Id
|
||||
h.videoService.PushTask(task)
|
||||
|
||||
// update user's power
|
||||
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "keling",
|
||||
Remark: fmt.Sprintf("keling 文生视频,任务ID:%d", job.Id),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *VideoHandler) List(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
t := c.Query("type")
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
all := h.GetBool(c, "all")
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
|
||||
if t != "" {
|
||||
session = session.Where("type", t)
|
||||
}
|
||||
if all {
|
||||
session = session.Where("publish", 0).Where("progress", 100)
|
||||
} else {
|
||||
session = session.Where("user_id", userId)
|
||||
session = session.Where("user_id", h.GetLoginUserId(c))
|
||||
}
|
||||
// 统计总数
|
||||
var total int64
|
||||
@@ -299,33 +161,6 @@ func (h *VideoHandler) List(c *gin.Context) {
|
||||
if item.VideoURL == "" {
|
||||
item.VideoURL = v.WaterURL
|
||||
}
|
||||
// 解析任务详情
|
||||
if item.Type == types.VideoKeLing {
|
||||
task := types.VideoTask{}
|
||||
err = utils.JsonDecode(v.TaskInfo, &task)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var params types.KeLingVideoParams
|
||||
err = utils.JsonDecode(utils.JsonEncode(task.Params), ¶ms)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
item.RawData = map[string]interface{}{
|
||||
"task_type": params.TaskType,
|
||||
"model": params.Model,
|
||||
"cfg_scale": params.CfgScale,
|
||||
"mode": params.Mode,
|
||||
"aspect_ratio": params.AspectRatio,
|
||||
"duration": params.Duration,
|
||||
"model_name": fmt.Sprintf("%s_%s_%s", params.Model, params.Mode, params.Duration),
|
||||
}
|
||||
|
||||
// 如果视频URL不为空,则设置为生成成功
|
||||
if item.VideoURL != "" {
|
||||
item.Progress = 100
|
||||
}
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
@@ -357,8 +192,6 @@ func (h *VideoHandler) Remove(c *gin.Context) {
|
||||
// 删除文件
|
||||
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
|
||||
_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *VideoHandler) Publish(c *gin.Context) {
|
||||
|
||||
150
api/handler/ws_handler.go
Normal file
150
api/handler/ws_handler.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Websocket 连接处理 handler
|
||||
|
||||
type WebsocketHandler struct {
|
||||
BaseHandler
|
||||
wsService *service.WebsocketService
|
||||
chatHandler *ChatHandler
|
||||
}
|
||||
|
||||
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
|
||||
return &WebsocketHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
chatHandler: chatHandler,
|
||||
wsService: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WebsocketHandler) Client(c *gin.Context) {
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
ws, err := (&websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: strings.Split(clientProtocols, ","),
|
||||
}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
clientId := c.Query("client_id")
|
||||
client := types.NewWsClient(ws, clientId)
|
||||
userId := h.GetLoginUserId(c)
|
||||
if userId == 0 {
|
||||
_ = client.Send([]byte("Invalid user_id"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||
_ = client.Send([]byte("Invalid user_id"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
h.wsService.Clients.Put(clientId, client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
||||
client.Close()
|
||||
h.wsService.Clients.Delete(clientId)
|
||||
break
|
||||
}
|
||||
|
||||
var message types.InputMessage
|
||||
err = utils.JsonDecode(string(msg), &message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("Receive a message:%+v", message)
|
||||
if message.Type == types.MsgTypePing {
|
||||
utils.SendChannelMsg(client, types.ChPing, "pong")
|
||||
continue
|
||||
}
|
||||
|
||||
// 当前只处理聊天消息,其他消息全部丢弃
|
||||
var chatMessage types.ChatMessage
|
||||
err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
|
||||
if err != nil || message.Channel != types.ChChat {
|
||||
logger.Warnf("invalid message body:%+v", message.Body)
|
||||
continue
|
||||
}
|
||||
var chatRole model.ChatRole
|
||||
err = h.DB.First(&chatRole, chatMessage.RoleId).Error
|
||||
if err != nil || !chatRole.Enable {
|
||||
utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
|
||||
continue
|
||||
}
|
||||
// if the role bind a model_id, use role's bind model_id
|
||||
if chatRole.ModelId > 0 {
|
||||
chatMessage.RoleId = chatRole.ModelId
|
||||
}
|
||||
// get model info
|
||||
var chatModel model.ChatModel
|
||||
err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
|
||||
if err != nil || chatModel.Enabled == false {
|
||||
utils.SendAndFlush(client, "当前AI模型暂未启用,请更换模型后再发起对话!!!")
|
||||
continue
|
||||
}
|
||||
|
||||
session := &types.ChatSession{
|
||||
ClientIP: c.ClientIP(),
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
// use old chat data override the chat model and role ID
|
||||
var chat model.ChatItem
|
||||
h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
|
||||
if chat.Id > 0 {
|
||||
chatModel.Id = chat.ModelId
|
||||
chatMessage.RoleId = int(chat.RoleId)
|
||||
}
|
||||
|
||||
session.ChatId = chatMessage.ChatId
|
||||
session.Tools = chatMessage.Tools
|
||||
session.Stream = chatMessage.Stream
|
||||
// 复制模型数据
|
||||
err = utils.CopyObject(chatModel, &session.Model)
|
||||
if err != nil {
|
||||
logger.Error(err, chatModel)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
|
||||
err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
utils.SendAndFlush(client, err.Error())
|
||||
} else {
|
||||
utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
|
||||
logger.Infof("回答完毕: %v", message.Body)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -8,12 +8,11 @@ package logger
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var logger *zap.Logger
|
||||
@@ -24,7 +23,7 @@ func GetLogger() *zap.SugaredLogger {
|
||||
return sugarLogger
|
||||
}
|
||||
|
||||
logLevel := zap.NewAtomicLevelAt(getLogLevel(os.Getenv("GEEKAI_LOG_LEVEL")))
|
||||
logLevel := zap.NewAtomicLevelAt(getLogLevel(os.Getenv("LOG_LEVEL")))
|
||||
encoder := getEncoder()
|
||||
writerSyncer := getLogWriter()
|
||||
fileCore := zapcore.NewCore(encoder, writerSyncer, logLevel)
|
||||
|
||||
389
api/main.go
389
api/main.go
@@ -17,9 +17,7 @@ import (
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/payment"
|
||||
"geekai/service/sd"
|
||||
@@ -31,7 +29,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -72,16 +70,15 @@ func main() {
|
||||
if configFile == "" {
|
||||
configFile = "config.toml"
|
||||
}
|
||||
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
|
||||
logger.Info("Loading config file: ", configFile)
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error("Panic Error:", err)
|
||||
// 打印堆栈信息
|
||||
if os.Getenv("GEEKAI_DEBUG") == "true" {
|
||||
debug.PrintStack()
|
||||
if !debug {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error("Panic Error:", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
}
|
||||
|
||||
app := fx.New(
|
||||
// 初始化配置应用配置
|
||||
@@ -91,16 +88,16 @@ func main() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
config.Path = configFile
|
||||
if debug {
|
||||
_ = core.SaveConfig(config)
|
||||
}
|
||||
return config
|
||||
}),
|
||||
// 创建应用服务
|
||||
fx.Provide(core.NewServer),
|
||||
// 初始化
|
||||
fx.Invoke(func(s *core.AppServer, client *redis.Client) {
|
||||
s.Init(client)
|
||||
}),
|
||||
fx.Provide(func(db *gorm.DB) *types.SystemConfig {
|
||||
return core.LoadSystemConfig(db)
|
||||
s.Init(debug, client)
|
||||
}),
|
||||
|
||||
// 初始化数据库
|
||||
@@ -128,7 +125,7 @@ func main() {
|
||||
}),
|
||||
|
||||
// 创建控制器
|
||||
fx.Provide(handler.NewChatAppHandler),
|
||||
fx.Provide(handler.NewChatRoleHandler),
|
||||
fx.Provide(handler.NewUserHandler),
|
||||
fx.Provide(handler.NewChatHandler),
|
||||
fx.Provide(handler.NewNetHandler),
|
||||
@@ -143,14 +140,7 @@ func main() {
|
||||
fx.Provide(handler.NewProductHandler),
|
||||
fx.Provide(handler.NewConfigHandler),
|
||||
fx.Provide(handler.NewPowerLogHandler),
|
||||
fx.Provide(handler.NewJimengHandler),
|
||||
|
||||
fx.Provide(service.NewMigrationService),
|
||||
fx.Invoke(func(migrationService *service.MigrationService) {
|
||||
migrationService.StartMigrate()
|
||||
}),
|
||||
|
||||
// 管理后台控制器
|
||||
fx.Provide(admin.NewConfigHandler),
|
||||
fx.Provide(admin.NewAdminHandler),
|
||||
fx.Provide(admin.NewApiKeyHandler),
|
||||
@@ -161,23 +151,29 @@ func main() {
|
||||
fx.Provide(admin.NewChatModelHandler),
|
||||
fx.Provide(admin.NewProductHandler),
|
||||
fx.Provide(admin.NewOrderHandler),
|
||||
fx.Provide(admin.NewChatHandler),
|
||||
fx.Provide(admin.NewPowerLogHandler),
|
||||
fx.Provide(admin.NewAdminJimengHandler),
|
||||
|
||||
// 创建服务
|
||||
fx.Provide(sms.NewSendServiceManager),
|
||||
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
|
||||
return service.NewCaptchaService(config.ApiConfig)
|
||||
}),
|
||||
fx.Provide(oss.NewUploaderManager),
|
||||
fx.Provide(dalle.NewService),
|
||||
fx.Invoke(func(s *dalle.Service) {
|
||||
s.Run()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadImages()
|
||||
s.CheckTaskStatus()
|
||||
}),
|
||||
|
||||
// 邮件服务
|
||||
fx.Provide(service.NewSmtpService),
|
||||
// License 服务
|
||||
fx.Provide(service.NewLicenseService),
|
||||
fx.Invoke(func(licenseService *service.LicenseService) {
|
||||
licenseService.SyncLicense()
|
||||
}),
|
||||
|
||||
// Dalle 服务
|
||||
fx.Provide(dalle.NewService),
|
||||
fx.Invoke(func(s *dalle.Service) {
|
||||
s.Run()
|
||||
s.DownloadImages()
|
||||
s.CheckTaskStatus()
|
||||
// licenseService.SyncLicense()
|
||||
}),
|
||||
|
||||
// MidJourney service pool
|
||||
@@ -186,6 +182,7 @@ func main() {
|
||||
fx.Invoke(func(s *mj.Service) {
|
||||
s.Run()
|
||||
s.SyncTaskProgress()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadImages()
|
||||
}),
|
||||
|
||||
@@ -194,219 +191,340 @@ func main() {
|
||||
fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
|
||||
s.Run()
|
||||
s.CheckTaskStatus()
|
||||
s.CheckTaskNotify()
|
||||
}),
|
||||
|
||||
fx.Provide(suno.NewService),
|
||||
fx.Invoke(func(s *suno.Service) {
|
||||
s.Run()
|
||||
s.SyncTaskProgress()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadFiles()
|
||||
}),
|
||||
fx.Provide(video.NewService),
|
||||
fx.Invoke(func(s *video.Service) {
|
||||
s.Run()
|
||||
s.SyncTaskProgress()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadFiles()
|
||||
}),
|
||||
|
||||
// 即梦AI 服务
|
||||
fx.Provide(jimeng.NewClient),
|
||||
fx.Provide(jimeng.NewService),
|
||||
fx.Invoke(func(service *jimeng.Service) {
|
||||
service.Start()
|
||||
}),
|
||||
|
||||
fx.Provide(service.NewSnowflake),
|
||||
|
||||
// 创建短信服务
|
||||
fx.Provide(sms.NewAliYunSmsService),
|
||||
fx.Provide(sms.NewBaoSmsService),
|
||||
fx.Provide(sms.NewSmsManager),
|
||||
fx.Provide(func(config *types.SystemConfig) *service.CaptchaService {
|
||||
return service.NewCaptchaService(config.Captcha)
|
||||
}),
|
||||
fx.Provide(func(config *types.SystemConfig, client *redis.Client) *service.WxLoginService {
|
||||
return service.NewWxLoginService(config.WxLogin, client)
|
||||
}),
|
||||
|
||||
// 支付服务
|
||||
fx.Provide(payment.NewAlipayService),
|
||||
fx.Provide(payment.NewEPayService),
|
||||
fx.Provide(payment.NewWxpayService),
|
||||
|
||||
// 文件上传服务
|
||||
fx.Provide(oss.NewLocalStorage),
|
||||
fx.Provide(oss.NewMiniOss),
|
||||
fx.Provide(oss.NewQiNiuOss),
|
||||
fx.Provide(oss.NewAliYunOss),
|
||||
fx.Provide(oss.NewUploaderManager),
|
||||
|
||||
// 用户服务
|
||||
fx.Provide(service.NewUserService),
|
||||
|
||||
// 文本审查服务
|
||||
fx.Provide(moderation.NewGiteeAIModeration),
|
||||
fx.Provide(moderation.NewBaiduAIModeration),
|
||||
fx.Provide(moderation.NewTencentAIModeration),
|
||||
fx.Provide(moderation.NewServiceManager),
|
||||
fx.Provide(admin.NewModerationHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ModerationHandler) {
|
||||
h.RegisterRoutes()
|
||||
fx.Provide(payment.NewAlipayService),
|
||||
fx.Provide(payment.NewHuPiPay),
|
||||
fx.Provide(payment.NewJPayService),
|
||||
fx.Provide(payment.NewWechatService),
|
||||
fx.Provide(service.NewSnowflake),
|
||||
fx.Provide(service.NewXXLJobExecutor),
|
||||
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
|
||||
if config.XXLConfig.Enabled {
|
||||
go func() {
|
||||
log.Fatal(exec.Run())
|
||||
}()
|
||||
}
|
||||
}),
|
||||
|
||||
// 注册路由
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppHandler) {
|
||||
h.RegisterRoutes()
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
||||
group := s.Engine.Group("/api/app/")
|
||||
group.GET("list", h.List)
|
||||
group.GET("list/user", h.ListByUser)
|
||||
group.POST("update", h.UpdateRole)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/user/")
|
||||
group.POST("register", h.Register)
|
||||
group.POST("login", h.Login)
|
||||
group.GET("logout", h.Logout)
|
||||
group.GET("session", h.Session)
|
||||
group.GET("profile", h.Profile)
|
||||
group.POST("profile/update", h.ProfileUpdate)
|
||||
group.POST("password", h.UpdatePass)
|
||||
group.POST("bind/mobile", h.BindMobile)
|
||||
group.POST("bind/email", h.BindEmail)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
group.GET("clogin", h.CLogin)
|
||||
group.GET("clogin/callback", h.CLoginCallback)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/chat/")
|
||||
group.GET("list", h.List)
|
||||
group.GET("detail", h.Detail)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("history", h.History)
|
||||
group.GET("clear", h.Clear)
|
||||
group.POST("tokens", h.Tokens)
|
||||
group.GET("stop", h.StopGenerate)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
|
||||
h.RegisterRoutes()
|
||||
s.Engine.POST("/api/upload", h.Upload)
|
||||
s.Engine.POST("/api/upload/list", h.List)
|
||||
s.Engine.GET("/api/upload/remove", h.Remove)
|
||||
s.Engine.GET("/api/download", h.Download)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/sms/")
|
||||
group.POST("code", h.SendCode)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/captcha/")
|
||||
group.GET("get", h.Get)
|
||||
group.POST("check", h.Check)
|
||||
group.GET("slide/get", h.SlideGet)
|
||||
group.POST("slide/check", h.SlideCheck)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/redeem/")
|
||||
group.POST("verify", h.Verify)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/mj/")
|
||||
group.POST("image", h.Image)
|
||||
group.POST("upscale", h.Upscale)
|
||||
group.POST("variation", h.Variation)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/sd")
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/config/")
|
||||
group.GET("get", h.Get)
|
||||
group.GET("license", h.License)
|
||||
}),
|
||||
|
||||
// 管理后台路由注册
|
||||
// 管理后台控制器
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/config")
|
||||
group.POST("update", h.Update)
|
||||
group.GET("get", h.Get)
|
||||
group.POST("active", h.Active)
|
||||
group.GET("fixData", h.FixData)
|
||||
group.GET("license", h.GetLicense)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/")
|
||||
group.POST("login", h.Login)
|
||||
group.GET("logout", h.Logout)
|
||||
group.GET("session", h.Session)
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("enable", h.Enable)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/apikey/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/user/")
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("loginLog", h.LoginLog)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/role/")
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("sort", h.Sort)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/redeem/")
|
||||
group.GET("list", h.List)
|
||||
group.POST("create", h.Create)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("export", h.Export)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/dashboard/")
|
||||
group.GET("stats", h.Stats)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatModelHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/model/")
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatModelHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/model/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("set", h.Set)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
|
||||
h.RegisterRoutes()
|
||||
h.StartSyncOrders()
|
||||
group := s.Engine.Group("/api/payment/")
|
||||
group.POST("doPay", h.Pay)
|
||||
group.GET("payWays", h.GetPayWays)
|
||||
group.POST("notify/alipay", h.AlipayNotify)
|
||||
group.GET("notify/geek", h.GeekPayNotify)
|
||||
group.POST("notify/wechat", h.WechatPayNotify)
|
||||
group.POST("notify/hupi", h.HuPiPayNotify)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/product/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.OrderHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/order/")
|
||||
group.POST("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("clear", h.Clear)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/order/")
|
||||
group.GET("list", h.List)
|
||||
group.GET("query", h.Query)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/product/")
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
|
||||
fx.Provide(handler.NewInviteHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/invite/")
|
||||
group.GET("code", h.Code)
|
||||
group.GET("list", h.List)
|
||||
group.GET("hits", h.Hits)
|
||||
}),
|
||||
|
||||
fx.Provide(admin.NewFunctionHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/function/")
|
||||
group.POST("save", h.Save)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("token", h.GenToken)
|
||||
}),
|
||||
|
||||
fx.Provide(admin.NewUploadHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
|
||||
h.RegisterRoutes()
|
||||
s.Engine.POST("/api/admin/upload", h.Upload)
|
||||
}),
|
||||
|
||||
fx.Provide(handler.NewFunctionHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/function/")
|
||||
group.POST("weibo", h.WeiBo)
|
||||
group.POST("zaobao", h.ZaoBao)
|
||||
group.POST("dalle3", h.Dall3)
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
fx.Provide(admin.NewChatHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/chat/")
|
||||
group.POST("list", h.List)
|
||||
group.POST("message", h.Messages)
|
||||
group.GET("history", h.History)
|
||||
group.GET("remove", h.RemoveChat)
|
||||
group.GET("message/remove", h.RemoveMessage)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/powerLog/")
|
||||
group.POST("list", h.List)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/powerLog/")
|
||||
group.POST("list", h.List)
|
||||
}),
|
||||
fx.Provide(admin.NewMenuHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/menu/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
}),
|
||||
fx.Provide(handler.NewMenuHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/menu/")
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
fx.Provide(handler.NewMarkMapHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
|
||||
h.RegisterRoutes()
|
||||
s.Engine.POST("/api/markMap/gen", h.Generate)
|
||||
}),
|
||||
fx.Provide(handler.NewDallJobHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/dall")
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
fx.Provide(handler.NewSunoHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/suno")
|
||||
group.POST("create", h.Create)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("detail", h.Detail)
|
||||
group.GET("play", h.Play)
|
||||
}),
|
||||
fx.Provide(handler.NewVideoHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
|
||||
// 即梦AI 路由
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.JimengHandler) {
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.AdminJimengHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/video")
|
||||
group.POST("luma/create", h.LumaCreate)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
fx.Provide(admin.NewChatAppTypeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/app/type")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
}),
|
||||
fx.Provide(handler.NewChatAppTypeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/app/type")
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
fx.Provide(handler.NewTestHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/test")
|
||||
group.Any("sse", h.PostTest, h.SseTest)
|
||||
}),
|
||||
fx.Provide(service.NewWebsocketService),
|
||||
fx.Provide(handler.NewWebsocketHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) {
|
||||
s.Engine.Any("/api/ws", h.Client)
|
||||
}),
|
||||
fx.Provide(handler.NewPromptHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/prompt")
|
||||
group.POST("/lyric", h.Lyric)
|
||||
group.POST("/image", h.Image)
|
||||
group.POST("/video", h.Video)
|
||||
group.POST("/meta", h.MetaPrompt)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||
go func() {
|
||||
@@ -431,15 +549,22 @@ func main() {
|
||||
}),
|
||||
fx.Provide(admin.NewImageHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/image")
|
||||
group.POST("/list/mj", h.MjList)
|
||||
group.POST("/list/sd", h.SdList)
|
||||
group.POST("/list/dall", h.DallList)
|
||||
group.GET("/remove", h.Remove)
|
||||
}),
|
||||
fx.Provide(admin.NewMediaHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
|
||||
h.RegisterRoutes()
|
||||
group := s.Engine.Group("/api/admin/media")
|
||||
group.POST("/list/suno", h.SunoList)
|
||||
group.POST("/list/luma", h.LumaList)
|
||||
group.GET("/remove", h.Remove)
|
||||
}),
|
||||
fx.Provide(handler.NewRealtimeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
|
||||
h.RegisterRoutes()
|
||||
s.Engine.Any("/api/realtime", h.Connection)
|
||||
}),
|
||||
)
|
||||
// 启动应用程序
|
||||
|
||||
@@ -8,38 +8,35 @@ package service
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CaptchaService struct {
|
||||
config types.CaptchaConfig
|
||||
config types.ApiConfig
|
||||
client *req.Client
|
||||
}
|
||||
|
||||
func NewCaptchaService(captchaConfig types.CaptchaConfig) *CaptchaService {
|
||||
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
|
||||
return &CaptchaService{
|
||||
config: captchaConfig,
|
||||
config: config,
|
||||
client: req.C().SetTimeout(10 * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CaptchaService) UpdateConfig(config types.CaptchaConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *CaptchaService) GetConfig() types.CaptchaConfig {
|
||||
return s.config
|
||||
}
|
||||
|
||||
func (s *CaptchaService) Get() (interface{}, error) {
|
||||
url := fmt.Sprintf("%s/api/captcha/get", types.GeekAPIURL)
|
||||
if s.config.Token == "" {
|
||||
return nil, errors.New("无效的 API Token")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/captcha/get", s.config.ApiURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||
@@ -52,11 +49,12 @@ func (s *CaptchaService) Get() (interface{}, error) {
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
func (s *CaptchaService) Check(data any) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/check", types.GeekAPIURL)
|
||||
func (s *CaptchaService) Check(data interface{}) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/check", s.config.ApiURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetBodyJsonMarshal(data).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
@@ -70,11 +68,16 @@ func (s *CaptchaService) Check(data any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *CaptchaService) SlideGet() (any, error) {
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/get", types.GeekAPIURL)
|
||||
func (s *CaptchaService) SlideGet() (interface{}, error) {
|
||||
if s.config.Token == "" {
|
||||
return nil, errors.New("无效的 API Token")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||
@@ -87,11 +90,12 @@ func (s *CaptchaService) SlideGet() (any, error) {
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
func (s *CaptchaService) SlideCheck(data any) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/check", types.GeekAPIURL)
|
||||
func (s *CaptchaService) SlideCheck(data interface{}) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetBodyJsonMarshal(data).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
|
||||
@@ -8,6 +8,7 @@ package dalle
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
@@ -16,10 +17,8 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
@@ -34,29 +33,33 @@ type Service struct {
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
userService *service.UserService
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
userService: userService,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (s *Service) PushTask(task types.DallTask) {
|
||||
logger.Infof("add a new DALL-E task to the task list: %+v", task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push dall-e task to queue failed: %v", err)
|
||||
}
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
// 将数据库中未提交的任务加载到队列
|
||||
// 将数据库中未提交的人物加载到队列
|
||||
var jobs []model.DallJob
|
||||
s.db.Where("progress", 0).Find(&jobs)
|
||||
for _, v := range jobs {
|
||||
@@ -80,37 +83,34 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
logger.Infof("handle a new DALL-E task: %+v", task)
|
||||
go func() {
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type imgReq struct {
|
||||
Model string `json:"model"`
|
||||
Image []string `json:"image,omitempty"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Size string `json:"size"`
|
||||
Quality string `json:"quality"`
|
||||
Style string `json:"style"`
|
||||
}
|
||||
|
||||
type imgRes struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []struct {
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
Url string `json:"url,omitempty"`
|
||||
B64Json string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt"`
|
||||
Url string `json:"url"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
@@ -125,25 +125,39 @@ type ErrRes struct {
|
||||
|
||||
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
logger.Debugf("绘画参数:%+v", task)
|
||||
prompt := task.Prompt
|
||||
// translate prompt
|
||||
if utils.HasChinese(prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, prompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
prompt = content
|
||||
logger.Debugf("重写后提示词:%s", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
var chatModel model.ChatModel
|
||||
if task.ModelId > 0 {
|
||||
s.db.Where("id", task.ModelId).First(&chatModel)
|
||||
} else {
|
||||
s.db.Where("value", task.ModelName).First(&chatModel)
|
||||
var user model.User
|
||||
s.db.Where("id", task.UserId).First(&user)
|
||||
if user.Power < task.Power {
|
||||
return "", errors.New("insufficient of power")
|
||||
}
|
||||
|
||||
// 扣减算力
|
||||
err := s.userService.DecreasePower(int(user.Id), task.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "dall-e-3",
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with decrease power: %v", err)
|
||||
}
|
||||
|
||||
// get image generation API KEY
|
||||
var apiKey model.ApiKey
|
||||
session := s.db.Where("enabled", true)
|
||||
if chatModel.KeyId > 0 {
|
||||
session = session.Where("id = ?", chatModel.KeyId)
|
||||
} else {
|
||||
session = session.Where("type = ?", "dalle")
|
||||
}
|
||||
err := session.Order("last_used_at ASC").First(&apiKey).Error
|
||||
err = s.db.Where("type", "dalle").
|
||||
Where("enabled", true).
|
||||
Order("last_used_at ASC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no available Image Generation api key: %v", err)
|
||||
return "", fmt.Errorf("no available DALL-E api key: %v", err)
|
||||
}
|
||||
|
||||
var res imgRes
|
||||
@@ -153,18 +167,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
|
||||
reqBody := imgReq{
|
||||
Model: chatModel.Value,
|
||||
Prompt: task.Prompt,
|
||||
Model: "dall-e-3",
|
||||
Prompt: prompt,
|
||||
N: 1,
|
||||
Size: task.Size,
|
||||
Style: task.Style,
|
||||
Quality: task.Quality,
|
||||
}
|
||||
// 图片编辑
|
||||
if len(task.Image) > 0 {
|
||||
reqBody.Prompt = fmt.Sprintf("%s, %s", strings.Join(task.Image, " "), task.Prompt)
|
||||
}
|
||||
|
||||
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
|
||||
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
@@ -173,48 +182,57 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
logger.Errorf("error with send request: %v", err)
|
||||
return "", fmt.Errorf("error with send request: %v", err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
logger.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
|
||||
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
|
||||
}
|
||||
|
||||
// update the api key last use time
|
||||
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
var imgURL string
|
||||
var data = map[string]interface{}{
|
||||
"progress": 100,
|
||||
"prompt": task.Prompt,
|
||||
}
|
||||
// 如果返回的是base64,则需要上传到oss
|
||||
if res.Data[0].B64Json != "" {
|
||||
imgURL, err = s.uploadManager.GetUploadHandler().PutBase64(res.Data[0].B64Json)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with upload image: %v", err)
|
||||
}
|
||||
logger.Infof("upload image to oss: %s", imgURL)
|
||||
data["img_url"] = imgURL
|
||||
} else {
|
||||
imgURL = res.Data[0].Url
|
||||
}
|
||||
data["org_url"] = imgURL
|
||||
// update task progress
|
||||
err = s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(data).Error
|
||||
err = s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": 100,
|
||||
"org_url": res.Data[0].Url,
|
||||
"prompt": prompt,
|
||||
}).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("err with update database: %v", err)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
var content string
|
||||
if sync {
|
||||
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", task.Prompt, imgURL)
|
||||
imgURL, err := s.downloadImage(task.Id, int(task.UserId), res.Data[0].Url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", prompt, imgURL)
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running DALL-E task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChDall, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskStatus() {
|
||||
go func() {
|
||||
logger.Info("Running DALL-E task status checking ...")
|
||||
@@ -224,7 +242,7 @@ func (s *Service) CheckTaskStatus() {
|
||||
s.db.Where("progress < ?", 100).Find(&jobs)
|
||||
for _, job := range jobs {
|
||||
// 超时的任务标记为失败
|
||||
if time.Since(job.CreatedAt) > time.Minute*10 {
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = "任务超时"
|
||||
s.db.Updates(&job)
|
||||
@@ -234,14 +252,9 @@ func (s *Service) CheckTaskStatus() {
|
||||
// 找出失败的任务,并恢复其扣减算力
|
||||
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||||
for _, job := range jobs {
|
||||
var task types.DallTask
|
||||
err := utils.JsonDecode(job.TaskInfo, &task)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
err = s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
err := s.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: task.ModelName,
|
||||
Model: "dall-e-3",
|
||||
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -271,7 +284,7 @@ func (s *Service) DownloadImages() {
|
||||
}
|
||||
|
||||
logger.Infof("try to download image: %s", v.OrgURL)
|
||||
imgURL, err := s.downloadImage(v.Id, v.OrgURL)
|
||||
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: %s, error: %v", imgURL, err)
|
||||
continue
|
||||
@@ -286,9 +299,9 @@ func (s *Service) DownloadImages() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) downloadImage(jobId uint, orgURL string) (string, error) {
|
||||
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
|
||||
// sava image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, ".png", false)
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -298,5 +311,6 @@ func (s *Service) downloadImage(jobId uint, orgURL string) (string, error) {
|
||||
if res.Error != nil {
|
||||
return "", err
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||
return imgURL, nil
|
||||
}
|
||||
|
||||
@@ -1,172 +0,0 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
)
|
||||
|
||||
// Client 即梦API客户端
|
||||
type Client struct {
|
||||
visual *visual.Visual
|
||||
config types.JimengConfig
|
||||
}
|
||||
|
||||
// NewClient 创建即梦API客户端
|
||||
func NewClient(sysConfig *types.SystemConfig) *Client {
|
||||
|
||||
client := &Client{}
|
||||
client.UpdateConfig(sysConfig.Jimeng)
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *Client) UpdateConfig(config types.JimengConfig) error {
|
||||
// 使用官方SDK的visual实例
|
||||
visualInstance := visual.NewInstance()
|
||||
visualInstance.Client.SetAccessKey(config.AccessKey)
|
||||
visualInstance.Client.SetSecretKey(config.SecretKey)
|
||||
|
||||
// 添加即梦AI专有的API配置
|
||||
jimengApis := map[string]*base.ApiInfo{
|
||||
"CVSync2AsyncSubmitTask": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVSync2AsyncSubmitTask"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
"CVSync2AsyncGetResult": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVSync2AsyncGetResult"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
"CVProcess": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVProcess"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 将即梦API添加到现有的ApiInfoList中
|
||||
for name, info := range jimengApis {
|
||||
visualInstance.Client.ApiInfoList[name] = info
|
||||
}
|
||||
|
||||
c.config = config
|
||||
c.visual = visualInstance
|
||||
|
||||
return c.testConnection()
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (c *Client) testConnection() error {
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
testReq := &QueryTaskRequest{
|
||||
ReqKey: "test_connection",
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := c.QueryTask(testReq)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
|
||||
// 直接将请求转为map[string]interface{}
|
||||
reqBodyBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 直接使用序列化后的字节
|
||||
jsonBody := reqBodyBytes
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncSubmitTask", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result SubmitTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// QueryTask 查询任务结果
|
||||
func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncGetResult", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SubmitSyncTask 提交同步任务(仅用于文生图)
|
||||
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVProcess", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应,同步任务直接返回结果
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
@@ -1,512 +0,0 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
// Service 即梦服务(合并了消费者功能)
|
||||
type Service struct {
|
||||
db *gorm.DB
|
||||
redis *redis.Client
|
||||
taskQueue *store.RedisQueue
|
||||
client *Client
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
running bool
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager, client *Client) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Service{
|
||||
db: db,
|
||||
redis: redisCli,
|
||||
taskQueue: taskQueue,
|
||||
client: client,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
running: false,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动服务(包含消费者)
|
||||
func (s *Service) Start() {
|
||||
if s.running {
|
||||
return
|
||||
}
|
||||
logger.Info("Starting Jimeng service and task consumer...")
|
||||
s.running = true
|
||||
go s.consumeTasks()
|
||||
go s.pollTaskStatus()
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (s *Service) Stop() {
|
||||
if !s.running {
|
||||
return
|
||||
}
|
||||
logger.Info("Stopping Jimeng service and task consumer...")
|
||||
s.running = false
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// consumeTasks 消费任务
|
||||
func (s *Service) consumeTasks() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
logger.Info("Jimeng task consumer stopped")
|
||||
return
|
||||
default:
|
||||
s.processNextTask()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processNextTask 处理下一个任务
|
||||
func (s *Service) processNextTask() {
|
||||
var jobId uint
|
||||
if err := s.taskQueue.LPop(&jobId); err != nil {
|
||||
// 队列为空,等待1秒后重试
|
||||
time.Sleep(time.Second)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("Processing Jimeng task: job_id=%d", jobId)
|
||||
|
||||
if err := s.ProcessTask(jobId); err != nil {
|
||||
logger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
|
||||
s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error())
|
||||
} else {
|
||||
logger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTask 创建任务
|
||||
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
|
||||
// 生成任务ID
|
||||
taskId := utils.RandString(20)
|
||||
|
||||
// 序列化任务参数
|
||||
paramsJson, err := json.Marshal(req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建任务记录
|
||||
job := &model.JimengJob{
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Type: req.Type,
|
||||
ReqKey: req.ReqKey,
|
||||
Prompt: req.Prompt,
|
||||
TaskParams: string(paramsJson),
|
||||
Status: model.JMTaskStatusInQueue,
|
||||
Power: req.Power,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
if err := s.db.Create(job).Error; err != nil {
|
||||
return nil, fmt.Errorf("create jimeng job failed: %w", err)
|
||||
}
|
||||
|
||||
// 推送到任务队列
|
||||
if err := s.taskQueue.RPush(job.Id); err != nil {
|
||||
return nil, fmt.Errorf("push jimeng task to queue failed: %w", err)
|
||||
}
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// ProcessTask 处理任务
|
||||
func (s *Service) ProcessTask(jobId uint) error {
|
||||
// 获取任务记录
|
||||
var job model.JimengJob
|
||||
if err := s.db.First(&job, jobId).Error; err != nil {
|
||||
return fmt.Errorf("get jimeng job failed: %w", err)
|
||||
}
|
||||
|
||||
// 更新任务状态为处理中
|
||||
if err := s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, ""); err != nil {
|
||||
return fmt.Errorf("update job status failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建请求并提交任务
|
||||
req, err := s.buildTaskRequest(&job)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
|
||||
}
|
||||
|
||||
logger.Infof("提交即梦任务: %+v", req)
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
// 更新任务ID和原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||
"task_id": resp.Data.TaskId,
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
logger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildTaskRequest 构建任务请求(统一的参数解析)
|
||||
func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
return nil, fmt.Errorf("parse task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建基础请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 根据任务类型设置特定参数
|
||||
switch job.Type {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
s.setTextToImageParams(req, params)
|
||||
case model.JMTaskTypeImageToImage:
|
||||
s.setImageToImageParams(req, params)
|
||||
case model.JMTaskTypeImageEdit:
|
||||
s.setImageEditParams(req, params)
|
||||
case model.JMTaskTypeImageEffects:
|
||||
s.setImageEffectsParams(req, params)
|
||||
case model.JMTaskTypeTextToVideo:
|
||||
s.setTextToVideoParams(req, params)
|
||||
case model.JMTaskTypeImageToVideo:
|
||||
s.setImageToVideoParams(req, params)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// setTextToImageParams 设置文生图参数
|
||||
func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
if usePreLlm, ok := params["use_pre_llm"]; ok {
|
||||
if usePreLlmVal, ok := usePreLlm.(bool); ok {
|
||||
req.UsePreLLM = usePreLlmVal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setImageToImageParams 设置图生图参数
|
||||
func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput, ok := params["image_input"].(string); ok {
|
||||
req.ImageInput = imageInput
|
||||
}
|
||||
if gpen, ok := params["gpen"]; ok {
|
||||
if gpenVal, ok := gpen.(float64); ok {
|
||||
req.Gpen = gpenVal
|
||||
}
|
||||
}
|
||||
if skin, ok := params["skin"]; ok {
|
||||
if skinVal, ok := skin.(float64); ok {
|
||||
req.Skin = skinVal
|
||||
}
|
||||
}
|
||||
if skinUnifi, ok := params["skin_unifi"]; ok {
|
||||
if skinUnifiVal, ok := skinUnifi.(float64); ok {
|
||||
req.SkinUnifi = skinUnifiVal
|
||||
}
|
||||
}
|
||||
if genMode, ok := params["gen_mode"].(string); ok {
|
||||
req.GenMode = genMode
|
||||
}
|
||||
s.setCommonParams(req, params) // 复用通用参数
|
||||
}
|
||||
|
||||
// setImageEditParams 设置图像编辑参数
|
||||
func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
||||
for _, url := range imageUrls {
|
||||
if urlStr, ok := url.(string); ok {
|
||||
req.ImageUrls = append(req.ImageUrls, urlStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if binaryData, ok := params["binary_data_base64"].([]any); ok {
|
||||
for _, data := range binaryData {
|
||||
if dataStr, ok := data.(string); ok {
|
||||
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageEffectsParams 设置图像特效参数
|
||||
func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput1, ok := params["image_input1"].(string); ok {
|
||||
req.ImageInput1 = imageInput1
|
||||
}
|
||||
if templateId, ok := params["template_id"].(string); ok {
|
||||
req.TemplateId = templateId
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setTextToVideoParams 设置文生视频参数
|
||||
func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageToVideoParams 设置图生视频参数
|
||||
func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
s.setImageEditParams(req, params) // 复用图像编辑的参数设置
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
}
|
||||
|
||||
// setCommonParams 设置通用参数(seed, width, height等)
|
||||
func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pollTaskStatus 轮询任务状态
|
||||
func (s *Service) pollTaskStatus() {
|
||||
|
||||
for {
|
||||
var jobs []model.JimengJob
|
||||
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
|
||||
if len(jobs) == 0 {
|
||||
logger.Debugf("no jimeng task to poll, sleep 10s")
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
// 任务超时处理
|
||||
if job.UpdatedAt.Before(time.Now().Add(-10 * time.Minute)) {
|
||||
s.handleTaskError(job.Id, "task timeout")
|
||||
continue
|
||||
}
|
||||
|
||||
// 查询任务状态
|
||||
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
TaskId: job.TaskId,
|
||||
ReqJson: `{"return_url":true}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", err.Error()))
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
|
||||
|
||||
if resp.Code != 10000 {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
|
||||
continue
|
||||
}
|
||||
|
||||
switch resp.Data.Status {
|
||||
case model.JMTaskStatusDone:
|
||||
// 判断任务是否成功
|
||||
if resp.Message != "Success" {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
|
||||
continue
|
||||
}
|
||||
|
||||
// 任务完成,更新结果
|
||||
updates := map[string]any{
|
||||
"status": model.JMTaskStatusSuccess,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 设置结果URL
|
||||
if len(resp.Data.ImageUrls) > 0 {
|
||||
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.ImageUrls[0], ".png", false)
|
||||
if err != nil {
|
||||
logger.Errorf("upload image failed: %v", err)
|
||||
imgUrl = resp.Data.ImageUrls[0]
|
||||
}
|
||||
updates["img_url"] = imgUrl
|
||||
}
|
||||
if resp.Data.VideoUrl != "" {
|
||||
videoUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.VideoUrl, ".mp4", false)
|
||||
if err != nil {
|
||||
logger.Errorf("upload video failed: %v", err)
|
||||
videoUrl = resp.Data.VideoUrl
|
||||
}
|
||||
updates["video_url"] = videoUrl
|
||||
}
|
||||
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
|
||||
// 任务处理中
|
||||
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
|
||||
|
||||
case model.JMTaskStatusNotFound:
|
||||
// 任务未找到
|
||||
s.handleTaskError(job.Id, "task not found")
|
||||
|
||||
case model.JMTaskStatusExpired:
|
||||
continue
|
||||
default:
|
||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// UpdateJobStatus 更新任务状态
|
||||
func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg string) error {
|
||||
updates := map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
if errMsg != "" {
|
||||
updates["err_msg"] = errMsg
|
||||
}
|
||||
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
|
||||
}
|
||||
|
||||
// handleTaskError 处理任务错误
|
||||
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
||||
logger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
|
||||
}
|
||||
|
||||
// PushTaskToQueue 推送任务到队列(用于手动重试)
|
||||
func (s *Service) PushTaskToQueue(jobId uint) error {
|
||||
return s.taskQueue.RPush(jobId)
|
||||
}
|
||||
|
||||
// GetTaskStats 获取任务统计信息
|
||||
func (s *Service) GetTaskStats() (map[string]any, error) {
|
||||
type StatResult struct {
|
||||
Status string `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
err := s.db.Model(&model.JimengJob{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"total": int64(0),
|
||||
"completed": int64(0),
|
||||
"processing": int64(0),
|
||||
"failed": int64(0),
|
||||
"pending": int64(0),
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
result["total"] = result["total"].(int64) + stat.Count
|
||||
result[stat.Status] = stat.Count
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetJob 获取任务
|
||||
func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
||||
var job model.JimengJob
|
||||
if err := s.db.First(&job, jobId).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &job, nil
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
package jimeng
|
||||
|
||||
import "geekai/store/model"
|
||||
|
||||
// ReqKey 常量定义
|
||||
const (
|
||||
ReqKeyTextToImage = "high_aes_general_v30l_zt2i" // 文生图
|
||||
ReqKeyImageToImagePortrait = "i2i_portrait_photo" // 图生图人像写真
|
||||
ReqKeyImageEdit = "seededit_v3.0" // 图像编辑
|
||||
ReqKeyImageEffects = "i2i_multi_style_zx2x" // 图像特效
|
||||
ReqKeyTextToVideo = "jimeng_vgfm_t2v_l20" // 文生视频
|
||||
ReqKeyImageToVideo = "jimeng_vgfm_i2v_l20" // 图生视频
|
||||
)
|
||||
|
||||
// SubmitTaskRequest 提交任务请求
|
||||
type SubmitTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
// 文生图参数
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Scale float64 `json:"scale,omitempty"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
UsePreLLM bool `json:"use_pre_llm,omitempty"`
|
||||
// 图生图参数
|
||||
ImageInput string `json:"image_input,omitempty"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
||||
Gpen float64 `json:"gpen,omitempty"`
|
||||
Skin float64 `json:"skin,omitempty"`
|
||||
SkinUnifi float64 `json:"skin_unifi,omitempty"`
|
||||
GenMode string `json:"gen_mode,omitempty"`
|
||||
// 图像编辑参数
|
||||
// 图像特效参数
|
||||
ImageInput1 string `json:"image_input1,omitempty"`
|
||||
TemplateId string `json:"template_id,omitempty"`
|
||||
// 视频生成参数
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
}
|
||||
|
||||
// SubmitTaskResponse 提交任务响应
|
||||
type SubmitTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Status int `json:"status"`
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
Data struct {
|
||||
TaskId string `json:"task_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// QueryTaskRequest 查询任务请求
|
||||
type QueryTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
TaskId string `json:"task_id"`
|
||||
ReqJson string `json:"req_json,omitempty"`
|
||||
}
|
||||
|
||||
// QueryTaskResponse 查询任务响应
|
||||
type QueryTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Status int `json:"status"`
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
Data struct {
|
||||
AlgorithmBaseResp struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
StatusMessage string `json:"status_message"`
|
||||
} `json:"algorithm_base_resp"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
RespData string `json:"resp_data"`
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
LlmResult string `json:"llm_result"`
|
||||
PeResult string `json:"pe_result"`
|
||||
PredictTagsResult string `json:"predict_tags_result"`
|
||||
RephraserResult string `json:"rephraser_result"`
|
||||
VlmResult string `json:"vlm_result"`
|
||||
InferCtx any `json:"infer_ctx"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// CreateTaskRequest 创建任务请求
|
||||
type CreateTaskRequest struct {
|
||||
Type model.JMTaskType `json:"type"`
|
||||
Prompt string `json:"prompt"`
|
||||
Params map[string]any `json:"params"`
|
||||
ReqKey string `json:"req_key"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
Power int `json:"power,omitempty"`
|
||||
}
|
||||
|
||||
// LogoInfo 水印信息
|
||||
type LogoInfo struct {
|
||||
AddLogo bool `json:"add_logo"`
|
||||
Position int `json:"position"`
|
||||
Language int `json:"language"`
|
||||
Opacity float64 `json:"opacity"`
|
||||
LogoTextContent string `json:"logo_text_content"`
|
||||
}
|
||||
|
||||
// ReqJsonConfig 查询配置
|
||||
type ReqJsonConfig struct {
|
||||
ReturnUrl bool `json:"return_url"`
|
||||
LogoInfo *LogoInfo `json:"logo_info,omitempty"`
|
||||
}
|
||||
|
||||
// ImageEffectTemplate 图像特效模板
|
||||
const (
|
||||
TemplateIdFelt3DPolaroid = "felt_3d_polaroid" // 毛毡3d拍立得风格
|
||||
TemplateIdMyWorld = "my_world" // 像素世界风
|
||||
TemplateIdMyWorldUniversal = "my_world_universal" // 像素世界-万物通用版
|
||||
TemplateIdPlasticBubbleFigure = "plastic_bubble_figure" // 盲盒玩偶风
|
||||
TemplateIdPlasticBubbleFigureCartoon = "plastic_bubble_figure_cartoon_text" // 塑料泡罩人偶-文字卡头版
|
||||
TemplateIdFurryDreamDoll = "furry_dream_doll" // 毛绒玩偶风
|
||||
TemplateIdMicroLandscapeMiniWorld = "micro_landscape_mini_world" // 迷你世界玩偶风
|
||||
TemplateIdMicroLandscapeProfessional = "micro_landscape_mini_world_professional" // 微型景观小世界-职业版
|
||||
TemplateIdAcrylicOrnaments = "acrylic_ornaments" // 亚克力挂饰
|
||||
TemplateIdFeltKeychain = "felt_keychain" // 毛毡钥匙扣
|
||||
TemplateIdLofiPixelCharacter = "lofi_pixel_character_mini_card" // Lofi像素人物小卡
|
||||
TemplateIdAngelFigurine = "angel_figurine" // 天使形象手办
|
||||
TemplateIdLyingInFluffyBelly = "lying_in_fluffy_belly" // 躺在毛茸茸肚皮里
|
||||
TemplateIdGlassBall = "glass_ball" // 玻璃球
|
||||
)
|
||||
|
||||
// AspectRatio 视频宽高比
|
||||
const (
|
||||
AspectRatio16_9 = "16:9" // 1280×720
|
||||
AspectRatio9_16 = "9:16" // 720×1280
|
||||
AspectRatio1_1 = "1:1" // 960×960
|
||||
AspectRatio4_3 = "4:3" // 960×720
|
||||
AspectRatio3_4 = "3:4" // 720×960
|
||||
AspectRatio21_9 = "21:9" // 1680×720
|
||||
AspectRatio9_21 = "9:21" // 720×1680
|
||||
)
|
||||
|
||||
// GenMode 生成模式
|
||||
const (
|
||||
GenModeCreative = "creative" // 提示词模式
|
||||
GenModeReference = "reference" // 全参考模式
|
||||
GenModeReferenceChar = "reference_char" // 人物参考模式
|
||||
)
|
||||
@@ -8,37 +8,30 @@ package service
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
"geekai/store"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LicenseService struct {
|
||||
config types.ApiConfig
|
||||
levelDB *store.LevelDB
|
||||
license *types.License
|
||||
urlWhiteList []string
|
||||
machineId string
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewLicenseService(sysConfig *types.SystemConfig, db *gorm.DB) *LicenseService {
|
||||
var machineId string
|
||||
info, err := host.Info()
|
||||
if err == nil {
|
||||
machineId = info.HostID
|
||||
}
|
||||
logger.Infof("License: %+v", sysConfig.License)
|
||||
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
|
||||
var license types.License
|
||||
return &LicenseService{
|
||||
license: &sysConfig.License,
|
||||
machineId: machineId,
|
||||
db: db,
|
||||
config: server.Config.ApiConfig,
|
||||
levelDB: levelDB,
|
||||
license: &license,
|
||||
machineId: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,15 +46,15 @@ type License struct {
|
||||
}
|
||||
|
||||
// ActiveLicense 激活 License
|
||||
func (s *LicenseService) ActiveLicense(license string) error {
|
||||
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
|
||||
var res struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data License `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/active")
|
||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
|
||||
response, err := req.C().R().
|
||||
SetBody(map[string]string{"license": license, "machine_id": s.machineId}).
|
||||
SetBody(map[string]string{"license": license, "machine_id": machineId}).
|
||||
SetSuccessResult(&res).Post(apiURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送激活请求失败: %v", err)
|
||||
@@ -75,24 +68,17 @@ func (s *LicenseService) ActiveLicense(license string) error {
|
||||
return fmt.Errorf("激活失败:%v", res.Message)
|
||||
}
|
||||
|
||||
if res.Data.ExpiredAt > 0 && res.Data.ExpiredAt < time.Now().Unix() {
|
||||
return fmt.Errorf("License 已过期")
|
||||
}
|
||||
|
||||
s.license = &types.License{
|
||||
Key: license,
|
||||
MachineId: s.machineId,
|
||||
MachineId: machineId,
|
||||
Configs: res.Data.Configs,
|
||||
ExpiredAt: res.Data.ExpiredAt,
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
// 保存 License 到数据库
|
||||
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
|
||||
err = s.levelDB.Put(types.LicenseKey, s.license)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存 License 到数据库失败: %v", err)
|
||||
return fmt.Errorf("保存许可证书失败:%v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -105,16 +91,11 @@ func (s *LicenseService) SyncLicense() {
|
||||
if err != nil {
|
||||
retryCounter++
|
||||
if retryCounter < 5 {
|
||||
logger.Debug(err)
|
||||
logger.Warn(err)
|
||||
}
|
||||
s.license.IsActive = false
|
||||
} else {
|
||||
s.license = license
|
||||
// 保存 License 到数据库
|
||||
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
|
||||
if err != nil {
|
||||
logger.Errorf("保存 License 到数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
urls, err := s.fetchUrlWhiteList()
|
||||
@@ -128,30 +109,33 @@ func (s *LicenseService) SyncLicense() {
|
||||
}
|
||||
|
||||
func (s *LicenseService) fetchLicense() (*types.License, error) {
|
||||
var res struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data License `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/check")
|
||||
response, err := req.C().R().
|
||||
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||
SetSuccessResult(&res).Post(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("License 同步失败: %v", err)
|
||||
}
|
||||
if response.IsErrorState() {
|
||||
return nil, fmt.Errorf("License 同步失败:%v", response.Status)
|
||||
}
|
||||
if res.Code != types.Success {
|
||||
return nil, fmt.Errorf("License 同步失败:%v", res.Message)
|
||||
}
|
||||
//var res struct {
|
||||
// Code types.BizCode `json:"code"`
|
||||
// Message string `json:"message"`
|
||||
// Data License `json:"data"`
|
||||
//}
|
||||
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
|
||||
//response, err := req.C().R().
|
||||
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||
// SetSuccessResult(&res).Post(apiURL)
|
||||
//if err != nil {
|
||||
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
|
||||
//}
|
||||
//if response.IsErrorState() {
|
||||
// return nil, fmt.Errorf("激活失败:%v", response.Status)
|
||||
//}
|
||||
//if res.Code != types.Success {
|
||||
// return nil, fmt.Errorf("激活失败:%v", res.Message)
|
||||
//}
|
||||
|
||||
return &types.License{
|
||||
Key: res.Data.License,
|
||||
MachineId: res.Data.MachineId,
|
||||
Configs: res.Data.Configs,
|
||||
ExpiredAt: res.Data.ExpiredAt,
|
||||
Key: "abc",
|
||||
MachineId: "abc",
|
||||
Configs: types.LicenseConfig{
|
||||
UserNum: 10000,
|
||||
DeCopy: false,
|
||||
},
|
||||
ExpiredAt: 0,
|
||||
IsActive: true,
|
||||
}, nil
|
||||
}
|
||||
@@ -162,7 +146,7 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
||||
Message string `json:"message"`
|
||||
Data []string `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/urls")
|
||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
|
||||
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求失败: %v", err)
|
||||
@@ -179,46 +163,35 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
||||
|
||||
// GetLicense 获取许可信息
|
||||
func (s *LicenseService) GetLicense() *types.License {
|
||||
if s.license == nil {
|
||||
var config model.Config
|
||||
s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).First(&config)
|
||||
if config.Value != "" {
|
||||
utils.JsonDecode(config.Value, &s.license)
|
||||
}
|
||||
}
|
||||
return s.license
|
||||
}
|
||||
|
||||
func (s *LicenseService) SetLicense(licenseKey string) {
|
||||
s.license.Key = licenseKey
|
||||
|
||||
}
|
||||
|
||||
// IsValidApiURL 判断是否合法的中转 URL
|
||||
func (s *LicenseService) IsValidApiURL(uri string) error {
|
||||
// 获得许可授权的直接放行
|
||||
if s.license.IsActive {
|
||||
if s.license.MachineId != s.machineId {
|
||||
return errors.New("系统使用了盗版的许可证书")
|
||||
}
|
||||
|
||||
if time.Now().Unix() > s.license.ExpiredAt {
|
||||
return errors.New("系统许可证书已经过期")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(s.urlWhiteList) == 0 {
|
||||
urls, err := s.fetchUrlWhiteList()
|
||||
if err == nil {
|
||||
s.urlWhiteList = urls
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range s.urlWhiteList {
|
||||
if strings.HasPrefix(uri, v) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||
return nil
|
||||
//if s.license.IsActive {
|
||||
// if s.license.MachineId != s.machineId {
|
||||
// return errors.New("系统使用了盗版的许可证书")
|
||||
// }
|
||||
//
|
||||
// if time.Now().Unix() > s.license.ExpiredAt {
|
||||
// return errors.New("系统许可证书已经过期")
|
||||
// }
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
//if len(s.urlWhiteList) == 0 {
|
||||
// urls, err := s.fetchUrlWhiteList()
|
||||
// if err == nil {
|
||||
// s.urlWhiteList = urls
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//for _, v := range s.urlWhiteList {
|
||||
// if strings.HasPrefix(uri, v) {
|
||||
// return nil
|
||||
// }
|
||||
//}
|
||||
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||
}
|
||||
|
||||
@@ -1,342 +0,0 @@
|
||||
package service
|
||||
|
||||
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// Use of this source code is governed by a Apache-2.0 license
|
||||
// that can be found in the LICENSE file.
|
||||
// @Author yangjian102621@163.com
|
||||
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"strings"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
// 迁移状态Redis key
|
||||
MigrationStatusKey = "config_migration:status"
|
||||
// 迁移完成标志
|
||||
MigrationCompleted = "completed"
|
||||
)
|
||||
|
||||
// MigrationService 配置迁移服务
|
||||
type MigrationService struct {
|
||||
db *gorm.DB
|
||||
redisClient *redis.Client
|
||||
appConfig *types.AppConfig
|
||||
levelDB *store.LevelDB
|
||||
licenseService *LicenseService
|
||||
}
|
||||
|
||||
func NewMigrationService(db *gorm.DB, redisClient *redis.Client, appConfig *types.AppConfig, levelDB *store.LevelDB, licenseService *LicenseService) *MigrationService {
|
||||
return &MigrationService{
|
||||
db: db,
|
||||
redisClient: redisClient,
|
||||
appConfig: appConfig,
|
||||
levelDB: levelDB,
|
||||
licenseService: licenseService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MigrationService) StartMigrate() {
|
||||
go func() {
|
||||
s.MigrateConfig(s.appConfig)
|
||||
s.TableMigration()
|
||||
s.MigrateLicense()
|
||||
}()
|
||||
}
|
||||
|
||||
// 迁移 License
|
||||
func (s *MigrationService) MigrateLicense() {
|
||||
key := "migrate:license"
|
||||
if s.redisClient.Get(context.Background(), key).Val() == "1" {
|
||||
logger.Info("License 已迁移,跳过迁移")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("开始迁移 License...")
|
||||
var license types.License
|
||||
err := s.levelDB.Get(types.LicenseKey, &license)
|
||||
if err != nil {
|
||||
license = types.License{
|
||||
Key: "",
|
||||
MachineId: "",
|
||||
Configs: types.LicenseConfig{UserNum: 0, DeCopy: false},
|
||||
ExpiredAt: 0,
|
||||
IsActive: false,
|
||||
}
|
||||
}
|
||||
logger.Infof("迁移 License: %+v", license)
|
||||
if err := s.saveConfig(types.ConfigKeyLicense, license); err != nil {
|
||||
logger.Errorf("迁移 License 失败: %v", err)
|
||||
return
|
||||
}
|
||||
s.licenseService.SetLicense(license.Key)
|
||||
logger.Info("迁移 License 完成")
|
||||
s.redisClient.Set(context.Background(), key, "1", 0)
|
||||
}
|
||||
|
||||
// 迁移配置内容
|
||||
func (s *MigrationService) MigrateConfigContent() error {
|
||||
// 用户协议
|
||||
if err := s.saveConfig(types.ConfigKeyPrivacy, map[string]string{
|
||||
"content": "用户协议内容",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
// 隐私政策
|
||||
if err := s.saveConfig(types.ConfigKeyAgreement, map[string]string{
|
||||
"content": "隐私政策内容",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
// 思维导图
|
||||
if err := s.saveConfig(types.ConfigKeyMarkMap, map[string]string{
|
||||
"content": `# GeekAI 演示站
|
||||
|
||||
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
||||
- 基于 Websocket 实现,完美的打字机体验。
|
||||
- 内置了各种预训练好的角色应用,轻松满足你的各种聊天和应用需求。
|
||||
- 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
|
||||
- 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
|
||||
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
||||
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
||||
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件。`,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 微信登录配置
|
||||
if err := s.saveConfig(types.ConfigKeyWxLogin, map[string]string{
|
||||
"api_key": "",
|
||||
"notify_url": "",
|
||||
"enabled": "false",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证码配置
|
||||
if err := s.saveConfig(types.ConfigKeyCaptcha, map[string]string{
|
||||
"api_key": "",
|
||||
"type": "dot",
|
||||
"enabled": "false",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if err := s.saveConfig(types.ConfigKeyModeration, map[string]any{
|
||||
"enable": "false",
|
||||
"active": "gitee",
|
||||
"enable_guide": "false",
|
||||
"guide_prompt": "",
|
||||
"gitee": map[string]string{
|
||||
"api_key": "",
|
||||
"model": "Security-semantic-filtering",
|
||||
},
|
||||
"baidu": map[string]string{
|
||||
"access_key": "",
|
||||
"secret_key": "",
|
||||
},
|
||||
"tencent": map[string]string{
|
||||
"access_key": "",
|
||||
"secret_key": "",
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 数据表迁移
|
||||
func (s *MigrationService) TableMigration() {
|
||||
// 新数据表
|
||||
s.db.AutoMigrate(&model.Moderation{})
|
||||
|
||||
// 订单字段整理
|
||||
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
|
||||
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")
|
||||
}
|
||||
if !s.db.Migrator().HasColumn(&model.Order{}, "checked") {
|
||||
s.db.Migrator().AddColumn(&model.Order{}, "checked")
|
||||
}
|
||||
|
||||
// 重命名 config 表字段
|
||||
if s.db.Migrator().HasColumn(&model.Config{}, "config_json") {
|
||||
s.db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Config{}, "marker") {
|
||||
s.db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
|
||||
}
|
||||
if s.db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
|
||||
s.db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
|
||||
}
|
||||
if s.db.Migrator().HasIndex(&model.Config{}, "marker") {
|
||||
s.db.Migrator().DropIndex(&model.Config{}, "marker")
|
||||
}
|
||||
|
||||
// 手动删除字段
|
||||
if s.db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
|
||||
s.db.Migrator().DropColumn(&model.Order{}, "deleted_at")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
|
||||
s.db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
|
||||
s.db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.User{}, "chat_config") {
|
||||
s.db.Migrator().DropColumn(&model.User{}, "chat_config")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatModel{}, "category") {
|
||||
s.db.Migrator().DropColumn(&model.ChatModel{}, "category")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatModel{}, "description") {
|
||||
s.db.Migrator().DropColumn(&model.ChatModel{}, "description")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "discount") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "discount")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "days") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "days")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "app_url") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "app_url")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "url") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "url")
|
||||
}
|
||||
}
|
||||
|
||||
// 迁移配置数据
|
||||
func (s *MigrationService) MigrateConfig(config *types.AppConfig) error {
|
||||
|
||||
logger.Info("开始迁移配置到数据库...")
|
||||
|
||||
// 迁移支付配置
|
||||
if err := s.migratePaymentConfig(config); err != nil {
|
||||
logger.Errorf("迁移支付配置失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 迁移存储配置
|
||||
if err := s.migrateStorageConfig(config); err != nil {
|
||||
logger.Errorf("迁移存储配置失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 迁移通信配置
|
||||
if err := s.migrateCommunicationConfig(config); err != nil {
|
||||
logger.Errorf("迁移通信配置失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 迁移配置内容
|
||||
if err := s.MigrateConfigContent(); err != nil {
|
||||
logger.Errorf("迁移配置内容失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("配置迁移完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 迁移支付配置
|
||||
func (s *MigrationService) migratePaymentConfig(config *types.AppConfig) error {
|
||||
|
||||
paymentConfig := types.PaymentConfig{
|
||||
Alipay: config.AlipayConfig,
|
||||
Epay: config.GeekPayConfig,
|
||||
WxPay: config.WechatPayConfig,
|
||||
}
|
||||
if err := s.saveConfig(types.ConfigKeyPayment, paymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 迁移存储配置
|
||||
func (s *MigrationService) migrateStorageConfig(config *types.AppConfig) error {
|
||||
|
||||
ossConfig := types.OSSConfig{
|
||||
Active: config.OSS.Active,
|
||||
Local: config.OSS.Local,
|
||||
Minio: config.OSS.Minio,
|
||||
QiNiu: config.OSS.QiNiu,
|
||||
AliYun: config.OSS.AliYun,
|
||||
}
|
||||
return s.saveConfig(types.ConfigKeyOss, ossConfig)
|
||||
}
|
||||
|
||||
// 迁移通信配置
|
||||
func (s *MigrationService) migrateCommunicationConfig(config *types.AppConfig) error {
|
||||
// SMTP配置
|
||||
smtpConfig := map[string]any{
|
||||
"use_tls": config.SmtpConfig.UseTls,
|
||||
"host": config.SmtpConfig.Host,
|
||||
"port": config.SmtpConfig.Port,
|
||||
"app_name": config.SmtpConfig.AppName,
|
||||
"from": config.SmtpConfig.From,
|
||||
"password": config.SmtpConfig.Password,
|
||||
}
|
||||
if err := s.saveConfig(types.ConfigKeySmtp, smtpConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 短信配置
|
||||
smsConfig := map[string]any{
|
||||
"active": strings.ToLower(config.SMS.Active),
|
||||
"aliyun": map[string]any{
|
||||
"access_key": config.SMS.Ali.AccessKey,
|
||||
"access_secret": config.SMS.Ali.AccessSecret,
|
||||
"sign": config.SMS.Ali.Sign,
|
||||
"code_temp_id": config.SMS.Ali.CodeTempId,
|
||||
},
|
||||
"bao": map[string]any{
|
||||
"username": config.SMS.Bao.Username,
|
||||
"password": config.SMS.Bao.Password,
|
||||
"sign": config.SMS.Bao.Sign,
|
||||
"code_template": config.SMS.Bao.CodeTemplate,
|
||||
},
|
||||
}
|
||||
return s.saveConfig(types.ConfigKeySms, smsConfig)
|
||||
}
|
||||
|
||||
// 保存配置到数据库
|
||||
func (s *MigrationService) saveConfig(key string, config any) error {
|
||||
// 检查是否已存在
|
||||
var existingConfig model.Config
|
||||
if err := s.db.Where("name", key).First(&existingConfig).Error; err == nil {
|
||||
// 配置已存在,跳过
|
||||
logger.Infof("配置 %s 已存在,跳过迁移", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 序列化配置
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
newConfig := model.Config{
|
||||
Name: key,
|
||||
Value: string(configJSON),
|
||||
}
|
||||
if err := s.db.Create(&newConfig).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("成功迁移配置 %s", key)
|
||||
return nil
|
||||
}
|
||||
@@ -15,11 +15,10 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -27,17 +26,23 @@ import (
|
||||
type Service struct {
|
||||
client *Client // MJ Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
wsService *service.WebsocketService
|
||||
uploaderManager *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, userService *service.UserService) *Service {
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
|
||||
client: client,
|
||||
wsService: wsService,
|
||||
uploaderManager: manager,
|
||||
clientIds: map[uint]string{},
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
@@ -54,6 +59,7 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
task.Id = v.Id
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
s.PushTask(task)
|
||||
}
|
||||
|
||||
@@ -67,10 +73,30 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
// translate negative prompt
|
||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
task.NegPrompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// use fast mode as default
|
||||
if task.Mode == "" {
|
||||
task.Mode = "fast"
|
||||
}
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
|
||||
var job model.MidJourneyJob
|
||||
tx := s.db.Where("id = ?", task.Id).First(&job)
|
||||
@@ -113,6 +139,7 @@ func (s *Service) Run() {
|
||||
// update the task progress
|
||||
s.db.Updates(&job)
|
||||
// 任务失败,通知前端
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
@@ -151,6 +178,24 @@ func GetImageHash(action string) string {
|
||||
return split[len(split)-1]
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("receive a new mj notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChMj, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadImages() {
|
||||
go func() {
|
||||
var items []model.MidJourneyJob
|
||||
@@ -172,7 +217,7 @@ func (s *Service) DownloadImages() {
|
||||
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
|
||||
proxy = true
|
||||
}
|
||||
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, ".png", proxy)
|
||||
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
|
||||
@@ -183,6 +228,12 @@ func (s *Service) DownloadImages() {
|
||||
|
||||
v.ImgURL = imgURL
|
||||
s.db.Updates(&v)
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[v.Id],
|
||||
UserId: v.UserId,
|
||||
JobId: int(v.Id),
|
||||
Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
@@ -193,9 +244,7 @@ func (s *Service) DownloadImages() {
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (s *Service) PushTask(task types.MjTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push mj task to queue failed: %v", err)
|
||||
}
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
@@ -203,20 +252,24 @@ func (s *Service) SyncTaskProgress() {
|
||||
go func() {
|
||||
var jobs []model.MidJourneyJob
|
||||
for {
|
||||
res := s.db.Where("progress < ?", 100).Where("channel_id <> ?", "").Find(&jobs)
|
||||
if res.Error != nil {
|
||||
err := s.db.Where("progress < ?", 100).Find(&jobs).Error
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
// 10 分钟还没完成的任务标记为失败
|
||||
if time.Since(job.CreatedAt) > time.Minute*10 {
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = "任务超时"
|
||||
s.db.Updates(&job)
|
||||
continue
|
||||
}
|
||||
|
||||
if job.ChannelId == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
|
||||
if err != nil {
|
||||
logger.Errorf("error with query task: %v", err)
|
||||
@@ -230,12 +283,18 @@ func (s *Service) SyncTaskProgress() {
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
logger.Errorf("task failed: %v", task.FailReason)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(task.Buttons) > 0 {
|
||||
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
oldProgress := job.Progress
|
||||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||
if task.ImageUrl != "" {
|
||||
job.OrgURL = task.ImageUrl
|
||||
@@ -245,6 +304,19 @@ func (s *Service) SyncTaskProgress() {
|
||||
logger.Errorf("error with update database: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 通知前端更新任务进度
|
||||
if oldProgress != job.Progress {
|
||||
message := service.TaskStatusRunning
|
||||
if job.Progress == 100 {
|
||||
message = service.TaskStatusFinished
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: message})
|
||||
}
|
||||
}
|
||||
|
||||
// 找出失败的任务,并恢复其扣减算力
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
package moderation
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"geekai/core/types"
|
||||
)
|
||||
|
||||
type BaiduAIModeration struct {
|
||||
config types.ModerationBaiduConfig
|
||||
}
|
||||
|
||||
func NewBaiduAIModeration(sysConfig *types.SystemConfig) *BaiduAIModeration {
|
||||
return &BaiduAIModeration{
|
||||
config: sysConfig.Moderation.Baidu,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaiduAIModeration) UpdateConfig(config types.ModerationBaiduConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *BaiduAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||
return types.ModerationResult{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
var _ Service = (*BaiduAIModeration)(nil)
|
||||
@@ -1,58 +0,0 @@
|
||||
package moderation
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type GiteeAIModeration struct {
|
||||
config types.ModerationGiteeConfig
|
||||
apiURL string
|
||||
}
|
||||
|
||||
func NewGiteeAIModeration(sysConfig *types.SystemConfig) *GiteeAIModeration {
|
||||
return &GiteeAIModeration{
|
||||
config: sysConfig.Moderation.Gitee,
|
||||
apiURL: "https://ai.gitee.com/v1/moderations",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GiteeAIModeration) UpdateConfig(config types.ModerationGiteeConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
type GiteeAIModerationResult struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Results []types.ModerationResult `json:"results"`
|
||||
}
|
||||
|
||||
func (s *GiteeAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||
|
||||
body := map[string]any{
|
||||
"input": text,
|
||||
"model": s.config.Model,
|
||||
}
|
||||
var res GiteeAIModerationResult
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+s.config.ApiKey).SetBody(body).SetSuccessResult(&res).Post(s.apiURL)
|
||||
if err != nil {
|
||||
return types.ModerationResult{}, err
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return types.ModerationResult{}, errors.New(r.String())
|
||||
}
|
||||
|
||||
return res.Results[0], nil
|
||||
}
|
||||
|
||||
var _ Service = (*GiteeAIModeration)(nil)
|
||||
@@ -1,58 +0,0 @@
|
||||
package moderation
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
|
||||
logger2 "geekai/logger"
|
||||
)
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type Service interface {
|
||||
Moderate(text string) (types.ModerationResult, error)
|
||||
}
|
||||
|
||||
type ServiceManager struct {
|
||||
gitee *GiteeAIModeration
|
||||
baidu *BaiduAIModeration
|
||||
tencent *TencentAIModeration
|
||||
active string
|
||||
}
|
||||
|
||||
func NewServiceManager(gitee *GiteeAIModeration, baidu *BaiduAIModeration, tencent *TencentAIModeration) *ServiceManager {
|
||||
return &ServiceManager{
|
||||
gitee: gitee,
|
||||
baidu: baidu,
|
||||
tencent: tencent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServiceManager) GetService() Service {
|
||||
switch s.active {
|
||||
case types.ModerationBaidu:
|
||||
return s.baidu
|
||||
case types.ModerationTencent:
|
||||
return s.tencent
|
||||
default:
|
||||
return s.gitee
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServiceManager) UpdateConfig(config types.ModerationConfig) {
|
||||
switch config.Active {
|
||||
case types.ModerationGitee:
|
||||
s.gitee.UpdateConfig(config.Gitee)
|
||||
case types.ModerationBaidu:
|
||||
s.baidu.UpdateConfig(config.Baidu)
|
||||
case types.ModerationTencent:
|
||||
s.tencent.UpdateConfig(config.Tencent)
|
||||
}
|
||||
s.active = config.Active
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package moderation
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"geekai/core/types"
|
||||
)
|
||||
|
||||
type TencentAIModeration struct {
|
||||
config types.ModerationTencentConfig
|
||||
}
|
||||
|
||||
func NewTencentAIModeration(sysConfig *types.SystemConfig) *TencentAIModeration {
|
||||
return &TencentAIModeration{
|
||||
config: sysConfig.Moderation.Tencent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TencentAIModeration) UpdateConfig(config types.ModerationTencentConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *TencentAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||
return types.ModerationResult{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
var _ Service = (*TencentAIModeration)(nil)
|
||||
@@ -23,35 +23,35 @@ import (
|
||||
)
|
||||
|
||||
type AliYunOss struct {
|
||||
config types.AliYunOssConfig
|
||||
config *types.AliYunOssConfig
|
||||
bucket *oss.Bucket
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*AliYunOss, error) {
|
||||
s := &AliYunOss{
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}
|
||||
err := s.UpdateConfig(sysConfig.OSS.AliYun)
|
||||
if err != nil {
|
||||
logger.Warnf("阿里云OSS初始化失败: %v", err)
|
||||
}
|
||||
return s, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *AliYunOss) UpdateConfig(config types.AliYunOssConfig) error {
|
||||
func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
|
||||
config := &appConfig.OSS.AliYun
|
||||
// 创建 OSS 客户端
|
||||
client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取存储空间
|
||||
bucket, err := client.Bucket(config.Bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
s.bucket = bucket
|
||||
s.config = config
|
||||
return nil
|
||||
|
||||
if config.SubDir == "" {
|
||||
config.SubDir = "gpt"
|
||||
}
|
||||
|
||||
return &AliYunOss{
|
||||
config: config,
|
||||
bucket: bucket,
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
@@ -68,7 +68,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
defer src.Close()
|
||||
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
// 上传文件
|
||||
err = s.bucket.PutObject(objectKey, src)
|
||||
if err != nil {
|
||||
@@ -84,7 +84,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -99,10 +99,8 @@ func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||
fileExt := utils.GetImgExt(parse.Path)
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
||||
if err != nil {
|
||||
@@ -116,7 +114,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||
if err != nil {
|
||||
@@ -128,7 +126,8 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||
func (s AliYunOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
objectKey = filepath.Base(fileURL)
|
||||
filename := filepath.Base(fileURL)
|
||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||
} else {
|
||||
objectKey = fileURL
|
||||
}
|
||||
|
||||
@@ -12,37 +12,32 @@ import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
config types.LocalStorageConfig
|
||||
config *types.LocalStorageConfig
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewLocalStorage(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *LocalStorage {
|
||||
return &LocalStorage{
|
||||
config: sysConfig.OSS.Local,
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
func NewLocalStorage(config *types.AppConfig) LocalStorage {
|
||||
return LocalStorage{
|
||||
config: &config.OSS.Local,
|
||||
proxyURL: config.ProxyURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LocalStorage) UpdateConfig(config types.LocalStorageConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
file, err := ctx.FormFile(name)
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("error with get form: %v", err)
|
||||
}
|
||||
|
||||
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, "")
|
||||
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
|
||||
}
|
||||
@@ -62,13 +57,13 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
filename := filepath.Base(parse.Path)
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, ext)
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with generate image dir: %v", err)
|
||||
}
|
||||
@@ -90,7 +85,7 @@ func (s LocalStorage) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
filePath, _ := utils.GenUploadPath(s.config.BasePath, "", ".png")
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
|
||||
err = os.WriteFile(filePath, imageData, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error writing to file:%v", err)
|
||||
|
||||
@@ -24,35 +24,27 @@ import (
|
||||
)
|
||||
|
||||
type MiniOss struct {
|
||||
config types.MiniOssConfig
|
||||
config *types.MiniOssConfig
|
||||
client *minio.Client
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) {
|
||||
|
||||
s := &MiniOss{proxyURL: appConfig.ProxyURL}
|
||||
err := s.UpdateConfig(sysConfig.OSS.Minio)
|
||||
if err != nil {
|
||||
logger.Warnf("MinioOSS初始化失败: %v", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *MiniOss) UpdateConfig(config types.MiniOssConfig) error {
|
||||
func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
||||
config := &appConfig.OSS.Minio
|
||||
minioClient, err := minio.New(config.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""),
|
||||
Secure: config.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return MiniOss{}, err
|
||||
}
|
||||
s.config = config
|
||||
s.client = minioClient
|
||||
return nil
|
||||
if config.SubDir == "" {
|
||||
config.SubDir = "gpt"
|
||||
}
|
||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -67,10 +59,8 @@ func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string,
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||
fileExt := filepath.Ext(parse.Path)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
@@ -96,8 +86,8 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}
|
||||
defer fileReader.Close()
|
||||
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||
fileExt := utils.GetImgExt(file.Filename)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||
ContentType: file.Header.Get("Body-Type"),
|
||||
})
|
||||
@@ -119,7 +109,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
@@ -136,7 +126,8 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||
func (s MiniOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
objectKey = filepath.Base(fileURL)
|
||||
filename := filepath.Base(fileURL)
|
||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||
} else {
|
||||
objectKey = fileURL
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user