mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-07 03:34:25 +08:00
Compare commits
140 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6cfc7175e8 | ||
|
|
a3f6a641aa | ||
|
|
54fe49de5d | ||
|
|
454dfc1aa7 | ||
|
|
3156701d4e | ||
|
|
41eb0e634a | ||
|
|
73d003d6c3 | ||
|
|
5e4ba6d971 | ||
|
|
76d32c78d8 | ||
|
|
66776556d8 | ||
|
|
149f598f6d | ||
|
|
1d6f0ab714 | ||
|
|
545f257476 | ||
|
|
6cb1f16f56 | ||
|
|
598f6c48fb | ||
|
|
97e489901a | ||
|
|
eea57790de | ||
|
|
81b32523ed | ||
|
|
6b6fe1bebd | ||
|
|
a7063bf30a | ||
|
|
32fc4d86a2 | ||
|
|
e685876cc0 | ||
|
|
41e4b1c7ac | ||
|
|
76a3ada85f | ||
|
|
b1ddcef593 | ||
|
|
94a5187e75 | ||
|
|
521ca77541 | ||
|
|
5dd3c1835a | ||
|
|
9fb01ee3ee | ||
|
|
9d72edc048 | ||
|
|
a9505ff72d | ||
|
|
628ae15fd7 | ||
|
|
1ce71374ac | ||
|
|
7c81d946a7 | ||
|
|
cdaeb2a404 | ||
|
|
ca2de54438 | ||
|
|
303e9ed052 | ||
|
|
347b640614 | ||
|
|
19099aed6f | ||
|
|
615515094b | ||
|
|
c4fe6c825e | ||
|
|
26c18fcd5a | ||
|
|
a914994483 | ||
|
|
9edd6621b1 | ||
|
|
10e3e61b2c | ||
|
|
643cf6085a | ||
|
|
fa74ae18ee | ||
|
|
dffdbf697b | ||
|
|
5e59b3a708 | ||
|
|
7bc55f3ed1 | ||
|
|
73f5a44e0a | ||
|
|
2c6abbe7e4 | ||
|
|
1f0cf11636 | ||
|
|
c44f5d40fe | ||
|
|
314d81303b | ||
|
|
b9859e5591 | ||
|
|
a3d65ba939 | ||
|
|
8a2d2f66b5 | ||
|
|
d1c9fd6eba | ||
|
|
d629d842be | ||
|
|
f752ec5b06 | ||
|
|
c7b09f29ca | ||
|
|
51c270fb29 | ||
|
|
b97d4b7895 | ||
|
|
0627109b2b | ||
|
|
c2d4530395 | ||
|
|
c3be47d4ce | ||
|
|
a1acca6f7a | ||
|
|
ccfc9f17e9 | ||
|
|
4641482865 | ||
|
|
79522d9ab5 | ||
|
|
e0b4e8970a | ||
|
|
4d93e901e0 | ||
|
|
bcc72a3091 | ||
|
|
1c1ddf76fb | ||
|
|
04caf92702 | ||
|
|
c797b35f5a | ||
|
|
0746cd49f4 | ||
|
|
a3a2500498 | ||
|
|
ff69cb231a | ||
|
|
afb9193985 | ||
|
|
14fa4fdaa0 | ||
|
|
2a71d5d557 | ||
|
|
cd31333d0c | ||
|
|
f080425ee6 | ||
|
|
96dd0ddb99 | ||
|
|
a4ee5cdeff | ||
|
|
d0025032b0 | ||
|
|
43f00b1481 | ||
|
|
f580f671a3 | ||
|
|
b1fb16995a | ||
|
|
47907b9f0c | ||
|
|
ba55fca7cc | ||
|
|
f687a10416 | ||
|
|
393bfa137e | ||
|
|
4dcb0d850c | ||
|
|
88fa374104 | ||
|
|
e1b1c195f6 | ||
|
|
1352369af0 | ||
|
|
ded041da0f | ||
|
|
7b9a7475a9 | ||
|
|
3958e99e4d | ||
|
|
0ef51714c9 | ||
|
|
668ff70bc1 | ||
|
|
ed063a1d9d | ||
|
|
88eaddbd1d | ||
|
|
8369e18bf0 | ||
|
|
79b9476d3d | ||
|
|
41bfa3a974 | ||
|
|
6b0d4e81bf | ||
|
|
41e66d85d5 | ||
|
|
f98dcee7d4 | ||
|
|
04b364c1cd | ||
|
|
6c84d2557c | ||
|
|
8a4596b36a | ||
|
|
6e2deeed87 | ||
|
|
bf6834da4e | ||
|
|
68dcd054f9 | ||
|
|
a77bebbc29 | ||
|
|
36beb74de8 | ||
|
|
dd1f98db1e | ||
|
|
b7f41c524a | ||
|
|
a3f0576535 | ||
|
|
5c8a237e27 | ||
|
|
447adf45eb | ||
|
|
ca77288a69 | ||
|
|
63be3f5f56 | ||
|
|
cad1ce6943 | ||
|
|
54b5a78c0e | ||
|
|
98d4d58393 | ||
|
|
887fdb6679 | ||
|
|
63fd125439 | ||
|
|
c39dd913fd | ||
|
|
b40f7ed5f3 | ||
|
|
183829a08b | ||
|
|
03d33c784c | ||
|
|
eec10fdfbc | ||
|
|
0a3c74cd6f | ||
|
|
5768c7959e | ||
|
|
d124eddd9d |
56
CHANGELOG.md
56
CHANGELOG.md
@@ -1,5 +1,61 @@
|
||||
# 更新日志
|
||||
|
||||
## v4.2.5
|
||||
|
||||
- 功能优化:在代码右下角增加复制代码功能按钮,增加收起和展开代码功能
|
||||
- Bug 修复:修复 Shift + Enter 不换行的 Bug
|
||||
- Bug 修复:修复管理后台菜单添加页面的文本错误
|
||||
- Bug 修复:解决聊天页面异常退出不断重连的 bug
|
||||
- 功能优化:把 Luma 和可灵视频生成页面整合成一个视频创作中心页面,统一管理视频任务
|
||||
- 功能新增:增加即梦 AI 专题页面,支持即梦官方原生 API 的图片和视频生成 🎉🎉🎉
|
||||
|
||||
## v4.2.4
|
||||
|
||||
- 功能优化:更改前端构建技术选型,使用 Vite 构建,提升构建速度和兼容性
|
||||
- 功能优化:使用 SSE 发送消息,替换原来的 Websocket 消息方案
|
||||
- 功能新增:管理后台支持设置默认昵称
|
||||
- 功能优化:支持 Suno v4.5 模型支持
|
||||
- 功能新增:用户注册和用户登录增加用户协议和隐私政策功能,需要用户同意协议才可注册和登录。
|
||||
- 功能优化:修改重新回答功能,撤回千面的问答内容为可编辑内容,撤回的内容不会增加额外的上下文
|
||||
- 功能优化:优化聊天记录的存储结构,增加模型名称字段,支持存储更长的模型名称
|
||||
- Bug 修复:聊天应用绑定模型后无效,还是会轮询 API KEY,导致一会成功,一会请求失败。
|
||||
- 功能优化:如果管理后台没有启用会员充值菜单,移动端也不显示充值套餐功能
|
||||
|
||||
## v4.2.3
|
||||
|
||||
- 功能优化:增加模型分组与模型描述,采用卡片展示模式改进模型选择功能体验
|
||||
- 功能优化:化思维导图下载图片的清晰度以及解决拖动、缩放操作后下载图片内容不全问题
|
||||
- Bug 修复:修复 MJ 画图页面已画出的图,点复制指令无效问题
|
||||
- 功能优化:MJ 画图的分辨率支持自定义,优先使用 prompt 中--ar 参数
|
||||
- Bug 修复:修复 MJ 绘画 U1-V1,拼写错误
|
||||
- 功能优化:支持自动迁移数据表结构,无需在手动执行 SQL 了
|
||||
- 功能优化:移除首页的文字动画效果
|
||||
- 功能优化:在聊天页面增加对话列表展开和隐藏功能
|
||||
- 功能优化:聊天页面增加 AI 思考中动画效果
|
||||
|
||||
## v4.2.2
|
||||
|
||||
- 功能优化:开启图形验证码功能的时候现检查是否配置了 API 服务,防止开启之后没法登录的 Bug。
|
||||
- 功能优化:支持原生的 DeepSeek 推理模型 API,聊天 API KEY 支持设置完整的 API 路径,比如 https://api.geekai.pro/v1/chat/completions
|
||||
- 功能优化:支持 GPT-4o 图片编辑功能。
|
||||
- 功能新增:对话页面支持 AI 输出语音播报(TTS)。
|
||||
- 功能优化:替换瀑布流组件,优化用户体验。
|
||||
- 功能优化:生成思维导图时候自动缓存上一次的结果。
|
||||
- 功能优化:优化 MJ 绘图页面,增加 MJ-V7 模型支持。
|
||||
- 功能优化:后台管理增加生成一键登录链接地址功能
|
||||
|
||||
## v4.2.1
|
||||
|
||||
- 功能新增:新增支持可灵生成视频,支持文生视频,图生生视频。
|
||||
- Bug 修复:修复手机端登录页面 Logo 无法修改的问题。
|
||||
- 功能新增:重构所有异步任务(绘图,音乐,视频)更新方式,使用 http pull 来替代 websocket。
|
||||
- 功能优化:优化 Luma 图生视频功能,支持本地上传图片和远程图片。
|
||||
- Bug 修复:修复移动端聊天页面新建对话时候角色没有更模型绑定的 Bug。
|
||||
- 功能优化:优化聊天页面代码块样式,优化公式的解析。
|
||||
- 功能优化:在绘图,视频相关 API 增加提示词长度的检查,防止提示词超出导致写入数据库失败。
|
||||
- Bug 修复:优化 Redis 连接池配置,增加连接池超时时间,单核服务器报错 `redis: connection pool timeout`。
|
||||
- 功能优化:优化邮件验证码发送逻辑,更新邮件发送成功提示。
|
||||
|
||||
## v4.2.0
|
||||
|
||||
- 功能优化:优化聊天页面 Notice 组件样式,采用 Vuepress 文档样式
|
||||
|
||||
66
CLAUDE.md
Normal file
66
CLAUDE.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Build Commands
|
||||
|
||||
### Go Backend (api/)
|
||||
- **Development**: `cd api && go run main.go` (uses config.toml)
|
||||
- **Build**: `cd api && make` (builds both amd64 and arm64 binaries)
|
||||
- **Individual builds**: `make amd64` or `make arm64`
|
||||
- **Clean**: `make clean`
|
||||
- **Config**: Copy `config.sample.toml` to `config.toml` and configure
|
||||
|
||||
### Web Frontend (web/)
|
||||
- **Development**: `cd web && npm run dev` (runs on Vite dev server with --host)
|
||||
- **Build**: `cd web && npm run build`
|
||||
- **Lint**: `cd web && npm run lint` (ESLint with auto-fix)
|
||||
|
||||
### Testing
|
||||
- Backend tests: `cd api/test && bash run_crawler_test.sh`
|
||||
- No specific frontend test configuration found
|
||||
|
||||
## Project Architecture
|
||||
|
||||
### Backend (Go)
|
||||
- **Framework**: Gin web framework with dependency injection via uber-go/fx
|
||||
- **Database**: GORM with MySQL, Redis for caching, LevelDB for local storage
|
||||
- **Authentication**: JWT tokens with Redis session storage
|
||||
- **Middleware**: CORS, authorization, parameter handling, static resource serving
|
||||
- **Structure**:
|
||||
- `handler/`: HTTP request handlers (REST API endpoints)
|
||||
- `service/`: Business logic services (AI integrations, payments, etc.)
|
||||
- `store/`: Database models and data access layer
|
||||
- `core/`: Application server and middleware configuration
|
||||
- `utils/`: Utility functions and helpers
|
||||
|
||||
### Frontend (Vue.js)
|
||||
- **Framework**: Vue 3 with Composition API
|
||||
- **UI Components**: Element Plus + Vant (mobile components)
|
||||
- **State Management**: Pinia
|
||||
- **Routing**: Vue Router with nested routes
|
||||
- **Build Tool**: Vite
|
||||
- **CSS**: Stylus preprocessor with Tailwind CSS utilities
|
||||
- **Features**: Responsive design (desktop/mobile views), theme switching (dark/light)
|
||||
|
||||
### Key Features
|
||||
- **AI Chat**: Multiple chat models and conversation management
|
||||
- **Image Generation**: MidJourney, Stable Diffusion, DALL-E integration
|
||||
- **Audio/Video**: Suno music creation, Luma/KeLing video generation
|
||||
- **User Management**: Authentication, payments, power logs, invitations
|
||||
- **Admin Panel**: Comprehensive management interface
|
||||
|
||||
### Database Models
|
||||
Key entities: User, ChatItem, ChatMessage, ChatRole, ChatModel, Order, Product, AdminUser, and various job types for AI services.
|
||||
|
||||
### API Structure
|
||||
- User APIs: `/api/user/*` (auth, profile, settings)
|
||||
- Chat APIs: `/api/chat/*` (conversations, messages)
|
||||
- AI Service APIs: `/api/mj/*`, `/api/sd/*`, `/api/dall/*`, `/api/suno/*`, `/api/video/*`
|
||||
- Admin APIs: `/api/admin/*` (management functions)
|
||||
|
||||
### Configuration
|
||||
- Backend: TOML configuration file (`config.toml`)
|
||||
- Database: MySQL with automatic migrations
|
||||
- Services: Redis, various AI API integrations
|
||||
- File Storage: Local, Aliyun OSS, MinIO, Qiniu options
|
||||
195
JIMENG_CONFIG_README.md
Normal file
195
JIMENG_CONFIG_README.md
Normal file
@@ -0,0 +1,195 @@
|
||||
# 即梦 AI 配置功能说明
|
||||
|
||||
## 功能概述
|
||||
|
||||
即梦 AI 配置功能允许管理员通过 Web 界面配置即梦 AI 的 API 密钥和算力消耗设置,支持动态配置更新,无需重启服务。
|
||||
|
||||
## 功能特性
|
||||
|
||||
### 1. 秘钥配置
|
||||
|
||||
- AccessKey 和 SecretKey 配置
|
||||
- 支持密码显示/隐藏
|
||||
- 连接测试功能
|
||||
|
||||
### 2. 算力配置
|
||||
|
||||
- 文生图算力消耗
|
||||
- 图生图算力消耗
|
||||
- 图片编辑算力消耗
|
||||
- 图片特效算力消耗
|
||||
- 文生视频算力消耗
|
||||
- 图生视频算力消耗
|
||||
|
||||
### 3. 动态配置
|
||||
|
||||
- 配置实时生效
|
||||
- 无需重启服务
|
||||
- 支持配置验证
|
||||
|
||||
## API 接口
|
||||
|
||||
### 获取配置
|
||||
|
||||
```
|
||||
GET /api/admin/jimeng/config
|
||||
```
|
||||
|
||||
### 更新配置
|
||||
|
||||
```
|
||||
POST /api/admin/jimeng/config
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"config": {
|
||||
"access_key": "your_access_key",
|
||||
"secret_key": "your_secret_key",
|
||||
"power": {
|
||||
"text_to_image": 10,
|
||||
"image_to_image": 15,
|
||||
"image_edit": 20,
|
||||
"image_effects": 25,
|
||||
"text_to_video": 30,
|
||||
"image_to_video": 35
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 测试连接
|
||||
|
||||
```
|
||||
POST /api/admin/jimeng/config/test
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"config": {
|
||||
"access_key": "your_access_key",
|
||||
"secret_key": "your_secret_key"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 前端页面
|
||||
|
||||
### 访问路径
|
||||
|
||||
管理后台 -> 即梦 AI -> 配置设置
|
||||
|
||||
### 页面功能
|
||||
|
||||
1. **秘钥配置标签页**
|
||||
|
||||
- AccessKey 输入框(密码模式)
|
||||
- SecretKey 输入框(密码模式)
|
||||
- 测试连接按钮
|
||||
|
||||
2. **算力配置标签页**
|
||||
|
||||
- 各种任务类型的算力消耗配置
|
||||
- 数字输入框,支持 1-100 范围
|
||||
- 提示信息说明
|
||||
|
||||
3. **操作按钮**
|
||||
- 保存配置
|
||||
- 重置配置
|
||||
|
||||
## 配置存储
|
||||
|
||||
配置存储在数据库的`config`表中:
|
||||
|
||||
- 配置键:`jimeng`
|
||||
- 配置值:JSON 格式的即梦 AI 配置
|
||||
|
||||
## 默认配置
|
||||
|
||||
如果配置不存在,系统会使用以下默认值:
|
||||
|
||||
```json
|
||||
{
|
||||
"access_key": "",
|
||||
"secret_key": "",
|
||||
"power": {
|
||||
"text_to_image": 10,
|
||||
"image_to_image": 15,
|
||||
"image_edit": 20,
|
||||
"image_effects": 25,
|
||||
"text_to_video": 30,
|
||||
"image_to_video": 35
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 使用流程
|
||||
|
||||
1. **初始配置**
|
||||
|
||||
- 访问管理后台即梦 AI 配置页面
|
||||
- 填写 AccessKey 和 SecretKey
|
||||
- 点击"测试连接"验证配置
|
||||
- 调整各功能算力消耗
|
||||
- 保存配置
|
||||
|
||||
2. **配置更新**
|
||||
|
||||
- 修改需要更新的配置项
|
||||
- 保存配置
|
||||
- 配置立即生效
|
||||
|
||||
3. **故障排查**
|
||||
- 使用"测试连接"功能验证 API 密钥
|
||||
- 检查配置是否正确保存
|
||||
- 查看服务日志
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **权限要求**
|
||||
|
||||
- 只有管理员可以访问配置页面
|
||||
- 需要有效的管理员登录会话
|
||||
|
||||
2. **配置验证**
|
||||
|
||||
- AccessKey 和 SecretKey 不能为空
|
||||
- 算力消耗必须大于 0
|
||||
- 建议先测试连接再保存配置
|
||||
|
||||
3. **服务影响**
|
||||
- 配置更新不会影响正在进行的任务
|
||||
- 新任务会使用更新后的配置
|
||||
- 客户端配置会在下次请求时更新
|
||||
|
||||
## 错误处理
|
||||
|
||||
1. **配置加载失败**
|
||||
|
||||
- 使用默认配置
|
||||
- 记录错误日志
|
||||
|
||||
2. **连接测试失败**
|
||||
|
||||
- 显示具体错误信息
|
||||
- 建议检查 API 密钥
|
||||
|
||||
3. **配置保存失败**
|
||||
- 显示错误信息
|
||||
- 保留原有配置
|
||||
|
||||
## 开发说明
|
||||
|
||||
### 后端文件
|
||||
|
||||
- `api/handler/admin/jimeng_handler.go` - 配置管理 API
|
||||
- `api/service/jimeng/service.go` - 配置服务逻辑
|
||||
- `api/core/types/jimeng.go` - 配置类型定义
|
||||
|
||||
### 前端文件
|
||||
|
||||
- `web/src/views/admin/jimeng/JimengSetting.vue` - 配置页面
|
||||
|
||||
### 数据库
|
||||
|
||||
- `config`表存储配置信息
|
||||
- 配置键:`jimeng`
|
||||
- 配置值:JSON 格式
|
||||
1
api/.gitignore
vendored
1
api/.gitignore
vendored
@@ -17,5 +17,6 @@ bin
|
||||
data
|
||||
config.toml
|
||||
static/upload
|
||||
static/audio
|
||||
storage.json
|
||||
res/certs/wechat/apiclient_key.pem
|
||||
|
||||
@@ -27,13 +27,58 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/nfnt/resize"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"golang.org/x/image/webp"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AuthConfig 定义授权配置
|
||||
type AuthConfig struct {
|
||||
ExactPaths map[string]bool // 精确匹配的路径
|
||||
PrefixPaths map[string]bool // 前缀匹配的路径
|
||||
}
|
||||
|
||||
var authConfig = &AuthConfig{
|
||||
ExactPaths: map[string]bool{
|
||||
"/api/user/login": false,
|
||||
"/api/user/logout": false,
|
||||
"/api/user/resetPass": false,
|
||||
"/api/user/register": false,
|
||||
"/api/admin/login": false,
|
||||
"/api/admin/logout": false,
|
||||
"/api/admin/login/captcha": false,
|
||||
"/api/app/list": false,
|
||||
"/api/app/type/list": false,
|
||||
"/api/app/list/user": false,
|
||||
"/api/model/list": false,
|
||||
"/api/mj/imgWall": false,
|
||||
"/api/mj/notify": false,
|
||||
"/api/invite/hits": false,
|
||||
"/api/sd/imgWall": false,
|
||||
"/api/dall/imgWall": false,
|
||||
"/api/product/list": false,
|
||||
"/api/menu/list": false,
|
||||
"/api/markMap/client": false,
|
||||
"/api/payment/doPay": false,
|
||||
"/api/payment/payWays": false,
|
||||
"/api/download": false,
|
||||
"/api/dall/models": false,
|
||||
},
|
||||
PrefixPaths: map[string]bool{
|
||||
"/api/test/": false,
|
||||
"/api/payment/notify/": false,
|
||||
"/api/user/clogin": false,
|
||||
"/api/config/": false,
|
||||
"/api/function/": false,
|
||||
"/api/sms/": false,
|
||||
"/api/captcha/": false,
|
||||
"/static/": false,
|
||||
},
|
||||
}
|
||||
|
||||
type AppServer struct {
|
||||
Debug bool
|
||||
Config *types.AppConfig
|
||||
Engine *gin.Engine
|
||||
SysConfig *types.SystemConfig // system config cache
|
||||
@@ -43,7 +88,6 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
return &AppServer{
|
||||
Debug: false,
|
||||
Config: appConfig,
|
||||
Engine: gin.Default(),
|
||||
}
|
||||
@@ -60,16 +104,100 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
|
||||
}
|
||||
|
||||
func (s *AppServer) Run(db *gorm.DB) error {
|
||||
|
||||
// 重命名 config 表字段
|
||||
if db.Migrator().HasColumn(&model.Config{}, "config_json") {
|
||||
db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.Config{}, "marker") {
|
||||
db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
|
||||
}
|
||||
if db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
|
||||
db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
|
||||
}
|
||||
if db.Migrator().HasIndex(&model.Config{}, "marker") {
|
||||
db.Migrator().DropIndex(&model.Config{}, "marker")
|
||||
}
|
||||
|
||||
// load system configs
|
||||
var sysConfig model.Config
|
||||
err := db.Where("marker", "system").First(&sysConfig).Error
|
||||
err := db.Where("name", "system").First(&sysConfig).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load system config: %v", err)
|
||||
}
|
||||
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &s.SysConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode system config: %v", err)
|
||||
}
|
||||
|
||||
// 迁移数据表
|
||||
logger.Info("Migrating database tables...")
|
||||
db.AutoMigrate(
|
||||
&model.ChatItem{},
|
||||
&model.ChatMessage{},
|
||||
&model.ChatRole{},
|
||||
&model.ChatModel{},
|
||||
&model.InviteCode{},
|
||||
&model.InviteLog{},
|
||||
&model.Menu{},
|
||||
&model.Order{},
|
||||
&model.Product{},
|
||||
&model.User{},
|
||||
&model.Function{},
|
||||
&model.File{},
|
||||
&model.Redeem{},
|
||||
&model.Config{},
|
||||
&model.ApiKey{},
|
||||
&model.AdminUser{},
|
||||
&model.AppType{},
|
||||
&model.SdJob{},
|
||||
&model.SunoJob{},
|
||||
&model.PowerLog{},
|
||||
&model.VideoJob{},
|
||||
&model.MidJourneyJob{},
|
||||
&model.UserLoginLog{},
|
||||
&model.DallJob{},
|
||||
&model.JimengJob{},
|
||||
)
|
||||
// 手动删除字段
|
||||
if db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
|
||||
db.Migrator().DropColumn(&model.Order{}, "deleted_at")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
|
||||
db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
|
||||
db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.User{}, "chat_config") {
|
||||
db.Migrator().DropColumn(&model.User{}, "chat_config")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.ChatModel{}, "category") {
|
||||
db.Migrator().DropColumn(&model.ChatModel{}, "category")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.ChatModel{}, "description") {
|
||||
db.Migrator().DropColumn(&model.ChatModel{}, "description")
|
||||
}
|
||||
|
||||
logger.Info("Database tables migrated successfully")
|
||||
|
||||
// 统计安装信息
|
||||
go func() {
|
||||
info, err := host.Info()
|
||||
if err == nil {
|
||||
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push")
|
||||
timestamp := time.Now().Unix()
|
||||
product := "geekai-plus"
|
||||
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
|
||||
sign := utils.Sha256(signStr)
|
||||
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
|
||||
if err != nil {
|
||||
logger.Errorf("register install info failed: %v", err)
|
||||
} else {
|
||||
logger.Debugf("register install info success: %v", resp.String())
|
||||
}
|
||||
}
|
||||
}()
|
||||
logger.Infof("http://%s", s.Config.Listen)
|
||||
return s.Engine.Run(s.Config.Listen)
|
||||
}
|
||||
@@ -93,20 +221,24 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
// 设置允许的请求源
|
||||
if origin != "" {
|
||||
// 设置允许的请求源
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
||||
//允许跨域设置可以返回其他子段,可以自定义字段
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
|
||||
// 允许浏览器(客户端)可以解析的头部 (重要)
|
||||
c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
||||
//设置缓存时间
|
||||
c.Header("Access-Control-Max-Age", "172800")
|
||||
//允许客户端传递校验信息比如 cookie (重要)
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
} else {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
||||
//允许跨域设置可以返回其他子段,可以自定义字段
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
|
||||
// 允许浏览器(客户端)可以解析的头部 (重要)
|
||||
c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
||||
//设置缓存时间
|
||||
c.Header("Access-Control-Max-Age", "172800")
|
||||
//允许客户端传递校验信息比如 cookie (重要)
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if method == http.MethodOptions {
|
||||
c.JSON(http.StatusOK, "ok!")
|
||||
}
|
||||
@@ -124,6 +256,11 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
// 用户授权验证
|
||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !needLogin(c) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
var tokenString string
|
||||
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
||||
@@ -142,18 +279,13 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
if tokenString == "" {
|
||||
if needLogin(c) {
|
||||
resp.NotAuth(c, "You should put Authorization in request headers")
|
||||
c.Abort()
|
||||
return
|
||||
} else { // 直接放行
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
resp.NotAuth(c, "You should put Authorization in request headers")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
if isAdminApi {
|
||||
@@ -164,21 +296,21 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
|
||||
})
|
||||
|
||||
if err != nil && needLogin(c) {
|
||||
if err != nil {
|
||||
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid && needLogin(c) {
|
||||
if !ok || !token.Valid {
|
||||
resp.NotAuth(c, "Token is invalid")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||||
resp.NotAuth(c, "Token is expired")
|
||||
c.Abort()
|
||||
return
|
||||
@@ -188,57 +320,48 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
if isAdminApi {
|
||||
key = fmt.Sprintf("admin/%v", claims["user_id"])
|
||||
}
|
||||
if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
|
||||
if _, err := client.Get(context.Background(), key).Result(); err != nil {
|
||||
resp.NotAuth(c, "Token is not found in redis")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set(types.LoginUserID, claims["user_id"])
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func needLogin(c *gin.Context) bool {
|
||||
if c.Request.URL.Path == "/api/user/login" ||
|
||||
c.Request.URL.Path == "/api/user/logout" ||
|
||||
c.Request.URL.Path == "/api/user/resetPass" ||
|
||||
c.Request.URL.Path == "/api/admin/login" ||
|
||||
c.Request.URL.Path == "/api/admin/logout" ||
|
||||
c.Request.URL.Path == "/api/admin/login/captcha" ||
|
||||
c.Request.URL.Path == "/api/user/register" ||
|
||||
c.Request.URL.Path == "/api/chat/history" ||
|
||||
c.Request.URL.Path == "/api/chat/detail" ||
|
||||
c.Request.URL.Path == "/api/chat/list" ||
|
||||
c.Request.URL.Path == "/api/app/list" ||
|
||||
c.Request.URL.Path == "/api/app/type/list" ||
|
||||
c.Request.URL.Path == "/api/app/list/user" ||
|
||||
c.Request.URL.Path == "/api/model/list" ||
|
||||
c.Request.URL.Path == "/api/mj/imgWall" ||
|
||||
c.Request.URL.Path == "/api/mj/notify" ||
|
||||
c.Request.URL.Path == "/api/invite/hits" ||
|
||||
c.Request.URL.Path == "/api/sd/imgWall" ||
|
||||
c.Request.URL.Path == "/api/dall/imgWall" ||
|
||||
c.Request.URL.Path == "/api/product/list" ||
|
||||
c.Request.URL.Path == "/api/menu/list" ||
|
||||
c.Request.URL.Path == "/api/markMap/client" ||
|
||||
c.Request.URL.Path == "/api/payment/doPay" ||
|
||||
c.Request.URL.Path == "/api/payment/payWays" ||
|
||||
c.Request.URL.Path == "/api/suno/detail" ||
|
||||
c.Request.URL.Path == "/api/suno/play" ||
|
||||
c.Request.URL.Path == "/api/download" ||
|
||||
c.Request.URL.Path == "/api/dall/models" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 如果不是 API 路径,不需要登录
|
||||
if !strings.HasPrefix(path, "/api") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查精确匹配的路径
|
||||
if skip, exists := authConfig.ExactPaths[path]; exists {
|
||||
return skip
|
||||
}
|
||||
|
||||
// 检查前缀匹配的路径
|
||||
for prefix, skip := range authConfig.PrefixPaths {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return skip
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 跳过授权
|
||||
func (s *AppServer) SkipAuth(url string, prefix bool) {
|
||||
if prefix {
|
||||
authConfig.PrefixPaths[url] = false
|
||||
} else {
|
||||
authConfig.ExactPaths[url] = false
|
||||
}
|
||||
}
|
||||
|
||||
// 统一参数处理
|
||||
func parameterHandlerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
@@ -9,20 +9,20 @@ package types
|
||||
|
||||
// ApiRequest API 请求实体
|
||||
type ApiRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Messages []interface{} `json:"messages,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
||||
ResponseFormat interface{} `json:"response_format,omitempty"` // 响应格式
|
||||
Model string `json:"model,omitempty"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Messages []any `json:"messages,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Functions []any `json:"functions,omitempty"` // 兼容中转平台
|
||||
ResponseFormat any `json:"response_format,omitempty"` // 响应格式
|
||||
|
||||
ToolChoice string `json:"tool_choice,omitempty"`
|
||||
|
||||
Input map[string]interface{} `json:"input,omitempty"` //兼容阿里通义千问
|
||||
Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问
|
||||
Input map[string]any `json:"input,omitempty"` //兼容阿里通义千问
|
||||
Parameters map[string]any `json:"parameters,omitempty"` //兼容阿里通义千问
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
@@ -41,27 +41,17 @@ type ChoiceItem struct {
|
||||
}
|
||||
|
||||
type Delta struct {
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
Content interface{} `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FunctionCall struct {
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
Content any `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
} `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatSession 聊天会话对象
|
||||
type ChatSession struct {
|
||||
UserId uint `json:"user_id"`
|
||||
ClientIP string `json:"client_ip"` // 客户端 IP
|
||||
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
||||
Model ChatModel `json:"model"` // GPT 模型
|
||||
Start int64 `json:"start"` // 开始请求时间戳
|
||||
Tools []int `json:"tools"` // 工具函数列表
|
||||
Stream bool `json:"stream"` // 是否采用流式输出
|
||||
}
|
||||
|
||||
type ChatModel struct {
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -69,6 +59,8 @@ type ChatModel struct {
|
||||
Power int `json:"power"`
|
||||
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||
Description string `json:"description"` //模型描述
|
||||
Category string `json:"category"` //模型类别
|
||||
Temperature float32 `json:"temperature"` // 模型温度
|
||||
KeyId int `json:"key_id"` // 绑定 API KEY
|
||||
}
|
||||
|
||||
@@ -43,9 +43,10 @@ type SmtpConfig struct {
|
||||
}
|
||||
|
||||
type ApiConfig struct {
|
||||
ApiURL string
|
||||
AppId string
|
||||
Token string
|
||||
ApiURL string
|
||||
AppId string
|
||||
Token string
|
||||
JimengConfig JimengConfig // 即梦AI配置
|
||||
}
|
||||
|
||||
type AlipayConfig struct {
|
||||
@@ -144,14 +145,15 @@ type SystemConfig struct {
|
||||
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
|
||||
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
|
||||
|
||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
|
||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
|
||||
PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力
|
||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
|
||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
|
||||
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
|
||||
PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力
|
||||
|
||||
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
||||
|
||||
@@ -161,13 +163,15 @@ type SystemConfig struct {
|
||||
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
|
||||
MjMode string `json:"mj_mode"` // midjourney 默认的API模式,relax, fast, turbo
|
||||
|
||||
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
||||
Copyright string `json:"copyright"` // 版权信息
|
||||
ICP string `json:"icp"` // ICP 备案号
|
||||
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
|
||||
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
||||
Copyright string `json:"copyright"` // 版权信息
|
||||
DefaultNickname string `json:"default_nickname"` // 默认昵称
|
||||
ICP string `json:"icp"` // ICP 备案号
|
||||
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
|
||||
|
||||
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
|
||||
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
|
||||
TranslateModelId int `json:"translate_model_id"` // 用来做提示词翻译的大模型 id
|
||||
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
|
||||
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位:MB
|
||||
|
||||
}
|
||||
|
||||
18
api/core/types/jimeng.go
Normal file
18
api/core/types/jimeng.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package types
|
||||
|
||||
// JimengConfig 即梦AI配置
|
||||
type JimengConfig struct {
|
||||
AccessKey string `json:"access_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Power JimengPower `json:"power"`
|
||||
}
|
||||
|
||||
// JimengPower 即梦AI算力配置
|
||||
type JimengPower struct {
|
||||
TextToImage int `json:"text_to_image"`
|
||||
ImageToImage int `json:"image_to_image"`
|
||||
ImageEdit int `json:"image_edit"`
|
||||
ImageEffects int `json:"image_effects"`
|
||||
TextToVideo int `json:"text_to_video"`
|
||||
ImageToVideo int `json:"image_to_video"`
|
||||
}
|
||||
@@ -16,7 +16,7 @@ type MKey interface {
|
||||
string | int | uint
|
||||
}
|
||||
type MValue interface {
|
||||
*WsClient | *ChatSession | context.CancelFunc | []interface{}
|
||||
*WsClient | context.CancelFunc | []any
|
||||
}
|
||||
type LMap[K MKey, T MValue] struct {
|
||||
lock sync.RWMutex
|
||||
|
||||
@@ -26,7 +26,6 @@ const (
|
||||
type MjTask struct {
|
||||
Id uint `json:"id"` // 任务ID
|
||||
TaskId string `json:"task_id"` // 中转任务ID
|
||||
ClientId string `json:"client_id"`
|
||||
ImgArr []string `json:"img_arr"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
@@ -44,7 +43,6 @@ type MjTask struct {
|
||||
type SdTask struct {
|
||||
Id int `json:"id"` // job 数据库ID
|
||||
Type TaskType `json:"type"`
|
||||
ClientId string `json:"client_id"`
|
||||
UserId int `json:"user_id"`
|
||||
Params SdTaskParams `json:"params"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
@@ -52,7 +50,6 @@ type SdTask struct {
|
||||
}
|
||||
|
||||
type SdTaskParams struct {
|
||||
ClientId string `json:"client_id"` // 客户端ID
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
||||
@@ -73,22 +70,20 @@ type SdTaskParams struct {
|
||||
|
||||
// DallTask DALL-E task
|
||||
type DallTask struct {
|
||||
ClientId string `json:"client_id"`
|
||||
ModelId uint `json:"model_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
Power int `json:"power"`
|
||||
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||
ModelId uint `json:"model_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
Power int `json:"power"`
|
||||
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||
}
|
||||
|
||||
type SunoTask struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
@@ -96,7 +91,8 @@ type SunoTask struct {
|
||||
Title string `json:"title"`
|
||||
RefTaskId string `json:"ref_task_id,omitempty"`
|
||||
RefSongId string `json:"ref_song_id,omitempty"`
|
||||
Prompt string `json:"prompt"` // 提示词/歌词
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Lyrics string `json:"lyrics,omitempty"` // 歌词
|
||||
Tags string `json:"tags"`
|
||||
Model string `json:"model"`
|
||||
Instrumental bool `json:"instrumental"` // 是否纯音乐
|
||||
@@ -109,21 +105,21 @@ const (
|
||||
VideoLuma = "luma"
|
||||
VideoRunway = "runway"
|
||||
VideoCog = "cog"
|
||||
VideoKeLing = "keling"
|
||||
)
|
||||
|
||||
type VideoTask struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
Type string `json:"type"`
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Params VideoParams `json:"params"`
|
||||
Params interface{} `json:"params"`
|
||||
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||
}
|
||||
|
||||
type VideoParams struct {
|
||||
type LumaVideoParams struct {
|
||||
PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词
|
||||
Loop bool `json:"loop"` // 是否循环参考图
|
||||
StartImgURL string `json:"start_img_url"` // 第一帧参考图地址
|
||||
@@ -133,3 +129,33 @@ type VideoParams struct {
|
||||
Style string `json:"style"` // 风格
|
||||
Duration int `json:"duration"` // 视频时长(秒)
|
||||
}
|
||||
|
||||
type KeLingVideoParams struct {
|
||||
TaskType string `json:"task_type"` // 任务类型: text2video/image2video
|
||||
Model string `json:"model"` // 模型: default/anime
|
||||
Prompt string `json:"prompt"` // 视频描述
|
||||
NegPrompt string `json:"negative_prompt"` // 负面提示词
|
||||
CfgScale float64 `json:"cfg_scale"` // 相关性系数(0-1)
|
||||
Mode string `json:"mode"` // 生成模式: std/pro
|
||||
AspectRatio string `json:"aspect_ratio"` // 画面比例: 16:9/9:16/1:1
|
||||
Duration string `json:"duration"` // 视频时长: 5/10
|
||||
CameraControl CameraControl `json:"camera_control"` // 摄像机控制
|
||||
Image string `json:"image"` // 参考图片URL(image2video)
|
||||
ImageTail string `json:"image_tail"` // 尾帧图片URL(image2video)
|
||||
}
|
||||
|
||||
// CameraControl 摄像机控制
|
||||
type CameraControl struct {
|
||||
Type string `json:"type"` // 控制类型: simple/down_back/forward_up/right_turn_forward/left_turn_forward
|
||||
Config CameraConfig `json:"config"` // 控制参数(仅simple类型时使用)
|
||||
}
|
||||
|
||||
// CameraConfig 摄像机参数
|
||||
type CameraConfig struct {
|
||||
Horizontal int `json:"horizontal"` // 水平移动(-10到10)
|
||||
Vertical int `json:"vertical"` // 垂直移动(-10到10)
|
||||
Pan int `json:"pan"` // 左右旋转(-10到10)
|
||||
Tilt int `json:"tilt"` // 上下旋转(-10到10)
|
||||
Roll int `json:"roll"` // 横向翻转(-10到10)
|
||||
Zoom int `json:"zoom"` // 镜头缩放(-10到10)
|
||||
}
|
||||
|
||||
@@ -34,13 +34,14 @@ const (
|
||||
MsgTypeErr = WsMsgType("error")
|
||||
MsgTypePing = WsMsgType("ping") // 心跳消息
|
||||
|
||||
ChPing = WsChannel("ping")
|
||||
ChChat = WsChannel("chat")
|
||||
ChMj = WsChannel("mj")
|
||||
ChSd = WsChannel("sd")
|
||||
ChDall = WsChannel("dall")
|
||||
ChSuno = WsChannel("suno")
|
||||
ChLuma = WsChannel("luma")
|
||||
ChPing = WsChannel("ping")
|
||||
ChChat = WsChannel("chat")
|
||||
ChMj = WsChannel("mj")
|
||||
ChSd = WsChannel("sd")
|
||||
ChDall = WsChannel("dall")
|
||||
ChSuno = WsChannel("suno")
|
||||
ChLuma = WsChannel("luma")
|
||||
ChKeLing = WsChannel("keling")
|
||||
)
|
||||
|
||||
// InputMessage 对话输入消息结构
|
||||
|
||||
19
api/go.mod
19
api/go.mod
@@ -18,6 +18,7 @@ require (
|
||||
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23
|
||||
go.uber.org/zap v1.23.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gorm.io/driver/mysql v1.4.7
|
||||
@@ -27,8 +28,10 @@ require github.com/xxl-job/xxl-job-executor-go v1.2.0
|
||||
|
||||
require (
|
||||
github.com/go-pay/gopay v1.5.101
|
||||
github.com/go-rod/rod v0.116.2
|
||||
github.com/google/go-tika v0.3.1
|
||||
github.com/microcosm-cc/bluemonday v1.0.26
|
||||
github.com/sashabaranov/go-openai v1.38.1
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.3.1
|
||||
github.com/syndtr/goleveldb v1.0.0
|
||||
@@ -43,15 +46,15 @@ require (
|
||||
github.com/go-pay/util v0.0.2 // indirect
|
||||
github.com/go-pay/xlog v0.0.2 // indirect
|
||||
github.com/go-pay/xtime v0.0.2 // indirect
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/gorilla/css v1.0.0 // indirect
|
||||
github.com/gravityblast/fresh v0.0.0-20240621171608-8d1fef547a99 // indirect
|
||||
github.com/howeyc/fsnotify v0.9.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/pilu/config v0.0.0-20131214182432-3eb99e6c0b9a // indirect
|
||||
github.com/pilu/fresh v0.0.0-20240621171608-8d1fef547a99 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
||||
github.com/tklauser/numcpus v0.7.0 // indirect
|
||||
github.com/ysmood/fetchup v0.3.0 // indirect
|
||||
github.com/ysmood/goob v0.4.0 // indirect
|
||||
github.com/ysmood/got v0.40.0 // indirect
|
||||
github.com/ysmood/gson v0.7.3 // indirect
|
||||
github.com/ysmood/leakless v0.9.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
)
|
||||
@@ -76,7 +79,7 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/klauspost/compress v1.16.7 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
@@ -118,7 +121,7 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/fx v1.19.3
|
||||
go.uber.org/multierr v1.6.0 // indirect
|
||||
go.uber.org/multierr v1.7.0 // indirect
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
gorm.io/gorm v1.25.1
|
||||
|
||||
109
api/go.sum
109
api/go.sum
@@ -1,3 +1,5 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I=
|
||||
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 h1:cKNFQmeCQFN0WNfjScKoVrGi7vXxTVbkCvCqSrOf+P4=
|
||||
@@ -6,6 +8,7 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible h1:Sg/2xHwDrioHpxTN6WMiw
|
||||
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
|
||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||
@@ -13,11 +16,13 @@ github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -28,6 +33,8 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
|
||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
@@ -73,6 +80,8 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA=
|
||||
github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg=
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
@@ -82,11 +91,27 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
|
||||
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
|
||||
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
|
||||
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
|
||||
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
||||
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
|
||||
@@ -100,15 +125,11 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
|
||||
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gravityblast/fresh v0.0.0-20240621171608-8d1fef547a99 h1:A6qlLfihaWef15viqtecCz4XknZcgjgD7mEuhu7bHEc=
|
||||
github.com/gravityblast/fresh v0.0.0-20240621171608-8d1fef547a99/go.mod h1:ukFDwXV66bGV7JnfyxFKuKiVp4zH4orBKXML+VCSrhI=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/howeyc/fsnotify v0.9.0 h1:0gtV5JmOKH4A8SsFxG2BczSeXWWPvcMT0euZt5gDAxY=
|
||||
github.com/howeyc/fsnotify v0.9.0/go.mod h1:41HzSPxBGeFRQKEEwgh49TRw/nKBsYZ2cF1OzPjSJsA=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I=
|
||||
github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI=
|
||||
@@ -117,8 +138,11 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
|
||||
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
|
||||
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
|
||||
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
|
||||
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
|
||||
github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
@@ -129,6 +153,7 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02
|
||||
github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
@@ -141,9 +166,6 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0 h1:LgmjED/yQILqmUED4GaXjrINWe7YJh4HM6z2EvEINPs=
|
||||
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
|
||||
@@ -177,10 +199,6 @@ github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:Ff
|
||||
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pilu/config v0.0.0-20131214182432-3eb99e6c0b9a h1:Tg4E4cXPZSZyd3H1tJlYo6ZreXV0ZJvE/lorNqyw1AU=
|
||||
github.com/pilu/config v0.0.0-20131214182432-3eb99e6c0b9a/go.mod h1:9Or9aIl95Kp43zONcHd5tLZGKXb9iLx0pZjau0uJ5zg=
|
||||
github.com/pilu/fresh v0.0.0-20240621171608-8d1fef547a99 h1:+X7Gb40b5Bl3v5+3MiGK8Jhemjp65MHc+nkVCfq1Yfc=
|
||||
github.com/pilu/fresh v0.0.0-20240621171608-8d1fef547a99/go.mod h1:2LLTtftTZSdAPR/iVyennXZDLZOYzyDn+T0qEKJ8eSw=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -188,6 +206,7 @@ github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 h1:IFhPCcB0/H
|
||||
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480/go.mod h1:BijIqAP84FMYC4XbdJgjyMpiSjusU8x0Y0W9K2t0QtU=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdkk=
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1 h1:UoQv7fBKtzAiD1qZPIvTy62Se48YLKxcCYP9nAwWMa0=
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYXOwye868w=
|
||||
@@ -203,6 +222,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/sashabaranov/go-openai v1.38.1 h1:TtZabbFQZa1nEni/IhVtDF/WQjVqDgd+cWR5OeddzF8=
|
||||
github.com/sashabaranov/go-openai v1.38.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
|
||||
@@ -215,6 +236,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@@ -237,8 +259,24 @@ github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVK
|
||||
github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
|
||||
github.com/ysmood/fetchup v0.3.0 h1:UhYz9xnLEVn2ukSuK3KCgcznWpHMdrmbsPpllcylyu8=
|
||||
github.com/ysmood/fetchup v0.3.0/go.mod h1:hbysoq65PXL0NQeNzUczNYIKpwpkwFL4LXMDEvIQq9A=
|
||||
github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ=
|
||||
github.com/ysmood/goob v0.4.0/go.mod h1:u6yx7ZhS4Exf2MwciFr6nIM8knHQIE22lFpWHnfql18=
|
||||
github.com/ysmood/gop v0.2.0 h1:+tFrG0TWPxT6p9ZaZs+VY+opCvHU8/3Fk6BaNv6kqKg=
|
||||
github.com/ysmood/gop v0.2.0/go.mod h1:rr5z2z27oGEbyB787hpEcx4ab8cCiPnKxn0SUHt6xzk=
|
||||
github.com/ysmood/got v0.40.0 h1:ZQk1B55zIvS7zflRrkGfPDrPG3d7+JOza1ZkNxcc74Q=
|
||||
github.com/ysmood/got v0.40.0/go.mod h1:W7DdpuX6skL3NszLmAsC5hT7JAhuLZhByVzHTq874Qg=
|
||||
github.com/ysmood/gotrace v0.6.0 h1:SyI1d4jclswLhg7SWTL6os3L1WOKeNn/ZtzVQF8QmdY=
|
||||
github.com/ysmood/gotrace v0.6.0/go.mod h1:TzhIG7nHDry5//eYZDYcTzuJLYQIkykJzCRIo4/dzQM=
|
||||
github.com/ysmood/gson v0.7.3 h1:QFkWbTH8MxyUTKPkVWAENJhxqdBa4lYTQWqZCiLG6kE=
|
||||
github.com/ysmood/gson v0.7.3/go.mod h1:3Kzs5zDl21g5F/BlLTNcuAGAYLKt2lV5G8D1zF3RNmg=
|
||||
github.com/ysmood/leakless v0.9.0 h1:qxCG5VirSBvmi3uynXFkcnLMzkphdh3xx5FtrORwDCU=
|
||||
github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY8q0JvMQ=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
@@ -253,8 +291,8 @@ go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
|
||||
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
|
||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||
go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec=
|
||||
go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
|
||||
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
|
||||
go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -268,15 +306,23 @@ golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDf
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
@@ -287,12 +333,15 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -302,7 +351,6 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -332,16 +380,39 @@ golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
|
||||
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
|
||||
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
@@ -354,6 +425,10 @@ gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYs
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
@@ -363,4 +438,6 @@ gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8o
|
||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
||||
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
|
||||
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
@@ -78,10 +78,10 @@ func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
typeIds := make([]int, 0)
|
||||
for _, v := range items {
|
||||
if v.ModelId > 0 {
|
||||
modelIds = append(modelIds, v.ModelId)
|
||||
modelIds = append(modelIds, int(v.ModelId))
|
||||
}
|
||||
if v.Tid > 0 {
|
||||
typeIds = append(typeIds, v.Tid)
|
||||
typeIds = append(typeIds, int(v.Tid))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,8 +113,8 @@ func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
role.Id = v.Id
|
||||
role.CreatedAt = v.CreatedAt.Unix()
|
||||
role.UpdatedAt = v.UpdatedAt.Unix()
|
||||
role.ModelName = modelNameMap[role.ModelId]
|
||||
role.TypeName = typeNameMap[role.Tid]
|
||||
role.ModelName = modelNameMap[int(role.ModelId)]
|
||||
role.TypeName = typeNameMap[int(role.Tid)]
|
||||
roles = append(roles, role)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -189,7 +190,7 @@ func (h *ChatHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
for _, item := range items {
|
||||
list = append(list, chatMessageVo{
|
||||
Id: item.Id,
|
||||
Id: uint(item.Id),
|
||||
UserId: item.UserId,
|
||||
Username: userMap[item.UserId],
|
||||
Content: item.Content,
|
||||
@@ -208,20 +209,28 @@ func (h *ChatHandler) Messages(c *gin.Context) {
|
||||
func (h *ChatHandler) History(c *gin.Context) {
|
||||
chatId := c.Query("chat_id") // 会话 ID
|
||||
var items []model.ChatMessage
|
||||
var messages = make([]vo.HistoryMessage, 0)
|
||||
var messages = make([]vo.ChatMessage, 0)
|
||||
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "No history message")
|
||||
return
|
||||
} else {
|
||||
for _, item := range items {
|
||||
var v vo.HistoryMessage
|
||||
var v vo.ChatMessage
|
||||
err := utils.CopyObject(item, &v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// 解析内容
|
||||
var content vo.MsgContent
|
||||
err = utils.JsonDecode(item.Content, &content)
|
||||
if err != nil {
|
||||
content.Text = item.Content
|
||||
}
|
||||
v.Content = content
|
||||
v.CreatedAt = item.CreatedAt.Unix()
|
||||
v.UpdatedAt = item.UpdatedAt.Unix()
|
||||
if err == nil {
|
||||
messages = append(messages, v)
|
||||
}
|
||||
messages = append(messages, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,20 +30,23 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||
|
||||
func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Enabled bool `json:"enabled"`
|
||||
SortNum int `json:"sort_num"`
|
||||
Open bool `json:"open"`
|
||||
Platform string `json:"platform"`
|
||||
Power int `json:"power"`
|
||||
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||
Temperature float32 `json:"temperature"` // 模型温度
|
||||
KeyId int `json:"key_id,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Type string `json:"type"`
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Enabled bool `json:"enabled"`
|
||||
SortNum int `json:"sort_num"`
|
||||
Open bool `json:"open"`
|
||||
Platform string `json:"platform"`
|
||||
Power int `json:"power"`
|
||||
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||
Desc string `json:"desc"` //模型描述
|
||||
Tag string `json:"tag"` //模型标签
|
||||
Temperature float32 `json:"temperature"` // 模型温度
|
||||
KeyId int `json:"key_id,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Type string `json:"type"`
|
||||
Options map[string]string `json:"options"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -59,14 +62,16 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||
item.Name = data.Name
|
||||
item.Value = data.Value
|
||||
item.Enabled = data.Enabled
|
||||
item.SortNum = data.SortNum
|
||||
item.Open = data.Open
|
||||
item.Power = data.Power
|
||||
item.MaxTokens = data.MaxTokens
|
||||
item.MaxContext = data.MaxContext
|
||||
item.Desc = data.Desc
|
||||
item.Tag = data.Tag
|
||||
item.Temperature = data.Temperature
|
||||
item.KeyId = data.KeyId
|
||||
item.KeyId = uint(data.KeyId)
|
||||
item.Type = data.Type
|
||||
item.Options = utils.JsonEncode(data.Options)
|
||||
var res *gorm.DB
|
||||
if data.Id > 0 {
|
||||
res = h.DB.Save(&item)
|
||||
@@ -95,12 +100,16 @@ func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
enable := h.GetBool(c, "enable")
|
||||
name := h.GetTrim(c, "name")
|
||||
modelType := h.GetTrim(c, "type")
|
||||
if enable {
|
||||
session = session.Where("enabled", enable)
|
||||
}
|
||||
if name != "" {
|
||||
session = session.Where("name LIKE ?", name+"%")
|
||||
}
|
||||
if modelType != "" {
|
||||
session = session.Where("type", modelType)
|
||||
}
|
||||
var items []model.ChatModel
|
||||
var cms = make([]vo.ChatModel, 0)
|
||||
res := session.Order("sort_num ASC").Find(&items)
|
||||
@@ -112,7 +121,7 @@ func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
// initialize key name
|
||||
keyIds := make([]int, 0)
|
||||
for _, v := range items {
|
||||
keyIds = append(keyIds, v.KeyId)
|
||||
keyIds = append(keyIds, int(v.KeyId))
|
||||
}
|
||||
var keys []model.ApiKey
|
||||
keyMap := make(map[uint]string)
|
||||
|
||||
@@ -48,6 +48,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
logger.Errorf("Update config failed: %v", err)
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
@@ -58,16 +59,22 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果要启用图形验证码功能,则检查是否配置了 API 服务
|
||||
if data.Config.EnabledVerify && h.App.Config.ApiConfig.AppId == "" {
|
||||
resp.ERROR(c, "启用验证码服务需要先配置 GeekAI 官方 API 服务 AppId 和 Token")
|
||||
return
|
||||
}
|
||||
|
||||
value := utils.JsonEncode(&data.Config)
|
||||
config := model.Config{Key: data.Key, Config: value}
|
||||
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
|
||||
config := model.Config{Name: data.Key, Value: value}
|
||||
res := h.DB.FirstOrCreate(&config, model.Config{Name: data.Key})
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if config.Id > 0 {
|
||||
config.Config = value
|
||||
config.Value = value
|
||||
res := h.DB.Updates(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
@@ -76,16 +83,16 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
|
||||
// update config cache for AppServer
|
||||
var cfg model.Config
|
||||
h.DB.Where("marker", data.Key).First(&cfg)
|
||||
h.DB.Where("name", data.Key).First(&cfg)
|
||||
var err error
|
||||
if data.Key == "system" {
|
||||
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
|
||||
err = utils.JsonDecode(cfg.Value, &h.App.SysConfig)
|
||||
}
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to update config cache: "+err.Error())
|
||||
return
|
||||
}
|
||||
logger.Infof("Update AppServer's config successfully: %v", config.Config)
|
||||
logger.Infof("Update AppServer's config successfully: %v", config.Value)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, config)
|
||||
@@ -95,14 +102,14 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
var config model.Config
|
||||
res := h.DB.Where("marker", key).First(&config)
|
||||
res := h.DB.Where("name", key).First(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var value map[string]interface{}
|
||||
err := utils.JsonDecode(config.Config, &value)
|
||||
err := utils.JsonDecode(config.Value, &value)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -132,7 +139,8 @@ func (h *ConfigHandler) Active(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, info.HostID)
|
||||
resp.SUCCESS(c)
|
||||
|
||||
}
|
||||
|
||||
// GetLicense 获取 License 信息
|
||||
@@ -144,7 +152,6 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
||||
// FixData 修复数据
|
||||
func (h *ConfigHandler) FixData(c *gin.Context) {
|
||||
resp.ERROR(c, "当前升级版本没有数据需要修正!")
|
||||
return
|
||||
//var fixed bool
|
||||
//version := "data_fix_4.1.4"
|
||||
//err := h.levelDB.Get(version, &fixed)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -189,11 +190,10 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
tx.Delete(&job)
|
||||
md = "mid-journey"
|
||||
power = job.Power
|
||||
userId = job.UserId
|
||||
userId = int(job.UserId)
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "sd":
|
||||
var job model.SdJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
@@ -205,11 +205,10 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
tx.Delete(&job)
|
||||
md = "stable-diffusion"
|
||||
power = job.Power
|
||||
userId = job.UserId
|
||||
userId = int(job.UserId)
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "dall":
|
||||
var job model.DallJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
@@ -225,14 +224,13 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
default:
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if progress != 100 {
|
||||
err := h.userService.IncreasePower(userId, power, model.PowerLog{
|
||||
err := h.userService.IncreasePower(uint(userId), power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: md,
|
||||
Remark: remark,
|
||||
|
||||
296
api/handler/admin/jimeng_handler.go
Normal file
296
api/handler/admin/jimeng_handler.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AdminJimengHandler 管理后台即梦AI处理器
|
||||
type AdminJimengHandler struct {
|
||||
handler.BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
userService *service.UserService
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
// NewAdminJimengHandler 创建管理后台即梦AI处理器
|
||||
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengService *jimeng.Service, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
|
||||
return &AdminJimengHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
userService: userService,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册即梦AI管理后台路由
|
||||
func (h *AdminJimengHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/admin/jimeng/")
|
||||
rg.GET("/jobs", h.Jobs)
|
||||
rg.GET("/jobs/:id", h.JobDetail)
|
||||
rg.POST("/jobs/remove", h.BatchRemove)
|
||||
rg.GET("/stats", h.Stats)
|
||||
rg.GET("/config", h.GetConfig)
|
||||
rg.POST("/config/update", h.UpdateConfig)
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
func (h *AdminJimengHandler) Jobs(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
taskType := h.GetTrim(c, "type")
|
||||
status := h.GetTrim(c, "status")
|
||||
|
||||
var tasks []model.JimengJob
|
||||
var total int64
|
||||
|
||||
session := h.DB.Model(&model.JimengJob{})
|
||||
|
||||
// 构建查询条件
|
||||
if userId > 0 {
|
||||
session = session.Where("user_id = ?", userId)
|
||||
}
|
||||
if taskType != "" {
|
||||
session = session.Where("type = ?", taskType)
|
||||
}
|
||||
if status != "" {
|
||||
session = session.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := session.Count(&total).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取任务数量失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = session.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&tasks).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取任务列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"jobs": tasks,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// JobDetail 获取任务详情
|
||||
func (h *AdminJimengHandler) JobDetail(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
jobId, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
var job model.JimengJob
|
||||
err = h.DB.Where("id = ?", jobId).First(&job).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// BatchRemove 批量删除任务
|
||||
func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
|
||||
var req struct {
|
||||
JobIds []uint `json:"job_ids" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
var deletedCount int64 = 0
|
||||
for _, jobId := range req.JobIds {
|
||||
var job model.JimengJob
|
||||
err := h.DB.Where("id = ?", jobId).First(&job).Error
|
||||
if err != nil {
|
||||
continue // 跳过不存在的
|
||||
}
|
||||
tx := h.DB.Begin()
|
||||
if job.Status != model.JMTaskStatusSuccess && job.Power > 0 {
|
||||
remark := fmt.Sprintf("任务未成功,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: remark,
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
}
|
||||
err = tx.Where("id = ?", jobId).Delete(&model.JimengJob{}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
tx.Commit()
|
||||
deletedCount++
|
||||
if job.ImgURL != "" {
|
||||
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||
if err != nil {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
}
|
||||
if job.VideoURL != "" {
|
||||
err = h.uploader.GetUploadHandler().Delete(job.VideoURL)
|
||||
if err != nil {
|
||||
logger.Error("remove video failed: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"message": "批量删除成功",
|
||||
"deleted_count": deletedCount,
|
||||
})
|
||||
}
|
||||
|
||||
// Stats 获取统计信息
|
||||
func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
type StatResult struct {
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
err := h.DB.Model(&model.JimengJob{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&stats).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取统计信息失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 整理统计数据
|
||||
result := gin.H{
|
||||
"totalTasks": int64(0),
|
||||
"completedTasks": int64(0),
|
||||
"processingTasks": int64(0),
|
||||
"failedTasks": int64(0),
|
||||
"pendingTasks": int64(0),
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
result["totalTasks"] = result["totalTasks"].(int64) + stat.Count
|
||||
switch stat.Status {
|
||||
case model.JMTaskStatusInQueue:
|
||||
result["pendingTasks"] = stat.Count
|
||||
case model.JMTaskStatusSuccess:
|
||||
result["completedTasks"] = stat.Count
|
||||
case model.JMTaskStatusGenerating:
|
||||
result["processingTasks"] = stat.Count
|
||||
case model.JMTaskStatusFailed:
|
||||
result["failedTasks"] = stat.Count
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, result)
|
||||
}
|
||||
|
||||
// GetConfig 获取即梦AI配置
|
||||
func (h *AdminJimengHandler) GetConfig(c *gin.Context) {
|
||||
jimengConfig := h.jimengService.GetConfig()
|
||||
resp.SUCCESS(c, jimengConfig)
|
||||
}
|
||||
|
||||
// UpdateConfig 更新即梦AI配置
|
||||
func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||
var req types.JimengConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if req.AccessKey == "" {
|
||||
resp.ERROR(c, "AccessKey不能为空")
|
||||
return
|
||||
}
|
||||
if req.SecretKey == "" {
|
||||
resp.ERROR(c, "SecretKey不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证算力配置
|
||||
if req.Power.TextToImage <= 0 {
|
||||
resp.ERROR(c, "文生图算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageToImage <= 0 {
|
||||
resp.ERROR(c, "图生图算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageEdit <= 0 {
|
||||
resp.ERROR(c, "图片编辑算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageEffects <= 0 {
|
||||
resp.ERROR(c, "图片特效算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.TextToVideo <= 0 {
|
||||
resp.ERROR(c, "文生视频算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageToVideo <= 0 {
|
||||
resp.ERROR(c, "图生视频算力必须大于0")
|
||||
return
|
||||
}
|
||||
|
||||
// 保存配置
|
||||
tx := h.DB.Begin()
|
||||
value := utils.JsonEncode(&req)
|
||||
config := model.Config{Name: "jimeng", Value: value}
|
||||
|
||||
err := tx.FirstOrCreate(&config, model.Config{Name: "jimeng"}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "保存配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if config.Id > 0 {
|
||||
config.Value = value
|
||||
err = tx.Updates(&config).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "更新配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新服务中的客户端配置
|
||||
updateErr := h.jimengService.UpdateClientConfig(req.AccessKey, req.SecretKey)
|
||||
if updateErr != nil {
|
||||
resp.ERROR(c, updateErr.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -33,6 +34,7 @@ func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.User
|
||||
}
|
||||
|
||||
type mediaQuery struct {
|
||||
Type string `json:"type"` // 任务类型 luma, keling
|
||||
Prompt string `json:"prompt"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt []string `json:"created_at"`
|
||||
@@ -84,15 +86,15 @@ func (h *MediaHandler) SunoList(c *gin.Context) {
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||
}
|
||||
|
||||
// LumaList Luma 视频任务列表
|
||||
func (h *MediaHandler) LumaList(c *gin.Context) {
|
||||
// Videos 视频任务列表
|
||||
func (h *MediaHandler) Videos(c *gin.Context) {
|
||||
var data mediaQuery
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
session := h.DB.Session(&gorm.Session{}).Where("type", data.Type)
|
||||
if data.Username != "" {
|
||||
var user model.User
|
||||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||
@@ -148,12 +150,12 @@ func (h *MediaHandler) Remove(c *gin.Context) {
|
||||
tx.Delete(&job)
|
||||
md = "suno"
|
||||
power = job.Power
|
||||
userId = job.UserId
|
||||
userId = int(job.UserId)
|
||||
remark = fmt.Sprintf("SUNO 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
fileURL = job.AudioURL
|
||||
break
|
||||
case "luma":
|
||||
case "keling":
|
||||
var job model.VideoJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
@@ -164,21 +166,20 @@ func (h *MediaHandler) Remove(c *gin.Context) {
|
||||
tx.Delete(&job)
|
||||
md = job.Type
|
||||
power = job.Power
|
||||
userId = job.UserId
|
||||
userId = int(job.UserId)
|
||||
remark = fmt.Sprintf("LUMA 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
fileURL = job.VideoURL
|
||||
if fileURL == "" {
|
||||
fileURL = job.WaterURL
|
||||
}
|
||||
break
|
||||
default:
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if progress != 100 {
|
||||
err := h.userService.IncreasePower(userId, power, model.PowerLog{
|
||||
err := h.userService.IncreasePower(uint(userId), power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: md,
|
||||
Remark: remark,
|
||||
|
||||
@@ -106,8 +106,8 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 设置响应头,告诉浏览器这是一个附件,需要下载
|
||||
c.Header("Content-Disposition", "attachment; filename=output.csv")
|
||||
c.Header("Content-Type", "text/csv")
|
||||
c.Header("Prompt-Disposition", "attachment; filename=output.csv")
|
||||
c.Header("Prompt-Type", "text/csv")
|
||||
|
||||
// 创建一个 CSV writer
|
||||
writer := csv.NewWriter(c.Writer)
|
||||
|
||||
@@ -13,9 +13,10 @@ import (
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UploadHandler struct {
|
||||
@@ -28,14 +29,27 @@ func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderMan
|
||||
}
|
||||
|
||||
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||
// 判断文件大小
|
||||
f, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.MaxFileSize)*1024*1024 {
|
||||
resp.ERROR(c, "文件大小超过限制")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
userId := 0
|
||||
res := h.DB.Create(&model.File{
|
||||
UserId: userId,
|
||||
UserId: uint(userId),
|
||||
Name: file.Name,
|
||||
ObjKey: file.ObjKey,
|
||||
URL: file.URL,
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -177,6 +178,7 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
Power: data.Power,
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode(data.ChatRoles),
|
||||
ChatConfig: "{}",
|
||||
ChatModels: utils.JsonEncode(data.ChatModels),
|
||||
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
|
||||
}
|
||||
@@ -320,3 +322,36 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, logs))
|
||||
}
|
||||
|
||||
// GenLoginLink 生成登录链接
|
||||
func (h *UserHandler) GenLoginLink(c *gin.Context) {
|
||||
id := c.Query("id")
|
||||
if id == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
if err := h.DB.Where("id = ?", id).First(&user).Error; err != nil {
|
||||
resp.ERROR(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 创建 token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.Id,
|
||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||
return
|
||||
}
|
||||
// 保存到 redis
|
||||
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, tokenString)
|
||||
}
|
||||
|
||||
@@ -21,26 +21,49 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ChatEventStart = "start"
|
||||
ChatEventEnd = "end"
|
||||
ChatEventError = "error"
|
||||
ChatEventMessageDelta = "message_delta"
|
||||
ChatEventTitle = "title"
|
||||
)
|
||||
|
||||
type ChatInput struct {
|
||||
UserId uint `json:"user_id"`
|
||||
RoleId uint `json:"role_id"`
|
||||
ModelId uint `json:"model_id"`
|
||||
ChatId string `json:"chat_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Tools []uint `json:"tools"`
|
||||
Stream bool `json:"stream"`
|
||||
Files []vo.File `json:"files"`
|
||||
ChatModel model.ChatModel `json:"chat_model,omitempty"`
|
||||
ChatRole model.ChatRole `json:"chat_role,omitempty"`
|
||||
LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID,用于重新生成答案的时候过滤上下文
|
||||
}
|
||||
|
||||
type ChatHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
uploadManager *oss.UploaderManager
|
||||
licenseService *service.LicenseService
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
@@ -51,22 +74,83 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
|
||||
uploadManager: manager,
|
||||
licenseService: licenseService,
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
ChatContexts: types.NewLMap[string, []interface{}](),
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
|
||||
if !h.App.Debug {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Recover message from error: ", r)
|
||||
}
|
||||
}()
|
||||
// Chat 处理聊天请求
|
||||
func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
var input ChatInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Prompt-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
// 这里做个全局的异常处理,防止整个请求异常,导致 SSE 连接断开
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Errorf("chat handler error: %v", err)
|
||||
pushMessage(c, ChatEventError, err)
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
|
||||
// 使用旧的聊天数据覆盖模型和角色ID
|
||||
var chat model.ChatItem
|
||||
h.DB.Where("chat_id", input.ChatId).First(&chat)
|
||||
if chat.Id > 0 {
|
||||
input.ModelId = chat.ModelId
|
||||
input.RoleId = chat.RoleId
|
||||
}
|
||||
|
||||
// 验证聊天角色
|
||||
var chatRole model.ChatRole
|
||||
err := h.DB.First(&chatRole, input.RoleId).Error
|
||||
if err != nil || !chatRole.Enable {
|
||||
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
|
||||
return
|
||||
}
|
||||
input.ChatRole = chatRole
|
||||
|
||||
// 获取模型信息
|
||||
var chatModel model.ChatModel
|
||||
err = h.DB.Where("id", input.ModelId).First(&chatModel).Error
|
||||
if err != nil || !chatModel.Enabled {
|
||||
pushMessage(c, ChatEventError, "当前AI模型暂未启用,请更换模型后再发起对话!")
|
||||
return
|
||||
}
|
||||
input.ChatModel = chatModel
|
||||
|
||||
// 发送消息
|
||||
err = h.sendMessage(ctx, input, c)
|
||||
if err != nil {
|
||||
pushMessage(c, ChatEventError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pushMessage(c, ChatEventEnd, "对话完成")
|
||||
}
|
||||
|
||||
func pushMessage(c *gin.Context, msgType string, content interface{}) {
|
||||
c.SSEvent("message", map[string]interface{}{
|
||||
"type": msgType,
|
||||
"body": content,
|
||||
})
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.Context) error {
|
||||
var user model.User
|
||||
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
|
||||
res := h.DB.Model(&model.User{}).First(&user, input.UserId)
|
||||
if res.Error != nil {
|
||||
return errors.New("未授权用户,您正在进行非法操作!")
|
||||
}
|
||||
@@ -77,12 +161,12 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
return errors.New("User 对象转换失败," + err.Error())
|
||||
}
|
||||
|
||||
if userVo.Status == false {
|
||||
if !userVo.Status {
|
||||
return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
|
||||
}
|
||||
|
||||
if userVo.Power < session.Model.Power {
|
||||
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, session.Model.Power)
|
||||
if userVo.Power < input.ChatModel.Power {
|
||||
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, input.ChatModel.Power)
|
||||
}
|
||||
|
||||
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
||||
@@ -90,31 +174,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
}
|
||||
|
||||
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
|
||||
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
|
||||
if promptTokens > session.Model.MaxContext {
|
||||
promptTokens, _ := utils.CalcTokens(input.Prompt, input.ChatModel.Value)
|
||||
if promptTokens > input.ChatModel.MaxContext {
|
||||
|
||||
return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
|
||||
}
|
||||
|
||||
var req = types.ApiRequest{
|
||||
Model: session.Model.Value,
|
||||
Stream: session.Stream,
|
||||
Temperature: session.Model.Temperature,
|
||||
Model: input.ChatModel.Value,
|
||||
Stream: input.Stream,
|
||||
Temperature: input.ChatModel.Temperature,
|
||||
}
|
||||
// 兼容 OpenAI 模型
|
||||
if strings.HasPrefix(session.Model.Value, "o1-") ||
|
||||
strings.HasPrefix(session.Model.Value, "o3-") ||
|
||||
strings.HasPrefix(session.Model.Value, "gpt") {
|
||||
utils.SendChunkMsg(ws, "> AI 正在思考...\n")
|
||||
req.MaxCompletionTokens = session.Model.MaxTokens
|
||||
session.Start = time.Now().Unix()
|
||||
if strings.HasPrefix(input.ChatModel.Value, "o1-") ||
|
||||
strings.HasPrefix(input.ChatModel.Value, "o3-") ||
|
||||
strings.HasPrefix(input.ChatModel.Value, "gpt") {
|
||||
req.MaxCompletionTokens = input.ChatModel.MaxTokens
|
||||
} else {
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
req.MaxTokens = input.ChatModel.MaxTokens
|
||||
}
|
||||
|
||||
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
|
||||
if len(input.Tools) > 0 && !strings.HasPrefix(input.ChatModel.Value, "o1-") {
|
||||
var items []model.Function
|
||||
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
|
||||
res = h.DB.Where("enabled", true).Where("id IN ?", input.Tools).Find(&items)
|
||||
if res.Error == nil {
|
||||
var tools = make([]types.Tool, 0)
|
||||
for _, v := range items {
|
||||
@@ -145,25 +227,27 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
}
|
||||
|
||||
// 加载聊天上下文
|
||||
chatCtx := make([]interface{}, 0)
|
||||
messages := make([]interface{}, 0)
|
||||
chatCtx := make([]any, 0)
|
||||
messages := make([]any, 0)
|
||||
if h.App.SysConfig.EnableContext {
|
||||
if h.ChatContexts.Has(session.ChatId) {
|
||||
messages = h.ChatContexts.Get(session.ChatId)
|
||||
} else {
|
||||
_ = utils.JsonDecode(role.Context, &messages)
|
||||
if h.App.SysConfig.ContextDeep > 0 {
|
||||
var historyMessages []model.ChatMessage
|
||||
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
|
||||
if res.Error == nil {
|
||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||
msg := historyMessages[i]
|
||||
ms := types.Message{Role: "user", Content: msg.Content}
|
||||
if msg.Type == types.ReplyMsg {
|
||||
ms.Role = "assistant"
|
||||
}
|
||||
chatCtx = append(chatCtx, ms)
|
||||
_ = utils.JsonDecode(input.ChatRole.Context, &messages)
|
||||
if h.App.SysConfig.ContextDeep > 0 {
|
||||
var historyMessages []model.ChatMessage
|
||||
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
|
||||
if input.LastMsgId > 0 { // 重新生成逻辑
|
||||
dbSession = dbSession.Where("id < ?", input.LastMsgId)
|
||||
// 删除对应的聊天记录
|
||||
h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{})
|
||||
}
|
||||
err = dbSession.Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages).Error
|
||||
if err == nil {
|
||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||
msg := historyMessages[i]
|
||||
ms := types.Message{Role: "user", Content: msg.Content}
|
||||
if msg.Type == types.ReplyMsg {
|
||||
ms.Role = "assistant"
|
||||
}
|
||||
chatCtx = append(chatCtx, ms)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -178,7 +262,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
v := messages[i]
|
||||
tks, _ = utils.CalcTokens(utils.JsonEncode(v), req.Model)
|
||||
// 上下文 token 超出了模型的最大上下文长度
|
||||
if tokens+tks >= session.Model.MaxContext {
|
||||
if tokens+tks >= input.ChatModel.MaxContext {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -190,78 +274,106 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
tokens += tks
|
||||
chatCtx = append(chatCtx, v)
|
||||
}
|
||||
|
||||
logger.Debugf("聊天上下文:%+v", chatCtx)
|
||||
}
|
||||
reqMgs := make([]interface{}, 0)
|
||||
reqMgs := make([]any, 0)
|
||||
|
||||
for i := len(chatCtx) - 1; i >= 0; i-- {
|
||||
reqMgs = append(reqMgs, chatCtx[i])
|
||||
}
|
||||
|
||||
fullPrompt := prompt
|
||||
text := prompt
|
||||
// extract files in prompt
|
||||
files := utils.ExtractFileURLs(prompt)
|
||||
logger.Debugf("detected FILES: %+v", files)
|
||||
// 如果不是逆向模型,则提取文件内容
|
||||
if len(files) > 0 && !(session.Model.Value == "gpt-4-all" ||
|
||||
strings.HasPrefix(session.Model.Value, "gpt-4-gizmo") ||
|
||||
strings.HasSuffix(session.Model.Value, "claude-3")) {
|
||||
contents := make([]string, 0)
|
||||
var file model.File
|
||||
for _, v := range files {
|
||||
h.DB.Where("url = ?", v).First(&file)
|
||||
content, err := utils.ReadFileContent(v, h.App.Config.TikaHost)
|
||||
if err != nil {
|
||||
logger.Error("error with read file: ", err)
|
||||
} else {
|
||||
contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
}
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
|
||||
}
|
||||
|
||||
tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
|
||||
if tokens > session.Model.MaxContext {
|
||||
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
|
||||
}
|
||||
}
|
||||
logger.Debug("最终Prompt:", fullPrompt)
|
||||
|
||||
// extract images from prompt
|
||||
imgURLs := utils.ExtractImgURLs(prompt)
|
||||
logger.Debugf("detected IMG: %+v", imgURLs)
|
||||
var content interface{}
|
||||
if len(imgURLs) > 0 {
|
||||
data := make([]interface{}, 0)
|
||||
for _, v := range imgURLs {
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
data = append(data, gin.H{
|
||||
fileContents := make([]string, 0) // 文件内容
|
||||
var finalPrompt = input.Prompt
|
||||
imgList := make([]any, 0)
|
||||
for _, file := range input.Files {
|
||||
logger.Debugf("detected file: %+v", file.URL)
|
||||
// 处理图片
|
||||
if isImageURL(file.URL) {
|
||||
imgList = append(imgList, gin.H{
|
||||
"type": "image_url",
|
||||
"image_url": gin.H{
|
||||
"url": v,
|
||||
"url": file.URL,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// 如果不是逆向模型,则提取文件内容
|
||||
modelValue := input.ChatModel.Value
|
||||
if !(strings.Contains(modelValue, "-all") || strings.HasPrefix(modelValue, "gpt-4-gizmo") || strings.HasPrefix(modelValue, "claude")) {
|
||||
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
|
||||
if err != nil {
|
||||
logger.Error("error with read file: ", err)
|
||||
continue
|
||||
} else {
|
||||
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
}
|
||||
}
|
||||
}
|
||||
data = append(data, gin.H{
|
||||
"type": "text",
|
||||
"text": strings.TrimSpace(text),
|
||||
})
|
||||
content = data
|
||||
} else {
|
||||
content = fullPrompt
|
||||
}
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
})
|
||||
|
||||
logger.Debugf("%+v", req.Messages)
|
||||
if len(fileContents) > 0 {
|
||||
finalPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(fileContents, "\n"), input.Prompt)
|
||||
tokens, _ := utils.CalcTokens(finalPrompt, req.Model)
|
||||
if tokens > input.ChatModel.MaxContext {
|
||||
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
|
||||
}
|
||||
} else {
|
||||
finalPrompt = input.Prompt
|
||||
}
|
||||
|
||||
return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, ws)
|
||||
if len(imgList) > 0 {
|
||||
imgList = append(imgList, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": input.Prompt,
|
||||
})
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": imgList,
|
||||
})
|
||||
} else {
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": finalPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
return h.sendOpenAiMessage(req, userVo, ctx, input, c)
|
||||
}
|
||||
|
||||
// 判断一个 URL 是否图片链接
|
||||
func isImageURL(url string) bool {
|
||||
// 检查是否是有效的URL
|
||||
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查文件扩展名
|
||||
ext := strings.ToLower(path.Ext(url))
|
||||
validImageExts := map[string]bool{
|
||||
".jpg": true,
|
||||
".jpeg": true,
|
||||
".png": true,
|
||||
".gif": true,
|
||||
".bmp": true,
|
||||
".webp": true,
|
||||
".svg": true,
|
||||
".ico": true,
|
||||
}
|
||||
|
||||
if !validImageExts[ext] {
|
||||
return false
|
||||
}
|
||||
|
||||
// 发送HEAD请求检查Content-Type
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
resp, err := client.Head(url)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
return strings.HasPrefix(contentType, "image/")
|
||||
}
|
||||
|
||||
// Tokens 统计 token 数量
|
||||
@@ -330,15 +442,14 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
|
||||
// 发送请求到 OpenAI 服务器
|
||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if session.Model.KeyId > 0 {
|
||||
h.DB.Where("id", session.Model.KeyId).Find(apiKey)
|
||||
}
|
||||
// use the last unused key
|
||||
if apiKey.Id == 0 {
|
||||
if input.ChatModel.KeyId > 0 {
|
||||
h.DB.Where("id", input.ChatModel.KeyId).Find(apiKey)
|
||||
} else { // use the last unused key
|
||||
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
|
||||
}
|
||||
|
||||
if apiKey.Id == 0 {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
@@ -349,8 +460,14 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
|
||||
return nil, err
|
||||
}
|
||||
logger.Debugf("对话请求消息体:%+v", req)
|
||||
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
var apiURL string
|
||||
p, _ := url.Parse(apiKey.ApiURL)
|
||||
// 如果设置的是 BASE_URL 没有路径,则添加 /v1/chat/completions
|
||||
if p.Path == "" {
|
||||
apiURL = fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
} else {
|
||||
apiURL = apiKey.ApiURL
|
||||
}
|
||||
// 创建 HttpClient 请求对象
|
||||
var client *http.Client
|
||||
requestBody, err := json.Marshal(req)
|
||||
@@ -382,16 +499,16 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
|
||||
}
|
||||
|
||||
// 扣减用户算力
|
||||
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
|
||||
func (h *ChatHandler) subUserPower(userVo vo.User, input ChatInput, promptTokens int, replyTokens int) {
|
||||
power := 1
|
||||
if session.Model.Power > 0 {
|
||||
power = session.Model.Power
|
||||
if input.ChatModel.Power > 0 {
|
||||
power = input.ChatModel.Power
|
||||
}
|
||||
|
||||
err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{
|
||||
err := h.userService.DecreasePower(userVo.Id, power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: session.Model.Value,
|
||||
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
|
||||
Model: input.ChatModel.Value,
|
||||
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", input.ChatModel.Name, promptTokens, replyTokens),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
@@ -402,19 +519,11 @@ func (h *ChatHandler) saveChatHistory(
|
||||
req types.ApiRequest,
|
||||
usage Usage,
|
||||
message types.Message,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
input ChatInput,
|
||||
userVo vo.User,
|
||||
promptCreatedAt time.Time,
|
||||
replyCreatedAt time.Time) {
|
||||
|
||||
// 更新上下文消息
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx := req.Messages // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
var promptTokens, replyTokens, totalTokens int
|
||||
@@ -425,12 +534,15 @@ func (h *ChatHandler) saveChatHistory(
|
||||
}
|
||||
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(usage.Prompt),
|
||||
UserId: userVo.Id,
|
||||
ChatId: input.ChatId,
|
||||
RoleId: input.RoleId,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: utils.JsonEncode(vo.MsgContent{
|
||||
Text: usage.Prompt,
|
||||
Files: input.Files,
|
||||
}),
|
||||
Tokens: promptTokens,
|
||||
TotalTokens: promptTokens,
|
||||
UseContext: true,
|
||||
@@ -453,12 +565,15 @@ func (h *ChatHandler) saveChatHistory(
|
||||
totalTokens = replyTokens + getTotalTokens(req)
|
||||
}
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: usage.Content,
|
||||
UserId: userVo.Id,
|
||||
ChatId: input.ChatId,
|
||||
RoleId: input.RoleId,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: input.ChatRole.Icon,
|
||||
Content: utils.JsonEncode(vo.MsgContent{
|
||||
Text: message.Content,
|
||||
Files: input.Files,
|
||||
}),
|
||||
Tokens: replyTokens,
|
||||
TotalTokens: totalTokens,
|
||||
UseContext: true,
|
||||
@@ -472,17 +587,17 @@ func (h *ChatHandler) saveChatHistory(
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
if session.Model.Power > 0 {
|
||||
h.subUserPower(userVo, session, promptTokens, replyTokens)
|
||||
if input.ChatModel.Power > 0 {
|
||||
h.subUserPower(userVo, input, promptTokens, replyTokens)
|
||||
}
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
|
||||
err = h.DB.Where("chat_id = ?", input.ChatId).First(&chatItem).Error
|
||||
if err != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.ChatId = input.ChatId
|
||||
chatItem.UserId = userVo.Id
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
chatItem.RoleId = input.RoleId
|
||||
chatItem.ModelId = input.ModelId
|
||||
if utf8.RuneCountInString(usage.Prompt) > 30 {
|
||||
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
|
||||
} else {
|
||||
@@ -496,28 +611,320 @@ func (h *ChatHandler) saveChatHistory(
|
||||
}
|
||||
}
|
||||
|
||||
// 将AI回复消息中生成的图片链接下载到本地
|
||||
func (h *ChatHandler) extractImgUrl(text string) string {
|
||||
pattern := `!\[([^\]]*)]\(([^)]+)\)`
|
||||
re := regexp.MustCompile(pattern)
|
||||
matches := re.FindAllStringSubmatch(text, -1)
|
||||
|
||||
// 下载图片并替换链接地址
|
||||
for _, match := range matches {
|
||||
imageURL := match[2]
|
||||
logger.Debug(imageURL)
|
||||
// 对于相同地址的图片,已经被替换了,就不再重复下载了
|
||||
if !strings.Contains(text, imageURL) {
|
||||
continue
|
||||
}
|
||||
|
||||
newImgURL, err := h.uploadManager.GetUploadHandler().PutUrlFile(imageURL, false)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
continue
|
||||
}
|
||||
|
||||
text = strings.ReplaceAll(text, imageURL, newImgURL)
|
||||
// TextToSpeech 文本生成语音
|
||||
func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
var data struct {
|
||||
ModelId int `json:"model_id"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
textHash := utils.Sha256(fmt.Sprintf("%d/%s", data.ModelId, data.Text))
|
||||
audioFile := fmt.Sprintf("%s/audio", h.App.Config.StaticDir)
|
||||
if _, err := os.Stat(audioFile); err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(audioFile, 0755); err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
audioFile = fmt.Sprintf("%s/%s.mp3", audioFile, textHash)
|
||||
if _, err := os.Stat(audioFile); err == nil {
|
||||
// 设置响应头
|
||||
c.Header("Prompt-Type", "audio/mpeg")
|
||||
c.Header("Prompt-Disposition", "attachment; filename=speech.mp3")
|
||||
c.File(audioFile)
|
||||
return
|
||||
}
|
||||
|
||||
// 查询模型
|
||||
var chatModel model.ChatModel
|
||||
err := h.DB.Where("id", data.ModelId).First(&chatModel).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "找不到语音模型")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用 DeepSeek 的 API 接口
|
||||
var apiKey model.ApiKey
|
||||
if chatModel.KeyId > 0 {
|
||||
h.DB.Where("id", chatModel.KeyId).First(&apiKey)
|
||||
}
|
||||
if apiKey.Id == 0 {
|
||||
h.DB.Where("type", "tts").Where("enabled", true).First(&apiKey)
|
||||
}
|
||||
if apiKey.Id == 0 {
|
||||
resp.ERROR(c, "no TTS API key, please import key")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("chatModel: %+v, apiKey: %+v", chatModel, apiKey)
|
||||
|
||||
// 调用 openai tts api
|
||||
config := openai.DefaultConfig(apiKey.Value)
|
||||
config.BaseURL = apiKey.ApiURL + "/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
voice := openai.VoiceAlloy
|
||||
var options map[string]string
|
||||
err = utils.JsonDecode(chatModel.Options, &options)
|
||||
if err == nil {
|
||||
voice = openai.SpeechVoice(options["voice"])
|
||||
}
|
||||
req := openai.CreateSpeechRequest{
|
||||
Model: openai.SpeechModel(chatModel.Value),
|
||||
Input: data.Text,
|
||||
Voice: voice,
|
||||
}
|
||||
|
||||
audioData, err := client.CreateSpeech(context.Background(), req)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 先将音频数据读取到内存
|
||||
audioBytes, err := io.ReadAll(audioData)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 保存到音频文件
|
||||
err = os.WriteFile(audioFile, audioBytes, 0644)
|
||||
if err != nil {
|
||||
logger.Error("failed to save audio file: ", err)
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Prompt-Type", "audio/mpeg")
|
||||
c.Header("Prompt-Disposition", "attachment; filename=speech.mp3")
|
||||
|
||||
// 直接写入完整的音频数据到响应
|
||||
_, err = c.Writer.Write(audioBytes)
|
||||
if err != nil {
|
||||
logger.Error("写入音频数据到响应失败:", err)
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// // OPenAI 消息发送实现
|
||||
// func (h *ChatHandler) sendOpenAiMessage(
|
||||
// req types.ApiRequest,
|
||||
// userVo vo.User,
|
||||
// ctx context.Context,
|
||||
// session *types.ChatSession,
|
||||
// role model.ChatRole,
|
||||
// prompt string,
|
||||
// c *gin.Context) error {
|
||||
// promptCreatedAt := time.Now() // 记录提问时间
|
||||
// start := time.Now()
|
||||
// var apiKey = model.ApiKey{}
|
||||
// response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
// logger.Info("HTTP请求完成,耗时:", time.Since(start))
|
||||
// if err != nil {
|
||||
// if strings.Contains(err.Error(), "context canceled") {
|
||||
// return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||
// } else if strings.Contains(err.Error(), "no available key") {
|
||||
// return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
// }
|
||||
// return err
|
||||
// } else {
|
||||
// defer response.Body.Close()
|
||||
// }
|
||||
|
||||
// if response.StatusCode != 200 {
|
||||
// body, _ := io.ReadAll(response.Body)
|
||||
// return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
|
||||
// }
|
||||
|
||||
// contentType := response.Header.Get("Prompt-Type")
|
||||
// if strings.Contains(contentType, "text/event-stream") {
|
||||
// replyCreatedAt := time.Now() // 记录回复时间
|
||||
// // 循环读取 Chunk 消息
|
||||
// var message = types.Message{Role: "assistant"}
|
||||
// var contents = make([]string, 0)
|
||||
// var function model.Function
|
||||
// var toolCall = false
|
||||
// var arguments = make([]string, 0)
|
||||
// var reasoning = false
|
||||
|
||||
// pushMessage(c, ChatEventStart, "开始响应")
|
||||
// scanner := bufio.NewScanner(response.Body)
|
||||
// for scanner.Scan() {
|
||||
// line := scanner.Text()
|
||||
// if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
// continue
|
||||
// }
|
||||
// var responseBody = types.ApiResponse{}
|
||||
// err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
// if err != nil { // 数据解析出错
|
||||
// return errors.New(line)
|
||||
// }
|
||||
// if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
// continue
|
||||
// }
|
||||
// if responseBody.Choices[0].Delta.Prompt == nil &&
|
||||
// responseBody.Choices[0].Delta.ToolCalls == nil &&
|
||||
// responseBody.Choices[0].Delta.ReasoningContent == "" {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
||||
// pushMessage(c, ChatEventError, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||
// break
|
||||
// }
|
||||
|
||||
// var tool types.ToolCall
|
||||
// if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
|
||||
// tool = responseBody.Choices[0].Delta.ToolCalls[0]
|
||||
// if toolCall && tool.Function.Name == "" {
|
||||
// arguments = append(arguments, tool.Function.Arguments)
|
||||
// continue
|
||||
// }
|
||||
// }
|
||||
|
||||
// // 兼容 Function Call
|
||||
// fun := responseBody.Choices[0].Delta.FunctionCall
|
||||
// if fun.Name != "" {
|
||||
// tool = *new(types.ToolCall)
|
||||
// tool.Function.Name = fun.Name
|
||||
// } else if toolCall {
|
||||
// arguments = append(arguments, fun.Arguments)
|
||||
// continue
|
||||
// }
|
||||
|
||||
// if !utils.IsEmptyValue(tool) {
|
||||
// res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
|
||||
// if res.Error == nil {
|
||||
// toolCall = true
|
||||
// callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": callMsg,
|
||||
// })
|
||||
// contents = append(contents, callMsg)
|
||||
// }
|
||||
// continue
|
||||
// }
|
||||
|
||||
// if responseBody.Choices[0].FinishReason == "tool_calls" ||
|
||||
// responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
||||
// break
|
||||
// }
|
||||
|
||||
// // output stopped
|
||||
// if responseBody.Choices[0].FinishReason != "" {
|
||||
// break // 输出完成或者输出中断了
|
||||
// } else { // 正常输出结果
|
||||
// // 兼容思考过程
|
||||
// if responseBody.Choices[0].Delta.ReasoningContent != "" {
|
||||
// reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
|
||||
// if !reasoning {
|
||||
// reasoningContent = fmt.Sprintf("<think>%s", reasoningContent)
|
||||
// reasoning = true
|
||||
// }
|
||||
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": reasoningContent,
|
||||
// })
|
||||
// contents = append(contents, reasoningContent)
|
||||
// } else if responseBody.Choices[0].Delta.Prompt != "" {
|
||||
// finalContent := responseBody.Choices[0].Delta.Prompt
|
||||
// if reasoning {
|
||||
// finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Prompt)
|
||||
// reasoning = false
|
||||
// }
|
||||
// contents = append(contents, utils.InterfaceToString(finalContent))
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": finalContent,
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// } // end for
|
||||
|
||||
// if err := scanner.Err(); err != nil {
|
||||
// if strings.Contains(err.Error(), "context canceled") {
|
||||
// logger.Info("用户取消了请求:", prompt)
|
||||
// } else {
|
||||
// logger.Error("信息读取出错:", err)
|
||||
// }
|
||||
// }
|
||||
|
||||
// if toolCall { // 调用函数完成任务
|
||||
// params := make(map[string]any)
|
||||
// _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||
// logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
||||
// params["user_id"] = userVo.Id
|
||||
// var apiRes types.BizVo
|
||||
// r, err := req2.C().R().SetHeader("Body-Type", "application/json").
|
||||
// SetHeader("Authorization", function.Token).
|
||||
// SetBody(params).Post(function.Action)
|
||||
// errMsg := ""
|
||||
// if err != nil {
|
||||
// errMsg = err.Error()
|
||||
// } else {
|
||||
// all, _ := io.ReadAll(r.Body)
|
||||
// err = json.Unmarshal(all, &apiRes)
|
||||
// if err != nil {
|
||||
// errMsg = err.Error()
|
||||
// } else if apiRes.Code != types.Success {
|
||||
// errMsg = apiRes.Message
|
||||
// }
|
||||
// }
|
||||
|
||||
// if errMsg != "" {
|
||||
// errMsg = "调用函数工具出错:" + errMsg
|
||||
// contents = append(contents, errMsg)
|
||||
// } else {
|
||||
// errMsg = utils.InterfaceToString(apiRes.Data)
|
||||
// contents = append(contents, errMsg)
|
||||
// }
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": errMsg,
|
||||
// })
|
||||
// }
|
||||
|
||||
// // 消息发送成功
|
||||
// if len(contents) > 0 {
|
||||
// usage := Usage{
|
||||
// Prompt: prompt,
|
||||
// Prompt: strings.Join(contents, ""),
|
||||
// PromptTokens: 0,
|
||||
// CompletionTokens: 0,
|
||||
// TotalTokens: 0,
|
||||
// }
|
||||
// message.Prompt = usage.Prompt
|
||||
// h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||
// }
|
||||
// } else {
|
||||
// var respVo OpenAIResVo
|
||||
// body, err := io.ReadAll(response.Body)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("读取响应失败:%v", body)
|
||||
// }
|
||||
// err = json.Unmarshal(body, &respVo)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("解析响应失败:%v", body)
|
||||
// }
|
||||
// content := respVo.Choices[0].Message.Prompt
|
||||
// if strings.HasPrefix(req.Model, "o1-") {
|
||||
// content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Prompt)
|
||||
// }
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": content,
|
||||
// })
|
||||
// respVo.Usage.Prompt = prompt
|
||||
// respVo.Usage.Prompt = content
|
||||
// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
// List 获取会话列表
|
||||
func (h *ChatHandler) List(c *gin.Context) {
|
||||
logger.Info(h.GetLoginUserId(c))
|
||||
if !h.IsLogin(c) {
|
||||
resp.SUCCESS(c)
|
||||
return
|
||||
@@ -28,7 +29,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
var items = make([]vo.ChatItem, 0)
|
||||
var chats []model.ChatItem
|
||||
h.DB.Where("user_id", userId).Order("id DESC").Find(&chats)
|
||||
h.DB.Debug().Where("user_id", userId).Order("id DESC").Find(&chats)
|
||||
if len(chats) == 0 {
|
||||
resp.SUCCESS(c, items)
|
||||
return
|
||||
@@ -104,8 +105,6 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
||||
var chatIds = make([]string, 0)
|
||||
for _, chat := range chats {
|
||||
chatIds = append(chatIds, chat.ChatId)
|
||||
// 清空会话上下文
|
||||
h.ChatContexts.Delete(chat.ChatId)
|
||||
}
|
||||
err = h.DB.Transaction(func(tx *gorm.DB) error {
|
||||
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
|
||||
@@ -133,20 +132,28 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
||||
func (h *ChatHandler) History(c *gin.Context) {
|
||||
chatId := c.Query("chat_id") // 会话 ID
|
||||
var items []model.ChatMessage
|
||||
var messages = make([]vo.HistoryMessage, 0)
|
||||
var messages = make([]vo.ChatMessage, 0)
|
||||
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "No history message")
|
||||
return
|
||||
} else {
|
||||
for _, item := range items {
|
||||
var v vo.HistoryMessage
|
||||
var v vo.ChatMessage
|
||||
err := utils.CopyObject(item, &v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// 解析内容
|
||||
var content vo.MsgContent
|
||||
err = utils.JsonDecode(item.Content, &content)
|
||||
if err != nil {
|
||||
content.Text = item.Content
|
||||
}
|
||||
v.Content = content
|
||||
v.CreatedAt = item.CreatedAt.Unix()
|
||||
v.UpdatedAt = item.UpdatedAt.Unix()
|
||||
if err == nil {
|
||||
messages = append(messages, v)
|
||||
}
|
||||
messages = append(messages, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,10 +186,6 @@ func (h *ChatHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 是否要删除 MidJourney 绘画记录和图片文件?
|
||||
|
||||
// 清空会话上下文
|
||||
h.ChatContexts.Delete(chatId)
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
|
||||
@@ -30,14 +30,16 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
var items []model.ChatModel
|
||||
var chatModels = make([]vo.ChatModel, 0)
|
||||
session := h.DB.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true)
|
||||
session := h.DB.Session(&gorm.Session{}).Where("enabled", true)
|
||||
t := c.Query("type")
|
||||
if t != "" {
|
||||
session = session.Where("type", t)
|
||||
} else {
|
||||
session = session.Where("type", "chat")
|
||||
}
|
||||
|
||||
session = session.Where("open", true)
|
||||
if h.IsLogin(c) {
|
||||
if h.IsLogin(c) && t == "chat" {
|
||||
user, _ := h.GetLoginUser(c)
|
||||
var models []int
|
||||
err := utils.JsonDecode(user.ChatModels, &models)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
req2 "github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
@@ -55,18 +56,16 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
req types.ApiRequest,
|
||||
userVo vo.User,
|
||||
ctx context.Context,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
prompt string,
|
||||
ws *types.WsClient) error {
|
||||
input ChatInput,
|
||||
c *gin.Context) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
response, err := h.doRequest(ctx, req, input, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Since(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||
return fmt.Errorf("用户取消了请求:%s", input.Prompt)
|
||||
} else if strings.Contains(err.Error(), "no available key") {
|
||||
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
}
|
||||
@@ -89,13 +88,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
var function model.Function
|
||||
var toolCall = false
|
||||
var arguments = make([]string, 0)
|
||||
|
||||
if strings.HasPrefix(req.Model, "o1-") {
|
||||
content := fmt.Sprintf("AI 思考结束,耗时:%d 秒。\n\n", time.Now().Unix()-session.Start)
|
||||
contents = append(contents, "> AI 正在思考中...\n")
|
||||
contents = append(contents, content)
|
||||
utils.SendChunkMsg(ws, content)
|
||||
}
|
||||
var reasoning = false
|
||||
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
@@ -111,12 +104,14 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
continue
|
||||
}
|
||||
if responseBody.Choices[0].Delta.Content == nil && responseBody.Choices[0].Delta.ToolCalls == nil {
|
||||
if responseBody.Choices[0].Delta.Content == nil &&
|
||||
responseBody.Choices[0].Delta.ToolCalls == nil &&
|
||||
responseBody.Choices[0].Delta.ReasoningContent == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
||||
utils.SendChunkMsg(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||
pushMessage(c, "text", "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -144,7 +139,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
if res.Error == nil {
|
||||
toolCall = true
|
||||
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||
utils.SendChunkMsg(ws, callMsg)
|
||||
pushMessage(c, "text", callMsg)
|
||||
contents = append(contents, callMsg)
|
||||
}
|
||||
continue
|
||||
@@ -159,22 +154,38 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
if responseBody.Choices[0].FinishReason != "" {
|
||||
break // 输出完成或者输出中断了
|
||||
} else { // 正常输出结果
|
||||
content := responseBody.Choices[0].Delta.Content
|
||||
contents = append(contents, utils.InterfaceToString(content))
|
||||
utils.SendChunkMsg(ws, content)
|
||||
// 兼容思考过程
|
||||
if responseBody.Choices[0].Delta.ReasoningContent != "" {
|
||||
reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
|
||||
if !reasoning {
|
||||
reasoningContent = fmt.Sprintf("<think>%s", reasoningContent)
|
||||
reasoning = true
|
||||
}
|
||||
|
||||
pushMessage(c, "text", reasoningContent)
|
||||
contents = append(contents, reasoningContent)
|
||||
} else if responseBody.Choices[0].Delta.Content != "" {
|
||||
finalContent := responseBody.Choices[0].Delta.Content
|
||||
if reasoning {
|
||||
finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Content)
|
||||
reasoning = false
|
||||
}
|
||||
contents = append(contents, utils.InterfaceToString(finalContent))
|
||||
pushMessage(c, "text", finalContent)
|
||||
}
|
||||
}
|
||||
} // end for
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
logger.Info("用户取消了请求:", input.Prompt)
|
||||
} else {
|
||||
logger.Error("信息读取出错:", err)
|
||||
}
|
||||
}
|
||||
|
||||
if toolCall { // 调用函数完成任务
|
||||
params := make(map[string]interface{})
|
||||
params := make(map[string]any)
|
||||
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
||||
params["user_id"] = userVo.Id
|
||||
@@ -202,20 +213,20 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
errMsg = utils.InterfaceToString(apiRes.Data)
|
||||
contents = append(contents, errMsg)
|
||||
}
|
||||
utils.SendChunkMsg(ws, errMsg)
|
||||
pushMessage(c, "text", errMsg)
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
usage := Usage{
|
||||
Prompt: prompt,
|
||||
Prompt: input.Prompt,
|
||||
Content: strings.Join(contents, ""),
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
}
|
||||
message.Content = usage.Content
|
||||
h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||
h.saveChatHistory(req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt)
|
||||
}
|
||||
} else { // 非流式输出
|
||||
var respVo OpenAIResVo
|
||||
@@ -228,13 +239,10 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
return fmt.Errorf("解析响应失败:%v", body)
|
||||
}
|
||||
content := respVo.Choices[0].Message.Content
|
||||
if strings.HasPrefix(req.Model, "o1-") {
|
||||
content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
|
||||
}
|
||||
utils.SendChunkMsg(ws, content)
|
||||
respVo.Usage.Prompt = prompt
|
||||
pushMessage(c, "text", content)
|
||||
respVo.Usage.Prompt = input.Prompt
|
||||
respVo.Usage.Content = content
|
||||
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
|
||||
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now())
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -64,10 +64,12 @@ func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
||||
var user model.User
|
||||
h.DB.First(&user, userId)
|
||||
var roleKeys []string
|
||||
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "角色解析失败!")
|
||||
return
|
||||
if user.ChatRoles != "" {
|
||||
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "角色解析失败!")
|
||||
return
|
||||
}
|
||||
}
|
||||
// 保证用户至少有一个角色可用
|
||||
if len(roleKeys) > 0 {
|
||||
|
||||
@@ -31,14 +31,14 @@ func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
var config model.Config
|
||||
res := h.DB.Where("marker", key).First(&config)
|
||||
res := h.DB.Where("name", key).First(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var value map[string]interface{}
|
||||
err := utils.JsonDecode(config.Config, &value)
|
||||
var value map[string]any
|
||||
err := utils.JsonDecode(config.Value, &value)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -70,7 +70,6 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
task := types.DallTask{
|
||||
ClientId: data.ClientId,
|
||||
UserId: uint(userId),
|
||||
ModelId: chatModel.Id,
|
||||
ModelName: chatModel.Value,
|
||||
@@ -78,7 +77,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
Power: chatModel.Power,
|
||||
}
|
||||
job := model.DallJob{
|
||||
@@ -97,7 +96,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
h.dallService.PushTask(task)
|
||||
|
||||
// 扣减算力
|
||||
err = h.userService.DecreasePower(int(user.Id), chatModel.Power, model.PowerLog{
|
||||
err = h.userService.DecreasePower(user.Id, chatModel.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: chatModel.Value,
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/crawler"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
@@ -212,7 +213,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
Prompt: prompt,
|
||||
ModelId: 0,
|
||||
ModelName: "dall-e-3",
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
N: 1,
|
||||
Quality: "standard",
|
||||
Size: "1024x1024",
|
||||
@@ -239,7 +240,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 扣减算力
|
||||
err = h.userService.DecreasePower(int(user.Id), job.Power, model.PowerLog{
|
||||
err = h.userService.DecreasePower(user.Id, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: task.ModelName,
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(job.Prompt, 10)),
|
||||
@@ -252,6 +253,76 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
// 实现一个联网搜索的函数工具,采用爬虫实现
|
||||
func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
||||
if err := h.checkAuth(c); err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var params map[string]interface{}
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 从参数中获取搜索关键词
|
||||
keyword, ok := params["keyword"].(string)
|
||||
if !ok || keyword == "" {
|
||||
resp.ERROR(c, "搜索关键词不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 从参数中获取最大页数,默认为1页
|
||||
maxPages := 1
|
||||
if pages, ok := params["max_pages"].(float64); ok {
|
||||
maxPages = int(pages)
|
||||
}
|
||||
|
||||
// 获取用户ID
|
||||
userID, ok := params["user_id"].(float64)
|
||||
if !ok {
|
||||
resp.ERROR(c, "用户ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 查询用户信息
|
||||
var user model.User
|
||||
res := h.DB.Where("id = ?", int(userID)).First(&user)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户算力是否足够
|
||||
searchPower := 1 // 每次搜索消耗1点算力
|
||||
if user.Power < searchPower {
|
||||
resp.ERROR(c, "算力不足,无法执行网络搜索")
|
||||
return
|
||||
}
|
||||
|
||||
// 执行网络搜索
|
||||
searchResults, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 扣减用户算力
|
||||
err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "web_search",
|
||||
Remark: fmt.Sprintf("网络搜索:%s", utils.CutWords(keyword, 10)),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "扣减算力失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回搜索结果
|
||||
resp.SUCCESS(c, searchResults)
|
||||
}
|
||||
|
||||
// List 获取所有的工具函数列表
|
||||
func (h *FunctionHandler) List(c *gin.Context) {
|
||||
var items []model.Function
|
||||
|
||||
442
api/handler/jimeng_handler.go
Normal file
442
api/handler/jimeng_handler.go
Normal file
@@ -0,0 +1,442 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// JimengHandler 即梦AI处理器
|
||||
type JimengHandler struct {
|
||||
BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewJimengHandler 创建即梦AI处理器
|
||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService) *JimengHandler {
|
||||
return &JimengHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由,新增统一任务接口
|
||||
func (h *JimengHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/jimeng")
|
||||
rg.POST("task", h.CreateTask) // 只保留统一任务接口
|
||||
rg.GET("power-config", h.GetPowerConfig) // 新增算力配置接口
|
||||
rg.POST("jobs", h.Jobs)
|
||||
rg.GET("remove", h.Remove)
|
||||
rg.GET("retry", h.Retry)
|
||||
}
|
||||
|
||||
// JimengTaskRequest 统一任务请求结构体
|
||||
// 支持所有生图和生成视频类型
|
||||
type JimengTaskRequest struct {
|
||||
TaskType string `json:"task_type" binding:"required"`
|
||||
Prompt string `json:"prompt"`
|
||||
ImageInput string `json:"image_input"`
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
Scale float64 `json:"scale"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Gpen float64 `json:"gpen"`
|
||||
Skin float64 `json:"skin"`
|
||||
SkinUnifi float64 `json:"skin_unifi"`
|
||||
GenMode string `json:"gen_mode"`
|
||||
Seed int64 `json:"seed"`
|
||||
UsePreLLM bool `json:"use_pre_llm"`
|
||||
TemplateId string `json:"template_id"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
}
|
||||
|
||||
// CreateTask 统一任务创建接口
|
||||
func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
var req JimengTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
// 新增:除图像特效外,其他任务类型必须有提示词
|
||||
if req.TaskType != "image_effects" && req.Prompt == "" {
|
||||
resp.ERROR(c, "提示词不能为空")
|
||||
return
|
||||
}
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
|
||||
var powerCost int
|
||||
var taskType model.JMTaskType
|
||||
var params map[string]any
|
||||
var reqKey string
|
||||
var modelName string
|
||||
|
||||
switch req.TaskType {
|
||||
case "text_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToImage)
|
||||
taskType = model.JMTaskTypeTextToImage
|
||||
reqKey = jimeng.ReqKeyTextToImage
|
||||
modelName = "即梦文生图"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 2.5
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"use_pre_llm": req.UsePreLLM,
|
||||
}
|
||||
case "image_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToImage)
|
||||
taskType = model.JMTaskTypeImageToImage
|
||||
reqKey = jimeng.ReqKeyImageToImagePortrait
|
||||
modelName = "即梦图生图"
|
||||
if req.Gpen == 0 {
|
||||
req.Gpen = 0.4
|
||||
}
|
||||
if req.Skin == 0 {
|
||||
req.Skin = 0.3
|
||||
}
|
||||
if req.GenMode == "" {
|
||||
if req.Prompt != "" {
|
||||
req.GenMode = jimeng.GenModeCreative
|
||||
} else {
|
||||
req.GenMode = jimeng.GenModeReference
|
||||
}
|
||||
}
|
||||
params = map[string]any{
|
||||
"image_input": req.ImageInput,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"gpen": req.Gpen,
|
||||
"skin": req.Skin,
|
||||
"skin_unifi": req.SkinUnifi,
|
||||
"gen_mode": req.GenMode,
|
||||
"seed": req.Seed,
|
||||
}
|
||||
case "image_edit":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEdit)
|
||||
taskType = model.JMTaskTypeImageEdit
|
||||
reqKey = jimeng.ReqKeyImageEdit
|
||||
modelName = "即梦图像编辑"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 0.5
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
case "image_effects":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
|
||||
taskType = model.JMTaskTypeImageEffects
|
||||
reqKey = jimeng.ReqKeyImageEffects
|
||||
modelName = "即梦图像特效"
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
params = map[string]any{
|
||||
"image_input1": req.ImageInput,
|
||||
"template_id": req.TemplateId,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
}
|
||||
case "text_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToVideo)
|
||||
taskType = model.JMTaskTypeTextToVideo
|
||||
reqKey = jimeng.ReqKeyTextToVideo
|
||||
modelName = "即梦文生视频"
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
if req.AspectRatio == "" {
|
||||
req.AspectRatio = jimeng.AspectRatio16_9
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
case "image_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToVideo)
|
||||
taskType = model.JMTaskTypeImageToVideo
|
||||
reqKey = jimeng.ReqKeyImageToVideo
|
||||
modelName = "即梦图生视频"
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
default:
|
||||
resp.ERROR(c, "不支持的任务类型")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: taskType,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: reqKey,
|
||||
Power: powerCost,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("%s,任务ID:%d", modelName, job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
func (h *JimengHandler) Jobs(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
var req struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Filter string `json:"filter"`
|
||||
Ids []uint `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
var jobs []model.JimengJob
|
||||
var total int64
|
||||
query := h.DB.Model(&model.JimengJob{}).Where("user_id = ?", userId)
|
||||
|
||||
switch req.Filter {
|
||||
case "image":
|
||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||
model.JMTaskTypeTextToImage,
|
||||
model.JMTaskTypeImageToImage,
|
||||
model.JMTaskTypeImageEdit,
|
||||
model.JMTaskTypeImageEffects,
|
||||
})
|
||||
case "video":
|
||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||
model.JMTaskTypeTextToVideo,
|
||||
model.JMTaskTypeImageToVideo,
|
||||
})
|
||||
}
|
||||
|
||||
if len(req.Ids) > 0 {
|
||||
query = query.Where("id IN (?)", req.Ids)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
if err := query.Order("updated_at DESC").Offset(offset).Limit(req.PageSize).Find(&jobs).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 填充 VO
|
||||
var jobVos []vo.JimengJob
|
||||
for _, job := range jobs {
|
||||
var jobVo vo.JimengJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
jobVo.CreatedAt = job.CreatedAt.Unix()
|
||||
jobVos = append(jobVos, jobVo)
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, req.Page, req.PageSize, jobVos))
|
||||
}
|
||||
|
||||
// Remove 删除任务
|
||||
func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
jobId := h.GetInt(c, "id", 0)
|
||||
if jobId == 0 {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取任务,判断状态
|
||||
job, err := h.jimengService.GetJob(uint(jobId))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
if job.UserId != user.Id {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
resp.ERROR(c, "只有失败的任务才能删除")
|
||||
return
|
||||
}
|
||||
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Where("id = ? AND user_id = ?", jobId, user.Id).Delete(&model.JimengJob{}).Error; err != nil {
|
||||
logger.Errorf("delete jimeng job failed: %v", err)
|
||||
resp.ERROR(c, "删除任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 退回算力
|
||||
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "退回算力失败")
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
resp.SUCCESS(c, gin.H{})
|
||||
}
|
||||
|
||||
// Retry 重试任务
|
||||
func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
jobId := h.GetInt(c, "id", 0)
|
||||
if jobId == 0 {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查任务是否存在且属于当前用户
|
||||
job, err := h.jimengService.GetJob(uint(jobId))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if job.UserId != userId {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
|
||||
// 只有失败的任务才能重试
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
resp.ERROR(c, "只有失败的任务才能重试")
|
||||
return
|
||||
}
|
||||
|
||||
// 重置任务状态
|
||||
if err := h.jimengService.UpdateJobStatus(uint(jobId), model.JMTaskStatusInQueue, ""); err != nil {
|
||||
logger.Errorf("reset job status failed: %v", err)
|
||||
resp.ERROR(c, "重置任务状态失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 重新推送到队列
|
||||
if err := h.jimengService.PushTaskToQueue(uint(jobId)); err != nil {
|
||||
logger.Errorf("push retry task to queue failed: %v", err)
|
||||
resp.ERROR(c, "推送重试任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"message": "重试任务已提交"})
|
||||
}
|
||||
|
||||
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
||||
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||
config := h.jimengService.GetConfig()
|
||||
|
||||
switch taskType {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
return config.Power.TextToImage
|
||||
case model.JMTaskTypeImageToImage:
|
||||
return config.Power.ImageToImage
|
||||
case model.JMTaskTypeImageEdit:
|
||||
return config.Power.ImageEdit
|
||||
case model.JMTaskTypeImageEffects:
|
||||
return config.Power.ImageEffects
|
||||
case model.JMTaskTypeTextToVideo:
|
||||
return config.Power.TextToVideo
|
||||
case model.JMTaskTypeImageToVideo:
|
||||
return config.Power.ImageToVideo
|
||||
default:
|
||||
return 10
|
||||
}
|
||||
}
|
||||
|
||||
// GetPowerConfig 获取即梦各任务类型算力消耗配置
|
||||
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
||||
config := h.jimengService.GetConfig()
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"text_to_image": config.Power.TextToImage,
|
||||
"image_to_image": config.Power.ImageToImage,
|
||||
"image_edit": config.Power.ImageEdit,
|
||||
"image_effects": config.Power.ImageEffects,
|
||||
"text_to_video": config.Power.TextToVideo,
|
||||
"image_to_video": config.Power.ImageToVideo,
|
||||
})
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -95,7 +96,7 @@ func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||
|
||||
// 扣减算力
|
||||
if chatModel.Power > 0 {
|
||||
err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
|
||||
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: chatModel.Value,
|
||||
Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
|
||||
|
||||
@@ -66,7 +66,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
var data struct {
|
||||
TaskType string `json:"task_type"`
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
NegPrompt string `json:"neg_prompt"`
|
||||
Rate string `json:"rate"`
|
||||
@@ -153,7 +152,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
task := types.MjTask{
|
||||
ClientId: data.ClientId,
|
||||
TaskId: taskId,
|
||||
Type: types.TaskType(data.TaskType),
|
||||
Prompt: data.Prompt,
|
||||
@@ -162,11 +160,11 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
UserId: userId,
|
||||
ImgArr: data.ImgArr,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: data.TaskType,
|
||||
UserId: userId,
|
||||
UserId: uint(userId),
|
||||
TaskId: taskId,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Progress: 0,
|
||||
@@ -207,7 +205,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
|
||||
type reqVo struct {
|
||||
Index int `json:"index"`
|
||||
ClientId string `json:"client_id"`
|
||||
ChannelId string `json:"channel_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
MessageHash string `json:"message_hash"`
|
||||
@@ -229,7 +226,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
taskId, _ := h.snowflake.Next(true)
|
||||
task := types.MjTask{
|
||||
ClientId: data.ClientId,
|
||||
Type: types.TaskUpscale,
|
||||
UserId: userId,
|
||||
ChannelId: data.ChannelId,
|
||||
@@ -240,7 +236,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskUpscale.String(),
|
||||
UserId: userId,
|
||||
UserId: uint(userId),
|
||||
TaskId: taskId,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Progress: 0,
|
||||
@@ -286,7 +282,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
taskId, _ := h.snowflake.Next(true)
|
||||
task := types.MjTask{
|
||||
Type: types.TaskVariation,
|
||||
ClientId: data.ClientId,
|
||||
UserId: userId,
|
||||
Index: data.Index,
|
||||
ChannelId: data.ChannelId,
|
||||
@@ -297,7 +292,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskVariation.String(),
|
||||
ChannelId: data.ChannelId,
|
||||
UserId: userId,
|
||||
UserId: uint(userId),
|
||||
TaskId: taskId,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Progress: 0,
|
||||
@@ -427,7 +422,7 @@ func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||
err := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
|
||||
err := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: uint(userId)}).UpdateColumn("publish", action).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -15,11 +15,12 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type NetHandler struct {
|
||||
@@ -46,7 +47,7 @@ func (h *NetHandler) Upload(c *gin.Context) {
|
||||
|
||||
userId := h.GetLoginUserId(c)
|
||||
res := h.DB.Create(&model.File{
|
||||
UserId: int(userId),
|
||||
UserId: uint(userId),
|
||||
Name: file.Name,
|
||||
ObjKey: file.ObjKey,
|
||||
URL: file.URL,
|
||||
@@ -143,7 +144,15 @@ func (h *NetHandler) Download(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 使用http.Get下载文件
|
||||
r, err := http.Get(fileUrl)
|
||||
req, err := http.NewRequest("GET", fileUrl, nil)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 模拟浏览器 UA
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
|
||||
client := &http.Client{}
|
||||
r, err := client.Do(req)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -156,6 +165,5 @@ func (h *NetHandler) Download(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
// 将下载的文件内容写入响应
|
||||
_, _ = io.Copy(c.Writer, r.Body)
|
||||
}
|
||||
|
||||
@@ -289,7 +289,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
}
|
||||
|
||||
// 增加用户算力
|
||||
err = h.userService.IncreasePower(int(order.UserId), remark.Power, model.PowerLog{
|
||||
err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{
|
||||
Type: types.PowerRecharge,
|
||||
Model: order.PayWay,
|
||||
Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo),
|
||||
|
||||
@@ -48,7 +48,7 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -56,11 +56,15 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||
|
||||
if h.App.SysConfig.PromptPower > 0 {
|
||||
userId := h.GetLoginUserId(c)
|
||||
h.userService.DecreasePower(int(userId), h.App.SysConfig.PromptPower, model.PowerLog{
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: h.getPromptModel(),
|
||||
Remark: "生成歌词",
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
@@ -75,18 +79,22 @@ func (h *PromptHandler) Image(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
if h.App.SysConfig.PromptPower > 0 {
|
||||
userId := h.GetLoginUserId(c)
|
||||
h.userService.DecreasePower(int(userId), h.App.SysConfig.PromptPower, model.PowerLog{
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: h.getPromptModel(),
|
||||
Remark: "生成绘画提示词",
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||
}
|
||||
@@ -100,7 +108,7 @@ func (h *PromptHandler) Video(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -108,11 +116,15 @@ func (h *PromptHandler) Video(c *gin.Context) {
|
||||
|
||||
if h.App.SysConfig.PromptPower > 0 {
|
||||
userId := h.GetLoginUserId(c)
|
||||
h.userService.DecreasePower(int(userId), h.App.SysConfig.PromptPower, model.PowerLog{
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: h.getPromptModel(),
|
||||
Remark: "生成视频脚本",
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||
@@ -146,9 +158,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *PromptHandler) getPromptModel() string {
|
||||
if h.App.SysConfig.TranslateModelId > 0 {
|
||||
if h.App.SysConfig.AssistantModelId > 0 {
|
||||
var chatModel model.ChatModel
|
||||
h.DB.Where("id", h.App.SysConfig.TranslateModelId).First(&chatModel)
|
||||
h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel)
|
||||
return chatModel.Value
|
||||
}
|
||||
return "gpt-4o"
|
||||
|
||||
@@ -198,7 +198,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
|
||||
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
// 扣减算力
|
||||
err = h.userService.DecreasePower(int(userId), h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "advanced-voice",
|
||||
Remark: "实时语音通话",
|
||||
|
||||
@@ -61,7 +61,7 @@ func (h *RedeemHandler) Verify(c *gin.Context) {
|
||||
}
|
||||
|
||||
tx := h.DB.Begin()
|
||||
err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{
|
||||
err := h.userService.IncreasePower(userId, item.Power, model.PowerLog{
|
||||
Type: types.PowerRedeem,
|
||||
Model: "兑换码",
|
||||
Remark: fmt.Sprintf("兑换码核销,算力:%d,兑换码:%s...", item.Power, item.Code[:10]),
|
||||
|
||||
@@ -102,6 +102,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
if data.Sampler == "" {
|
||||
data.Sampler = "Euler a"
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
taskId, err := h.snowflake.Next(true)
|
||||
@@ -111,8 +112,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
task := types.SdTask{
|
||||
ClientId: data.ClientId,
|
||||
Type: types.TaskImage,
|
||||
Type: types.TaskImage,
|
||||
Params: types.SdTaskParams{
|
||||
TaskId: taskId,
|
||||
Prompt: data.Prompt,
|
||||
@@ -131,11 +131,11 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
HdSteps: data.HdSteps,
|
||||
},
|
||||
UserId: userId,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
}
|
||||
|
||||
job := model.SdJob{
|
||||
UserId: userId,
|
||||
UserId: uint(userId),
|
||||
Type: types.TaskImage.String(),
|
||||
TaskId: taskId,
|
||||
Params: utils.JsonEncode(task.Params),
|
||||
@@ -273,7 +273,7 @@ func (h *SdJobHandler) Publish(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
|
||||
|
||||
err := h.DB.Model(&model.SdJob{Id: uint(id), UserId: int(userId)}).UpdateColumn("publish", action).Error
|
||||
err := h.DB.Model(&model.SdJob{Id: uint(id), UserId: uint(userId)}).UpdateColumn("publish", action).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -111,9 +111,5 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.Debug {
|
||||
resp.SUCCESS(c, code)
|
||||
} else {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -18,9 +18,10 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SunoHandler struct {
|
||||
@@ -45,7 +46,6 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
|
||||
func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Instrumental bool `json:"instrumental"`
|
||||
Lyrics string `json:"lyrics"`
|
||||
@@ -90,7 +90,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
task := types.SunoTask{
|
||||
ClientId: data.ClientId,
|
||||
UserId: int(h.GetLoginUserId(c)),
|
||||
Type: data.Type,
|
||||
Title: data.Title,
|
||||
@@ -98,6 +97,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
RefSongId: data.RefSongId,
|
||||
ExtendSecs: data.ExtendSecs,
|
||||
Prompt: data.Prompt,
|
||||
Lyrics: data.Lyrics,
|
||||
Tags: data.Tags,
|
||||
Model: data.Model,
|
||||
Instrumental: data.Instrumental,
|
||||
@@ -107,7 +107,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
// 插入数据库
|
||||
job := model.SunoJob{
|
||||
UserId: task.UserId,
|
||||
UserId: uint(task.UserId),
|
||||
Prompt: data.Prompt,
|
||||
Instrumental: data.Instrumental,
|
||||
ModelName: data.Model,
|
||||
|
||||
@@ -137,13 +137,15 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
|
||||
salt := utils.RandString(8)
|
||||
user := model.User{
|
||||
Username: data.Username,
|
||||
Password: utils.GenPassword(data.Password, salt),
|
||||
Avatar: "/images/avatar/user.png",
|
||||
Salt: salt,
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||
Power: h.App.SysConfig.InitPower,
|
||||
Username: data.Username,
|
||||
Password: utils.GenPassword(data.Password, salt),
|
||||
Avatar: "/images/avatar/user.png",
|
||||
Salt: salt,
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||
ChatConfig: "{}",
|
||||
ChatModels: "{}",
|
||||
Power: h.App.SysConfig.InitPower,
|
||||
}
|
||||
|
||||
// check if the username is existing
|
||||
@@ -170,10 +172,15 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
if data.InviteCode != "" {
|
||||
user.Power += h.App.SysConfig.InvitePower
|
||||
}
|
||||
|
||||
if h.licenseService.GetLicense().Configs.DeCopy {
|
||||
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
|
||||
} else {
|
||||
user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
|
||||
defaultNickname := h.App.SysConfig.DefaultNickname
|
||||
if defaultNickname == "" {
|
||||
defaultNickname = "极客学长"
|
||||
}
|
||||
user.Nickname = fmt.Sprintf("%s@%d", defaultNickname, utils.RandomNumber(6))
|
||||
}
|
||||
|
||||
tx := h.DB.Begin()
|
||||
@@ -187,7 +194,7 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
// 增加邀请数量
|
||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||||
if h.App.SysConfig.InvitePower > 0 {
|
||||
err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{
|
||||
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.InvitePower, model.PowerLog{
|
||||
Type: types.PowerInvite,
|
||||
Model: "Invite",
|
||||
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
|
||||
@@ -736,7 +743,7 @@ func (h *UserHandler) SignIn(c *gin.Context) {
|
||||
// 签到
|
||||
h.levelDB.Put(key, true)
|
||||
if h.App.SysConfig.DailyPower > 0 {
|
||||
h.userService.IncreasePower(int(userId), h.App.SysConfig.DailyPower, model.PowerLog{
|
||||
h.userService.IncreasePower(userId, h.App.SysConfig.DailyPower, model.PowerLog{
|
||||
Type: types.PowerSignIn,
|
||||
Model: "SignIn",
|
||||
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.DailyPower),
|
||||
|
||||
@@ -18,9 +18,10 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type VideoHandler struct {
|
||||
@@ -45,7 +46,6 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
|
||||
func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
FirstFrameImg string `json:"first_frame_img,omitempty"`
|
||||
EndFrameImg string `json:"end_frame_img,omitempty"`
|
||||
@@ -56,6 +56,11 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
// 检查 Prompt 长度
|
||||
if data.Prompt == "" {
|
||||
resp.ERROR(c, "prompt is needed")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
@@ -68,29 +73,23 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if data.Prompt == "" {
|
||||
resp.ERROR(c, "prompt is needed")
|
||||
return
|
||||
}
|
||||
|
||||
userId := int(h.GetLoginUserId(c))
|
||||
params := types.VideoParams{
|
||||
params := types.LumaVideoParams{
|
||||
PromptOptimize: data.ExpandPrompt,
|
||||
Loop: data.Loop,
|
||||
StartImgURL: data.FirstFrameImg,
|
||||
EndImgURL: data.EndFrameImg,
|
||||
}
|
||||
task := types.VideoTask{
|
||||
ClientId: data.ClientId,
|
||||
UserId: userId,
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
}
|
||||
// 插入数据库
|
||||
job := model.VideoJob{
|
||||
UserId: userId,
|
||||
UserId: uint(userId),
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Power: h.App.SysConfig.LumaPower,
|
||||
@@ -119,20 +118,117 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *VideoHandler) KeLingCreate(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
Channel string `json:"channel"`
|
||||
TaskType string `json:"task_type"` // 任务类型: text2video/image2video
|
||||
Model string `json:"model"` // 模型: kling-v1-5,kling-v1-6
|
||||
Prompt string `json:"prompt"` // 视频描述
|
||||
NegPrompt string `json:"negative_prompt"` // 负面提示词
|
||||
CfgScale float64 `json:"cfg_scale"` // 相关性系数(0-1)
|
||||
Mode string `json:"mode"` // 生成模式: std/pro
|
||||
AspectRatio string `json:"aspect_ratio"` // 画面比例: 16:9/9:16/1:1
|
||||
Duration string `json:"duration"` // 视频时长: 5/10
|
||||
CameraControl types.CameraControl `json:"camera_control"` // 摄像机控制
|
||||
Image string `json:"image"` // 参考图片URL(image2video)
|
||||
ImageTail string `json:"image_tail"` // 尾帧图片URL(image2video)
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 计算当前任务所需算力
|
||||
key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration)
|
||||
power := h.App.SysConfig.KeLingPowers[key]
|
||||
if power == 0 {
|
||||
resp.ERROR(c, "当前模型暂不支持")
|
||||
return
|
||||
}
|
||||
if user.Power < power {
|
||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||
return
|
||||
}
|
||||
|
||||
if data.Prompt == "" {
|
||||
resp.ERROR(c, "prompt is needed")
|
||||
return
|
||||
}
|
||||
|
||||
userId := int(h.GetLoginUserId(c))
|
||||
params := types.KeLingVideoParams{
|
||||
TaskType: data.TaskType,
|
||||
Model: data.Model,
|
||||
Prompt: data.Prompt,
|
||||
NegPrompt: data.NegPrompt,
|
||||
CfgScale: data.CfgScale,
|
||||
Mode: data.Mode,
|
||||
AspectRatio: data.AspectRatio,
|
||||
Duration: data.Duration,
|
||||
CameraControl: data.CameraControl,
|
||||
Image: data.Image,
|
||||
ImageTail: data.ImageTail,
|
||||
}
|
||||
task := types.VideoTask{
|
||||
UserId: userId,
|
||||
Type: types.VideoKeLing,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
Channel: data.Channel,
|
||||
}
|
||||
// 插入数据库
|
||||
job := model.VideoJob{
|
||||
UserId: uint(userId),
|
||||
Type: types.VideoKeLing,
|
||||
Prompt: data.Prompt,
|
||||
Power: power,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
}
|
||||
tx := h.DB.Create(&job)
|
||||
if tx.Error != nil {
|
||||
resp.ERROR(c, tx.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
task.Id = job.Id
|
||||
h.videoService.PushTask(task)
|
||||
|
||||
// update user's power
|
||||
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "keling",
|
||||
Remark: fmt.Sprintf("keling 文生视频,任务ID:%d", job.Id),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *VideoHandler) List(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
t := c.Query("type")
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
all := h.GetBool(c, "all")
|
||||
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if t != "" {
|
||||
session = session.Where("type", t)
|
||||
}
|
||||
if all {
|
||||
session = session.Where("publish", 0).Where("progress", 100)
|
||||
} else {
|
||||
session = session.Where("user_id", h.GetLoginUserId(c))
|
||||
session = session.Where("user_id", userId)
|
||||
}
|
||||
// 统计总数
|
||||
var total int64
|
||||
@@ -161,6 +257,33 @@ func (h *VideoHandler) List(c *gin.Context) {
|
||||
if item.VideoURL == "" {
|
||||
item.VideoURL = v.WaterURL
|
||||
}
|
||||
// 解析任务详情
|
||||
if item.Type == types.VideoKeLing {
|
||||
task := types.VideoTask{}
|
||||
err = utils.JsonDecode(v.TaskInfo, &task)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var params types.KeLingVideoParams
|
||||
err = utils.JsonDecode(utils.JsonEncode(task.Params), ¶ms)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
item.RawData = map[string]interface{}{
|
||||
"task_type": params.TaskType,
|
||||
"model": params.Model,
|
||||
"cfg_scale": params.CfgScale,
|
||||
"mode": params.Mode,
|
||||
"aspect_ratio": params.AspectRatio,
|
||||
"duration": params.Duration,
|
||||
"model_name": fmt.Sprintf("%s_%s_%s", params.Model, params.Mode, params.Duration),
|
||||
}
|
||||
|
||||
// 如果视频URL不为空,则设置为生成成功
|
||||
if item.VideoURL != "" {
|
||||
item.Progress = 100
|
||||
}
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
@@ -192,6 +315,8 @@ func (h *VideoHandler) Remove(c *gin.Context) {
|
||||
// 删除文件
|
||||
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
|
||||
_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *VideoHandler) Publish(c *gin.Context) {
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Websocket 连接处理 handler
|
||||
|
||||
type WebsocketHandler struct {
|
||||
BaseHandler
|
||||
wsService *service.WebsocketService
|
||||
chatHandler *ChatHandler
|
||||
}
|
||||
|
||||
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
|
||||
return &WebsocketHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
chatHandler: chatHandler,
|
||||
wsService: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WebsocketHandler) Client(c *gin.Context) {
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
ws, err := (&websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: strings.Split(clientProtocols, ","),
|
||||
}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
clientId := c.Query("client_id")
|
||||
client := types.NewWsClient(ws, clientId)
|
||||
userId := h.GetLoginUserId(c)
|
||||
if userId == 0 {
|
||||
_ = client.Send([]byte("Invalid user_id"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||
_ = client.Send([]byte("Invalid user_id"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
h.wsService.Clients.Put(clientId, client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
||||
client.Close()
|
||||
h.wsService.Clients.Delete(clientId)
|
||||
break
|
||||
}
|
||||
|
||||
var message types.InputMessage
|
||||
err = utils.JsonDecode(string(msg), &message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("Receive a message:%+v", message)
|
||||
if message.Type == types.MsgTypePing {
|
||||
utils.SendChannelMsg(client, types.ChPing, "pong")
|
||||
continue
|
||||
}
|
||||
|
||||
// 当前只处理聊天消息,其他消息全部丢弃
|
||||
var chatMessage types.ChatMessage
|
||||
err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
|
||||
if err != nil || message.Channel != types.ChChat {
|
||||
logger.Warnf("invalid message body:%+v", message.Body)
|
||||
continue
|
||||
}
|
||||
var chatRole model.ChatRole
|
||||
err = h.DB.First(&chatRole, chatMessage.RoleId).Error
|
||||
if err != nil || !chatRole.Enable {
|
||||
utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
|
||||
continue
|
||||
}
|
||||
// if the role bind a model_id, use role's bind model_id
|
||||
if chatRole.ModelId > 0 {
|
||||
chatMessage.RoleId = chatRole.ModelId
|
||||
}
|
||||
// get model info
|
||||
var chatModel model.ChatModel
|
||||
err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
|
||||
if err != nil || chatModel.Enabled == false {
|
||||
utils.SendAndFlush(client, "当前AI模型暂未启用,请更换模型后再发起对话!!!")
|
||||
continue
|
||||
}
|
||||
|
||||
session := &types.ChatSession{
|
||||
ClientIP: c.ClientIP(),
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
// use old chat data override the chat model and role ID
|
||||
var chat model.ChatItem
|
||||
h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
|
||||
if chat.Id > 0 {
|
||||
chatModel.Id = chat.ModelId
|
||||
chatMessage.RoleId = int(chat.RoleId)
|
||||
}
|
||||
|
||||
session.ChatId = chatMessage.ChatId
|
||||
session.Tools = chatMessage.Tools
|
||||
session.Stream = chatMessage.Stream
|
||||
// 复制模型数据
|
||||
err = utils.CopyObject(chatModel, &session.Model)
|
||||
if err != nil {
|
||||
logger.Error(err, chatModel)
|
||||
}
|
||||
session.Model.Id = chatModel.Id
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
|
||||
err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
utils.SendAndFlush(client, err.Error())
|
||||
} else {
|
||||
utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
|
||||
logger.Infof("回答完毕: %v", message.Body)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -8,11 +8,12 @@ package logger
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var logger *zap.Logger
|
||||
@@ -23,7 +24,7 @@ func GetLogger() *zap.SugaredLogger {
|
||||
return sugarLogger
|
||||
}
|
||||
|
||||
logLevel := zap.NewAtomicLevelAt(getLogLevel(os.Getenv("LOG_LEVEL")))
|
||||
logLevel := zap.NewAtomicLevelAt(getLogLevel(os.Getenv("GEEKAI_LOG_LEVEL")))
|
||||
encoder := getEncoder()
|
||||
writerSyncer := getLogWriter()
|
||||
fileCore := zapcore.NewCore(encoder, writerSyncer, logLevel)
|
||||
|
||||
36
api/main.go
36
api/main.go
@@ -17,6 +17,7 @@ import (
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/payment"
|
||||
@@ -140,6 +141,7 @@ func main() {
|
||||
fx.Provide(handler.NewProductHandler),
|
||||
fx.Provide(handler.NewConfigHandler),
|
||||
fx.Provide(handler.NewPowerLogHandler),
|
||||
fx.Provide(handler.NewJimengHandler),
|
||||
|
||||
fx.Provide(admin.NewConfigHandler),
|
||||
fx.Provide(admin.NewAdminHandler),
|
||||
@@ -153,6 +155,7 @@ func main() {
|
||||
fx.Provide(admin.NewOrderHandler),
|
||||
fx.Provide(admin.NewChatHandler),
|
||||
fx.Provide(admin.NewPowerLogHandler),
|
||||
fx.Provide(admin.NewAdminJimengHandler),
|
||||
|
||||
// 创建服务
|
||||
fx.Provide(sms.NewSendServiceManager),
|
||||
@@ -163,7 +166,6 @@ func main() {
|
||||
fx.Provide(dalle.NewService),
|
||||
fx.Invoke(func(s *dalle.Service) {
|
||||
s.Run()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadImages()
|
||||
s.CheckTaskStatus()
|
||||
}),
|
||||
@@ -182,7 +184,6 @@ func main() {
|
||||
fx.Invoke(func(s *mj.Service) {
|
||||
s.Run()
|
||||
s.SyncTaskProgress()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadImages()
|
||||
}),
|
||||
|
||||
@@ -191,23 +192,26 @@ func main() {
|
||||
fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
|
||||
s.Run()
|
||||
s.CheckTaskStatus()
|
||||
s.CheckTaskNotify()
|
||||
}),
|
||||
|
||||
fx.Provide(suno.NewService),
|
||||
fx.Invoke(func(s *suno.Service) {
|
||||
s.Run()
|
||||
s.SyncTaskProgress()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadFiles()
|
||||
}),
|
||||
fx.Provide(video.NewService),
|
||||
fx.Invoke(func(s *video.Service) {
|
||||
s.Run()
|
||||
s.SyncTaskProgress()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadFiles()
|
||||
}),
|
||||
|
||||
// 即梦AI 服务
|
||||
fx.Provide(jimeng.NewService),
|
||||
fx.Invoke(func(service *jimeng.Service) {
|
||||
service.Start()
|
||||
}),
|
||||
fx.Provide(service.NewUserService),
|
||||
fx.Provide(payment.NewAlipayService),
|
||||
fx.Provide(payment.NewHuPiPay),
|
||||
@@ -248,6 +252,7 @@ func main() {
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
|
||||
group := s.Engine.Group("/api/chat/")
|
||||
group.Any("message", h.Chat)
|
||||
group.GET("list", h.List)
|
||||
group.GET("detail", h.Detail)
|
||||
group.POST("update", h.Update)
|
||||
@@ -256,6 +261,7 @@ func main() {
|
||||
group.GET("clear", h.Clear)
|
||||
group.POST("tokens", h.Tokens)
|
||||
group.GET("stop", h.StopGenerate)
|
||||
group.POST("tts", h.TextToSpeech)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
|
||||
s.Engine.POST("/api/upload", h.Upload)
|
||||
@@ -335,6 +341,7 @@ func main() {
|
||||
group.POST("save", h.Save)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("loginLog", h.LoginLog)
|
||||
group.GET("genLoginLink", h.GenLoginLink)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
|
||||
@@ -431,6 +438,7 @@ func main() {
|
||||
group.POST("weibo", h.WeiBo)
|
||||
group.POST("zaobao", h.ZaoBao)
|
||||
group.POST("dalle3", h.Dall3)
|
||||
group.POST("websearch", h.WebSearch)
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
|
||||
@@ -492,10 +500,19 @@ func main() {
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
|
||||
group := s.Engine.Group("/api/video")
|
||||
group.POST("luma/create", h.LumaCreate)
|
||||
group.POST("keling/create", h.KeLingCreate)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
|
||||
// 即梦AI 路由
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.JimengHandler) {
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.AdminJimengHandler) {
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(admin.NewChatAppTypeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
|
||||
group := s.Engine.Group("/api/admin/app/type")
|
||||
@@ -515,11 +532,6 @@ func main() {
|
||||
group := s.Engine.Group("/api/test")
|
||||
group.Any("sse", h.PostTest, h.SseTest)
|
||||
}),
|
||||
fx.Provide(service.NewWebsocketService),
|
||||
fx.Provide(handler.NewWebsocketHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) {
|
||||
s.Engine.Any("/api/ws", h.Client)
|
||||
}),
|
||||
fx.Provide(handler.NewPromptHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
||||
group := s.Engine.Group("/api/prompt")
|
||||
@@ -560,8 +572,8 @@ func main() {
|
||||
fx.Provide(admin.NewMediaHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
|
||||
group := s.Engine.Group("/api/admin/media")
|
||||
group.POST("/list/suno", h.SunoList)
|
||||
group.POST("/list/luma", h.LumaList)
|
||||
group.POST("/suno", h.SunoList)
|
||||
group.POST("/videos", h.Videos)
|
||||
group.GET("/remove", h.Remove)
|
||||
}),
|
||||
fx.Provide(handler.NewRealtimeHandler),
|
||||
|
||||
333
api/service/crawler/service.go
Normal file
333
api/service/crawler/service.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package crawler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/logger"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/launcher"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
)
|
||||
|
||||
// Service 网络爬虫服务
|
||||
type Service struct {
|
||||
browser *rod.Browser
|
||||
}
|
||||
|
||||
// NewService 创建一个新的爬虫服务
|
||||
func NewService() (*Service, error) {
|
||||
// 启动浏览器
|
||||
path, _ := launcher.LookPath()
|
||||
u := launcher.New().Bin(path).
|
||||
Headless(true). // 无头模式
|
||||
Set("disable-web-security", ""). // 禁用网络安全限制
|
||||
Set("disable-gpu", ""). // 禁用 GPU 加速
|
||||
Set("no-sandbox", ""). // 禁用沙箱模式
|
||||
Set("disable-setuid-sandbox", ""). // 禁用 setuid 沙箱
|
||||
MustLaunch()
|
||||
|
||||
browser := rod.New().ControlURL(u).MustConnect()
|
||||
|
||||
return &Service{
|
||||
browser: browser,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SearchResult 搜索结果
|
||||
type SearchResult struct {
|
||||
Title string `json:"title"` // 标题
|
||||
URL string `json:"url"` // 链接
|
||||
Content string `json:"content"` // 内容摘要
|
||||
}
|
||||
|
||||
// WebSearch 网络搜索
|
||||
func (s *Service) WebSearch(keyword string, maxPages int) ([]SearchResult, error) {
|
||||
if keyword == "" {
|
||||
return nil, errors.New("搜索关键词不能为空")
|
||||
}
|
||||
|
||||
if maxPages <= 0 {
|
||||
maxPages = 1
|
||||
}
|
||||
if maxPages > 10 {
|
||||
maxPages = 10 // 最多搜索 10 页
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0)
|
||||
|
||||
// 使用百度搜索
|
||||
searchURL := fmt.Sprintf("https://www.baidu.com/s?wd=%s", url.QueryEscape(keyword))
|
||||
|
||||
// 设置页面超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 创建页面
|
||||
page := s.browser.MustPage()
|
||||
defer page.MustClose()
|
||||
|
||||
// 设置视口大小
|
||||
err := page.SetViewport(&proto.EmulationSetDeviceMetricsOverride{
|
||||
Width: 1280,
|
||||
Height: 800,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("设置视口失败: %v", err)
|
||||
}
|
||||
|
||||
// 导航到搜索页面
|
||||
err = page.Context(ctx).Navigate(searchURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("导航到搜索页面失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待搜索结果加载完成
|
||||
err = page.WaitLoad()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("等待页面加载完成失败: %v", err)
|
||||
}
|
||||
|
||||
// 分析当前页面的搜索结果
|
||||
for i := 0; i < maxPages; i++ {
|
||||
if i > 0 {
|
||||
// 点击下一页按钮
|
||||
nextPage, err := page.Element("a.n")
|
||||
if err != nil || nextPage == nil {
|
||||
break // 没有下一页
|
||||
}
|
||||
|
||||
err = nextPage.Click(proto.InputMouseButtonLeft, 1)
|
||||
if err != nil {
|
||||
break // 点击下一页失败
|
||||
}
|
||||
|
||||
// 等待新页面加载
|
||||
err = page.WaitLoad()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 提取搜索结果
|
||||
resultElements, err := page.Elements(".result, .c-container")
|
||||
if err != nil || resultElements == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, result := range resultElements {
|
||||
// 获取标题
|
||||
titleElement, err := result.Element("h3, .t")
|
||||
if err != nil || titleElement == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
title, err := titleElement.Text()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取 URL
|
||||
linkElement, err := titleElement.Element("a")
|
||||
if err != nil || linkElement == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
href, err := linkElement.Attribute("href")
|
||||
if err != nil || href == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取内容摘要 - 尝试多个可能的选择器
|
||||
var contentElement *rod.Element
|
||||
var content string
|
||||
|
||||
// 尝试多个可能的选择器来适应不同版本的百度搜索结果
|
||||
selectors := []string{".content-right_8Zs40", ".c-abstract", ".content_LJ0WN", ".content"}
|
||||
for _, selector := range selectors {
|
||||
contentElement, err = result.Element(selector)
|
||||
if err == nil && contentElement != nil {
|
||||
content, _ = contentElement.Text()
|
||||
if content != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有选择器都失败,尝试直接从结果块中提取文本
|
||||
if content == "" {
|
||||
// 获取结果元素的所有文本
|
||||
fullText, err := result.Text()
|
||||
if err == nil && fullText != "" {
|
||||
// 简单处理:从全文中移除标题,剩下的可能是摘要
|
||||
fullText = strings.Replace(fullText, title, "", 1)
|
||||
// 清理文本
|
||||
content = strings.TrimSpace(fullText)
|
||||
// 限制内容长度
|
||||
if len(content) > 200 {
|
||||
content = content[:200] + "..."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加到结果集
|
||||
results = append(results, SearchResult{
|
||||
Title: title,
|
||||
URL: *href,
|
||||
Content: content,
|
||||
})
|
||||
|
||||
// 限制结果数量,每页最多 10 条
|
||||
if len(results) >= 10*maxPages {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取真实 URL(百度搜索结果中的 URL 是短链接,需要跳转获取真实 URL)
|
||||
for i, result := range results {
|
||||
realURL, err := s.getRedirectURL(result.URL)
|
||||
if err == nil && realURL != "" {
|
||||
results[i].URL = realURL
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// 获取真实 URL
|
||||
func (s *Service) getRedirectURL(shortURL string) (string, error) {
|
||||
// 创建页面
|
||||
page, err := s.browser.Page(proto.TargetCreateTarget{URL: ""})
|
||||
if err != nil {
|
||||
return shortURL, err // 返回原始URL
|
||||
}
|
||||
defer func() {
|
||||
_ = page.Close()
|
||||
}()
|
||||
|
||||
// 导航到短链接
|
||||
err = page.Navigate(shortURL)
|
||||
if err != nil {
|
||||
return shortURL, err // 返回原始URL
|
||||
}
|
||||
|
||||
// 等待重定向完成
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 获取当前 URL
|
||||
info, err := page.Info()
|
||||
if err != nil {
|
||||
return shortURL, err // 返回原始URL
|
||||
}
|
||||
|
||||
return info.URL, nil
|
||||
}
|
||||
|
||||
// Close 关闭浏览器
|
||||
func (s *Service) Close() error {
|
||||
if s.browser != nil {
|
||||
err := s.browser.Close()
|
||||
s.browser = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchWeb 封装的搜索方法
|
||||
func SearchWeb(keyword string, maxPages int) (string, error) {
|
||||
// 添加panic恢复机制
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log := logger.GetLogger()
|
||||
log.Errorf("爬虫服务崩溃: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
service, err := NewService()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建爬虫服务失败: %v", err)
|
||||
}
|
||||
defer service.Close()
|
||||
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 使用goroutine和通道来处理超时
|
||||
resultChan := make(chan []SearchResult, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
results, err := service.WebSearch(keyword, maxPages)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
resultChan <- results
|
||||
}()
|
||||
|
||||
// 等待结果或超时
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", fmt.Errorf("搜索超时: %v", ctx.Err())
|
||||
case err := <-errChan:
|
||||
return "", fmt.Errorf("搜索失败: %v", err)
|
||||
case results := <-resultChan:
|
||||
if len(results) == 0 {
|
||||
return "未找到关于 \"" + keyword + "\" 的相关搜索结果", nil
|
||||
}
|
||||
|
||||
// 格式化结果
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("为您找到关于 \"%s\" 的 %d 条搜索结果:\n\n", keyword, len(results)))
|
||||
|
||||
for i, result := range results {
|
||||
// // 尝试打开链接获取实际内容
|
||||
// page := service.browser.MustPage()
|
||||
// defer page.MustClose()
|
||||
|
||||
// // 设置页面超时
|
||||
// pageCtx, pageCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
// defer pageCancel()
|
||||
|
||||
// // 导航到目标页面
|
||||
// err := page.Context(pageCtx).Navigate(result.URL)
|
||||
// if err == nil {
|
||||
// // 等待页面加载
|
||||
// _ = page.WaitLoad()
|
||||
|
||||
// // 获取页面标题
|
||||
// title, err := page.Eval("() => document.title")
|
||||
// if err == nil && title.Value.String() != "" {
|
||||
// result.Title = title.Value.String()
|
||||
// }
|
||||
|
||||
// // 获取页面主要内容
|
||||
// if content, err := page.Element("body"); err == nil {
|
||||
// if text, err := content.Text(); err == nil {
|
||||
// // 清理并截取内容
|
||||
// text = strings.TrimSpace(text)
|
||||
// if len(text) > 200 {
|
||||
// text = text[:200] + "..."
|
||||
// }
|
||||
// result.Prompt = text
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
builder.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, result.Title))
|
||||
builder.WriteString(fmt.Sprintf(" 链接: %s\n", result.URL))
|
||||
if result.Content != "" {
|
||||
builder.WriteString(fmt.Sprintf(" 摘要: %s\n", result.Content))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
return builder.String(), nil
|
||||
}
|
||||
}
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
@@ -34,33 +33,29 @@ type Service struct {
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
userService *service.UserService
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
userService: userService,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (s *Service) PushTask(task types.DallTask) {
|
||||
logger.Infof("add a new DALL-E task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push dall-e task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
// 将数据库中未提交的人物加载到队列
|
||||
// 将数据库中未提交的任务加载到队列
|
||||
var jobs []model.DallJob
|
||||
s.db.Where("progress", 0).Find(&jobs)
|
||||
for _, v := range jobs {
|
||||
@@ -84,16 +79,16 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
logger.Infof("handle a new DALL-E task: %+v", task)
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
}
|
||||
go func() {
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -138,7 +133,11 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
}
|
||||
|
||||
var chatModel model.ChatModel
|
||||
s.db.Where("id = ?", task.ModelId).First(&chatModel)
|
||||
if task.ModelId > 0 {
|
||||
s.db.Where("id", task.ModelId).First(&chatModel)
|
||||
} else {
|
||||
s.db.Where("value", task.ModelName).First(&chatModel)
|
||||
}
|
||||
|
||||
// get image generation API KEY
|
||||
var apiKey model.ApiKey
|
||||
@@ -184,9 +183,6 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
|
||||
}
|
||||
|
||||
all, _ := io.ReadAll(r.Body)
|
||||
logger.Debugf("response: %+v", string(all))
|
||||
|
||||
// update the api key last use time
|
||||
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
var imgURL string
|
||||
@@ -212,10 +208,9 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
return "", fmt.Errorf("err with update database: %v", err)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
var content string
|
||||
if sync {
|
||||
imgURL, err := s.downloadImage(task.Id, int(task.UserId), res.Data[0].Url)
|
||||
imgURL, err := s.downloadImage(task.Id, res.Data[0].Url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
@@ -225,26 +220,6 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running DALL-E task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChDall, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskStatus() {
|
||||
go func() {
|
||||
logger.Info("Running DALL-E task status checking ...")
|
||||
@@ -254,7 +229,7 @@ func (s *Service) CheckTaskStatus() {
|
||||
s.db.Where("progress < ?", 100).Find(&jobs)
|
||||
for _, job := range jobs {
|
||||
// 超时的任务标记为失败
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
|
||||
if time.Since(job.CreatedAt) > time.Minute*10 {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = "任务超时"
|
||||
s.db.Updates(&job)
|
||||
@@ -269,7 +244,7 @@ func (s *Service) CheckTaskStatus() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
err = s.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
|
||||
err = s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: task.ModelName,
|
||||
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg),
|
||||
@@ -301,7 +276,7 @@ func (s *Service) DownloadImages() {
|
||||
}
|
||||
|
||||
logger.Infof("try to download image: %s", v.OrgURL)
|
||||
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
|
||||
imgURL, err := s.downloadImage(v.Id, v.OrgURL)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: %s, error: %v", imgURL, err)
|
||||
continue
|
||||
@@ -316,9 +291,9 @@ func (s *Service) DownloadImages() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
|
||||
func (s *Service) downloadImage(jobId uint, orgURL string) (string, error) {
|
||||
// sava image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, ".png", false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -328,6 +303,5 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
|
||||
if res.Error != nil {
|
||||
return "", err
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||
return imgURL, nil
|
||||
}
|
||||
|
||||
139
api/service/jimeng/client.go
Normal file
139
api/service/jimeng/client.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
)
|
||||
|
||||
// Client 即梦API客户端
|
||||
type Client struct {
|
||||
visual *visual.Visual
|
||||
}
|
||||
|
||||
// NewClient 创建即梦API客户端
|
||||
func NewClient(accessKey, secretKey string) *Client {
|
||||
// 使用官方SDK的visual实例
|
||||
visualInstance := visual.NewInstance()
|
||||
visualInstance.Client.SetAccessKey(accessKey)
|
||||
visualInstance.Client.SetSecretKey(secretKey)
|
||||
|
||||
// 添加即梦AI专有的API配置
|
||||
jimengApis := map[string]*base.ApiInfo{
|
||||
"CVSync2AsyncSubmitTask": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVSync2AsyncSubmitTask"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
"CVSync2AsyncGetResult": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVSync2AsyncGetResult"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
"CVProcess": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVProcess"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 将即梦API添加到现有的ApiInfoList中
|
||||
for name, info := range jimengApis {
|
||||
visualInstance.Client.ApiInfoList[name] = info
|
||||
}
|
||||
|
||||
return &Client{
|
||||
visual: visualInstance,
|
||||
}
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
|
||||
// 直接将请求转为map[string]interface{}
|
||||
reqBodyBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 直接使用序列化后的字节
|
||||
jsonBody := reqBodyBytes
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncSubmitTask", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result SubmitTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// QueryTask 查询任务结果
|
||||
func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncGetResult", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SubmitSyncTask 提交同步任务(仅用于文生图)
|
||||
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVProcess", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应,同步任务直接返回结果
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
600
api/service/jimeng/service.go
Normal file
600
api/service/jimeng/service.go
Normal file
@@ -0,0 +1,600 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
// Service 即梦服务(合并了消费者功能)
|
||||
type Service struct {
|
||||
db *gorm.DB
|
||||
redis *redis.Client
|
||||
taskQueue *store.RedisQueue
|
||||
client *Client
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
running bool
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
// 从数据库加载配置
|
||||
var config model.Config
|
||||
db.Where("name = ?", "Jimeng").First(&config)
|
||||
var jimengConfig types.JimengConfig
|
||||
if config.Id > 0 {
|
||||
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
}
|
||||
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Service{
|
||||
db: db,
|
||||
redis: redisCli,
|
||||
taskQueue: taskQueue,
|
||||
client: client,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
running: false,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动服务(包含消费者)
|
||||
func (s *Service) Start() {
|
||||
if s.running {
|
||||
return
|
||||
}
|
||||
logger.Info("Starting Jimeng service and task consumer...")
|
||||
s.running = true
|
||||
go s.consumeTasks()
|
||||
go s.pollTaskStatus()
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (s *Service) Stop() {
|
||||
if !s.running {
|
||||
return
|
||||
}
|
||||
logger.Info("Stopping Jimeng service and task consumer...")
|
||||
s.running = false
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// consumeTasks 消费任务
|
||||
func (s *Service) consumeTasks() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
logger.Info("Jimeng task consumer stopped")
|
||||
return
|
||||
default:
|
||||
s.processNextTask()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processNextTask 处理下一个任务
|
||||
func (s *Service) processNextTask() {
|
||||
var jobId uint
|
||||
if err := s.taskQueue.LPop(&jobId); err != nil {
|
||||
// 队列为空,等待1秒后重试
|
||||
time.Sleep(time.Second)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("Processing Jimeng task: job_id=%d", jobId)
|
||||
|
||||
if err := s.ProcessTask(jobId); err != nil {
|
||||
logger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
|
||||
s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error())
|
||||
} else {
|
||||
logger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTask 创建任务
|
||||
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
|
||||
// 生成任务ID
|
||||
taskId := utils.RandString(20)
|
||||
|
||||
// 序列化任务参数
|
||||
paramsJson, err := json.Marshal(req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建任务记录
|
||||
job := &model.JimengJob{
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Type: req.Type,
|
||||
ReqKey: req.ReqKey,
|
||||
Prompt: req.Prompt,
|
||||
TaskParams: string(paramsJson),
|
||||
Status: model.JMTaskStatusInQueue,
|
||||
Power: req.Power,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
if err := s.db.Create(job).Error; err != nil {
|
||||
return nil, fmt.Errorf("create jimeng job failed: %w", err)
|
||||
}
|
||||
|
||||
// 推送到任务队列
|
||||
if err := s.taskQueue.RPush(job.Id); err != nil {
|
||||
return nil, fmt.Errorf("push jimeng task to queue failed: %w", err)
|
||||
}
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// ProcessTask 处理任务
|
||||
func (s *Service) ProcessTask(jobId uint) error {
|
||||
// 获取任务记录
|
||||
var job model.JimengJob
|
||||
if err := s.db.First(&job, jobId).Error; err != nil {
|
||||
return fmt.Errorf("get jimeng job failed: %w", err)
|
||||
}
|
||||
|
||||
// 更新任务状态为处理中
|
||||
if err := s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, ""); err != nil {
|
||||
return fmt.Errorf("update job status failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建请求并提交任务
|
||||
req, err := s.buildTaskRequest(&job)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
|
||||
}
|
||||
|
||||
logger.Infof("提交即梦任务: %+v", req)
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
// 更新任务ID和原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||
"task_id": resp.Data.TaskId,
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
logger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildTaskRequest 构建任务请求(统一的参数解析)
|
||||
func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
return nil, fmt.Errorf("parse task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建基础请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 根据任务类型设置特定参数
|
||||
switch job.Type {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
s.setTextToImageParams(req, params)
|
||||
case model.JMTaskTypeImageToImage:
|
||||
s.setImageToImageParams(req, params)
|
||||
case model.JMTaskTypeImageEdit:
|
||||
s.setImageEditParams(req, params)
|
||||
case model.JMTaskTypeImageEffects:
|
||||
s.setImageEffectsParams(req, params)
|
||||
case model.JMTaskTypeTextToVideo:
|
||||
s.setTextToVideoParams(req, params)
|
||||
case model.JMTaskTypeImageToVideo:
|
||||
s.setImageToVideoParams(req, params)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// setTextToImageParams 设置文生图参数
|
||||
func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
if usePreLlm, ok := params["use_pre_llm"]; ok {
|
||||
if usePreLlmVal, ok := usePreLlm.(bool); ok {
|
||||
req.UsePreLLM = usePreLlmVal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setImageToImageParams 设置图生图参数
|
||||
func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput, ok := params["image_input"].(string); ok {
|
||||
req.ImageInput = imageInput
|
||||
}
|
||||
if gpen, ok := params["gpen"]; ok {
|
||||
if gpenVal, ok := gpen.(float64); ok {
|
||||
req.Gpen = gpenVal
|
||||
}
|
||||
}
|
||||
if skin, ok := params["skin"]; ok {
|
||||
if skinVal, ok := skin.(float64); ok {
|
||||
req.Skin = skinVal
|
||||
}
|
||||
}
|
||||
if skinUnifi, ok := params["skin_unifi"]; ok {
|
||||
if skinUnifiVal, ok := skinUnifi.(float64); ok {
|
||||
req.SkinUnifi = skinUnifiVal
|
||||
}
|
||||
}
|
||||
if genMode, ok := params["gen_mode"].(string); ok {
|
||||
req.GenMode = genMode
|
||||
}
|
||||
s.setCommonParams(req, params) // 复用通用参数
|
||||
}
|
||||
|
||||
// setImageEditParams 设置图像编辑参数
|
||||
func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
||||
for _, url := range imageUrls {
|
||||
if urlStr, ok := url.(string); ok {
|
||||
req.ImageUrls = append(req.ImageUrls, urlStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if binaryData, ok := params["binary_data_base64"].([]any); ok {
|
||||
for _, data := range binaryData {
|
||||
if dataStr, ok := data.(string); ok {
|
||||
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageEffectsParams 设置图像特效参数
|
||||
func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput1, ok := params["image_input1"].(string); ok {
|
||||
req.ImageInput1 = imageInput1
|
||||
}
|
||||
if templateId, ok := params["template_id"].(string); ok {
|
||||
req.TemplateId = templateId
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setTextToVideoParams 设置文生视频参数
|
||||
func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageToVideoParams 设置图生视频参数
|
||||
func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
s.setImageEditParams(req, params) // 复用图像编辑的参数设置
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
}
|
||||
|
||||
// setCommonParams 设置通用参数(seed, width, height等)
|
||||
func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pollTaskStatus 轮询任务状态
|
||||
func (s *Service) pollTaskStatus() {
|
||||
|
||||
for {
|
||||
var jobs []model.JimengJob
|
||||
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
|
||||
if len(jobs) == 0 {
|
||||
logger.Debugf("no jimeng task to poll, sleep 10s")
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
// 任务超时处理
|
||||
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
|
||||
s.handleTaskError(job.Id, "task timeout")
|
||||
continue
|
||||
}
|
||||
|
||||
// 查询任务状态
|
||||
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
TaskId: job.TaskId,
|
||||
ReqJson: `{"return_url":true}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("query jimeng task status failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
|
||||
|
||||
if resp.Code != 10000 {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
|
||||
continue
|
||||
}
|
||||
|
||||
switch resp.Data.Status {
|
||||
case model.JMTaskStatusDone:
|
||||
// 判断任务是否成功
|
||||
if resp.Message != "Success" {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
|
||||
continue
|
||||
}
|
||||
|
||||
// 任务完成,更新结果
|
||||
updates := map[string]any{
|
||||
"status": model.JMTaskStatusSuccess,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 设置结果URL
|
||||
if len(resp.Data.ImageUrls) > 0 {
|
||||
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.ImageUrls[0], ".png", false)
|
||||
if err != nil {
|
||||
logger.Errorf("upload image failed: %v", err)
|
||||
imgUrl = resp.Data.ImageUrls[0]
|
||||
}
|
||||
updates["img_url"] = imgUrl
|
||||
}
|
||||
if resp.Data.VideoUrl != "" {
|
||||
videoUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.VideoUrl, ".mp4", false)
|
||||
if err != nil {
|
||||
logger.Errorf("upload video failed: %v", err)
|
||||
videoUrl = resp.Data.VideoUrl
|
||||
}
|
||||
updates["video_url"] = videoUrl
|
||||
}
|
||||
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
|
||||
// 任务处理中
|
||||
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
|
||||
|
||||
case model.JMTaskStatusNotFound:
|
||||
// 任务未找到
|
||||
s.handleTaskError(job.Id, "task not found")
|
||||
|
||||
case model.JMTaskStatusExpired:
|
||||
// 任务过期
|
||||
s.handleTaskError(job.Id, "task expired")
|
||||
|
||||
default:
|
||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// UpdateJobStatus 更新任务状态
|
||||
func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg string) error {
|
||||
updates := map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
if errMsg != "" {
|
||||
updates["err_msg"] = errMsg
|
||||
}
|
||||
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
|
||||
}
|
||||
|
||||
// handleTaskError 处理任务错误
|
||||
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
||||
logger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
|
||||
}
|
||||
|
||||
// PushTaskToQueue 推送任务到队列(用于手动重试)
|
||||
func (s *Service) PushTaskToQueue(jobId uint) error {
|
||||
return s.taskQueue.RPush(jobId)
|
||||
}
|
||||
|
||||
// GetTaskStats 获取任务统计信息
|
||||
func (s *Service) GetTaskStats() (map[string]any, error) {
|
||||
type StatResult struct {
|
||||
Status string `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
err := s.db.Model(&model.JimengJob{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"total": int64(0),
|
||||
"completed": int64(0),
|
||||
"processing": int64(0),
|
||||
"failed": int64(0),
|
||||
"pending": int64(0),
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
result["total"] = result["total"].(int64) + stat.Count
|
||||
result[stat.Status] = stat.Count
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetJob 获取任务
|
||||
func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
||||
var job model.JimengJob
|
||||
if err := s.db.First(&job, jobId).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
||||
testClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
testReq := &QueryTaskRequest{
|
||||
ReqKey: "test_connection",
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := testClient.QueryTask(testReq)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateClientConfig 更新客户端配置
|
||||
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
|
||||
// 创建新的客户端
|
||||
newClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 测试新客户端是否可用
|
||||
err := s.testConnection(accessKey, secretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新客户端
|
||||
s.client = newClient
|
||||
return nil
|
||||
}
|
||||
|
||||
var defaultPower = types.JimengPower{
|
||||
TextToImage: 20,
|
||||
ImageToImage: 20,
|
||||
ImageEdit: 20,
|
||||
ImageEffects: 20,
|
||||
TextToVideo: 300,
|
||||
ImageToVideo: 300,
|
||||
}
|
||||
|
||||
// GetConfig 获取即梦AI配置
|
||||
func (s *Service) GetConfig() *types.JimengConfig {
|
||||
var config model.Config
|
||||
err := s.db.Where("name", "jimeng").First(&config).Error
|
||||
if err != nil {
|
||||
// 如果配置不存在,返回默认配置
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
var jimengConfig types.JimengConfig
|
||||
err = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
if err != nil {
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
return &jimengConfig
|
||||
}
|
||||
145
api/service/jimeng/types.go
Normal file
145
api/service/jimeng/types.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package jimeng
|
||||
|
||||
import "geekai/store/model"
|
||||
|
||||
// ReqKey 常量定义
|
||||
const (
|
||||
ReqKeyTextToImage = "high_aes_general_v30l_zt2i" // 文生图
|
||||
ReqKeyImageToImagePortrait = "i2i_portrait_photo" // 图生图人像写真
|
||||
ReqKeyImageEdit = "seededit_v3.0" // 图像编辑
|
||||
ReqKeyImageEffects = "i2i_multi_style_zx2x" // 图像特效
|
||||
ReqKeyTextToVideo = "jimeng_vgfm_t2v_l20" // 文生视频
|
||||
ReqKeyImageToVideo = "jimeng_vgfm_i2v_l20" // 图生视频
|
||||
)
|
||||
|
||||
// SubmitTaskRequest 提交任务请求
|
||||
type SubmitTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
// 文生图参数
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Scale float64 `json:"scale,omitempty"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
UsePreLLM bool `json:"use_pre_llm,omitempty"`
|
||||
// 图生图参数
|
||||
ImageInput string `json:"image_input,omitempty"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
||||
Gpen float64 `json:"gpen,omitempty"`
|
||||
Skin float64 `json:"skin,omitempty"`
|
||||
SkinUnifi float64 `json:"skin_unifi,omitempty"`
|
||||
GenMode string `json:"gen_mode,omitempty"`
|
||||
// 图像编辑参数
|
||||
// 图像特效参数
|
||||
ImageInput1 string `json:"image_input1,omitempty"`
|
||||
TemplateId string `json:"template_id,omitempty"`
|
||||
// 视频生成参数
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
}
|
||||
|
||||
// SubmitTaskResponse 提交任务响应
|
||||
type SubmitTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Status int `json:"status"`
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
Data struct {
|
||||
TaskId string `json:"task_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// QueryTaskRequest 查询任务请求
|
||||
type QueryTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
TaskId string `json:"task_id"`
|
||||
ReqJson string `json:"req_json,omitempty"`
|
||||
}
|
||||
|
||||
// QueryTaskResponse 查询任务响应
|
||||
type QueryTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Status int `json:"status"`
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
Data struct {
|
||||
AlgorithmBaseResp struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
StatusMessage string `json:"status_message"`
|
||||
} `json:"algorithm_base_resp"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
RespData string `json:"resp_data"`
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
LlmResult string `json:"llm_result"`
|
||||
PeResult string `json:"pe_result"`
|
||||
PredictTagsResult string `json:"predict_tags_result"`
|
||||
RephraserResult string `json:"rephraser_result"`
|
||||
VlmResult string `json:"vlm_result"`
|
||||
InferCtx any `json:"infer_ctx"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// CreateTaskRequest 创建任务请求
|
||||
type CreateTaskRequest struct {
|
||||
Type model.JMTaskType `json:"type"`
|
||||
Prompt string `json:"prompt"`
|
||||
Params map[string]any `json:"params"`
|
||||
ReqKey string `json:"req_key"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
Power int `json:"power,omitempty"`
|
||||
}
|
||||
|
||||
// LogoInfo 水印信息
|
||||
type LogoInfo struct {
|
||||
AddLogo bool `json:"add_logo"`
|
||||
Position int `json:"position"`
|
||||
Language int `json:"language"`
|
||||
Opacity float64 `json:"opacity"`
|
||||
LogoTextContent string `json:"logo_text_content"`
|
||||
}
|
||||
|
||||
// ReqJsonConfig 查询配置
|
||||
type ReqJsonConfig struct {
|
||||
ReturnUrl bool `json:"return_url"`
|
||||
LogoInfo *LogoInfo `json:"logo_info,omitempty"`
|
||||
}
|
||||
|
||||
// ImageEffectTemplate 图像特效模板
|
||||
const (
|
||||
TemplateIdFelt3DPolaroid = "felt_3d_polaroid" // 毛毡3d拍立得风格
|
||||
TemplateIdMyWorld = "my_world" // 像素世界风
|
||||
TemplateIdMyWorldUniversal = "my_world_universal" // 像素世界-万物通用版
|
||||
TemplateIdPlasticBubbleFigure = "plastic_bubble_figure" // 盲盒玩偶风
|
||||
TemplateIdPlasticBubbleFigureCartoon = "plastic_bubble_figure_cartoon_text" // 塑料泡罩人偶-文字卡头版
|
||||
TemplateIdFurryDreamDoll = "furry_dream_doll" // 毛绒玩偶风
|
||||
TemplateIdMicroLandscapeMiniWorld = "micro_landscape_mini_world" // 迷你世界玩偶风
|
||||
TemplateIdMicroLandscapeProfessional = "micro_landscape_mini_world_professional" // 微型景观小世界-职业版
|
||||
TemplateIdAcrylicOrnaments = "acrylic_ornaments" // 亚克力挂饰
|
||||
TemplateIdFeltKeychain = "felt_keychain" // 毛毡钥匙扣
|
||||
TemplateIdLofiPixelCharacter = "lofi_pixel_character_mini_card" // Lofi像素人物小卡
|
||||
TemplateIdAngelFigurine = "angel_figurine" // 天使形象手办
|
||||
TemplateIdLyingInFluffyBelly = "lying_in_fluffy_belly" // 躺在毛茸茸肚皮里
|
||||
TemplateIdGlassBall = "glass_ball" // 玻璃球
|
||||
)
|
||||
|
||||
// AspectRatio 视频宽高比
|
||||
const (
|
||||
AspectRatio16_9 = "16:9" // 1280×720
|
||||
AspectRatio9_16 = "9:16" // 720×1280
|
||||
AspectRatio1_1 = "1:1" // 960×960
|
||||
AspectRatio4_3 = "4:3" // 960×720
|
||||
AspectRatio3_4 = "3:4" // 720×960
|
||||
AspectRatio21_9 = "21:9" // 1680×720
|
||||
AspectRatio9_21 = "9:21" // 720×1680
|
||||
)
|
||||
|
||||
// GenMode 生成模式
|
||||
const (
|
||||
GenModeCreative = "creative" // 提示词模式
|
||||
GenModeReference = "reference" // 全参考模式
|
||||
GenModeReferenceChar = "reference_char" // 人物参考模式
|
||||
)
|
||||
@@ -31,7 +31,8 @@ type LicenseService struct {
|
||||
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
|
||||
var license types.License
|
||||
var machineId string
|
||||
_ = levelDB.Get(types.LicenseKey, &license)
|
||||
err := levelDB.Get(types.LicenseKey, &license)
|
||||
logger.Infof("License: %+v", server.SysConfig)
|
||||
info, err := host.Info()
|
||||
if err == nil {
|
||||
machineId = info.HostID
|
||||
@@ -101,7 +102,7 @@ func (s *LicenseService) SyncLicense() {
|
||||
if err != nil {
|
||||
retryCounter++
|
||||
if retryCounter < 5 {
|
||||
logger.Warn(err)
|
||||
logger.Debug(err)
|
||||
}
|
||||
s.license.IsActive = false
|
||||
} else {
|
||||
|
||||
@@ -15,10 +15,11 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -26,23 +27,17 @@ import (
|
||||
type Service struct {
|
||||
client *Client // MJ Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
wsService *service.WebsocketService
|
||||
uploaderManager *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, userService *service.UserService) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
|
||||
client: client,
|
||||
wsService: wsService,
|
||||
uploaderManager: manager,
|
||||
clientIds: map[uint]string{},
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
@@ -59,7 +54,6 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
task.Id = v.Id
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
s.PushTask(task)
|
||||
}
|
||||
|
||||
@@ -96,7 +90,6 @@ func (s *Service) Run() {
|
||||
if task.Mode == "" {
|
||||
task.Mode = "fast"
|
||||
}
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
|
||||
var job model.MidJourneyJob
|
||||
tx := s.db.Where("id = ?", task.Id).First(&job)
|
||||
@@ -139,7 +132,6 @@ func (s *Service) Run() {
|
||||
// update the task progress
|
||||
s.db.Updates(&job)
|
||||
// 任务失败,通知前端
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
@@ -178,24 +170,6 @@ func GetImageHash(action string) string {
|
||||
return split[len(split)-1]
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("receive a new mj notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChMj, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadImages() {
|
||||
go func() {
|
||||
var items []model.MidJourneyJob
|
||||
@@ -217,7 +191,7 @@ func (s *Service) DownloadImages() {
|
||||
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
|
||||
proxy = true
|
||||
}
|
||||
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
|
||||
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, ".png", proxy)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
|
||||
@@ -228,12 +202,6 @@ func (s *Service) DownloadImages() {
|
||||
|
||||
v.ImgURL = imgURL
|
||||
s.db.Updates(&v)
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[v.Id],
|
||||
UserId: v.UserId,
|
||||
JobId: int(v.Id),
|
||||
Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
@@ -244,7 +212,9 @@ func (s *Service) DownloadImages() {
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (s *Service) PushTask(task types.MjTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push mj task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
@@ -259,7 +229,7 @@ func (s *Service) SyncTaskProgress() {
|
||||
|
||||
for _, job := range jobs {
|
||||
// 10 分钟还没完成的任务标记为失败
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
|
||||
if time.Since(job.CreatedAt) > time.Minute*10 {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = "任务超时"
|
||||
s.db.Updates(&job)
|
||||
@@ -279,18 +249,12 @@ func (s *Service) SyncTaskProgress() {
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
logger.Errorf("task failed: %v", task.FailReason)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(task.Buttons) > 0 {
|
||||
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
oldProgress := job.Progress
|
||||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||
if task.ImageUrl != "" {
|
||||
job.OrgURL = task.ImageUrl
|
||||
@@ -300,19 +264,6 @@ func (s *Service) SyncTaskProgress() {
|
||||
logger.Errorf("error with update database: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 通知前端更新任务进度
|
||||
if oldProgress != job.Progress {
|
||||
message := service.TaskStatusRunning
|
||||
if job.Progress == 100 {
|
||||
message = service.TaskStatusFinished
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: message})
|
||||
}
|
||||
}
|
||||
|
||||
// 找出失败的任务,并恢复其扣减算力
|
||||
|
||||
@@ -84,7 +84,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -99,8 +99,10 @@ func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := utils.GetImgExt(parse.Path)
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
||||
if err != nil {
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
@@ -37,7 +38,7 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
return File{}, fmt.Errorf("error with get form: %v", err)
|
||||
}
|
||||
|
||||
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
|
||||
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, "")
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
|
||||
}
|
||||
@@ -57,13 +58,13 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
filename := filepath.Base(parse.Path)
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, ext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with generate image dir: %v", err)
|
||||
}
|
||||
@@ -85,7 +86,7 @@ func (s LocalStorage) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
|
||||
filePath, _ := utils.GenUploadPath(s.config.BasePath, "", ".png")
|
||||
err = os.WriteFile(filePath, imageData, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error writing to file:%v", err)
|
||||
|
||||
@@ -44,7 +44,7 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -59,8 +59,10 @@ func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := filepath.Ext(parse.Path)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
@@ -86,7 +88,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}
|
||||
defer fileReader.Close()
|
||||
|
||||
fileExt := utils.GetImgExt(file.Filename)
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||
ContentType: file.Header.Get("Body-Type"),
|
||||
|
||||
@@ -93,7 +93,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -108,8 +108,10 @@ func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := utils.GetImgExt(parse.Path)
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
|
||||
@@ -23,7 +23,7 @@ type File struct {
|
||||
}
|
||||
type Uploader interface {
|
||||
PutFile(ctx *gin.Context, name string) (File, error)
|
||||
PutUrlFile(url string, useProxy bool) (string, error)
|
||||
PutUrlFile(url string, ext string, useProxy bool) (string, error)
|
||||
PutBase64(imageData string) (string, error)
|
||||
Delete(fileURL string) error
|
||||
}
|
||||
|
||||
@@ -16,9 +16,10 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -30,20 +31,16 @@ var logger = logger2.GetLogger()
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
wsService *service.WebsocketService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C(),
|
||||
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||
db: db,
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
userService: userService,
|
||||
}
|
||||
@@ -102,8 +99,6 @@ func (s *Service) Run() {
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
// 通知前端,任务失败
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -225,15 +220,12 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||
return nil
|
||||
default:
|
||||
err, resp := s.checkTaskProgress(apiKey)
|
||||
resp, err := s.checkTaskProgress(apiKey)
|
||||
// 更新任务进度
|
||||
if err == nil && resp.Progress > 0 {
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
@@ -242,7 +234,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
|
||||
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (*TaskProgressResp, error) {
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
|
||||
var res TaskProgressResp
|
||||
response, err := s.httpClient.R().
|
||||
@@ -250,37 +242,20 @@ func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressRe
|
||||
SetSuccessResult(&res).
|
||||
Get(apiURL)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return nil, err
|
||||
}
|
||||
if response.IsErrorState() {
|
||||
return fmt.Errorf("error http code status: %v", response.Status), nil
|
||||
return nil, fmt.Errorf("error http code status: %v", response.Status)
|
||||
}
|
||||
|
||||
return nil, &res
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.SdTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChSd, message.Message)
|
||||
}
|
||||
}()
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push sd task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||
@@ -297,7 +272,7 @@ func (s *Service) CheckTaskStatus() {
|
||||
|
||||
for _, job := range jobs {
|
||||
// 5 分钟还没完成的任务标记为失败
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
|
||||
if time.Since(job.CreatedAt) > time.Minute*5 {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = "任务超时"
|
||||
s.db.Updates(&job)
|
||||
|
||||
@@ -18,10 +18,11 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -34,27 +35,25 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[string]string
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
|
||||
uploadManager: manager,
|
||||
wsService: wsService,
|
||||
clientIds: map[string]string{},
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.SunoTask) {
|
||||
logger.Infof("add a new Suno task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push suno task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
@@ -70,7 +69,6 @@ func (s *Service) Run() {
|
||||
}
|
||||
task.Id = v.Id
|
||||
s.PushTask(task)
|
||||
s.clientIds[v.TaskId] = task.ClientId
|
||||
}
|
||||
logger.Info("Starting Suno job consumer...")
|
||||
go func() {
|
||||
@@ -95,16 +93,16 @@ func (s *Service) Run() {
|
||||
"err_msg": err.Error(),
|
||||
"progress": service.FailTaskProgress,
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("任务提交成功: %+v", r)
|
||||
|
||||
// 更新任务信息
|
||||
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"task_id": r.Data,
|
||||
"channel": r.Channel,
|
||||
})
|
||||
s.clientIds[r.Data] = task.ClientId
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -133,20 +131,20 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
"continue_clip_id": task.RefSongId,
|
||||
"continue_at": task.ExtendSecs,
|
||||
"make_instrumental": task.Instrumental,
|
||||
"mv": task.Model,
|
||||
}
|
||||
// 灵感模式
|
||||
if task.Type == 1 {
|
||||
reqBody["gpt_description_prompt"] = task.Prompt
|
||||
} else { // 自定义模式
|
||||
reqBody["prompt"] = task.Prompt
|
||||
reqBody["prompt"] = task.Lyrics
|
||||
reqBody["tags"] = task.Tags
|
||||
reqBody["mv"] = task.Model
|
||||
reqBody["title"] = task.Title
|
||||
}
|
||||
|
||||
var res RespVo
|
||||
apiURL := fmt.Sprintf("%s/suno/submit/music", apiKey.ApiURL)
|
||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||
logger.Debugf("API URL: %s, request body: %s", apiURL, utils.JsonEncode(reqBody))
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
@@ -262,27 +260,6 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Suno task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
logger.Debugf("client id: %+v", s.wsService.Clients)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
logger.Debugf("%+v", client)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChSuno, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadFiles() {
|
||||
go func() {
|
||||
var items []model.SunoJob
|
||||
@@ -295,14 +272,14 @@ func (s *Service) DownloadFiles() {
|
||||
for _, v := range items {
|
||||
// 下载图片和音频
|
||||
logger.Infof("try download cover image: %s", v.CoverURL)
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, ".png", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try download audio: %s", v.AudioURL)
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, ".mp3", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download audio with error: %v", err)
|
||||
continue
|
||||
@@ -311,7 +288,6 @@ func (s *Service) DownloadFiles() {
|
||||
v.AudioURL = audioURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
@@ -377,12 +353,10 @@ func (s *Service) SyncTaskProgress() {
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
|
||||
} else if task.Data.FailReason != "" {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = task.Data.FailReason
|
||||
s.db.Updates(&job)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,109 +12,89 @@ type NotifyMessage struct {
|
||||
ClientId string `json:"client_id"`
|
||||
JobId int `json:"job_id"`
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||
|
||||
const ImagePromptOptimizeTemplate = `
|
||||
Create a highly effective prompt to provide to an AI image generation tool in order to create an artwork based on a desired concept.
|
||||
以下是一条 AI 提示词示例,用于优化和扩写绘图提示词:
|
||||
|
||||
Please specify details about the artwork, such as the style, subject, mood, and other important characteristics you want the resulting image to have.
|
||||
请你作为一名专业的 AI 绘图提示词优化专家,基于用户提供的简单绘图描述,生成一份详细、专业且富有创意的 AI 绘图提示词指令。在优化过程中,你需要做到以下几点:
|
||||
|
||||
Remember, prompts should always be output in English.
|
||||
1. 深入理解用户描述的核心意图和关键元素,挖掘潜在的细节和情感氛围,将其融入到提示词中。
|
||||
2. 丰富画面细节,包括但不限于场景背景、人物特征、物体属性、光影效果、色彩搭配等,使画面更加生动逼真。
|
||||
3. 运用专业的艺术风格术语,如超现实主义、印象派、赛博朋克等,为画面增添独特的艺术魅力。
|
||||
4. 考虑构图和视角,如俯视、仰视、特写、全景等,提升画面的视觉冲击力。
|
||||
5. 确保提示词指令清晰、准确、完整,便于 AI 绘图模型理解和生成高质量图像。最终输出的提示词应简洁明了,避免冗余信息,以逗号分隔各个元素,突出重点,
|
||||
让用户能够直接复制使用,从而帮助用户将简单的想法转化为精美绝伦的画作。
|
||||
6. 不管用户输入的是什么语言,你务必要用英文输出优化后的提示词。
|
||||
7. 直接输出优化后的提示词,不要输出其他任何五官内容。
|
||||
|
||||
# Steps
|
||||
下面是一个提示词优化示例:
|
||||
===示例开始===
|
||||
原始指令 :一个穿着红色连衣裙的少女在花园里浇花,阳光明媚。
|
||||
|
||||
1. **Subject Description**: Describe the main subject of the image clearly. Include as much detail as possible about what should be in the scene. For example, "a majestic lion roaring at sunrise" or "a futuristic city with flying cars."
|
||||
|
||||
2. **Art Style**: Specify the art style you envision. Possible options include 'realistic', 'impressionist', a specific artist name, or imaginative styles like "cyberpunk." This helps the AI achieve your visual expectations.
|
||||
优化后的 AI 绘图提示词指令:一位年轻美丽的少女,约 16 - 18 岁,有着柔顺的黑色长发,披散在肩上,面容精致,眼神温柔而专注。她穿着一条复古风格的红色连衣裙,裙子上有精致的褶皱和白色的蕾丝花边,裙摆轻轻飘动。少女站在一个充满生机的花园中,花园里种满了各种各样的鲜花,有娇艳的玫瑰、淡雅的百合、缤纷的郁金香等,花朵色彩鲜艳,绿叶繁茂。她手持一个银色的 watering can(浇水壶),正在细心地给一朵盛开的玫瑰浇水。阳光从画面的右侧洒下,形成明亮而温暖的光晕,照亮了少女和整个花园,营造出一种宁静、美好的氛围,画面采用写实风格,光影效果逼真,色彩鲜明且富有层次感,构图以少女为中心,前景是盛开的花朵,背景是花园的树木和篱笆,整体画面充满诗意和浪漫气息。
|
||||
===示例结束===
|
||||
|
||||
3. **Mood or Atmosphere**: Convey the feeling you want the image to evoke. For instance, peaceful, chaotic, epic, etc.
|
||||
|
||||
4. **Color Palette and Lighting**: Mention color preferences or lighting. For example, "vibrant with shades of blue and purple" or "dim and dramatic lighting."
|
||||
|
||||
5. **Optional Features**: You can add any additional attributes, such as background details, attention to textures, or any specific kind of framing.
|
||||
|
||||
# Output Format
|
||||
|
||||
- **Prompt Format**: A descriptive phrase that includes key aspects of the artwork (subject, style, mood, colors, lighting, any optional features).
|
||||
|
||||
Here is an example of how the final prompt should look:
|
||||
|
||||
"An ethereal landscape featuring towering ice mountains, in an impressionist style reminiscent of Claude Monet, with a serene mood. The sky is glistening with soft purples and whites, with a gentle morning sun illuminating the scene."
|
||||
|
||||
**Please input the prompt words directly in English, and do not input any other explanatory statements**
|
||||
|
||||
# Examples
|
||||
|
||||
1. **Input**:
|
||||
- Subject: A white tiger in a dense jungle
|
||||
- Art Style: Realistic
|
||||
- Mood: Intense, mysterious
|
||||
- Lighting: Dramatic contrast with light filtering through leaves
|
||||
|
||||
**Output Prompt**: "A realistic rendering of a white tiger stealthily moving through a dense jungle, with an intense, mysterious mood. The lighting creates strong contrasts as beams of sunlight filter through a thick canopy of leaves."
|
||||
|
||||
2. **Input**:
|
||||
- Subject: An enchanted castle on a floating island
|
||||
- Art Style: Fantasy
|
||||
- Mood: Majestic, magical
|
||||
- Colors: Bright blues, greens, and gold
|
||||
|
||||
**Output Prompt**: "A majestic fantasy castle on a floating island above the clouds, with bright blues, greens, and golds to create a magical, dreamy atmosphere. Textured cobblestone details and glistening waters surround the scene."
|
||||
|
||||
# Notes
|
||||
|
||||
- Ensure that you mix different aspects to get a comprehensive and visually compelling prompt.
|
||||
- Be as descriptive as possible as it often helps generate richer, more detailed images.
|
||||
- If you want the image to resemble a particular artist's work, be sure to mention the artist explicitly. e.g., "in the style of Van Gogh."
|
||||
|
||||
The theme of the creation is:【%s】
|
||||
现在用户输入的原始提示词为:【%s】
|
||||
`
|
||||
|
||||
const LyricPromptTemplate = `
|
||||
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
|
||||
请以【%s】为主题创作一首歌曲,歌曲时间不要太短,3分钟左右,不要输出任何解释性的内容。
|
||||
输出格式如下:
|
||||
下面是一个标准的歌词输出模板:
|
||||
歌曲名称
|
||||
第一节:
|
||||
{{歌词内容}}
|
||||
副歌:
|
||||
{{歌词内容}}
|
||||
|
||||
第二节:
|
||||
{{歌词内容}}
|
||||
副歌:
|
||||
{{歌词内容}}
|
||||
[Verse]
|
||||
[歌词]
|
||||
|
||||
尾声:
|
||||
{{歌词内容}}
|
||||
[Verse 2]
|
||||
[歌词]
|
||||
|
||||
[Chorus]
|
||||
[歌词]
|
||||
|
||||
[Verse 3]
|
||||
[歌词]
|
||||
|
||||
[Bridge]
|
||||
[歌词]
|
||||
|
||||
[Chorus]
|
||||
[歌词]
|
||||
|
||||
[Verse 4]
|
||||
[歌词]
|
||||
|
||||
[Bridge]
|
||||
假如此刻眼泪能倒流
|
||||
让我学会微笑不掩忧
|
||||
一次次的碎片堆积的愁
|
||||
最终也会开成希望的秋
|
||||
|
||||
[Chorus]
|
||||
假如我还能牵你的手
|
||||
天空也许会更蔚蓝悠游
|
||||
曾经那些未完成的错过
|
||||
愿能变成今天的收获
|
||||
`
|
||||
|
||||
const VideoPromptTemplate = `
|
||||
As an expert in video generation prompts, please create a detailed descriptive prompt for the following video concept. The description should include the setting, character appearance, actions, overall atmosphere, and camera angles. Please make it as detailed and vivid as possible to help ensure that every aspect of the video is accurately captured.
|
||||
const VideoPromptTemplate = `## 任务描述
|
||||
你是一位优秀AI视频创作专家,擅长编写专业的AI视频提示词,现在你的任务是对用户输入的简单视频描述提示词进行专业优化和扩写,使其转化为详细的、具备专业影视画面感的 AI 生成视频提示词指令。需涵盖风格、主体元素、环境氛围、细节特征、人物状态(若有)、镜头运用及整体氛围营造等方面,以生动形象、富有感染力且精准的描述,引导 AI 生成高质量的视频内容。下面是一个示例:
|
||||
===示例开始===
|
||||
输入: “汽车在沙漠功能上行驶”,
|
||||
输出: “纪实摄影风格,一辆尘土飞扬的复古越野车在无垠的沙漠公路上疾驰,车身线条硬朗,漆面斑驳,透露出岁月的痕迹。驾驶室内的司机戴着墨镜,专注地握着方向盘,眼神坚定地望向前方。夕阳的余晖洒在车身上,沙漠的沙丘在远处延绵起伏,一片金黄。广角镜头捕捉到车辆行驶时扬起的沙尘,营造出动感与冒险的氛围。远景全貌,强调速度感与环境辽阔。”
|
||||
===示例结束===
|
||||
|
||||
Please remember that regardless of the user’s input, the final output must be in English.
|
||||
## 输出要求:
|
||||
1. 直接输出扩写后的提示词就好,不要输出其他任何不相关信息
|
||||
2. 如果用户用中文提问,你就用中文回答,如果用英文提问,你也必须用英文回答。
|
||||
3. 请确保提示词的长度长度在1000个字以内。
|
||||
|
||||
# Details to Include
|
||||
|
||||
- Describe the overall visual style of the video (e.g., animated, realistic, retro tone, etc.)
|
||||
- Identify key characters or objects in the video and describe their appearance, attire, and expressions
|
||||
- Describe the environment of the scene, including weather, lighting, colors, and important details
|
||||
- Explain the behavior and interactions of the characters
|
||||
- Include any unique camera angles, movements, or special effects
|
||||
|
||||
# Output Format
|
||||
Provide the prompt in paragraph form, ensuring that the description is detailed enough for a video generation system to recreate the envisioned scene. Include the beginning, middle, and end of the scene to convey a complete storyline.
|
||||
|
||||
# Example
|
||||
**User Input:**
|
||||
“A small cat basking in the sun on a balcony.”
|
||||
|
||||
**Generated Prompt:**
|
||||
On a bright spring afternoon, an orange-striped kitten lies lazily on a balcony, basking in the warm sunlight. The iron railings around the balcony cast soft shadows that dance gently with the light. The cat’s eyes are half-closed, exuding a sense of contentment and tranquility in its surroundings. In the distance, a few fluffy white clouds drift slowly across the blue sky. The camera initially focuses on the cat’s face, capturing the delicate details of its fur, and then gradually zooms out to reveal the full balcony scene, immersing viewers in a moment of calm and relaxation.
|
||||
|
||||
The theme of the creation is:【%s】
|
||||
=====
|
||||
用户的输入的视频主题是:【%s】
|
||||
`
|
||||
|
||||
const MetaPromptTemplate = `
|
||||
@@ -133,7 +113,7 @@ Please remember, the final output must be the same language with user’s input.
|
||||
- What kinds of examples may need to be included, how many, and whether they are complex enough to benefit from placeholders.
|
||||
- Clarity and Conciseness: Use clear, specific language. Avoid unnecessary instructions or bland statements.
|
||||
- Formatting: Use markdown features for readability. DO NOT USE CODE BLOCKS UNLESS SPECIFICALLY REQUESTED.
|
||||
- Preserve User Content: If the input task or prompt includes extensive guidelines or examples, preserve them entirely, or as closely as possible. If they are vague, consider breaking down into sub-steps. Keep any details, guidelines, examples, variables, or placeholders provided by the user.
|
||||
- Preserve User Prompt: If the input task or prompt includes extensive guidelines or examples, preserve them entirely, or as closely as possible. If they are vague, consider breaking down into sub-steps. Keep any details, guidelines, examples, variables, or placeholders provided by the user.
|
||||
- Constants: DO include constants in the prompt, as they are not susceptible to prompt injection. Such as guides, rubrics, and examples.
|
||||
- Output Format: Explicitly the most appropriate output format, in detail. This should include length and syntax (e.g. short sentence, paragraph, JSON, etc.)
|
||||
- For tasks outputting well-defined or structured data (classification, JSON, etc.) bias toward outputting a JSON.
|
||||
|
||||
@@ -4,9 +4,10 @@ import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"gorm.io/gorm"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
@@ -19,7 +20,7 @@ func NewUserService(db *gorm.DB) *UserService {
|
||||
}
|
||||
|
||||
// IncreasePower 增加用户算力
|
||||
func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error {
|
||||
func (s *UserService) IncreasePower(userId uint, power int, log model.PowerLog) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
@@ -51,7 +52,7 @@ func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) e
|
||||
}
|
||||
|
||||
// DecreasePower 减少用户算力
|
||||
func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error {
|
||||
func (s *UserService) DecreasePower(userId uint, power int, log model.PowerLog) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
|
||||
@@ -1,377 +0,0 @@
|
||||
package video
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
clientIds: map[uint]string{},
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.VideoTask) {
|
||||
logger.Infof("add a new Video task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
// 将数据库中未提交的人物加载到队列
|
||||
var jobs []model.VideoJob
|
||||
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
|
||||
for _, v := range jobs {
|
||||
var task types.VideoTask
|
||||
err := utils.JsonDecode(v.TaskInfo, &task)
|
||||
if err != nil {
|
||||
logger.Errorf("decode task info with error: %v", err)
|
||||
continue
|
||||
}
|
||||
task.Id = v.Id
|
||||
s.PushTask(task)
|
||||
s.clientIds[v.Id] = task.ClientId
|
||||
}
|
||||
logger.Info("Starting Video job consumer...")
|
||||
go func() {
|
||||
for {
|
||||
var task types.VideoTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if task.ClientId != "" {
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
}
|
||||
|
||||
var r LumaRespVo
|
||||
r, err = s.LumaCreate(task)
|
||||
if err != nil {
|
||||
logger.Errorf("create task with error: %v", err)
|
||||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"err_msg": err.Error(),
|
||||
"progress": service.FailTaskProgress,
|
||||
"cover_url": "/images/failed.jpg",
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务信息
|
||||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"task_id": r.Id,
|
||||
"channel": r.Channel,
|
||||
"prompt_ext": r.Prompt,
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
s.PushTask(task)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type LumaRespVo struct {
|
||||
Id string `json:"id"`
|
||||
Prompt string `json:"prompt"`
|
||||
State string `json:"state"`
|
||||
QueueState interface{} `json:"queue_state"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Video interface{} `json:"video"`
|
||||
VideoRaw interface{} `json:"video_raw"`
|
||||
Liked interface{} `json:"liked"`
|
||||
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
|
||||
Thumbnail interface{} `json:"thumbnail"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
|
||||
if task.Channel != "" {
|
||||
session = session.Where("api_url", task.Channel)
|
||||
}
|
||||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return LumaRespVo{}, errors.New("no available API KEY for Luma")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"user_prompt": task.Prompt,
|
||||
"expand_prompt": task.Params.PromptOptimize,
|
||||
"loop": task.Params.Loop,
|
||||
"image_url": task.Params.StartImgURL,
|
||||
"image_end_url": task.Params.EndImgURL,
|
||||
}
|
||||
var res LumaRespVo
|
||||
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
|
||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.StatusCode != 200 && r.StatusCode != 201 {
|
||||
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
// update the last_use_at for api key
|
||||
apiKey.LastUsedAt = time.Now().Unix()
|
||||
session.Updates(&apiKey)
|
||||
res.Channel = apiKey.ApiURL
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Suno task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("Receive notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChLuma, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadFiles() {
|
||||
go func() {
|
||||
var items []model.VideoJob
|
||||
for {
|
||||
res := s.db.Where("progress", 102).Find(&items)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
if v.WaterURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try download video: %s", v.WaterURL)
|
||||
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("download video success: %s", videoURL)
|
||||
v.WaterURL = videoURL
|
||||
|
||||
if v.VideoURL != "" {
|
||||
logger.Infof("try download no water video: %s", v.VideoURL)
|
||||
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
logger.Infof("download no water video success: %s", videoURL)
|
||||
v.VideoURL = videoURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
func (s *Service) SyncTaskProgress() {
|
||||
go func() {
|
||||
var jobs []model.VideoJob
|
||||
for {
|
||||
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
|
||||
if err != nil {
|
||||
logger.Errorf("query task with error: %v", err)
|
||||
// 更新任务信息
|
||||
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("task: %+v", task)
|
||||
if task.State == "completed" { // 更新任务信息
|
||||
data := map[string]interface{}{
|
||||
"progress": 102, // 102 表示资源未下载完成,
|
||||
"water_url": task.Video.Url,
|
||||
"raw_data": utils.JsonEncode(task),
|
||||
"prompt_ext": task.Prompt,
|
||||
"cover_url": task.Thumbnail.Url,
|
||||
}
|
||||
if task.Video.DownloadUrl != "" {
|
||||
data["video_url"] = task.Video.DownloadUrl
|
||||
}
|
||||
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
|
||||
if err != nil {
|
||||
logger.Errorf("更新数据库失败:%v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 找出失败的任务,并恢复其扣减算力
|
||||
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||||
for _, job := range jobs {
|
||||
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "luma",
|
||||
Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// 更新任务状态
|
||||
s.db.Model(&job).UpdateColumn("power", 0)
|
||||
}
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type LumaTaskVo struct {
|
||||
Id string `json:"id"`
|
||||
Liked interface{} `json:"liked"`
|
||||
State string `json:"state"`
|
||||
Video struct {
|
||||
Url string `json:"url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Thumbnail string `json:"thumbnail"`
|
||||
DownloadUrl string `json:"download_url"`
|
||||
} `json:"video"`
|
||||
Prompt string `json:"prompt"`
|
||||
UserId string `json:"user_id"`
|
||||
BatchId string `json:"batch_id"`
|
||||
Thumbnail struct {
|
||||
Url string `json:"url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
} `json:"thumbnail"`
|
||||
VideoRaw struct {
|
||||
Url string `json:"url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
} `json:"video_raw"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastFrame struct {
|
||||
Url string `json:"url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
} `json:"last_frame"`
|
||||
}
|
||||
|
||||
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
|
||||
Where("api_url", channel).
|
||||
Where("enabled", true).
|
||||
Order("last_used_at DESC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
|
||||
var res LumaTaskVo
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
||||
|
||||
if err != nil {
|
||||
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
if r.StatusCode != 200 {
|
||||
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
663
api/service/video/video.go
Normal file
663
api/service/video/video.go
Normal file
@@ -0,0 +1,663 @@
|
||||
package video
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
|
||||
uploadManager: manager,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.VideoTask) {
|
||||
logger.Infof("add a new Video task to the task list: %+v", task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push video task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
// 将数据库中未提交的任务加载到队列
|
||||
var jobs []model.VideoJob
|
||||
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
|
||||
for _, v := range jobs {
|
||||
var task types.VideoTask
|
||||
err := utils.JsonDecode(v.TaskInfo, &task)
|
||||
if err != nil {
|
||||
logger.Errorf("decode task info with error: %v", err)
|
||||
continue
|
||||
}
|
||||
task.Id = v.Id
|
||||
s.PushTask(task)
|
||||
}
|
||||
logger.Info("Starting Video job consumer...")
|
||||
go func() {
|
||||
for {
|
||||
var task types.VideoTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if task.Type == types.VideoLuma {
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
var r LumaRespVo
|
||||
r, err = s.LumaCreate(task)
|
||||
if err != nil {
|
||||
logger.Errorf("create task with error: %v", err)
|
||||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"err_msg": err.Error(),
|
||||
"progress": service.FailTaskProgress,
|
||||
"cover_url": "/images/failed.jpg",
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务信息
|
||||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"task_id": r.Id,
|
||||
"channel": r.Channel,
|
||||
"prompt_ext": r.Prompt,
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
s.PushTask(task)
|
||||
}
|
||||
} else if task.Type == types.VideoKeLing {
|
||||
var r KeLingRespVo
|
||||
r, err = s.KeLingCreate(task)
|
||||
logger.Debugf("ke ling create task result: %+v", r)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("create task with error: %v", err)
|
||||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"err_msg": err.Error(),
|
||||
"progress": service.FailTaskProgress,
|
||||
"cover_url": "/images/failed.jpg",
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务信息
|
||||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"task_id": r.Data.TaskID,
|
||||
"channel": r.Channel,
|
||||
"prompt_ext": task.Prompt,
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
s.PushTask(task)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadFiles() {
|
||||
go func() {
|
||||
var items []model.VideoJob
|
||||
for {
|
||||
res := s.db.Where("progress", 102).Find(&items)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
if v.WaterURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try download video: %s", v.WaterURL)
|
||||
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, ".mp4", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("download video success: %s", videoURL)
|
||||
v.WaterURL = videoURL
|
||||
|
||||
if v.VideoURL != "" {
|
||||
logger.Infof("try download no water video: %s", v.VideoURL)
|
||||
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, ".mp4", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
logger.Infof("download no water video success: %s", videoURL)
|
||||
v.VideoURL = videoURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
|
||||
// Convert TaskInfo to VideoTask
|
||||
var videoTask types.VideoTask
|
||||
if err := json.Unmarshal([]byte(v.TaskInfo), &videoTask); err != nil {
|
||||
logger.Errorf("failed to unmarshal task info to VideoTask: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
func (s *Service) SyncTaskProgress() {
|
||||
go func() {
|
||||
var jobs []model.VideoJob
|
||||
for {
|
||||
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
if job.Type == types.VideoLuma {
|
||||
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
|
||||
if err != nil {
|
||||
logger.Errorf("query task with error: %v", err)
|
||||
// 更新任务信息
|
||||
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
|
||||
"err_msg": err.Error(),
|
||||
"cover_url": "/images/failed.jpg",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("task: %+v", task)
|
||||
if task.State == "completed" { // 更新任务信息
|
||||
data := map[string]interface{}{
|
||||
"progress": 102, // 102 表示资源未下载完成,
|
||||
"water_url": task.Video.Url,
|
||||
"raw_data": utils.JsonEncode(task),
|
||||
"prompt_ext": task.Prompt,
|
||||
"cover_url": task.Thumbnail.Url,
|
||||
}
|
||||
if task.Video.DownloadUrl != "" {
|
||||
data["video_url"] = task.Video.DownloadUrl
|
||||
}
|
||||
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
|
||||
if err != nil {
|
||||
logger.Errorf("更新数据库失败:%v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else if job.Type == types.VideoKeLing {
|
||||
// Convert TaskInfo to VideoTask
|
||||
var videoTask types.VideoTask
|
||||
if err := json.Unmarshal([]byte(job.TaskInfo), &videoTask); err != nil {
|
||||
logger.Errorf("failed to unmarshal task info to VideoTask: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Type assert task.Params to KeLingVideoParams
|
||||
paramsMap, ok := videoTask.Params.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert map to KeLingVideoParams
|
||||
paramsBytes, err := json.Marshal(paramsMap)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var params types.KeLingVideoParams
|
||||
if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
task, err := s.QueryKeLingTask(job.TaskId, job.Channel, params.TaskType)
|
||||
if err != nil {
|
||||
logger.Errorf("query task with error: %v", err)
|
||||
// 更新任务信息
|
||||
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
|
||||
"err_msg": err.Error(),
|
||||
"cover_url": "/images/failed.jpg",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("task: %+v", task)
|
||||
if task.TaskStatus == "succeed" { // 更新任务信息
|
||||
data := map[string]interface{}{
|
||||
"progress": 102, // 102 表示资源未下载完成,
|
||||
"water_url": task.TaskResult.Videos[0].URL,
|
||||
"raw_data": utils.JsonEncode(task),
|
||||
"prompt_ext": job.Prompt,
|
||||
"cover_url": "",
|
||||
}
|
||||
if len(task.TaskResult.Videos) > 0 {
|
||||
data["video_url"] = task.TaskResult.Videos[0].URL
|
||||
}
|
||||
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
|
||||
if err != nil {
|
||||
logger.Errorf("更新数据库失败:%v", err)
|
||||
continue
|
||||
}
|
||||
} else if task.TaskStatus == "failed" {
|
||||
// 更新任务信息
|
||||
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": task.TaskStatusMsg,
|
||||
"cover_url": "/images/failed.jpg",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 找出失败的任务,并恢复其扣减算力
|
||||
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||||
for _, job := range jobs {
|
||||
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: job.Type,
|
||||
Remark: fmt.Sprintf("%s 任务失败,退回算力。任务ID:%s,Err:%s", job.Type, job.TaskId, job.ErrMsg),
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// 更新任务状态
|
||||
s.db.Model(&job).UpdateColumn("power", 0)
|
||||
}
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type LumaTaskVo struct {
|
||||
Id string `json:"id"`
|
||||
Liked interface{} `json:"liked"`
|
||||
State string `json:"state"`
|
||||
Video struct {
|
||||
Url string `json:"url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Thumbnail string `json:"thumbnail"`
|
||||
DownloadUrl string `json:"download_url"`
|
||||
} `json:"video"`
|
||||
Prompt string `json:"prompt"`
|
||||
UserId string `json:"user_id"`
|
||||
BatchId string `json:"batch_id"`
|
||||
Thumbnail struct {
|
||||
Url string `json:"url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
} `json:"thumbnail"`
|
||||
VideoRaw struct {
|
||||
Url string `json:"url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
} `json:"video_raw"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastFrame struct {
|
||||
Url string `json:"url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
} `json:"last_frame"`
|
||||
}
|
||||
|
||||
type LumaRespVo struct {
|
||||
Id string `json:"id"`
|
||||
Prompt string `json:"prompt"`
|
||||
State string `json:"state"`
|
||||
QueueState interface{} `json:"queue_state"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Video interface{} `json:"video"`
|
||||
VideoRaw interface{} `json:"video_raw"`
|
||||
Liked interface{} `json:"liked"`
|
||||
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
|
||||
Thumbnail interface{} `json:"thumbnail"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
|
||||
if task.Channel != "" {
|
||||
session = session.Where("api_url", task.Channel)
|
||||
}
|
||||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return LumaRespVo{}, errors.New("no available API KEY for Luma")
|
||||
}
|
||||
|
||||
// Type assert task.Params to LumaVideoParams
|
||||
paramsMap, ok := task.Params.(map[string]interface{})
|
||||
if !ok {
|
||||
return LumaRespVo{}, errors.New("invalid params type for Luma video task")
|
||||
}
|
||||
|
||||
// Convert map to LumaVideoParams
|
||||
paramsBytes, err := json.Marshal(paramsMap)
|
||||
if err != nil {
|
||||
return LumaRespVo{}, fmt.Errorf("failed to marshal params: %v", err)
|
||||
}
|
||||
|
||||
var params types.LumaVideoParams
|
||||
if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
|
||||
return LumaRespVo{}, fmt.Errorf("failed to unmarshal params: %v", err)
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"user_prompt": task.Prompt,
|
||||
"expand_prompt": params.PromptOptimize,
|
||||
"loop": params.Loop,
|
||||
"image_url": params.StartImgURL, // 图生视频
|
||||
"image_end_url": params.EndImgURL, // 图生视频
|
||||
}
|
||||
|
||||
var res LumaRespVo
|
||||
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
|
||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.StatusCode != 200 && r.StatusCode != 201 {
|
||||
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
// update the last_use_at for api key
|
||||
apiKey.LastUsedAt = time.Now().Unix()
|
||||
session.Updates(&apiKey)
|
||||
res.Channel = apiKey.ApiURL
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
|
||||
Where("api_url", channel).
|
||||
Where("enabled", true).
|
||||
Order("last_used_at DESC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
|
||||
var res LumaTaskVo
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
||||
|
||||
if err != nil {
|
||||
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
if r.StatusCode != 200 {
|
||||
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
type KeLingRespVo struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestID string `json:"request_id"`
|
||||
Data struct {
|
||||
TaskID string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
} `json:"data"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) KeLingCreate(task types.VideoTask) (KeLingRespVo, error) {
|
||||
var apiKey model.ApiKey
|
||||
session := s.db.Session(&gorm.Session{}).Where("type", "keling").Where("enabled", true)
|
||||
if task.Channel != "" {
|
||||
session = session.Where("api_url", task.Channel)
|
||||
}
|
||||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return KeLingRespVo{}, errors.New("no available API KEY for keling")
|
||||
}
|
||||
|
||||
// Type assert task.Params to KeLingVideoParams
|
||||
paramsMap, ok := task.Params.(map[string]interface{})
|
||||
if !ok {
|
||||
return KeLingRespVo{}, errors.New("invalid params type for KeLing video task")
|
||||
}
|
||||
|
||||
// Convert map to KeLingVideoParams
|
||||
paramsBytes, err := json.Marshal(paramsMap)
|
||||
if err != nil {
|
||||
return KeLingRespVo{}, fmt.Errorf("failed to marshal params: %v", err)
|
||||
}
|
||||
|
||||
var params types.KeLingVideoParams
|
||||
if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
|
||||
return KeLingRespVo{}, fmt.Errorf("failed to unmarshal params: %v", err)
|
||||
}
|
||||
|
||||
// 2. 构建API请求参数
|
||||
payload := map[string]interface{}{
|
||||
"model_name": params.Model,
|
||||
"prompt": task.Prompt,
|
||||
"negative_prompt": params.NegPrompt,
|
||||
"cfg_scale": params.CfgScale,
|
||||
"mode": params.Mode,
|
||||
"aspect_ratio": params.AspectRatio,
|
||||
"duration": params.Duration,
|
||||
}
|
||||
|
||||
// 只有当 CameraControl 的类型不为空时,才处理摄像机控制参数
|
||||
if params.CameraControl.Type != "" {
|
||||
cameraControl := map[string]interface{}{
|
||||
"type": params.CameraControl.Type,
|
||||
}
|
||||
|
||||
// 只有在 simple 类型时才添加 config 参数
|
||||
if params.CameraControl.Type == "simple" {
|
||||
cameraControl["config"] = params.CameraControl.Config
|
||||
}
|
||||
|
||||
payload["camera_control"] = cameraControl
|
||||
}
|
||||
|
||||
// 处理图生视频
|
||||
if params.TaskType == "image2video" {
|
||||
payload["image"] = params.Image
|
||||
payload["image_tail"] = params.ImageTail
|
||||
}
|
||||
|
||||
jsonPayload, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return KeLingRespVo{}, fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
// 3. 准备HTTP请求
|
||||
url := fmt.Sprintf("%s/kling/v1/videos/%s", apiKey.ApiURL, params.TaskType)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewReader(jsonPayload))
|
||||
if err != nil {
|
||||
return KeLingRespVo{}, fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey.Value)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 4. 发送请求
|
||||
client := &http.Client{Timeout: time.Duration(30) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return KeLingRespVo{}, fmt.Errorf("failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 5. 处理响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return KeLingRespVo{}, fmt.Errorf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return KeLingRespVo{}, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var apiResponse = KeLingRespVo{}
|
||||
if err := json.Unmarshal(body, &apiResponse); err != nil {
|
||||
return KeLingRespVo{}, fmt.Errorf("failed to parse response: %v", err)
|
||||
}
|
||||
// 设置 API 通道
|
||||
apiResponse.Channel = apiKey.ApiURL
|
||||
return apiResponse, nil
|
||||
}
|
||||
|
||||
// VideoCallbackData 表示视频生成任务的回调数据
|
||||
type VideoCallbackData struct {
|
||||
TaskID string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
TaskStatusMsg string `json:"task_status_msg"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
TaskResult TaskResult `json:"task_result"`
|
||||
}
|
||||
|
||||
type TaskResult struct {
|
||||
Images []CallBackImageResult `json:"images,omitempty"`
|
||||
Videos []CallBackVideoResult `json:"videos,omitempty"`
|
||||
}
|
||||
|
||||
type CallBackImageResult struct {
|
||||
Index int `json:"index"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type CallBackVideoResult struct {
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
Duration string `json:"duration"`
|
||||
}
|
||||
|
||||
func (s *Service) QueryKeLingTask(taskId string, channel string, action string) (VideoCallbackData, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := s.db.Session(&gorm.Session{}).Where("type", "keling").
|
||||
//Where("api_url", channel).
|
||||
Where("enabled", true).
|
||||
Order("last_used_at DESC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return VideoCallbackData{}, errors.New("no available API KEY for keling")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/kling/v1/videos/%s/%s", apiKey.ApiURL, action, taskId)
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return VideoCallbackData{}, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey.Value)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return VideoCallbackData{}, fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return VideoCallbackData{}, fmt.Errorf("unexpected status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return VideoCallbackData{}, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
var response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data VideoCallbackData `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return VideoCallbackData{}, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
if response.Code != 0 {
|
||||
return VideoCallbackData{}, fmt.Errorf("API error: %s", response.Message)
|
||||
}
|
||||
|
||||
return response.Data, nil
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package service
|
||||
|
||||
import "geekai/core/types"
|
||||
|
||||
type WebsocketService struct {
|
||||
Clients *types.LMap[string, *types.WsClient] // clientId => Client
|
||||
}
|
||||
|
||||
func NewWebsocketService() *WebsocketService {
|
||||
return &WebsocketService{
|
||||
Clients: types.NewLMap[string, *types.WsClient](),
|
||||
}
|
||||
}
|
||||
@@ -9,14 +9,11 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
|
||||
"github.com/xxl-job/xxl-job-executor-go"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
@@ -46,97 +43,13 @@ func NewXXLJobExecutor(config *types.AppConfig, db *gorm.DB) *XXLJobExecutor {
|
||||
|
||||
func (e *XXLJobExecutor) Run() error {
|
||||
e.executor.RegTask("ClearOrders", e.ClearOrders)
|
||||
e.executor.RegTask("ResetVipPower", e.ResetVipPower)
|
||||
e.executor.RegTask("ResetUserPower", e.ResetUserPower)
|
||||
return e.executor.Run()
|
||||
}
|
||||
|
||||
// ClearOrders 清理未支付的订单,如果没有抛出异常则表示执行成功
|
||||
func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (msg string) {
|
||||
logger.Info("执行清理未支付订单...")
|
||||
var sysConfig model.Config
|
||||
res := e.db.Where("marker", "system").First(&sysConfig)
|
||||
if res.Error != nil {
|
||||
return "error with get system config: " + res.Error.Error()
|
||||
}
|
||||
|
||||
var config types.SystemConfig
|
||||
err := utils.JsonDecode(sysConfig.Config, &config)
|
||||
if err != nil {
|
||||
return "error with decode system config: " + err.Error()
|
||||
}
|
||||
|
||||
if config.OrderPayTimeout == 0 { // 默认未支付订单的生命周期为 30 分钟
|
||||
config.OrderPayTimeout = 1800
|
||||
}
|
||||
timeout := time.Now().Unix() - int64(config.OrderPayTimeout)
|
||||
start := utils.Stamp2str(timeout)
|
||||
// 这里不是用软删除,而是永久删除订单
|
||||
res = e.db.Unscoped().Where("status IN ? AND created_at < ?", []types.OrderStatus{types.OrderNotPaid, types.OrderScanned}, start).Delete(&model.Order{})
|
||||
logger.Infof("Clear order successfully, affect rows: %d", res.RowsAffected)
|
||||
return "success"
|
||||
}
|
||||
|
||||
// ResetVipPower 重置VIP会员算力
|
||||
// 自动将 VIP 会员的算力补充到每月赠送的最大值
|
||||
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
|
||||
logger.Info("开始进行月底账号盘点...")
|
||||
return "success"
|
||||
}
|
||||
|
||||
func (e *XXLJobExecutor) ResetUserPower(cxt context.Context, param *xxl.RunReq) (msg string) {
|
||||
logger.Info("今日算力派发开始:", time.Now())
|
||||
var users []model.User
|
||||
res := e.db.Where("status", 1).Find(&users)
|
||||
if res.Error != nil {
|
||||
return "No matching users"
|
||||
}
|
||||
|
||||
var sysConfig model.Config
|
||||
res = e.db.Where("marker", "system").First(&sysConfig)
|
||||
if res.Error != nil {
|
||||
return "error with get system config: " + res.Error.Error()
|
||||
}
|
||||
|
||||
var config types.SystemConfig
|
||||
err := utils.JsonDecode(sysConfig.Config, &config)
|
||||
if err != nil {
|
||||
return "error with decode system config: " + err.Error()
|
||||
}
|
||||
|
||||
if config.DailyPower <= 0 {
|
||||
return "success"
|
||||
}
|
||||
|
||||
var counter = 0
|
||||
var totalPower = 0
|
||||
for _, u := range users {
|
||||
if u.Power >= config.DailyPower {
|
||||
continue
|
||||
}
|
||||
var power = config.DailyPower - u.Power
|
||||
// update user
|
||||
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power))
|
||||
// 记录算力充值日志
|
||||
if tx.Error == nil {
|
||||
var user model.User
|
||||
e.db.Where("id", u.Id).First(&user)
|
||||
e.db.Create(&model.PowerLog{
|
||||
UserId: u.Id,
|
||||
Username: u.Username,
|
||||
Type: types.PowerGift,
|
||||
Amount: power,
|
||||
Mark: types.PowerAdd,
|
||||
Balance: user.Power,
|
||||
Model: "系统赠送",
|
||||
Remark: fmt.Sprintf("系统每日算力派发,今日额度:%d", config.DailyPower),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
counter++
|
||||
totalPower += power
|
||||
}
|
||||
logger.Infof("今日派发算力结束!累计派发 %d 人,累计派发算力:%d", counter, totalPower)
|
||||
|
||||
return "success"
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,11 +1,22 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type AdminUser struct {
|
||||
BaseModel
|
||||
Username string
|
||||
Password string
|
||||
Salt string // 密码盐
|
||||
Status bool `gorm:"default:true"` // 当前状态
|
||||
LastLoginAt int64 // 最后登录时间
|
||||
LastLoginIp string // 最后登录 IP
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Username string `gorm:"column:username;type:varchar(30);uniqueIndex;not null;comment:用户名" json:"username"`
|
||||
Password string `gorm:"column:password;type:char(64);not null;comment:密码" json:"password"`
|
||||
Salt string `gorm:"column:salt;type:char(12);not null;comment:密码盐" json:"salt"`
|
||||
Status bool `gorm:"column:status;type:tinyint(1);not null;comment:当前状态" json:"status"`
|
||||
LastLoginAt int64 `gorm:"column:last_login_at;type:int;not null;comment:最后登录时间" json:"last_login_at"`
|
||||
LastLoginIp string `gorm:"column:last_login_ip;type:char(16);not null;comment:最后登录 IP" json:"last_login_ip"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;comment:更新时间" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (m *AdminUser) TableName() string {
|
||||
return "chatgpt_admin_users"
|
||||
}
|
||||
|
||||
@@ -1,13 +1,24 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ApiKey OpenAI API 模型
|
||||
type ApiKey struct {
|
||||
BaseModel
|
||||
Name string
|
||||
Type string // 用途 chat => 聊天,img => 绘图
|
||||
Value string // API Key 的值
|
||||
ApiURL string // 当前 KEY 的 API 地址
|
||||
Enabled bool // 是否启用
|
||||
ProxyURL string // 代理地址
|
||||
LastUsedAt int64 // 最后使用时间
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"column:name;type:varchar(30);comment:名称" json:"name"`
|
||||
Value string `gorm:"column:value;type:varchar(255);not null;comment:API KEY value" json:"value"`
|
||||
Type string `gorm:"column:type;type:varchar(10);default:chat;not null;comment:用途(chat=>聊天,img=>图片)" json:"type"`
|
||||
LastUsedAt int64 `gorm:"column:last_used_at;type:int;not null;comment:最后使用时间" json:"last_used_at"`
|
||||
ApiURL string `gorm:"column:api_url;type:varchar(255);comment:API 地址" json:"api_url"`
|
||||
Enabled bool `gorm:"column:enabled;type:tinyint(1);comment:是否启用" json:"enabled"`
|
||||
ProxyURL string `gorm:"column:proxy_url;type:varchar(100);comment:代理地址" json:"proxy_url"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (m *ApiKey) TableName() string {
|
||||
return "chatgpt_api_keys"
|
||||
}
|
||||
|
||||
@@ -3,10 +3,15 @@ package model
|
||||
import "time"
|
||||
|
||||
type AppType struct {
|
||||
Id uint `gorm:"primarykey"`
|
||||
Name string
|
||||
Icon string
|
||||
Enabled bool
|
||||
SortNum int
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"column:name;type:varchar(50);not null;comment:名称" json:"name"`
|
||||
Icon string `gorm:"column:icon;type:varchar(255);not null;comment:图标URL" json:"icon"`
|
||||
SortNum int `gorm:"column:sort_num;type:tinyint;not null;comment:排序" json:"sort_num"`
|
||||
Enabled bool `gorm:"column:enabled;type:tinyint(1);not null;comment:是否启用" json:"enabled"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (m *AppType) TableName() string {
|
||||
return "chatgpt_app_types"
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type BaseModel struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package model
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type ChatMessage struct {
|
||||
BaseModel
|
||||
ChatId string // 会话 ID
|
||||
UserId uint // 用户 ID
|
||||
RoleId uint // 角色 ID
|
||||
Model string // AI模型
|
||||
Type string
|
||||
Icon string
|
||||
Tokens int
|
||||
TotalTokens int // 总 token 消耗
|
||||
Content string
|
||||
UseContext bool // 是否可以作为聊天上下文
|
||||
DeletedAt gorm.DeletedAt
|
||||
}
|
||||
|
||||
func (ChatMessage) TableName() string {
|
||||
return "chatgpt_chat_history"
|
||||
}
|
||||
@@ -1,14 +1,21 @@
|
||||
package model
|
||||
|
||||
import "gorm.io/gorm"
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatItem struct {
|
||||
BaseModel
|
||||
ChatId string `gorm:"column:chat_id;unique"` // 会话 ID
|
||||
UserId uint // 用户 ID
|
||||
RoleId uint // 角色 ID
|
||||
ModelId uint // 模型 ID
|
||||
Model string // 模型
|
||||
Title string // 会话标题
|
||||
DeletedAt gorm.DeletedAt
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
ChatId string `gorm:"column:chat_id;type:char(40);uniqueIndex;not null;comment:会话 ID" json:"chat_id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户 ID" json:"user_id"`
|
||||
RoleId uint `gorm:"column:role_id;type:int;not null;comment:角色 ID" json:"role_id"`
|
||||
Title string `gorm:"column:title;type:varchar(100);not null;comment:会话标题" json:"title"`
|
||||
ModelId uint `gorm:"column:model_id;type:int;not null;default:0;comment:模型 ID" json:"model_id"`
|
||||
Model string `gorm:"column:model;type:varchar(30);comment:模型名称" json:"model"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;comment:更新时间" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (m *ChatItem) TableName() string {
|
||||
return "chatgpt_chat_items"
|
||||
}
|
||||
|
||||
25
api/store/model/chat_message.go
Normal file
25
api/store/model/chat_message.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatMessage struct {
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户 ID" json:"user_id"`
|
||||
ChatId string `gorm:"column:chat_id;type:char(40);not null;index;comment:会话 ID" json:"chat_id"`
|
||||
Type string `gorm:"column:type;type:varchar(10);not null;comment:类型:prompt|reply" json:"type"`
|
||||
Icon string `gorm:"column:icon;type:varchar(255);not null;comment:角色图标" json:"icon"`
|
||||
RoleId uint `gorm:"column:role_id;type:int;not null;comment:角色 ID" json:"role_id"`
|
||||
Model string `gorm:"column:model;type:varchar(255);comment:模型名称" json:"model"`
|
||||
Content string `gorm:"column:content;type:text;not null;comment:聊天内容" json:"content"`
|
||||
Tokens int `gorm:"column:tokens;type:smallint;not null;comment:耗费 token 数量" json:"tokens"`
|
||||
TotalTokens int `gorm:"column:total_tokens;type:int;not null;comment:消耗总Token长度" json:"total_tokens"`
|
||||
UseContext bool `gorm:"column:use_context;type:tinyint(1);not null;comment:是否允许作为上下文语料" json:"use_context"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (m *ChatMessage) TableName() string {
|
||||
return "chatgpt_chat_history"
|
||||
}
|
||||
@@ -1,16 +1,29 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatModel struct {
|
||||
BaseModel
|
||||
Name string
|
||||
Value string // API Key 的值
|
||||
SortNum int
|
||||
Enabled bool
|
||||
Power int // 每次对话消耗算力
|
||||
Open bool // 是否开放模型给所有人使用
|
||||
MaxTokens int // 最大响应长度
|
||||
MaxContext int // 最大上下文长度
|
||||
Temperature float32 // 模型温度
|
||||
KeyId int // 绑定 API KEY ID
|
||||
Type string // 模型类型
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Desc string `gorm:"column:desc;type:varchar(1024);not null;default:'';comment:模型类型描述" json:"desc"`
|
||||
Tag string `gorm:"column:tag;type:varchar(1024);not null;default:'';comment:模型标签" json:"tag"`
|
||||
Type string `gorm:"column:type;type:varchar(10);not null;default:chat;comment:模型类型(chat,img)" json:"type"`
|
||||
Name string `gorm:"column:name;type:varchar(255);not null;comment:模型名称" json:"name"`
|
||||
Value string `gorm:"column:value;type:varchar(255);not null;comment:模型值" json:"value"`
|
||||
SortNum int `gorm:"column:sort_num;type:tinyint(1);not null;comment:排序数字" json:"sort_num"`
|
||||
Enabled bool `gorm:"column:enabled;type:tinyint(1);not null;default:0;comment:是否启用模型" json:"enabled"`
|
||||
Power int `gorm:"column:power;type:smallint;not null;comment:消耗算力点数" json:"power"`
|
||||
Temperature float32 `gorm:"column:temperature;type:float(3,1);not null;default:1.0;comment:模型创意度" json:"temperature"`
|
||||
MaxTokens int `gorm:"column:max_tokens;type:int;not null;default:1024;comment:最大响应长度" json:"max_tokens"`
|
||||
MaxContext int `gorm:"column:max_context;type:int;not null;default:4096;comment:最大上下文长度" json:"max_context"`
|
||||
Open bool `gorm:"column:open;type:tinyint(1);not null;comment:是否开放模型" json:"open"`
|
||||
KeyId uint `gorm:"column:key_id;type:int;not null;comment:绑定API KEY ID" json:"key_id"`
|
||||
Options string `gorm:"column:options;type:text;not null;comment:模型自定义选项" json:"options"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (m *ChatModel) TableName() string {
|
||||
return "chatgpt_chat_models"
|
||||
}
|
||||
|
||||
@@ -1,14 +1,24 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatRole struct {
|
||||
BaseModel
|
||||
Tid int
|
||||
Key string `gorm:"column:marker;unique"` // 角色唯一标识
|
||||
Name string // 角色名称
|
||||
Context string `gorm:"column:context_json"` // 角色语料信息 json
|
||||
HelloMsg string // 打招呼的消息
|
||||
Icon string // 角色聊天图标
|
||||
Enable bool // 是否启用被启用
|
||||
SortNum int //排序数字
|
||||
ModelId int // 绑定模型ID,绑定模型ID的角色只能用指定的模型来问答
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"column:name;type:varchar(30);not null;comment:角色名称" json:"name"`
|
||||
Tid uint `gorm:"column:tid;type:int;not null;comment:分类ID" json:"tid"`
|
||||
Key string `gorm:"column:marker;type:varchar(30);uniqueIndex;not null;comment:角色标识" json:"marker"`
|
||||
Context string `gorm:"column:context_json;type:text;not null;comment:角色语料 json" json:"context_json"`
|
||||
HelloMsg string `gorm:"column:hello_msg;type:varchar(255);not null;comment:打招呼信息" json:"hello_msg"`
|
||||
Icon string `gorm:"column:icon;type:varchar(255);not null;comment:角色图标" json:"icon"`
|
||||
Enable bool `gorm:"column:enable;type:tinyint(1);not null;comment:是否被启用" json:"enable"`
|
||||
SortNum int `gorm:"column:sort_num;type:smallint;not null;default:0;comment:角色排序" json:"sort_num"`
|
||||
ModelId uint `gorm:"column:model_id;type:int;not null;default:0;comment:绑定模型ID" json:"model_id"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (m *ChatRole) TableName() string {
|
||||
return "chatgpt_chat_roles"
|
||||
}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package model
|
||||
|
||||
type Config struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
Key string `gorm:"column:marker;unique"`
|
||||
Config string `gorm:"column:config_json"`
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
Name string `gorm:"column:name;type:varchar(20);uniqueIndex;not null;comment:配置名称"`
|
||||
Value string `gorm:"column:value;type:text;not null"`
|
||||
}
|
||||
|
||||
func (m *Config) TableName() string {
|
||||
return "chatgpt_configs"
|
||||
}
|
||||
|
||||
@@ -3,15 +3,19 @@ package model
|
||||
import "time"
|
||||
|
||||
type DallJob struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
UserId uint
|
||||
Prompt string
|
||||
TaskInfo string // 原始任务信息
|
||||
ImgURL string
|
||||
OrgURL string
|
||||
Publish bool
|
||||
Power int
|
||||
Progress int
|
||||
ErrMsg string
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户ID" json:"user_id"`
|
||||
Prompt string `gorm:"column:prompt;type:text;not null;comment:提示词" json:"prompt"`
|
||||
TaskInfo string `gorm:"column:task_info;type:text;not null;comment:任务详情" json:"task_info"`
|
||||
ImgURL string `gorm:"column:img_url;type:varchar(255);not null;comment:图片地址" json:"img_url"`
|
||||
OrgURL string `gorm:"column:org_url;type:varchar(1024);comment:原图地址" json:"org_url"`
|
||||
Publish int `gorm:"column:publish;type:tinyint(1);not null;comment:是否发布" json:"publish"`
|
||||
Power int `gorm:"column:power;type:smallint;not null;comment:消耗算力" json:"power"`
|
||||
Progress int `gorm:"column:progress;type:smallint;not null;comment:任务进度" json:"progress"`
|
||||
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);not null;comment:错误信息" json:"err_msg"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
}
|
||||
|
||||
func (m *DallJob) TableName() string {
|
||||
return "chatgpt_dall_jobs"
|
||||
}
|
||||
|
||||
@@ -3,12 +3,16 @@ package model
|
||||
import "time"
|
||||
|
||||
type File struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
UserId int
|
||||
Name string
|
||||
ObjKey string
|
||||
URL string
|
||||
Ext string
|
||||
Size int64
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户 ID" json:"user_id"`
|
||||
Name string `gorm:"column:name;type:varchar(255);not null;comment:文件名" json:"name"`
|
||||
ObjKey string `gorm:"column:obj_key;type:varchar(100);comment:文件标识" json:"obj_key"`
|
||||
URL string `gorm:"column:url;type:varchar(255);not null;comment:文件地址" json:"url"`
|
||||
Ext string `gorm:"column:ext;type:varchar(10);not null;comment:文件后缀" json:"ext"`
|
||||
Size int64 `gorm:"column:size;type:bigint;not null;default:0;comment:文件大小" json:"size"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"`
|
||||
}
|
||||
|
||||
func (m *File) TableName() string {
|
||||
return "chatgpt_files"
|
||||
}
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package model
|
||||
|
||||
type Function struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
Name string
|
||||
Label string
|
||||
Description string
|
||||
Parameters string
|
||||
Action string
|
||||
Token string
|
||||
Enabled bool
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"column:name;type:varchar(30);uniqueIndex;not null;comment:函数名称" json:"name"`
|
||||
Label string `gorm:"column:label;type:varchar(30);comment:函数标签" json:"label"`
|
||||
Description string `gorm:"column:description;type:varchar(255);comment:函数描述" json:"description"`
|
||||
Parameters string `gorm:"column:parameters;type:text;comment:函数参数(JSON)" json:"parameters"`
|
||||
Token string `gorm:"column:token;type:varchar(255);comment:API授权token" json:"token"`
|
||||
Action string `gorm:"column:action;type:varchar(255);comment:函数处理 API" json:"action"`
|
||||
Enabled bool `gorm:"column:enabled;type:tinyint(1);not null;default:0;comment:是否启用" json:"enabled"`
|
||||
}
|
||||
|
||||
func (m *Function) TableName() string {
|
||||
return "chatgpt_functions"
|
||||
}
|
||||
|
||||
@@ -3,10 +3,14 @@ package model
|
||||
import "time"
|
||||
|
||||
type InviteCode struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
UserId uint
|
||||
Code string
|
||||
Hits int // 点击次数
|
||||
RegNum int // 注册人数
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户ID" json:"user_id"`
|
||||
Code string `gorm:"column:code;type:char(8);uniqueIndex;not null;comment:邀请码" json:"code"`
|
||||
Hits int `gorm:"column:hits;type:int;not null;comment:点击次数" json:"hits"`
|
||||
RegNum int `gorm:"column:reg_num;type:smallint;not null;comment:注册数量" json:"reg_num"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
}
|
||||
|
||||
func (m *InviteCode) TableName() string {
|
||||
return "chatgpt_invite_codes"
|
||||
}
|
||||
|
||||
@@ -5,11 +5,15 @@ import (
|
||||
)
|
||||
|
||||
type InviteLog struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
InviterId uint
|
||||
UserId uint
|
||||
Username string
|
||||
InviteCode string
|
||||
Remark string
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
InviterId uint `gorm:"column:inviter_id;type:int;not null;comment:邀请人ID" json:"inviter_id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:注册用户ID" json:"user_id"`
|
||||
Username string `gorm:"column:username;type:varchar(30);not null;comment:用户名" json:"username"`
|
||||
InviteCode string `gorm:"column:invite_code;type:char(8);not null;comment:邀请码" json:"invite_code"`
|
||||
Remark string `gorm:"column:remark;type:varchar(255);not null;comment:备注" json:"remark"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
}
|
||||
|
||||
func (m *InviteLog) TableName() string {
|
||||
return "chatgpt_invite_logs"
|
||||
}
|
||||
|
||||
55
api/store/model/jimeng_job.go
Normal file
55
api/store/model/jimeng_job.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// JimengJob 即梦AI任务模型
|
||||
type JimengJob struct {
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;index;comment:用户ID" json:"user_id"`
|
||||
TaskId string `gorm:"column:task_id;type:varchar(100);not null;index;comment:任务ID" json:"task_id"`
|
||||
Type JMTaskType `gorm:"column:type;type:varchar(50);not null;comment:任务类型" json:"type"`
|
||||
ReqKey string `gorm:"column:req_key;type:varchar(100);comment:请求Key" json:"req_key"`
|
||||
Prompt string `gorm:"column:prompt;type:text;comment:提示词" json:"prompt"`
|
||||
TaskParams string `gorm:"column:task_params;type:text;comment:任务参数JSON" json:"task_params"`
|
||||
ImgURL string `gorm:"column:img_url;type:varchar(1024);comment:图片或封面URL" json:"img_url"`
|
||||
VideoURL string `gorm:"column:video_url;type:varchar(1024);comment:视频URL" json:"video_url"`
|
||||
RawData string `gorm:"column:raw_data;type:text;comment:原始API响应" json:"raw_data"`
|
||||
Progress int `gorm:"column:progress;type:int;default:0;comment:进度百分比" json:"progress"`
|
||||
Status JMTaskStatus `gorm:"column:status;type:varchar(20);default:'pending';comment:任务状态" json:"status"`
|
||||
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
|
||||
Power int `gorm:"column:power;type:int(11);default:0;comment:消耗算力" json:"power"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;comment:更新时间" json:"updated_at"`
|
||||
}
|
||||
|
||||
// JMTaskStatus 任务状态
|
||||
type JMTaskStatus string
|
||||
|
||||
const (
|
||||
JMTaskStatusInQueue = JMTaskStatus("in_queue") // 任务已提交
|
||||
JMTaskStatusGenerating = JMTaskStatus("generating") // 任务处理中
|
||||
JMTaskStatusDone = JMTaskStatus("done") // 处理完成
|
||||
JMTaskStatusNotFound = JMTaskStatus("not_found") // 任务未找到
|
||||
JMTaskStatusSuccess = JMTaskStatus("success") // 任务成功
|
||||
JMTaskStatusFailed = JMTaskStatus("failed") // 任务失败
|
||||
JMTaskStatusExpired = JMTaskStatus("expired") // 任务过期
|
||||
)
|
||||
|
||||
// JMTaskType 任务类型
|
||||
type JMTaskType string
|
||||
|
||||
const (
|
||||
JMTaskTypeTextToImage = JMTaskType("text_to_image") // 文生图
|
||||
JMTaskTypeImageToImage = JMTaskType("image_to_image") // 图生图
|
||||
JMTaskTypeImageEdit = JMTaskType("image_edit") // 图像编辑
|
||||
JMTaskTypeImageEffects = JMTaskType("image_effects") // 图像特效
|
||||
JMTaskTypeTextToVideo = JMTaskType("text_to_video") // 文生视频
|
||||
JMTaskTypeImageToVideo = JMTaskType("image_to_video") // 图生视频
|
||||
)
|
||||
|
||||
// TableName 返回数据表名称
|
||||
func (JimengJob) TableName() string {
|
||||
return "chatgpt_jimeng_jobs"
|
||||
}
|
||||
@@ -2,10 +2,14 @@ package model
|
||||
|
||||
// Menu 系统菜单
|
||||
type Menu struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
Name string // 菜单名称
|
||||
Icon string // 菜单图标
|
||||
URL string // 菜单跳转地址
|
||||
SortNum int // 排序
|
||||
Enabled bool // 启用状态
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"column:name;type:varchar(30);not null;comment:菜单名称" json:"name"`
|
||||
Icon string `gorm:"column:icon;type:varchar(150);not null;comment:菜单图标" json:"icon"`
|
||||
URL string `gorm:"column:url;type:varchar(100);not null;comment:地址" json:"url"`
|
||||
SortNum int `gorm:"column:sort_num;type:smallint;not null;comment:排序" json:"sort_num"`
|
||||
Enabled bool `gorm:"column:enabled;type:tinyint(1);not null;comment:是否启用" json:"enabled"`
|
||||
}
|
||||
|
||||
func (m *Menu) TableName() string {
|
||||
return "chatgpt_menus"
|
||||
}
|
||||
|
||||
@@ -3,26 +3,26 @@ package model
|
||||
import "time"
|
||||
|
||||
type MidJourneyJob struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
Type string
|
||||
UserId int
|
||||
TaskId string
|
||||
TaskInfo string // 原始任务信息
|
||||
ChannelId string
|
||||
MessageId string
|
||||
ReferenceId string
|
||||
ImgURL string
|
||||
OrgURL string // 原图地址
|
||||
Hash string // message hash
|
||||
Progress int
|
||||
Prompt string
|
||||
UseProxy bool // 是否使用反代加载图片
|
||||
Publish bool //是否发布图片到画廊
|
||||
ErrMsg string // 报错信息
|
||||
Power int // 消耗算力
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户 ID" json:"user_id"`
|
||||
TaskId string `gorm:"column:task_id;type:varchar(20);uniqueIndex;comment:任务 ID" json:"task_id"`
|
||||
TaskInfo string `gorm:"column:task_info;type:text;not null;comment:任务详情" json:"task_info"`
|
||||
Type string `gorm:"column:type;type:varchar(20);default:image;comment:任务类别" json:"type"`
|
||||
MessageId string `gorm:"column:message_id;type:char(40);not null;index;comment:消息 ID" json:"message_id"`
|
||||
ChannelId string `gorm:"column:channel_id;type:varchar(100);comment:频道ID" json:"channel_id"`
|
||||
RefId string `gorm:"column:reference_id;type:char(40);comment:引用消息 ID" json:"reference_id"`
|
||||
Prompt string `gorm:"column:prompt;type:text;not null;comment:会话提示词" json:"prompt"`
|
||||
ImgURL string `gorm:"column:img_url;type:varchar(400);comment:图片URL" json:"img_url"`
|
||||
OrgURL string `gorm:"column:org_url;type:varchar(400);comment:原始图片地址" json:"org_url"`
|
||||
Hash string `gorm:"column:hash;type:varchar(100);comment:message hash" json:"hash"`
|
||||
Progress int `gorm:"column:progress;type:smallint;default:0;comment:任务进度" json:"progress"`
|
||||
UseProxy int `gorm:"column:use_proxy;type:tinyint(1);not null;default:0;comment:是否使用反代" json:"use_proxy"`
|
||||
Publish int `gorm:"column:publish;type:tinyint(1);not null;comment:是否发布" json:"publish"`
|
||||
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
|
||||
Power int `gorm:"column:power;type:smallint;not null;default:0;comment:消耗算力" json:"power"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
}
|
||||
|
||||
func (MidJourneyJob) TableName() string {
|
||||
func (m *MidJourneyJob) TableName() string {
|
||||
return "chatgpt_mj_jobs"
|
||||
}
|
||||
|
||||
@@ -2,21 +2,28 @@ package model
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Order 充值订单
|
||||
type Order struct {
|
||||
BaseModel
|
||||
UserId uint
|
||||
ProductId uint
|
||||
Username string
|
||||
OrderNo string
|
||||
TradeNo string
|
||||
Subject string
|
||||
Amount float64
|
||||
Status types.OrderStatus
|
||||
Remark string
|
||||
PayTime int64
|
||||
PayWay string // 支付渠道
|
||||
PayType string // 支付类型
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户ID" json:"user_id"`
|
||||
ProductId uint `gorm:"column:product_id;type:int;not null;comment:产品ID" json:"product_id"`
|
||||
Username string `gorm:"column:username;type:varchar(30);not null;comment:用户名" json:"username"`
|
||||
OrderNo string `gorm:"column:order_no;type:varchar(30);uniqueIndex;not null;comment:订单ID" json:"order_no"`
|
||||
TradeNo string `gorm:"column:trade_no;type:varchar(60);comment:支付平台交易流水号" json:"trade_no"`
|
||||
Subject string `gorm:"column:subject;type:varchar(100);not null;comment:订单产品" json:"subject"`
|
||||
Amount float64 `gorm:"column:amount;type:decimal(10,2);not null;default:0.00;comment:订单金额" json:"amount"`
|
||||
Status types.OrderStatus `gorm:"column:status;type:tinyint(1);not null;default:0;comment:订单状态(0:待支付,1:已扫码,2:支付成功)" json:"status"`
|
||||
Remark string `gorm:"column:remark;type:varchar(255);not null;comment:备注" json:"remark"`
|
||||
PayTime int64 `gorm:"column:pay_time;type:int;comment:支付时间" json:"pay_time"`
|
||||
PayWay string `gorm:"column:pay_way;type:varchar(20);not null;comment:支付方式" json:"pay_way"`
|
||||
PayType string `gorm:"column:pay_type;type:varchar(30);not null;comment:支付类型" json:"pay_type"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (m *Order) TableName() string {
|
||||
return "chatgpt_orders"
|
||||
}
|
||||
|
||||
@@ -7,14 +7,18 @@ import (
|
||||
|
||||
// PowerLog 算力消费日志
|
||||
type PowerLog struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
UserId uint
|
||||
Username string
|
||||
Type types.PowerType
|
||||
Amount int
|
||||
Balance int
|
||||
Model string // 模型
|
||||
Remark string // 备注
|
||||
Mark types.PowerMark // 资金类型
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户ID" json:"user_id"`
|
||||
Username string `gorm:"column:username;type:varchar(30);not null;comment:用户名" json:"username"`
|
||||
Type types.PowerType `gorm:"column:type;type:tinyint(1);not null;comment:类型(1:充值,2:消费,3:退费)" json:"type"`
|
||||
Amount int `gorm:"column:amount;type:smallint;not null;comment:算力数值" json:"amount"`
|
||||
Balance int `gorm:"column:balance;type:int;not null;comment:余额" json:"balance"`
|
||||
Model string `gorm:"column:model;type:varchar(30);not null;comment:模型" json:"model"`
|
||||
Remark string `gorm:"column:remark;type:varchar(512);not null;comment:备注" json:"remark"`
|
||||
Mark types.PowerMark `gorm:"column:mark;type:tinyint(1);not null;comment:资金类型(0:支出,1:收入)" json:"mark"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"`
|
||||
}
|
||||
|
||||
func (m *PowerLog) TableName() string {
|
||||
return "chatgpt_power_logs"
|
||||
}
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Product 充值产品
|
||||
type Product struct {
|
||||
BaseModel
|
||||
Name string
|
||||
Price float64
|
||||
Discount float64
|
||||
Days int
|
||||
Power int
|
||||
Enabled bool
|
||||
Sales int
|
||||
SortNum int
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"column:name;type:varchar(30);not null;comment:名称" json:"name"`
|
||||
Price float64 `gorm:"column:price;type:decimal(10,2);not null;default:0.00;comment:价格" json:"price"`
|
||||
Discount float64 `gorm:"column:discount;type:decimal(10,2);not null;default:0.00;comment:优惠金额" json:"discount"`
|
||||
Days int `gorm:"column:days;type:smallint;not null;default:0;comment:延长天数" json:"days"`
|
||||
Power int `gorm:"column:power;type:int;not null;default:0;comment:增加算力值" json:"power"`
|
||||
Enabled bool `gorm:"column:enabled;type:tinyint(1);not null;default:0;comment:是否启动" json:"enabled"`
|
||||
Sales int `gorm:"column:sales;type:int;not null;default:0;comment:销量" json:"sales"`
|
||||
SortNum int `gorm:"column:sort_num;type:tinyint;not null;default:0;comment:排序" json:"sort_num"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
|
||||
AppUrl string `gorm:"column:app_url;type:varchar(255);comment:App跳转地址" json:"app_url"`
|
||||
Url string `gorm:"column:url;type:varchar(255);comment:跳转地址" json:"url"`
|
||||
}
|
||||
|
||||
func (m *Product) TableName() string {
|
||||
return "chatgpt_products"
|
||||
}
|
||||
|
||||
@@ -5,12 +5,16 @@ import "time"
|
||||
// 兑换码
|
||||
|
||||
type Redeem struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
UserId uint // 用户 ID
|
||||
Name string // 名称
|
||||
Power int // 算力
|
||||
Code string // 兑换码
|
||||
Enabled bool // 启用状态
|
||||
RedeemedAt int64 // 兑换时间
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户 ID" json:"user_id"`
|
||||
Name string `gorm:"column:name;type:varchar(30);not null;comment:兑换码名称" json:"name"`
|
||||
Power int `gorm:"column:power;type:int;not null;comment:算力" json:"power"`
|
||||
Code string `gorm:"column:code;type:varchar(100);uniqueIndex;not null;comment:兑换码" json:"code"`
|
||||
Enabled bool `gorm:"column:enabled;type:tinyint(1);not null;comment:是否启用" json:"enabled"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
RedeemedAt int64 `gorm:"column:redeemed_at;type:int;not null;comment:兑换时间" json:"redeemed_at"`
|
||||
}
|
||||
|
||||
func (m *Redeem) TableName() string {
|
||||
return "chatgpt_redeems"
|
||||
}
|
||||
|
||||
@@ -3,21 +3,21 @@ package model
|
||||
import "time"
|
||||
|
||||
type SdJob struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
Type string
|
||||
UserId int
|
||||
TaskId string
|
||||
TaskInfo string // 原始任务信息
|
||||
ImgURL string
|
||||
Progress int
|
||||
Prompt string
|
||||
Params string
|
||||
Publish bool //是否发布图片到画廊
|
||||
ErrMsg string // 报错信息
|
||||
Power int // 消耗算力
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户 ID" json:"user_id"`
|
||||
Type string `gorm:"column:type;type:varchar(20);default:txt2img;comment:任务类别" json:"type"`
|
||||
TaskId string `gorm:"column:task_id;type:char(30);uniqueIndex;not null;comment:任务 ID" json:"task_id"`
|
||||
TaskInfo string `gorm:"column:task_info;type:text;not null;comment:任务详情" json:"task_info"`
|
||||
Prompt string `gorm:"column:prompt;type:text;not null;comment:会话提示词" json:"prompt"`
|
||||
ImgURL string `gorm:"column:img_url;type:varchar(255);comment:图片URL" json:"img_url"`
|
||||
Params string `gorm:"column:params;type:text;comment:绘画参数json" json:"params"`
|
||||
Progress int `gorm:"column:progress;type:smallint;default:0;comment:任务进度" json:"progress"`
|
||||
Publish int `gorm:"column:publish;type:tinyint(1);not null;comment:是否发布" json:"publish"`
|
||||
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
|
||||
Power int `gorm:"column:power;type:smallint;not null;default:0;comment:消耗算力" json:"power"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
}
|
||||
|
||||
func (SdJob) TableName() string {
|
||||
func (m *SdJob) TableName() string {
|
||||
return "chatgpt_sd_jobs"
|
||||
}
|
||||
|
||||
@@ -3,33 +3,33 @@ package model
|
||||
import "time"
|
||||
|
||||
type SunoJob struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
UserId int
|
||||
Channel string // 频道
|
||||
Title string
|
||||
Type int
|
||||
TaskId string
|
||||
TaskInfo string // 原始任务信息
|
||||
RefTaskId string // 续写的任务id
|
||||
Tags string // 歌曲风格和标签
|
||||
Instrumental bool // 是否生成纯音乐
|
||||
ExtendSecs int // 续写秒数
|
||||
SongId string // 续写的歌曲id
|
||||
RefSongId string
|
||||
Prompt string // 提示词
|
||||
CoverURL string // 封面图 URL
|
||||
AudioURL string // 音频 URL
|
||||
ModelName string // 模型名称
|
||||
Progress int // 任务进度
|
||||
Duration int // 银屏时长,秒
|
||||
Publish bool // 是否发布
|
||||
ErrMsg string // 错误信息
|
||||
RawData string // 原始数据 json
|
||||
Power int // 消耗算力
|
||||
PlayTimes int // 播放次数
|
||||
CreatedAt time.Time
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户 ID" json:"user_id"`
|
||||
Channel string `gorm:"column:channel;type:varchar(100);not null;comment:渠道" json:"channel"`
|
||||
Title string `gorm:"column:title;type:varchar(100);comment:歌曲标题" json:"title"`
|
||||
Type int `gorm:"column:type;type:tinyint(1);default:0;comment:任务类型,1:灵感创作,2:自定义创作" json:"type"`
|
||||
TaskId string `gorm:"column:task_id;type:varchar(50);comment:任务 ID" json:"task_id"`
|
||||
TaskInfo string `gorm:"column:task_info;type:text;not null;comment:任务详情" json:"task_info"`
|
||||
RefTaskId string `gorm:"column:ref_task_id;type:char(50);comment:引用任务 ID" json:"ref_task_id"`
|
||||
Tags string `gorm:"column:tags;type:varchar(100);comment:歌曲风格" json:"tags"`
|
||||
Instrumental bool `gorm:"column:instrumental;type:tinyint(1);default:0;comment:是否为纯音乐" json:"instrumental"`
|
||||
ExtendSecs int `gorm:"column:extend_secs;type:smallint;default:0;comment:延长秒数" json:"extend_secs"`
|
||||
SongId string `gorm:"column:song_id;type:varchar(50);comment:要续写的歌曲 ID" json:"song_id"`
|
||||
RefSongId string `gorm:"column:ref_song_id;type:varchar(50);not null;comment:引用的歌曲ID" json:"ref_song_id"`
|
||||
Prompt string `gorm:"column:prompt;type:varchar(2000);not null;comment:提示词" json:"prompt"`
|
||||
CoverURL string `gorm:"column:cover_url;type:varchar(512);comment:封面图地址" json:"cover_url"`
|
||||
AudioURL string `gorm:"column:audio_url;type:varchar(512);comment:音频地址" json:"audio_url"`
|
||||
ModelName string `gorm:"column:model_name;type:varchar(30);comment:模型地址" json:"model_name"`
|
||||
Progress int `gorm:"column:progress;type:smallint;default:0;comment:任务进度" json:"progress"`
|
||||
Duration int `gorm:"column:duration;type:smallint;not null;default:0;comment:歌曲时长" json:"duration"`
|
||||
Publish int `gorm:"column:publish;type:tinyint(1);not null;comment:是否发布" json:"publish"`
|
||||
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
|
||||
RawData string `gorm:"column:raw_data;type:text;comment:原始数据" json:"raw_data"`
|
||||
Power int `gorm:"column:power;type:smallint;not null;default:0;comment:消耗算力" json:"power"`
|
||||
PlayTimes int `gorm:"column:play_times;type:int;comment:播放次数" json:"play_times"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
}
|
||||
|
||||
func (SunoJob) TableName() string {
|
||||
func (m *SunoJob) TableName() string {
|
||||
return "chatgpt_suno_jobs"
|
||||
}
|
||||
|
||||
@@ -1,23 +1,33 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
BaseModel
|
||||
Username string
|
||||
Nickname string
|
||||
Email string
|
||||
Mobile string
|
||||
Password string
|
||||
Avatar string
|
||||
Salt string // 密码盐
|
||||
Power int // 剩余算力
|
||||
ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json
|
||||
ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色
|
||||
ChatModels string `gorm:"column:chat_models_json"` // AI 模型,不同的用户拥有不同的聊天模型
|
||||
ExpiredTime int64 // 账户到期时间
|
||||
Status bool `gorm:"default:true"` // 当前状态
|
||||
LastLoginAt int64 // 最后登录时间
|
||||
LastLoginIp string // 最后登录 IP
|
||||
OpenId string `gorm:"column:openid"`
|
||||
Platform string `json:"platform"`
|
||||
Vip bool // 是否 VIP 会员
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Username string `gorm:"column:username;type:varchar(30);uniqueIndex;not null;comment:用户名" json:"username"`
|
||||
Mobile string `gorm:"column:mobile;type:char(11);comment:手机号" json:"mobile"`
|
||||
Email string `gorm:"column:email;type:varchar(50);comment:邮箱地址" json:"email"`
|
||||
Nickname string `gorm:"column:nickname;type:varchar(30);not null;comment:昵称" json:"nickname"`
|
||||
Password string `gorm:"column:password;type:char(64);not null;comment:密码" json:"password"`
|
||||
Avatar string `gorm:"column:avatar;type:varchar(255);not null;comment:头像" json:"avatar"`
|
||||
Salt string `gorm:"column:salt;type:char(12);not null;comment:密码盐" json:"salt"`
|
||||
Power int `gorm:"column:power;type:int;default:0;comment:剩余算力" json:"power"`
|
||||
ExpiredTime int64 `gorm:"column:expired_time;type:int;not null;comment:用户过期时间" json:"expired_time"`
|
||||
Status bool `gorm:"column:status;type:tinyint(1);not null;comment:当前状态" json:"status"`
|
||||
ChatConfig string `gorm:"column:chat_config_json;type:text;default:null;comment:聊天配置json" json:"chat_config"`
|
||||
ChatRoles string `gorm:"column:chat_roles_json;type:text;default:null;comment:聊天角色 json" json:"chat_roles"`
|
||||
ChatModels string `gorm:"column:chat_models_json;type:text;default:null;comment:AI模型 json" json:"chat_models"`
|
||||
LastLoginAt int64 `gorm:"column:last_login_at;type:int;not null;comment:最后登录时间" json:"last_login_at"`
|
||||
Vip bool `gorm:"column:vip;type:tinyint(1);not null;default:0;comment:是否会员" json:"vip"`
|
||||
LastLoginIp string `gorm:"column:last_login_ip;type:char(16);not null;comment:最后登录 IP" json:"last_login_ip"`
|
||||
OpenId string `gorm:"column:openid;type:varchar(100);comment:第三方登录账号ID" json:"openid"`
|
||||
Platform string `gorm:"column:platform;type:varchar(30);comment:登录平台" json:"platform"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (m *User) TableName() string {
|
||||
return "chatgpt_users"
|
||||
}
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserLoginLog struct {
|
||||
BaseModel
|
||||
UserId uint
|
||||
Username string
|
||||
LoginIp string
|
||||
LoginAddress string
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户ID" json:"user_id"`
|
||||
Username string `gorm:"column:username;type:varchar(30);not null;comment:用户名" json:"username"`
|
||||
LoginIp string `gorm:"column:login_ip;type:char(16);not null;comment:登录IP" json:"login_ip"`
|
||||
LoginAddress string `gorm:"column:login_address;type:varchar(30);not null;comment:登录地址" json:"login_address"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (m *UserLoginLog) TableName() string {
|
||||
return "chatgpt_user_login_logs"
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user