mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-05 00:33:47 +08:00
Compare commits
140 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8eed7ff534 | ||
|
|
c79c4e74d0 | ||
|
|
f1855fd0a1 | ||
|
|
1f964c74e9 | ||
|
|
4fb2c5803c | ||
|
|
b5947545cb | ||
|
|
342b76f666 | ||
|
|
49b5906bc7 | ||
|
|
3075bfb7fc | ||
|
|
82e06fad33 | ||
|
|
4a9028747b | ||
|
|
4a8ff0ccf0 | ||
|
|
99341f0484 | ||
|
|
f58ac29ad0 | ||
|
|
7060edb3e5 | ||
|
|
41ae411f9b | ||
|
|
79b7fee47c | ||
|
|
0044bf10af | ||
|
|
e9348d3611 | ||
|
|
b9236e09a7 | ||
|
|
09b38d5f42 | ||
|
|
7bb539a06e | ||
|
|
5cdada8265 | ||
|
|
4147c217b1 | ||
|
|
8dda639b23 | ||
|
|
8487d2c9eb | ||
|
|
c5e583b215 | ||
|
|
549f618cff | ||
|
|
e9a3510346 | ||
|
|
30e6e963b3 | ||
|
|
c72d963f45 | ||
|
|
172d498618 | ||
|
|
313993532e | ||
|
|
e53db3582c | ||
|
|
72c6bd3f77 | ||
|
|
ca8b349df3 | ||
|
|
1b206c3640 | ||
|
|
c60276fc9f | ||
|
|
d00a3167c0 | ||
|
|
6b1cd8c30c | ||
|
|
46f12dc9ad | ||
|
|
a3e1d8ae21 | ||
|
|
72a066b93e | ||
|
|
0327a829ac | ||
|
|
882e9b8819 | ||
|
|
ef58cfadaa | ||
|
|
bf958d6113 | ||
|
|
71611273d7 | ||
|
|
b27c654311 | ||
|
|
90930ea9f9 | ||
|
|
1ab2185ff1 | ||
|
|
0f2f978d4c | ||
|
|
f61963b0b0 | ||
|
|
2aa413960d | ||
|
|
aa4bbba5ec | ||
|
|
eba61fea2d | ||
|
|
34e3455128 | ||
|
|
07dca3e739 | ||
|
|
4cb4b145f9 | ||
|
|
1ed417cb69 | ||
|
|
6cf91a84ca | ||
|
|
0b566980fc | ||
|
|
f86176b342 | ||
|
|
c700b32670 | ||
|
|
22641b452a | ||
|
|
d3fbb8c19e | ||
|
|
e3bb69ff10 | ||
|
|
770360c614 | ||
|
|
f302a0478f | ||
|
|
a88697b43a | ||
|
|
cc6f140812 | ||
|
|
424f2b3bdc | ||
|
|
ec0c13a600 | ||
|
|
a1f03bec4c | ||
|
|
b5bd4a5e0e | ||
|
|
7c2e49bfdb | ||
|
|
f80fe6d041 | ||
|
|
72f80a96bc | ||
|
|
2de655a1cf | ||
|
|
da2bd4a501 | ||
|
|
e0aa62c40d | ||
|
|
9d26a892d1 | ||
|
|
4ece7f2847 | ||
|
|
32368caf1b | ||
|
|
e91f54e79e | ||
|
|
bb8f4c57c4 | ||
|
|
43bfac99b6 | ||
|
|
be379b6d63 | ||
|
|
17f3c9b840 | ||
|
|
24de97fac2 | ||
|
|
bf27b44fee | ||
|
|
1802b4fe4d | ||
|
|
241a5c7bc9 | ||
|
|
557d547bf1 | ||
|
|
2e7b75affb | ||
|
|
bc21a1d443 | ||
|
|
3fc9e10a24 | ||
|
|
5fa1aa2060 | ||
|
|
ff4b267858 | ||
|
|
a590d0497f | ||
|
|
ac30d906f0 | ||
|
|
5bc071e038 | ||
|
|
88b956cf98 | ||
|
|
f725cf4661 | ||
|
|
057cc1e8a6 | ||
|
|
de122735b8 | ||
|
|
e87ede981c | ||
|
|
606fb498e1 | ||
|
|
a0c06e40a4 | ||
|
|
aba8f57279 | ||
|
|
960286a350 | ||
|
|
8c93fa51f6 | ||
|
|
cb0e7d64ff | ||
|
|
8e7413da97 | ||
|
|
a36f14eb94 | ||
|
|
f2f9f6e488 | ||
|
|
85068b8ca2 | ||
|
|
f2cfcfeefc | ||
|
|
755273a898 | ||
|
|
d4a24a0f1d | ||
|
|
92281fcbb7 | ||
|
|
636db4afcc | ||
|
|
ba25b8755e | ||
|
|
6399d13a49 | ||
|
|
06fa54fd25 | ||
|
|
a335b965d0 | ||
|
|
725adaa7d0 | ||
|
|
7e7e81e974 | ||
|
|
8cfe6bfc17 | ||
|
|
33de83f2ac | ||
|
|
3f856afec8 | ||
|
|
02a9c422fe | ||
|
|
ca69341024 | ||
|
|
169bf069ce | ||
|
|
1bee0ab04d | ||
|
|
440d91dd0e | ||
|
|
8168e246a8 | ||
|
|
2ef07574ae | ||
|
|
37392f2bb2 | ||
|
|
a80cd3848e |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
deploy
|
||||||
|
docs
|
||||||
|
api/static
|
||||||
|
web/node_modules
|
||||||
|
desktop
|
||||||
|
|
||||||
29
CHANGELOG.md
29
CHANGELOG.md
@@ -1,4 +1,33 @@
|
|||||||
# 更新日志
|
# 更新日志
|
||||||
|
## 4.0.2
|
||||||
|
* 功能新增:支持前端菜单可以配置
|
||||||
|
* 功能优化:手机端支持免登录预览功能
|
||||||
|
* 功能新增:手机端支持 Stable-Diffusion 绘画
|
||||||
|
* 功能新增:管理后台登录页面增加行为验证码,防止爆破
|
||||||
|
|
||||||
|
## v4.0.1
|
||||||
|
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion 发行版,稳定性更强一些
|
||||||
|
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容 MJ-Plus 中转
|
||||||
|
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
|
||||||
|
* Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换
|
||||||
|
* Bug修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
|
||||||
|
|
||||||
|
## v4.0.0
|
||||||
|
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。
|
||||||
|
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力...
|
||||||
|
|
||||||
|
* 功能重构:重构整体系统,全部采用算力来进行结算
|
||||||
|
* 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
|
||||||
|
* 功能优化:移动端聊天页面图片支持预览和放大功能
|
||||||
|
* 功能优化:MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
|
||||||
|
* 功能优化:**PC端不登录也可以预览功能,只有在发起操作的时候才需要登录**
|
||||||
|
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
|
||||||
|
* 功能新增:支持H5支付
|
||||||
|
* 功能优化:支持数学公式的识别和美化输出
|
||||||
|
* 功能新增:新增算力消费日志功能
|
||||||
|
* 功能优化:整合 XXL-JOB 实现订单清理,每日算力派发,VIP 算力重置等任务
|
||||||
|
* 功能新增:管理后台新增7日内新增用户和新增订单统计
|
||||||
|
|
||||||
## v3.2.7
|
## v3.2.7
|
||||||
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
|
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
|
||||||
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
|
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
|
||||||
|
|||||||
@@ -73,9 +73,11 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
|
|||||||
**演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。
|
**演示站不提供任何充值点卡售卖或者VIP充值服务。** 如果您体验过后觉得还不错的话,可以花两分钟用下面的一键部署脚本自己部署一套。
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.5-400fea2598.sh)"
|
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.7-6c232bdaf8.sh)"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
最新版本的一键部署脚本请参考 [**ChatGPT-Plus 文档**](https://ai.r9it.com/docs/install/)。
|
||||||
|
|
||||||
目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址
|
目前仅支持 Ubuntu 和 Centos 系统。 部署成功之后可以访问下面地址
|
||||||
|
|
||||||
* 前端访问地址:http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。
|
* 前端访问地址:http://localhost:8080/chat 使用移动设备访问会自动跳转到移动端页面。
|
||||||
@@ -145,6 +147,3 @@ KEY。
|
|||||||

|

|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
Listen = "0.0.0.0:5678"
|
Listen = "0.0.0.0:5678"
|
||||||
ProxyURL = "" # 如 http://127.0.0.1:7777
|
ProxyURL = "" # 如 http://127.0.0.1:7777
|
||||||
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8&parseTime=True&loc=Local"
|
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
|
||||||
StaticDir = "./static" # 静态资源的目录
|
StaticDir = "./static" # 静态资源的目录
|
||||||
StaticUrl = "/static" # 静态资源访问 URL
|
StaticUrl = "/static" # 静态资源访问 URL
|
||||||
AesEncryptKey = ""
|
AesEncryptKey = ""
|
||||||
@@ -10,10 +10,6 @@ WeChatBot = false
|
|||||||
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
|
SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
|
||||||
MaxAge = 86400
|
MaxAge = 86400
|
||||||
|
|
||||||
[Manager]
|
|
||||||
Username = "admin"
|
|
||||||
Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改
|
|
||||||
|
|
||||||
[Redis] # redis 配置信息
|
[Redis] # redis 配置信息
|
||||||
Host = "localhost"
|
Host = "localhost"
|
||||||
Port = 6379
|
Port = 6379
|
||||||
@@ -46,7 +42,7 @@ WeChatBot = false
|
|||||||
Active = "local" # 默认使用本地文件存储引擎
|
Active = "local" # 默认使用本地文件存储引擎
|
||||||
[OSS.Local]
|
[OSS.Local]
|
||||||
BasePath = "./static/upload" # 本地文件上传根路径
|
BasePath = "./static/upload" # 本地文件上传根路径
|
||||||
BaseURL = "/static/upload" # 本地上传文件根 URL 如果是线上,则直接设置为 /static/upload 即可
|
BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP
|
||||||
[OSS.Minio]
|
[OSS.Minio]
|
||||||
Endpoint = "" # 如 172.22.11.200:9000
|
Endpoint = "" # 如 172.22.11.200:9000
|
||||||
AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
|
AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
|
||||||
@@ -60,25 +56,24 @@ WeChatBot = false
|
|||||||
AccessSecret = ""
|
AccessSecret = ""
|
||||||
Bucket = ""
|
Bucket = ""
|
||||||
Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
|
Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
|
||||||
|
[OSS.AliYun]
|
||||||
|
Endpoint = "oss-cn-hangzhou.aliyuncs.com"
|
||||||
|
AccessKey = ""
|
||||||
|
AccessSecret = ""
|
||||||
|
Bucket = "chatgpt-plus"
|
||||||
|
SubDir = ""
|
||||||
|
Domain = ""
|
||||||
|
|
||||||
[[MjConfigs]]
|
[[MjProxyConfigs]]
|
||||||
Enabled = false
|
Enabled = true
|
||||||
UserToken = ""
|
ApiURL = "http://midjourney-proxy:8082"
|
||||||
BotToken = ""
|
ApiKey = "sk-geekmaster"
|
||||||
GuildId = ""
|
|
||||||
ChanelId = ""
|
|
||||||
UseCDN = false #是否使用反向代理访问,设置为true下面的设置才会生效
|
|
||||||
DiscordAPI = "" # discord API 反代地址
|
|
||||||
DiscordCDN = "" # mj 图片反代地址
|
|
||||||
DiscordGateway = "" # discord 机器人反代地址
|
|
||||||
|
|
||||||
[[MjPlusConfigs]]
|
[[MjPlusConfigs]]
|
||||||
Enabled = false
|
Enabled = false
|
||||||
ApiURL = "https://api.chatgpt-plus.net" # 目前暂时不支持更改
|
ApiURL = "https://api.chat-plus.net"
|
||||||
CdnURL = "" # CND 加速的 URL,如果有的话就设置
|
|
||||||
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
|
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
|
||||||
ApiKey = "sk-xxx"
|
ApiKey = "sk-xxx"
|
||||||
NotifyURL = "https://ai.r9it.com/api/mj/notify" # 这里需要改成你的域名
|
|
||||||
|
|
||||||
[[SdConfigs]]
|
[[SdConfigs]]
|
||||||
Enabled = false
|
Enabled = false
|
||||||
|
|||||||
@@ -28,10 +28,9 @@ type AppServer struct {
|
|||||||
Debug bool
|
Debug bool
|
||||||
Config *types.AppConfig
|
Config *types.AppConfig
|
||||||
Engine *gin.Engine
|
Engine *gin.Engine
|
||||||
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
|
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
|
||||||
|
|
||||||
ChatConfig *types.ChatConfig // chat config cache
|
SysConfig *types.SystemConfig // system config cache
|
||||||
SysConfig *types.SystemConfig // system config cache
|
|
||||||
|
|
||||||
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
|
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
|
||||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||||
@@ -47,7 +46,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
|||||||
Debug: false,
|
Debug: false,
|
||||||
Config: appConfig,
|
Config: appConfig,
|
||||||
Engine: gin.Default(),
|
Engine: gin.Default(),
|
||||||
ChatContexts: types.NewLMap[string, []interface{}](),
|
ChatContexts: types.NewLMap[string, []types.Message](),
|
||||||
ChatSession: types.NewLMap[string, *types.ChatSession](),
|
ChatSession: types.NewLMap[string, *types.ChatSession](),
|
||||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
@@ -69,23 +68,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *AppServer) Run(db *gorm.DB) error {
|
func (s *AppServer) Run(db *gorm.DB) error {
|
||||||
// load chat config from database
|
|
||||||
var chatConfig model.Config
|
|
||||||
res := db.Where("marker", "chat").First(&chatConfig)
|
|
||||||
if res.Error != nil {
|
|
||||||
return res.Error
|
|
||||||
}
|
|
||||||
err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// load system configs
|
// load system configs
|
||||||
var sysConfig model.Config
|
var sysConfig model.Config
|
||||||
res = db.Where("marker", "system").First(&sysConfig)
|
res := db.Where("marker", "system").First(&sysConfig)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -143,73 +132,64 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
// 用户授权验证
|
// 用户授权验证
|
||||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if c.Request.URL.Path == "/api/user/login" ||
|
|
||||||
c.Request.URL.Path == "/api/user/resetPass" ||
|
|
||||||
c.Request.URL.Path == "/api/admin/login" ||
|
|
||||||
c.Request.URL.Path == "/api/user/register" ||
|
|
||||||
c.Request.URL.Path == "/api/chat/history" ||
|
|
||||||
c.Request.URL.Path == "/api/chat/detail" ||
|
|
||||||
c.Request.URL.Path == "/api/role/list" ||
|
|
||||||
c.Request.URL.Path == "/api/mj/jobs" ||
|
|
||||||
c.Request.URL.Path == "/api/mj/client" ||
|
|
||||||
c.Request.URL.Path == "/api/mj/notify" ||
|
|
||||||
c.Request.URL.Path == "/api/invite/hits" ||
|
|
||||||
c.Request.URL.Path == "/api/sd/jobs" ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
|
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
|
|
||||||
c.Request.URL.Path == "/api/admin/config/get" {
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var tokenString string
|
var tokenString string
|
||||||
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
|
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
||||||
|
if isAdminApi { // 后台管理 API
|
||||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||||
} else if c.Request.URL.Path == "/api/chat/new" {
|
} else if c.Request.URL.Path == "/api/chat/new" {
|
||||||
tokenString = c.Query("token")
|
tokenString = c.Query("token")
|
||||||
} else {
|
} else {
|
||||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
resp.ERROR(c, "You should put Authorization in request headers")
|
if needLogin(c) {
|
||||||
c.Abort()
|
resp.ERROR(c, "You should put Authorization in request headers")
|
||||||
return
|
c.Abort()
|
||||||
|
return
|
||||||
|
} else { // 直接放行
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
}
|
}
|
||||||
|
if isAdminApi {
|
||||||
|
return []byte(s.Config.AdminSession.SecretKey), nil
|
||||||
|
} else {
|
||||||
|
return []byte(s.Config.Session.SecretKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
return []byte(s.Config.Session.SecretKey), nil
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil && needLogin(c) {
|
||||||
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
|
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
if !ok || !token.Valid {
|
if !ok || !token.Valid && needLogin(c) {
|
||||||
resp.NotAuth(c, "Token is invalid")
|
resp.NotAuth(c, "Token is invalid")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
|
||||||
resp.NotAuth(c, "Token is expired")
|
resp.NotAuth(c, "Token is expired")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("users/%v", claims["user_id"])
|
key := fmt.Sprintf("users/%v", claims["user_id"])
|
||||||
if _, err := client.Get(context.Background(), key).Result(); err != nil {
|
if isAdminApi {
|
||||||
|
key = fmt.Sprintf("admin/%v", claims["user_id"])
|
||||||
|
}
|
||||||
|
if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
|
||||||
resp.NotAuth(c, "Token is not found in redis")
|
resp.NotAuth(c, "Token is not found in redis")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
@@ -218,6 +198,36 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func needLogin(c *gin.Context) bool {
|
||||||
|
if c.Request.URL.Path == "/api/user/login" ||
|
||||||
|
c.Request.URL.Path == "/api/user/resetPass" ||
|
||||||
|
c.Request.URL.Path == "/api/admin/login" ||
|
||||||
|
c.Request.URL.Path == "/api/admin/login/captcha" ||
|
||||||
|
c.Request.URL.Path == "/api/user/register" ||
|
||||||
|
c.Request.URL.Path == "/api/chat/history" ||
|
||||||
|
c.Request.URL.Path == "/api/chat/detail" ||
|
||||||
|
c.Request.URL.Path == "/api/chat/list" ||
|
||||||
|
c.Request.URL.Path == "/api/role/list" ||
|
||||||
|
c.Request.URL.Path == "/api/model/list" ||
|
||||||
|
c.Request.URL.Path == "/api/mj/imgWall" ||
|
||||||
|
c.Request.URL.Path == "/api/mj/client" ||
|
||||||
|
c.Request.URL.Path == "/api/mj/notify" ||
|
||||||
|
c.Request.URL.Path == "/api/invite/hits" ||
|
||||||
|
c.Request.URL.Path == "/api/sd/imgWall" ||
|
||||||
|
c.Request.URL.Path == "/api/sd/client" ||
|
||||||
|
c.Request.URL.Path == "/api/config/get" ||
|
||||||
|
c.Request.URL.Path == "/api/product/list" ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/static/") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// 统一参数处理
|
// 统一参数处理
|
||||||
func parameterHandlerMiddleware() gin.HandlerFunc {
|
func parameterHandlerMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
return &types.AppConfig{
|
return &types.AppConfig{
|
||||||
Listen: "0.0.0.0:5678",
|
Listen: "0.0.0.0:5678",
|
||||||
ProxyURL: "",
|
ProxyURL: "",
|
||||||
Manager: types.Manager{Username: "admin", Password: "admin123"},
|
|
||||||
StaticDir: "./static",
|
StaticDir: "./static",
|
||||||
StaticUrl: "http://localhost/5678/static",
|
StaticUrl: "http://localhost/5678/static",
|
||||||
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
|
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
|
||||||
|
|||||||
@@ -54,10 +54,14 @@ type ChatSession struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChatModel struct {
|
type ChatModel struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Platform Platform `json:"platform"`
|
Platform Platform `json:"platform"`
|
||||||
Value string `json:"value"`
|
Name string `json:"name"`
|
||||||
Weight int `json:"weight"`
|
Value string `json:"value"`
|
||||||
|
Power int `json:"power"`
|
||||||
|
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||||
|
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||||
|
Temperature float32 `json:"temperature"` // 模型温度
|
||||||
}
|
}
|
||||||
|
|
||||||
type ApiError struct {
|
type ApiError struct {
|
||||||
@@ -72,23 +76,36 @@ type ApiError struct {
|
|||||||
const PromptMsg = "prompt" // prompt message
|
const PromptMsg = "prompt" // prompt message
|
||||||
const ReplyMsg = "reply" // reply message
|
const ReplyMsg = "reply" // reply message
|
||||||
|
|
||||||
var ModelToTokens = map[string]int{
|
// PowerType 算力日志类型
|
||||||
"gpt-3.5-turbo": 4096,
|
type PowerType int
|
||||||
"gpt-3.5-turbo-16k": 16384,
|
|
||||||
"gpt-4": 8192,
|
const (
|
||||||
"gpt-4-32k": 32768,
|
PowerRecharge = PowerType(1) // 充值
|
||||||
"chatglm_pro": 32768, // 清华智普
|
PowerConsume = PowerType(2) // 消费
|
||||||
"chatglm_std": 16384,
|
PowerRefund = PowerType(3) // 任务(SD,MJ)执行失败,退款
|
||||||
"chatglm_lite": 4096,
|
PowerInvite = PowerType(4) // 邀请奖励
|
||||||
"ernie_bot_turbo": 8192, // 文心一言
|
PowerReward = PowerType(5) // 众筹
|
||||||
"general": 8192, // 科大讯飞
|
PowerGift = PowerType(6) // 系统赠送
|
||||||
"general2": 8192,
|
)
|
||||||
"general3": 8192,
|
|
||||||
|
func (t PowerType) String() string {
|
||||||
|
switch t {
|
||||||
|
case PowerRecharge:
|
||||||
|
return "充值"
|
||||||
|
case PowerConsume:
|
||||||
|
return "消费"
|
||||||
|
case PowerRefund:
|
||||||
|
return "退款"
|
||||||
|
case PowerReward:
|
||||||
|
return "众筹"
|
||||||
|
|
||||||
|
}
|
||||||
|
return "其他"
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetModelMaxToken(model string) int {
|
type PowerMark int
|
||||||
if token, ok := ModelToTokens[model]; ok {
|
|
||||||
return token
|
const (
|
||||||
}
|
PowerSub = PowerMark(0)
|
||||||
return 4096
|
PowerAdd = PowerMark(1)
|
||||||
}
|
)
|
||||||
|
|||||||
@@ -5,22 +5,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AppConfig struct {
|
type AppConfig struct {
|
||||||
Path string `toml:"-"`
|
Path string `toml:"-"`
|
||||||
Listen string
|
Listen string
|
||||||
Session Session
|
Session Session
|
||||||
ProxyURL string
|
AdminSession Session
|
||||||
MysqlDns string // mysql 连接地址
|
ProxyURL string
|
||||||
Manager Manager // 后台管理员账户信息
|
MysqlDns string // mysql 连接地址
|
||||||
StaticDir string // 静态资源目录
|
StaticDir string // 静态资源目录
|
||||||
StaticUrl string // 静态资源 URL
|
StaticUrl string // 静态资源 URL
|
||||||
Redis RedisConfig // redis 连接信息
|
Redis RedisConfig // redis 连接信息
|
||||||
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
||||||
SMS SMSConfig // send mobile message config
|
SMS SMSConfig // send mobile message config
|
||||||
OSS OSSConfig // OSS config
|
OSS OSSConfig // OSS config
|
||||||
MjConfigs []MidJourneyConfig // mj AI draw service pool
|
MjProxyConfigs []MjProxyConfig // MJ proxy config
|
||||||
MjPlusConfigs []MidJourneyPlusConfig // MJ plus config
|
MjPlusConfigs []MjPlusConfig // MJ plus config
|
||||||
WeChatBot bool // 是否启用微信机器人
|
WeChatBot bool // 是否启用微信机器人
|
||||||
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
||||||
|
|
||||||
XXLConfig XXLConfig
|
XXLConfig XXLConfig
|
||||||
AlipayConfig AlipayConfig
|
AlipayConfig AlipayConfig
|
||||||
@@ -43,32 +43,25 @@ type ChatPlusApiConfig struct {
|
|||||||
Token string
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidJourneyConfig struct {
|
type MjProxyConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
UserToken string
|
ApiURL string // api 地址
|
||||||
BotToken string
|
Mode string // 绘画模式,可选值:fast/turbo/relax
|
||||||
GuildId string // Server ID
|
ApiKey string
|
||||||
ChanelId string // Chanel ID
|
|
||||||
UseCDN bool
|
|
||||||
ImgCdnURL string // 图片反代加速地址
|
|
||||||
DiscordAPI string
|
|
||||||
DiscordGateway string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type StableDiffusionConfig struct {
|
type StableDiffusionConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
ApiURL string
|
Model string // 模型名称
|
||||||
ApiKey string
|
ApiURL string
|
||||||
Txt2ImgJsonPath string
|
ApiKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidJourneyPlusConfig struct {
|
type MjPlusConfig struct {
|
||||||
Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
|
Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
|
||||||
ApiURL string // api 地址
|
ApiURL string // api 地址
|
||||||
Mode string // 绘画模式,可选值:fast/turbo/relax
|
Mode string // 绘画模式,可选值:fast/turbo/relax
|
||||||
CdnURL string // CDN 加速地址
|
ApiKey string
|
||||||
ApiKey string
|
|
||||||
NotifyURL string // 任务进度更新回调地址
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AlipayConfig struct {
|
type AlipayConfig struct {
|
||||||
@@ -81,6 +74,7 @@ type AlipayConfig struct {
|
|||||||
AlipayPublicKey string // 支付宝公钥文件路径
|
AlipayPublicKey string // 支付宝公钥文件路径
|
||||||
RootCert string // Root 秘钥路径
|
RootCert string // Root 秘钥路径
|
||||||
NotifyURL string // 异步通知回调
|
NotifyURL string // 异步通知回调
|
||||||
|
ReturnURL string // 支付成功返回地址
|
||||||
}
|
}
|
||||||
|
|
||||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
||||||
@@ -90,6 +84,7 @@ type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
|||||||
AppSecret string // app 密钥
|
AppSecret string // app 密钥
|
||||||
ApiURL string // 支付网关
|
ApiURL string // 支付网关
|
||||||
NotifyURL string // 异步通知回调
|
NotifyURL string // 异步通知回调
|
||||||
|
ReturnURL string // 支付成功返回地址
|
||||||
}
|
}
|
||||||
|
|
||||||
// JPayConfig PayJs 支付配置
|
// JPayConfig PayJs 支付配置
|
||||||
@@ -100,6 +95,7 @@ type JPayConfig struct {
|
|||||||
PrivateKey string // 私钥
|
PrivateKey string // 私钥
|
||||||
ApiURL string // API 网关
|
ApiURL string // API 网关
|
||||||
NotifyURL string // 异步回调地址
|
NotifyURL string // 异步回调地址
|
||||||
|
ReturnURL string // 支付成功返回地址
|
||||||
}
|
}
|
||||||
|
|
||||||
type XXLConfig struct { // XXL 任务调度配置
|
type XXLConfig struct { // XXL 任务调度配置
|
||||||
@@ -122,26 +118,6 @@ func (c RedisConfig) Url() string {
|
|||||||
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager 管理员
|
|
||||||
type Manager struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatConfig 系统默认的聊天配置
|
|
||||||
type ChatConfig struct {
|
|
||||||
OpenAI ModelAPIConfig `json:"open_ai"`
|
|
||||||
Azure ModelAPIConfig `json:"azure"`
|
|
||||||
ChatGML ModelAPIConfig `json:"chat_gml"`
|
|
||||||
Baidu ModelAPIConfig `json:"baidu"`
|
|
||||||
XunFei ModelAPIConfig `json:"xun_fei"`
|
|
||||||
|
|
||||||
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
|
|
||||||
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
|
|
||||||
ContextDeep int `json:"context_deep"` // 上下文深度
|
|
||||||
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
|
|
||||||
}
|
|
||||||
|
|
||||||
type Platform string
|
type Platform string
|
||||||
|
|
||||||
const OpenAI = Platform("OpenAI")
|
const OpenAI = Platform("OpenAI")
|
||||||
@@ -151,42 +127,33 @@ const Baidu = Platform("Baidu")
|
|||||||
const XunFei = Platform("XunFei")
|
const XunFei = Platform("XunFei")
|
||||||
const QWen = Platform("QWen")
|
const QWen = Platform("QWen")
|
||||||
|
|
||||||
// UserChatConfig 用户的聊天配置
|
|
||||||
type UserChatConfig struct {
|
|
||||||
ApiKeys map[Platform]string `json:"api_keys"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type InviteReward struct {
|
|
||||||
ChatCalls int `json:"chat_calls"`
|
|
||||||
ImgCalls int `json:"img_calls"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ModelAPIConfig struct {
|
|
||||||
Temperature float32 `json:"temperature"`
|
|
||||||
MaxTokens int `json:"max_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SystemConfig struct {
|
type SystemConfig struct {
|
||||||
Title string `json:"title"`
|
Title string `json:"title,omitempty"`
|
||||||
AdminTitle string `json:"admin_title"`
|
AdminTitle string `json:"admin_title,omitempty"`
|
||||||
InitChatCalls int `json:"init_chat_calls"` // 新用户注册赠送对话次数
|
Logo string `json:"logo,omitempty"`
|
||||||
InitImgCalls int `json:"init_img_calls"` // 新用户注册赠送绘图次数
|
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||||
VipMonthCalls int `json:"vip_month_calls"` // VIP 会员每月赠送的对话次数
|
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
|
||||||
VipMonthImgCalls int `json:"vip_month_img_calls"` // VIP 会员每月赠送绘图次数
|
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||||
|
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
|
||||||
|
|
||||||
RegisterWays []string `json:"register_ways"` // 注册方式:支持手机,邮箱注册
|
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机,邮箱注册,账号密码注册
|
||||||
EnabledRegister bool `json:"enabled_register"` // 是否开放注册
|
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
|
||||||
|
|
||||||
RewardImg string `json:"reward_img"` // 众筹收款二维码地址
|
RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址
|
||||||
EnabledReward bool `json:"enabled_reward"` // 启用众筹功能
|
EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能
|
||||||
ChatCallPrice float64 `json:"chat_call_price"` // 对话单次调用费用
|
PowerPrice float64 `json:"power_price,omitempty"` // 算力单价
|
||||||
ImgCallPrice float64 `json:"img_call_price"` // 绘图单次调用费用
|
|
||||||
|
|
||||||
OrderPayTimeout int `json:"order_pay_timeout"` //订单支付超时时间
|
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
|
||||||
DefaultModels []string `json:"default_models"` // 默认开通的 AI 模型
|
VipInfoText string `json:"vip_info_text"` // 会员页面充值说明
|
||||||
OrderPayInfoText string `json:"order_pay_info_text"` // 订单支付页面说明文字
|
DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型
|
||||||
InviteChatCalls int `json:"invite_chat_calls"` // 邀请用户注册奖励对话次数
|
|
||||||
InviteImgCalls int `json:"invite_img_calls"` // 邀请用户注册奖励绘图次数
|
|
||||||
|
|
||||||
WechatCardURL string `json:"wechat_card_url"` // 微信客服地址
|
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||||
|
MjActionPower int `json:"mj_action_power"` // MJ 操作(放大,变换)消耗算力
|
||||||
|
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||||
|
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
|
||||||
|
|
||||||
|
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
||||||
|
|
||||||
|
EnableContext bool `json:"enable_context,omitempty"`
|
||||||
|
ContextDeep int `json:"context_deep,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ type MKey interface {
|
|||||||
string | int | uint
|
string | int | uint
|
||||||
}
|
}
|
||||||
type MValue interface {
|
type MValue interface {
|
||||||
*WsClient | *ChatSession | context.CancelFunc | []interface{}
|
*WsClient | *ChatSession | context.CancelFunc | []Message
|
||||||
}
|
}
|
||||||
type LMap[K MKey, T MValue] struct {
|
type LMap[K MKey, T MValue] struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
|
|||||||
@@ -9,10 +9,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type OrderRemark struct {
|
type OrderRemark struct {
|
||||||
Days int `json:"days"` // 有效期
|
Days int `json:"days"` // 有效期
|
||||||
Calls int `json:"calls"` // 增加对话次数
|
Power int `json:"power"` // 增加算力点数
|
||||||
ImgCalls int `json:"img_calls"` // 增加绘图次数
|
Name string `json:"name"` // 产品名称
|
||||||
Name string `json:"name"` // 产品名称
|
|
||||||
Price float64 `json:"price"`
|
Price float64 `json:"price"`
|
||||||
Discount float64 `json:"discount"`
|
Discount float64 `json:"discount"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ type SdTask struct {
|
|||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Type TaskType `json:"type"`
|
Type TaskType `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
Params SdTaskParams `json:"params"`
|
Params SdTaskParams `json:"params"`
|
||||||
RetryCount int `json:"retry_count"`
|
RetryCount int `json:"retry_count"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ const (
|
|||||||
Success = BizCode(0)
|
Success = BizCode(0)
|
||||||
Failed = BizCode(1)
|
Failed = BizCode(1)
|
||||||
NotAuthorized = BizCode(400) // 未授权
|
NotAuthorized = BizCode(400) // 未授权
|
||||||
|
NotPermission = BizCode(403) // 没有权限
|
||||||
|
|
||||||
OkMsg = "Success"
|
OkMsg = "Success"
|
||||||
ErrorMsg = "系统开小差了"
|
ErrorMsg = "系统开小差了"
|
||||||
|
|||||||
10
api/go.mod
10
api/go.mod
@@ -25,7 +25,15 @@ require (
|
|||||||
|
|
||||||
require github.com/xxl-job/xxl-job-executor-go v1.2.0
|
require github.com/xxl-job/xxl-job-executor-go v1.2.0
|
||||||
|
|
||||||
require github.com/bg5t/mydiscordgo v0.28.1
|
require (
|
||||||
|
github.com/mojocn/base64Captcha v1.3.1
|
||||||
|
github.com/shopspring/decimal v1.3.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||||
|
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect
|
||||||
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/andybalholm/brotli v1.0.4 // indirect
|
github.com/andybalholm/brotli v1.0.4 // indirect
|
||||||
|
|||||||
12
api/go.sum
12
api/go.sum
@@ -7,8 +7,6 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
|
|||||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||||
github.com/bg5t/mydiscordgo v0.28.1 h1:mVH0ZWstVdJffCi/EXJAYQDtXwIKAJYVXLmECu1hEK8=
|
|
||||||
github.com/bg5t/mydiscordgo v0.28.1/go.mod h1:n3aba73N18k1DzM0t0mGE8rwW3Z+vwTvI8pcsBgxN/8=
|
|
||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
@@ -65,6 +63,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
|||||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||||
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
@@ -76,7 +76,6 @@ github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9S
|
|||||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
|
||||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
@@ -129,6 +128,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
|||||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
|
||||||
|
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||||
@@ -166,6 +167,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
|||||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||||
|
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
|
||||||
|
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
@@ -219,7 +222,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
|||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
||||||
@@ -227,6 +229,8 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
|
|||||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
|
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
|
||||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
|
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
|
||||||
|
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
|
||||||
|
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
|
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
|
||||||
|
|||||||
@@ -5,10 +5,15 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/handler"
|
"chatplus/handler"
|
||||||
logger2 "chatplus/logger"
|
logger2 "chatplus/logger"
|
||||||
|
"chatplus/store/model"
|
||||||
|
"chatplus/store/vo"
|
||||||
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/mojocn/base64Captcha"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -17,47 +22,88 @@ import (
|
|||||||
|
|
||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
// Manager 管理员
|
||||||
|
type Manager struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
Captcha string `json:"captcha"` // 验证码
|
||||||
|
CaptchaId string `json:"captcha_id"` // 验证码id
|
||||||
|
}
|
||||||
|
|
||||||
|
const SuperManagerID = 1
|
||||||
|
|
||||||
type ManagerHandler struct {
|
type ManagerHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
|
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
|
||||||
h := ManagerHandler{db: db, redis: client}
|
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login 登录
|
// Login 登录
|
||||||
func (h *ManagerHandler) Login(c *gin.Context) {
|
func (h *ManagerHandler) Login(c *gin.Context) {
|
||||||
var data types.Manager
|
var data Manager
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
manager := h.App.Config.Manager
|
|
||||||
if data.Username == manager.Username && data.Password == manager.Password {
|
// add captcha
|
||||||
// 创建 token
|
if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
resp.ERROR(c, "验证码错误!")
|
||||||
"user_id": manager.Username,
|
return
|
||||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
|
||||||
})
|
|
||||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 保存到 redis
|
|
||||||
key := "users/" + manager.Username
|
|
||||||
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
|
|
||||||
resp.ERROR(c, "error with save token: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c, tokenString)
|
|
||||||
} else {
|
|
||||||
resp.ERROR(c, "用户名或者密码错误")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var manager model.AdminUser
|
||||||
|
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "请检查用户名或者密码是否填写正确")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
password := utils.GenPassword(data.Password, manager.Salt)
|
||||||
|
if password != manager.Password {
|
||||||
|
resp.ERROR(c, "用户名或密码错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 超级管理员默认是ID:1
|
||||||
|
if manager.Id != SuperManagerID && manager.Status == false {
|
||||||
|
resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 token
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
|
"user_id": manager.Id,
|
||||||
|
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||||||
|
})
|
||||||
|
tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey))
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 保存到 redis
|
||||||
|
key := fmt.Sprintf("admin/%d", manager.Id)
|
||||||
|
if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
|
||||||
|
resp.ERROR(c, "error with save token: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新最后登录时间和IP
|
||||||
|
manager.LastLoginIp = c.ClientIP()
|
||||||
|
manager.LastLoginAt = time.Now().Unix()
|
||||||
|
h.DB.Updates(&manager)
|
||||||
|
|
||||||
|
var result = struct {
|
||||||
|
IsSuperAdmin bool `json:"is_super_admin"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}{
|
||||||
|
IsSuperAdmin: manager.Id == 1,
|
||||||
|
Token: tokenString,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logout 注销
|
// Logout 注销
|
||||||
@@ -72,10 +118,155 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
|
|||||||
|
|
||||||
// Session 会话检测
|
// Session 会话检测
|
||||||
func (h *ManagerHandler) Session(c *gin.Context) {
|
func (h *ManagerHandler) Session(c *gin.Context) {
|
||||||
token := c.GetHeader(types.AdminAuthHeader)
|
id := h.GetLoginUserId(c)
|
||||||
if token == "" {
|
key := fmt.Sprintf("admin/%d", id)
|
||||||
|
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
} else {
|
return
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
}
|
||||||
|
var manager model.AdminUser
|
||||||
|
res := h.DB.Where("id", id).First(&manager)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.NotAuth(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 数据列表
|
||||||
|
func (h *ManagerHandler) List(c *gin.Context) {
|
||||||
|
var items []model.AdminUser
|
||||||
|
res := h.DB.Find(&items)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]vo.AdminUser, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
var u vo.AdminUser
|
||||||
|
err := utils.CopyObject(item, &u)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
u.Id = item.Id
|
||||||
|
u.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
users = append(users, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, users)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ManagerHandler) Save(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
Status bool `json:"status"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var user model.AdminUser
|
||||||
|
res := h.DB.Where("username", data.Username).First(&user)
|
||||||
|
if res.Error == nil {
|
||||||
|
resp.ERROR(c, "用户名已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成密码
|
||||||
|
salt := utils.RandString(8)
|
||||||
|
password := utils.GenPassword(data.Password, salt)
|
||||||
|
res = h.DB.Save(&model.AdminUser{
|
||||||
|
Username: data.Username,
|
||||||
|
Password: password,
|
||||||
|
Salt: salt,
|
||||||
|
Status: data.Status,
|
||||||
|
})
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "failed with update database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove 删除管理员
|
||||||
|
func (h *ManagerHandler) Remove(c *gin.Context) {
|
||||||
|
id := h.GetInt(c, "id", 0)
|
||||||
|
if id <= 0 {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == SuperManagerID {
|
||||||
|
resp.ERROR(c, "超级管理员不能删除")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := h.DB.Where("id", id).Delete(&model.AdminUser{})
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable 启用/禁用
|
||||||
|
func (h *ManagerHandler) Enable(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := h.DB.Model(&model.AdminUser{}).Where("id", data.Id).UpdateColumn("status", data.Enabled)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetPass 重置密码
|
||||||
|
func (h *ManagerHandler) ResetPass(c *gin.Context) {
|
||||||
|
id := h.GetLoginUserId(c)
|
||||||
|
if id != SuperManagerID {
|
||||||
|
resp.ERROR(c, "只有超级管理员能够进行该操作")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var data struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var user model.AdminUser
|
||||||
|
res := h.DB.Where("id", data.Id).First(&user)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
password := utils.GenPassword(data.Password, user.Salt)
|
||||||
|
user.Password = password
|
||||||
|
res = h.DB.Updates(&user)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,13 +14,10 @@ import (
|
|||||||
|
|
||||||
type ApiKeyHandler struct {
|
type ApiKeyHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
||||||
h := ApiKeyHandler{db: db}
|
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ApiKeyHandler) Save(c *gin.Context) {
|
func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||||
@@ -32,7 +29,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
|||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
ApiURL string `json:"api_url"`
|
ApiURL string `json:"api_url"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
UseProxy bool `json:"use_proxy"`
|
ProxyURL string `json:"proxy_url"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -41,16 +38,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
|||||||
|
|
||||||
apiKey := model.ApiKey{}
|
apiKey := model.ApiKey{}
|
||||||
if data.Id > 0 {
|
if data.Id > 0 {
|
||||||
h.db.Find(&apiKey, data.Id)
|
h.DB.Find(&apiKey, data.Id)
|
||||||
}
|
}
|
||||||
apiKey.Platform = data.Platform
|
apiKey.Platform = data.Platform
|
||||||
apiKey.Value = data.Value
|
apiKey.Value = data.Value
|
||||||
apiKey.Type = data.Type
|
apiKey.Type = data.Type
|
||||||
apiKey.ApiURL = data.ApiURL
|
apiKey.ApiURL = data.ApiURL
|
||||||
apiKey.Enabled = data.Enabled
|
apiKey.Enabled = data.Enabled
|
||||||
apiKey.UseProxy = data.UseProxy
|
apiKey.ProxyURL = data.ProxyURL
|
||||||
apiKey.Name = data.Name
|
apiKey.Name = data.Name
|
||||||
res := h.db.Save(&apiKey)
|
res := h.DB.Save(&apiKey)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -68,9 +65,14 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ApiKeyHandler) List(c *gin.Context) {
|
func (h *ApiKeyHandler) List(c *gin.Context) {
|
||||||
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var items []model.ApiKey
|
var items []model.ApiKey
|
||||||
var keys = make([]vo.ApiKey, 0)
|
var keys = make([]vo.ApiKey, 0)
|
||||||
res := h.db.Find(&items)
|
res := h.DB.Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var key vo.ApiKey
|
var key vo.ApiKey
|
||||||
@@ -100,7 +102,7 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -109,10 +111,15 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ApiKeyHandler) Remove(c *gin.Context) {
|
func (h *ApiKeyHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
var data struct {
|
||||||
|
Id uint
|
||||||
if id > 0 {
|
}
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.ApiKey{})
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data.Id > 0 {
|
||||||
|
res := h.DB.Where("id = ?", data.Id).Delete(&model.ApiKey{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
|
|||||||
39
api/handler/admin/captcha_handler.go
Normal file
39
api/handler/admin/captcha_handler.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core"
|
||||||
|
"chatplus/handler"
|
||||||
|
"chatplus/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mojocn/base64Captcha"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CaptchaHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
|
||||||
|
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CaptchaVo struct {
|
||||||
|
CaptchaId string `json:"captcha_id"`
|
||||||
|
PicPath string `json:"pic_path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCaptcha 获取验证码
|
||||||
|
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
|
||||||
|
var captchaVo CaptchaVo
|
||||||
|
driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
|
||||||
|
cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
|
||||||
|
// b64s是图片的base64编码
|
||||||
|
id, b64s, err := cp.Generate()
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "生成验证码错误!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
captchaVo.CaptchaId = id
|
||||||
|
captchaVo.PicPath = b64s
|
||||||
|
|
||||||
|
resp.SUCCESS(c, captchaVo)
|
||||||
|
}
|
||||||
@@ -14,27 +14,30 @@ import (
|
|||||||
|
|
||||||
type ChatHandler struct {
|
type ChatHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||||||
h := ChatHandler{db: db}
|
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type chatItemVo struct {
|
type chatItemVo struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
UserId uint `json:"user_id"`
|
UserId uint `json:"user_id"`
|
||||||
ChatId string `json:"chat_id"`
|
ChatId string `json:"chat_id"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Model string `json:"model"`
|
Role vo.ChatRole `json:"role"`
|
||||||
Token int `json:"token"`
|
Model string `json:"model"`
|
||||||
CreatedAt int64 `json:"created_at"`
|
Token int `json:"token"`
|
||||||
MsgNum int `json:"msg_num"` // 消息数量
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
MsgNum int `json:"msg_num"` // 消息数量
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) List(c *gin.Context) {
|
func (h *ChatHandler) List(c *gin.Context) {
|
||||||
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
UserId uint `json:"user_id"`
|
UserId uint `json:"user_id"`
|
||||||
@@ -48,7 +51,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if data.Title != "" {
|
if data.Title != "" {
|
||||||
session = session.Where("title LIKE ?", "%"+data.Title+"%")
|
session = session.Where("title LIKE ?", "%"+data.Title+"%")
|
||||||
}
|
}
|
||||||
@@ -73,18 +76,23 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
userIds := make([]uint, 0)
|
userIds := make([]uint, 0)
|
||||||
chatIds := make([]string, 0)
|
chatIds := make([]string, 0)
|
||||||
|
roleIds := make([]uint, 0)
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
userIds = append(userIds, item.UserId)
|
userIds = append(userIds, item.UserId)
|
||||||
chatIds = append(chatIds, item.ChatId)
|
chatIds = append(chatIds, item.ChatId)
|
||||||
|
roleIds = append(roleIds, item.RoleId)
|
||||||
}
|
}
|
||||||
var messages []model.ChatMessage
|
var messages []model.ChatMessage
|
||||||
var users []model.User
|
var users []model.User
|
||||||
h.db.Where("chat_id IN ?", chatIds).Find(&messages)
|
var roles []model.ChatRole
|
||||||
h.db.Where("id IN ?", userIds).Find(&users)
|
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
|
||||||
|
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||||
|
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||||
|
|
||||||
tokenMap := make(map[string]int)
|
tokenMap := make(map[string]int)
|
||||||
userMap := make(map[uint]string)
|
userMap := make(map[uint]string)
|
||||||
msgMap := make(map[string]int)
|
msgMap := make(map[string]int)
|
||||||
|
roleMap := make(map[uint]vo.ChatRole)
|
||||||
for _, msg := range messages {
|
for _, msg := range messages {
|
||||||
tokenMap[msg.ChatId] += msg.Tokens
|
tokenMap[msg.ChatId] += msg.Tokens
|
||||||
msgMap[msg.ChatId] += 1
|
msgMap[msg.ChatId] += 1
|
||||||
@@ -92,6 +100,14 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
userMap[user.Id] = user.Username
|
userMap[user.Id] = user.Username
|
||||||
}
|
}
|
||||||
|
for _, r := range roles {
|
||||||
|
var roleVo vo.ChatRole
|
||||||
|
err := utils.CopyObject(r, &roleVo)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
roleMap[r.Id] = roleVo
|
||||||
|
}
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
list = append(list, chatItemVo{
|
list = append(list, chatItemVo{
|
||||||
UserId: item.UserId,
|
UserId: item.UserId,
|
||||||
@@ -101,6 +117,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
Model: item.Model,
|
Model: item.Model,
|
||||||
Token: tokenMap[item.ChatId],
|
Token: tokenMap[item.ChatId],
|
||||||
MsgNum: msgMap[item.ChatId],
|
MsgNum: msgMap[item.ChatId],
|
||||||
|
Role: roleMap[item.RoleId],
|
||||||
CreatedAt: item.CreatedAt.Unix(),
|
CreatedAt: item.CreatedAt.Unix(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -135,7 +152,7 @@ func (h *ChatHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if data.Content != "" {
|
if data.Content != "" {
|
||||||
session = session.Where("content LIKE ?", "%"+data.Content+"%")
|
session = session.Where("content LIKE ?", "%"+data.Content+"%")
|
||||||
}
|
}
|
||||||
@@ -163,7 +180,7 @@ func (h *ChatHandler) Messages(c *gin.Context) {
|
|||||||
userIds = append(userIds, item.UserId)
|
userIds = append(userIds, item.UserId)
|
||||||
}
|
}
|
||||||
var users []model.User
|
var users []model.User
|
||||||
h.db.Where("id IN ?", userIds).Find(&users)
|
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||||
userMap := make(map[uint]string)
|
userMap := make(map[uint]string)
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
userMap[user.Id] = user.Username
|
userMap[user.Id] = user.Username
|
||||||
@@ -190,7 +207,7 @@ func (h *ChatHandler) History(c *gin.Context) {
|
|||||||
chatId := c.Query("chat_id") // 会话 ID
|
chatId := c.Query("chat_id") // 会话 ID
|
||||||
var items []model.ChatMessage
|
var items []model.ChatMessage
|
||||||
var messages = make([]vo.HistoryMessage, 0)
|
var messages = make([]vo.HistoryMessage, 0)
|
||||||
res := h.db.Where("chat_id = ?", chatId).Find(&items)
|
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No history message")
|
resp.ERROR(c, "No history message")
|
||||||
return
|
return
|
||||||
@@ -212,9 +229,14 @@ func (h *ChatHandler) History(c *gin.Context) {
|
|||||||
// RemoveChat 删除对话
|
// RemoveChat 删除对话
|
||||||
func (h *ChatHandler) RemoveChat(c *gin.Context) {
|
func (h *ChatHandler) RemoveChat(c *gin.Context) {
|
||||||
chatId := h.GetTrim(c, "chat_id")
|
chatId := h.GetTrim(c, "chat_id")
|
||||||
tx := h.db.Begin()
|
if chatId == "" {
|
||||||
|
resp.ERROR(c, "请传入 ChatId")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := h.DB.Begin()
|
||||||
// 删除聊天记录
|
// 删除聊天记录
|
||||||
res := tx.Unscoped().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
|
res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "failed to remove chat message")
|
resp.ERROR(c, "failed to remove chat message")
|
||||||
return
|
return
|
||||||
@@ -235,7 +257,7 @@ func (h *ChatHandler) RemoveChat(c *gin.Context) {
|
|||||||
// RemoveMessage 删除聊天记录
|
// RemoveMessage 删除聊天记录
|
||||||
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
|
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
tx := h.db.Unscoped().Delete(&model.ChatMessage{}, id)
|
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -15,26 +15,26 @@ import (
|
|||||||
|
|
||||||
type ChatModelHandler struct {
|
type ChatModelHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||||
h := ChatModelHandler{db: db}
|
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatModelHandler) Save(c *gin.Context) {
|
func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
SortNum int `json:"sort_num"`
|
SortNum int `json:"sort_num"`
|
||||||
Open bool `json:"open"`
|
Open bool `json:"open"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Weight int `json:"weight"`
|
Power int `json:"power"`
|
||||||
CreatedAt int64 `json:"created_at"`
|
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||||
|
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||||
|
Temperature float32 `json:"temperature"` // 模型温度
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -42,18 +42,21 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
item := model.ChatModel{
|
item := model.ChatModel{
|
||||||
Platform: data.Platform,
|
Platform: data.Platform,
|
||||||
Name: data.Name,
|
Name: data.Name,
|
||||||
Value: data.Value,
|
Value: data.Value,
|
||||||
Enabled: data.Enabled,
|
Enabled: data.Enabled,
|
||||||
SortNum: data.SortNum,
|
SortNum: data.SortNum,
|
||||||
Open: data.Open,
|
Open: data.Open,
|
||||||
Weight: data.Weight}
|
MaxTokens: data.MaxTokens,
|
||||||
|
MaxContext: data.MaxContext,
|
||||||
|
Temperature: data.Temperature,
|
||||||
|
Power: data.Power}
|
||||||
item.Id = data.Id
|
item.Id = data.Id
|
||||||
if item.Id > 0 {
|
if item.Id > 0 {
|
||||||
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||||
}
|
}
|
||||||
res := h.db.Save(&item)
|
res := h.DB.Save(&item)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -72,7 +75,12 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
|
|||||||
|
|
||||||
// List 模型列表
|
// List 模型列表
|
||||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||||
session := h.db.Session(&gorm.Session{})
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
enable := h.GetBool(c, "enable")
|
enable := h.GetBool(c, "enable")
|
||||||
if enable {
|
if enable {
|
||||||
session = session.Where("enabled", enable)
|
session = session.Where("enabled", enable)
|
||||||
@@ -109,7 +117,7 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -129,7 +137,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.db.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -141,13 +149,15 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *ChatModelHandler) Remove(c *gin.Context) {
|
func (h *ChatModelHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
if id <= 0 {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if id > 0 {
|
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.ChatModel{})
|
if res.Error != nil {
|
||||||
if res.Error != nil {
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,13 +15,10 @@ import (
|
|||||||
|
|
||||||
type ChatRoleHandler struct {
|
type ChatRoleHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
||||||
h := ChatRoleHandler{db: db}
|
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save 创建或者更新某个角色
|
// Save 创建或者更新某个角色
|
||||||
@@ -41,7 +38,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
|||||||
if data.CreatedAt > 0 {
|
if data.CreatedAt > 0 {
|
||||||
role.CreatedAt = time.Unix(data.CreatedAt, 0)
|
role.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||||
}
|
}
|
||||||
res := h.db.Save(&role)
|
res := h.DB.Save(&role)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -53,9 +50,14 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||||
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var items []model.ChatRole
|
var items []model.ChatRole
|
||||||
var roles = make([]vo.ChatRole, 0)
|
var roles = make([]vo.ChatRole, 0)
|
||||||
res := h.db.Order("sort_num ASC").Find(&items)
|
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No data found")
|
resp.ERROR(c, "No data found")
|
||||||
return
|
return
|
||||||
@@ -88,7 +90,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.db.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -110,7 +112,7 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -119,13 +121,18 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatRoleHandler) Remove(c *gin.Context) {
|
func (h *ChatRoleHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
var data struct {
|
||||||
if id <= 0 {
|
Id uint
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if data.Id <= 0 {
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.ChatRole{})
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "删除失败!")
|
resp.ERROR(c, "删除失败!")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -14,19 +14,20 @@ import (
|
|||||||
|
|
||||||
type ConfigHandler struct {
|
type ConfigHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
|
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
|
||||||
h := ConfigHandler{db: db}
|
return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ConfigHandler) Update(c *gin.Context) {
|
func (h *ConfigHandler) Update(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
Config map[string]interface{} `json:"config"`
|
Config struct {
|
||||||
|
types.SystemConfig
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
Updated bool `json:"updated,omitempty"`
|
||||||
|
} `json:"config"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
@@ -36,7 +37,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
|||||||
|
|
||||||
value := utils.JsonEncode(&data.Config)
|
value := utils.JsonEncode(&data.Config)
|
||||||
config := model.Config{Key: data.Key, Config: value}
|
config := model.Config{Key: data.Key, Config: value}
|
||||||
res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key})
|
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -44,7 +45,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
|||||||
|
|
||||||
if config.Id > 0 {
|
if config.Id > 0 {
|
||||||
config.Config = value
|
config.Config = value
|
||||||
res := h.db.Updates(&config)
|
res := h.DB.Updates(&config)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -52,12 +53,10 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
|||||||
|
|
||||||
// update config cache for AppServer
|
// update config cache for AppServer
|
||||||
var cfg model.Config
|
var cfg model.Config
|
||||||
h.db.Where("marker", data.Key).First(&cfg)
|
h.DB.Where("marker", data.Key).First(&cfg)
|
||||||
var err error
|
var err error
|
||||||
if data.Key == "system" {
|
if data.Key == "system" {
|
||||||
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
|
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
|
||||||
} else if data.Key == "chat" {
|
|
||||||
err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig)
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "Failed to update config cache: "+err.Error())
|
resp.ERROR(c, "Failed to update config cache: "+err.Error())
|
||||||
@@ -71,20 +70,25 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
|||||||
|
|
||||||
// Get 获取指定的系统配置
|
// Get 获取指定的系统配置
|
||||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||||
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
key := c.Query("key")
|
key := c.Query("key")
|
||||||
var config model.Config
|
var config model.Config
|
||||||
res := h.db.Where("marker", key).First(&config)
|
res := h.DB.Where("marker", key).First(&config)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var m map[string]interface{}
|
var value map[string]interface{}
|
||||||
err := utils.JsonDecode(config.Config, &m)
|
err := utils.JsonDecode(config.Config, &value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c, m)
|
resp.SUCCESS(c, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,26 +7,25 @@ import (
|
|||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DashboardHandler struct {
|
type DashboardHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
|
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
|
||||||
h := DashboardHandler{db: db}
|
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type statsVo struct {
|
type statsVo struct {
|
||||||
Users int64 `json:"users"`
|
Users int64 `json:"users"`
|
||||||
Chats int64 `json:"chats"`
|
Chats int64 `json:"chats"`
|
||||||
Tokens int `json:"tokens"`
|
Tokens int `json:"tokens"`
|
||||||
Income float64 `json:"income"`
|
Income float64 `json:"income"`
|
||||||
|
Chart map[string]map[string]float64 `json:"chart"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DashboardHandler) Stats(c *gin.Context) {
|
func (h *DashboardHandler) Stats(c *gin.Context) {
|
||||||
@@ -35,37 +34,84 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
|
|||||||
var userCount int64
|
var userCount int64
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||||
res := h.db.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
|
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
stats.Users = userCount
|
stats.Users = userCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// new chats statistic
|
// new chats statistic
|
||||||
var chatCount int64
|
var chatCount int64
|
||||||
res = h.db.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
|
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
stats.Chats = chatCount
|
stats.Chats = chatCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokens took stats
|
// tokens took stats
|
||||||
var historyMessages []model.ChatMessage
|
var historyMessages []model.ChatMessage
|
||||||
res = h.db.Where("created_at > ?", zeroTime).Find(&historyMessages)
|
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
|
||||||
for _, item := range historyMessages {
|
for _, item := range historyMessages {
|
||||||
stats.Tokens += item.Tokens
|
stats.Tokens += item.Tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
// 众筹收入
|
// 众筹收入
|
||||||
var rewards []model.Reward
|
var rewards []model.Reward
|
||||||
res = h.db.Where("created_at > ?", zeroTime).Find(&rewards)
|
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
|
||||||
for _, item := range rewards {
|
for _, item := range rewards {
|
||||||
stats.Income += item.Amount
|
stats.Income += item.Amount
|
||||||
}
|
}
|
||||||
|
|
||||||
// 订单收入
|
// 订单收入
|
||||||
var orders []model.Order
|
var orders []model.Order
|
||||||
res = h.db.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
|
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
|
||||||
for _, item := range orders {
|
for _, item := range orders {
|
||||||
stats.Income += item.Amount
|
stats.Income += item.Amount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 统计7天的订单的图表
|
||||||
|
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
|
||||||
|
var statsChart = make(map[string]map[string]float64)
|
||||||
|
//// 初始化
|
||||||
|
var userStatistic, historyMessagesStatistic, incomeStatistic = make(map[string]float64), make(map[string]float64), make(map[string]float64)
|
||||||
|
for i := 0; i < 7; i++ {
|
||||||
|
var initTime = time.Date(now.Year(), now.Month(), now.Day()-i, 0, 0, 0, 0, now.Location()).Format("2006-01-02")
|
||||||
|
userStatistic[initTime] = float64(0)
|
||||||
|
historyMessagesStatistic[initTime] = float64(0)
|
||||||
|
incomeStatistic[initTime] = float64(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计用户7天增加的曲线
|
||||||
|
var users []model.User
|
||||||
|
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
|
||||||
|
if res.Error == nil {
|
||||||
|
for _, item := range users {
|
||||||
|
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计7天Token 消耗
|
||||||
|
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
|
||||||
|
for _, item := range historyMessages {
|
||||||
|
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 浮点数相加?
|
||||||
|
// 统计最近7天的众筹
|
||||||
|
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
|
||||||
|
for _, item := range rewards {
|
||||||
|
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计最近7天的订单
|
||||||
|
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
|
||||||
|
for _, item := range orders {
|
||||||
|
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
statsChart["users"] = userStatistic
|
||||||
|
statsChart["historyMessage"] = historyMessagesStatistic
|
||||||
|
statsChart["orders"] = incomeStatistic
|
||||||
|
|
||||||
|
stats.Chart = statsChart
|
||||||
|
|
||||||
resp.SUCCESS(c, stats)
|
resp.SUCCESS(c, stats)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,13 +17,10 @@ import (
|
|||||||
|
|
||||||
type FunctionHandler struct {
|
type FunctionHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
|
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
|
||||||
h := FunctionHandler{db: db}
|
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *FunctionHandler) Save(c *gin.Context) {
|
func (h *FunctionHandler) Save(c *gin.Context) {
|
||||||
@@ -44,7 +41,7 @@ func (h *FunctionHandler) Save(c *gin.Context) {
|
|||||||
Enabled: data.Enabled,
|
Enabled: data.Enabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Save(&f)
|
res := h.DB.Save(&f)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "error with save data:"+res.Error.Error())
|
resp.ERROR(c, "error with save data:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -65,7 +62,7 @@ func (h *FunctionHandler) Set(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -74,8 +71,13 @@ func (h *FunctionHandler) Set(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *FunctionHandler) List(c *gin.Context) {
|
func (h *FunctionHandler) List(c *gin.Context) {
|
||||||
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var items []model.Function
|
var items []model.Function
|
||||||
res := h.db.Find(&items)
|
res := h.DB.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No data found")
|
resp.ERROR(c, "No data found")
|
||||||
return
|
return
|
||||||
@@ -97,7 +99,7 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
|
|||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
res := h.db.Delete(&model.Function{Id: uint(id)})
|
res := h.DB.Delete(&model.Function{Id: uint(id)})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -8,24 +8,28 @@ import (
|
|||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OrderHandler struct {
|
type OrderHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||||
h := OrderHandler{db: db}
|
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OrderHandler) List(c *gin.Context) {
|
func (h *OrderHandler) List(c *gin.Context) {
|
||||||
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
OrderNo string `json:"order_no"`
|
OrderNo string `json:"order_no"`
|
||||||
|
Status int `json:"status"`
|
||||||
PayTime []string `json:"pay_time"`
|
PayTime []string `json:"pay_time"`
|
||||||
Page int `json:"page"`
|
Page int `json:"page"`
|
||||||
PageSize int `json:"page_size"`
|
PageSize int `json:"page_size"`
|
||||||
@@ -35,7 +39,7 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if data.OrderNo != "" {
|
if data.OrderNo != "" {
|
||||||
session = session.Where("order_no", data.OrderNo)
|
session = session.Where("order_no", data.OrderNo)
|
||||||
}
|
}
|
||||||
@@ -44,8 +48,9 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
end := utils.Str2stamp(data.PayTime[1] + " 00:00:00")
|
end := utils.Str2stamp(data.PayTime[1] + " 00:00:00")
|
||||||
session = session.Where("pay_time >= ? AND pay_time <= ?", start, end)
|
session = session.Where("pay_time >= ? AND pay_time <= ?", start, end)
|
||||||
}
|
}
|
||||||
session = session.Where("status = ?", types.OrderPaidSuccess)
|
if data.Status >= 0 {
|
||||||
|
session = session.Where("status", data.Status)
|
||||||
|
}
|
||||||
var total int64
|
var total int64
|
||||||
session.Model(&model.Order{}).Count(&total)
|
session.Model(&model.Order{}).Count(&total)
|
||||||
var items []model.Order
|
var items []model.Order
|
||||||
@@ -74,7 +79,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
|
|||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
var item model.Order
|
var item model.Order
|
||||||
res := h.db.First(&item, id)
|
res := h.DB.First(&item, id)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "记录不存在!")
|
resp.ERROR(c, "记录不存在!")
|
||||||
return
|
return
|
||||||
@@ -85,7 +90,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.db.Where("id = ?", id).Delete(&model.Order{})
|
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
|
|||||||
77
api/handler/admin/power_log_handler.go
Normal file
77
api/handler/admin/power_log_handler.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core"
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/handler"
|
||||||
|
"chatplus/store/model"
|
||||||
|
"chatplus/store/vo"
|
||||||
|
"chatplus/utils"
|
||||||
|
"chatplus/utils/resp"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PowerLogHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
||||||
|
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Type int `json:"type"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Date []string `json:"date"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if data.Model != "" {
|
||||||
|
session = session.Where("model", data.Model)
|
||||||
|
}
|
||||||
|
if data.Type > 0 {
|
||||||
|
session = session.Where("type", data.Type)
|
||||||
|
}
|
||||||
|
if len(data.Date) == 2 {
|
||||||
|
start := data.Date[0] + " 00:00:00"
|
||||||
|
end := data.Date[1] + " 00:00:00"
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.PowerLog{}).Count(&total)
|
||||||
|
var items []model.PowerLog
|
||||||
|
var list = make([]vo.PowerLog, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
||||||
|
if res.Error == nil {
|
||||||
|
for _, item := range items {
|
||||||
|
var log vo.PowerLog
|
||||||
|
err := utils.CopyObject(item, &log)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Id = item.Id
|
||||||
|
log.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
log.TypeStr = item.Type.String()
|
||||||
|
list = append(list, log)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计消费算力总和
|
||||||
|
var totalPower float64
|
||||||
|
if len(data.Date) == 2 {
|
||||||
|
session.Where("mark", 0).Select("SUM(amount) as total_sum").Scan(&totalPower)
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, gin.H{"data": vo.NewPage(total, data.Page, data.PageSize, list), "stat": totalPower})
|
||||||
|
}
|
||||||
@@ -15,13 +15,10 @@ import (
|
|||||||
|
|
||||||
type ProductHandler struct {
|
type ProductHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
||||||
h := ProductHandler{db: db}
|
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ProductHandler) Save(c *gin.Context) {
|
func (h *ProductHandler) Save(c *gin.Context) {
|
||||||
@@ -32,8 +29,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
|||||||
Discount float64 `json:"discount"`
|
Discount float64 `json:"discount"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Days int `json:"days"`
|
Days int `json:"days"`
|
||||||
Calls int `json:"calls"`
|
Power int `json:"power"`
|
||||||
ImgCalls int `json:"img_calls"`
|
|
||||||
CreatedAt int64 `json:"created_at"`
|
CreatedAt int64 `json:"created_at"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
@@ -46,14 +42,13 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
|||||||
Price: data.Price,
|
Price: data.Price,
|
||||||
Discount: data.Discount,
|
Discount: data.Discount,
|
||||||
Days: data.Days,
|
Days: data.Days,
|
||||||
Calls: data.Calls,
|
Power: data.Power,
|
||||||
ImgCalls: data.ImgCalls,
|
|
||||||
Enabled: data.Enabled}
|
Enabled: data.Enabled}
|
||||||
item.Id = data.Id
|
item.Id = data.Id
|
||||||
if item.Id > 0 {
|
if item.Id > 0 {
|
||||||
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||||
}
|
}
|
||||||
res := h.db.Save(&item)
|
res := h.DB.Save(&item)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -72,7 +67,12 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
|||||||
|
|
||||||
// List 模型列表
|
// List 模型列表
|
||||||
func (h *ProductHandler) List(c *gin.Context) {
|
func (h *ProductHandler) List(c *gin.Context) {
|
||||||
session := h.db.Session(&gorm.Session{})
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
enable := h.GetBool(c, "enable")
|
enable := h.GetBool(c, "enable")
|
||||||
if enable {
|
if enable {
|
||||||
session = session.Where("enabled", enable)
|
session = session.Where("enabled", enable)
|
||||||
@@ -108,7 +108,7 @@ func (h *ProductHandler) Enable(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.Product{}).Where("id = ?", data.Id).Update("enabled", data.Enabled)
|
res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -128,7 +128,7 @@ func (h *ProductHandler) Sort(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for index, id := range data.Ids {
|
for index, id := range data.Ids {
|
||||||
res := h.db.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
res := h.DB.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
@@ -142,7 +142,7 @@ func (h *ProductHandler) Remove(c *gin.Context) {
|
|||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
|
|
||||||
if id > 0 {
|
if id > 0 {
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.Product{})
|
res := h.DB.Where("id = ?", id).Delete(&model.Product{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"chatplus/core"
|
||||||
|
"chatplus/core/types"
|
||||||
"chatplus/handler"
|
"chatplus/handler"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
@@ -13,18 +14,20 @@ import (
|
|||||||
|
|
||||||
type RewardHandler struct {
|
type RewardHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
|
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
|
||||||
h := RewardHandler{db: db}
|
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RewardHandler) List(c *gin.Context) {
|
func (h *RewardHandler) List(c *gin.Context) {
|
||||||
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var items []model.Reward
|
var items []model.Reward
|
||||||
res := h.db.Order("id DESC").Find(&items)
|
res := h.DB.Order("id DESC").Find(&items)
|
||||||
var rewards = make([]vo.Reward, 0)
|
var rewards = make([]vo.Reward, 0)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
userIds := make([]uint, 0)
|
userIds := make([]uint, 0)
|
||||||
@@ -32,7 +35,7 @@ func (h *RewardHandler) List(c *gin.Context) {
|
|||||||
userIds = append(userIds, v.UserId)
|
userIds = append(userIds, v.UserId)
|
||||||
}
|
}
|
||||||
var users []model.User
|
var users []model.User
|
||||||
h.db.Where("id IN ?", userIds).Find(&users)
|
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||||
var userMap = make(map[uint]model.User)
|
var userMap = make(map[uint]model.User)
|
||||||
for _, u := range users {
|
for _, u := range users {
|
||||||
userMap[u.Id] = u
|
userMap[u.Id] = u
|
||||||
@@ -57,10 +60,15 @@ func (h *RewardHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *RewardHandler) Remove(c *gin.Context) {
|
func (h *RewardHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
var data struct {
|
||||||
|
Id uint
|
||||||
if id > 0 {
|
}
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.Reward{})
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data.Id > 0 {
|
||||||
|
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
|
|||||||
45
api/handler/admin/upload_handler.go
Normal file
45
api/handler/admin/upload_handler.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core"
|
||||||
|
"chatplus/handler"
|
||||||
|
"chatplus/service/oss"
|
||||||
|
"chatplus/store/model"
|
||||||
|
"chatplus/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UploadHandler struct {
|
||||||
|
handler.BaseHandler
|
||||||
|
uploaderManager *oss.UploaderManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
|
||||||
|
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||||
|
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId := 0
|
||||||
|
res := h.DB.Create(&model.File{
|
||||||
|
UserId: userId,
|
||||||
|
Name: file.Name,
|
||||||
|
ObjKey: file.ObjKey,
|
||||||
|
URL: file.URL,
|
||||||
|
Ext: file.Ext,
|
||||||
|
Size: file.Size,
|
||||||
|
CreatedAt: time.Time{},
|
||||||
|
})
|
||||||
|
if res.Error != nil || res.RowsAffected == 0 {
|
||||||
|
resp.ERROR(c, "error with update database: "+res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, file)
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -16,17 +17,19 @@ import (
|
|||||||
|
|
||||||
type UserHandler struct {
|
type UserHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
|
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
|
||||||
h := UserHandler{db: db}
|
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 用户列表
|
// List 用户列表
|
||||||
func (h *UserHandler) List(c *gin.Context) {
|
func (h *UserHandler) List(c *gin.Context) {
|
||||||
|
if err := utils.CheckPermission(c, h.DB); err != nil {
|
||||||
|
resp.NotPermission(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
page := h.GetInt(c, "page", 1)
|
page := h.GetInt(c, "page", 1)
|
||||||
pageSize := h.GetInt(c, "page_size", 20)
|
pageSize := h.GetInt(c, "page_size", 20)
|
||||||
username := h.GetTrim(c, "username")
|
username := h.GetTrim(c, "username")
|
||||||
@@ -36,7 +39,7 @@ func (h *UserHandler) List(c *gin.Context) {
|
|||||||
var users = make([]vo.User, 0)
|
var users = make([]vo.User, 0)
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
session := h.DB.Session(&gorm.Session{})
|
||||||
if username != "" {
|
if username != "" {
|
||||||
session = session.Where("username LIKE ?", "%"+username+"%")
|
session = session.Where("username LIKE ?", "%"+username+"%")
|
||||||
}
|
}
|
||||||
@@ -66,13 +69,12 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Calls int `json:"calls"`
|
|
||||||
ImgCalls int `json:"img_calls"`
|
|
||||||
ChatRoles []string `json:"chat_roles"`
|
ChatRoles []string `json:"chat_roles"`
|
||||||
ChatModels []string `json:"chat_models"`
|
ChatModels []int `json:"chat_models"`
|
||||||
ExpiredTime string `json:"expired_time"`
|
ExpiredTime string `json:"expired_time"`
|
||||||
Status bool `json:"status"`
|
Status bool `json:"status"`
|
||||||
Vip bool `json:"vip"`
|
Vip bool `json:"vip"`
|
||||||
|
Power int `json:"power"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
@@ -82,18 +84,45 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
var res *gorm.DB
|
var res *gorm.DB
|
||||||
var userVo vo.User
|
var userVo vo.User
|
||||||
if data.Id > 0 { // 更新
|
if data.Id > 0 { // 更新
|
||||||
user.Id = data.Id
|
res = h.DB.Where("id", data.Id).First(&user)
|
||||||
// 此处需要用 map 更新,用结构体无法更新 0 值
|
if res.Error != nil {
|
||||||
res = h.db.Model(&user).Updates(map[string]interface{}{
|
resp.ERROR(c, "user not found")
|
||||||
"username": data.Username,
|
return
|
||||||
"calls": data.Calls,
|
}
|
||||||
"img_calls": data.ImgCalls,
|
var oldPower = user.Power
|
||||||
"status": data.Status,
|
user.Username = data.Username
|
||||||
"vip": data.Vip,
|
user.Status = data.Status
|
||||||
"chat_roles_json": utils.JsonEncode(data.ChatRoles),
|
user.Vip = data.Vip
|
||||||
"chat_models_json": utils.JsonEncode(data.ChatModels),
|
user.Power = data.Power
|
||||||
"expired_time": utils.Str2stamp(data.ExpiredTime),
|
user.ChatRoles = utils.JsonEncode(data.ChatRoles)
|
||||||
})
|
user.ChatModels = utils.JsonEncode(data.ChatModels)
|
||||||
|
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
|
||||||
|
|
||||||
|
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 记录算力日志
|
||||||
|
if oldPower != user.Power {
|
||||||
|
mark := types.PowerAdd
|
||||||
|
amount := user.Power - oldPower
|
||||||
|
if oldPower > user.Power {
|
||||||
|
mark = types.PowerSub
|
||||||
|
amount = oldPower - user.Power
|
||||||
|
}
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerGift,
|
||||||
|
Amount: amount,
|
||||||
|
Balance: user.Power,
|
||||||
|
Mark: mark,
|
||||||
|
Model: "管理员",
|
||||||
|
Remark: fmt.Sprintf("后台管理员强制修改用户算力,修改前:%d,修改后:%d, 管理员ID:%d", oldPower, user.Power, h.GetLoginUserId(c)),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
salt := utils.RandString(8)
|
salt := utils.RandString(8)
|
||||||
u := model.User{
|
u := model.User{
|
||||||
@@ -102,21 +131,13 @@ func (h *UserHandler) Save(c *gin.Context) {
|
|||||||
Password: utils.GenPassword(data.Password, salt),
|
Password: utils.GenPassword(data.Password, salt),
|
||||||
Avatar: "/images/avatar/user.png",
|
Avatar: "/images/avatar/user.png",
|
||||||
Salt: salt,
|
Salt: salt,
|
||||||
|
Power: data.Power,
|
||||||
Status: true,
|
Status: true,
|
||||||
ChatRoles: utils.JsonEncode(data.ChatRoles),
|
ChatRoles: utils.JsonEncode(data.ChatRoles),
|
||||||
ChatModels: utils.JsonEncode(data.ChatModels),
|
ChatModels: utils.JsonEncode(data.ChatModels),
|
||||||
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
|
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
|
||||||
ChatConfig: utils.JsonEncode(types.UserChatConfig{
|
|
||||||
ApiKeys: map[types.Platform]string{
|
|
||||||
types.OpenAI: "",
|
|
||||||
types.Azure: "",
|
|
||||||
types.ChatGLM: "",
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
Calls: data.Calls,
|
|
||||||
ImgCalls: data.ImgCalls,
|
|
||||||
}
|
}
|
||||||
res = h.db.Create(&u)
|
res = h.DB.Create(&u)
|
||||||
_ = utils.CopyObject(u, &userVo)
|
_ = utils.CopyObject(u, &userVo)
|
||||||
userVo.Id = u.Id
|
userVo.Id = u.Id
|
||||||
userVo.CreatedAt = u.CreatedAt.Unix()
|
userVo.CreatedAt = u.CreatedAt.Unix()
|
||||||
@@ -143,7 +164,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.db.First(&user, data.Id)
|
res := h.DB.First(&user, data.Id)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No user found")
|
resp.ERROR(c, "No user found")
|
||||||
return
|
return
|
||||||
@@ -151,7 +172,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
|
|
||||||
password := utils.GenPassword(data.Password, user.Salt)
|
password := utils.GenPassword(data.Password, user.Salt)
|
||||||
user.Password = password
|
user.Password = password
|
||||||
res = h.db.Updates(&user)
|
res = h.DB.Updates(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c)
|
resp.ERROR(c)
|
||||||
} else {
|
} else {
|
||||||
@@ -161,36 +182,32 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *UserHandler) Remove(c *gin.Context) {
|
func (h *UserHandler) Remove(c *gin.Context) {
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
if id > 0 {
|
if id <= 0 {
|
||||||
tx := h.db.Begin()
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
res := h.db.Where("id = ?", id).Delete(&model.User{})
|
return
|
||||||
if res.Error != nil {
|
|
||||||
resp.ERROR(c, "删除失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 删除聊天记录
|
|
||||||
res = h.db.Where("user_id = ?", id).Delete(&model.ChatItem{})
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, "删除失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 删除聊天历史记录
|
|
||||||
res = h.db.Where("user_id = ?", id).Delete(&model.ChatMessage{})
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, "删除失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 删除登录日志
|
|
||||||
res = h.db.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
resp.ERROR(c, "删除失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
// 删除用户
|
||||||
|
res := h.DB.Where("id = ?", id).Delete(&model.User{})
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "删除失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除聊天记录
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
|
||||||
|
// 删除聊天历史记录
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
|
||||||
|
// 删除登录日志
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
|
||||||
|
// 删除算力日志
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
|
||||||
|
// 删除众筹日志
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
|
||||||
|
// 删除绘图任务
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
|
||||||
|
// 删除订单
|
||||||
|
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,10 +215,10 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
|
|||||||
page := h.GetInt(c, "page", 1)
|
page := h.GetInt(c, "page", 1)
|
||||||
pageSize := h.GetInt(c, "page_size", 20)
|
pageSize := h.GetInt(c, "page_size", 20)
|
||||||
var total int64
|
var total int64
|
||||||
h.db.Model(&model.UserLoginLog{}).Count(&total)
|
h.DB.Model(&model.UserLoginLog{}).Count(&total)
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
var items []model.UserLoginLog
|
var items []model.UserLoginLog
|
||||||
res := h.db.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
|
res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "获取数据失败")
|
resp.ERROR(c, "获取数据失败")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -4,8 +4,11 @@ import (
|
|||||||
"chatplus/core"
|
"chatplus/core"
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
logger2 "chatplus/logger"
|
logger2 "chatplus/logger"
|
||||||
|
"chatplus/store/model"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"gorm.io/gorm"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -15,6 +18,7 @@ var logger = logger2.GetLogger()
|
|||||||
|
|
||||||
type BaseHandler struct {
|
type BaseHandler struct {
|
||||||
App *core.AppServer
|
App *core.AppServer
|
||||||
|
DB *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
|
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
|
||||||
@@ -57,3 +61,27 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
|
|||||||
}
|
}
|
||||||
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
|
||||||
|
return h.GetLoginUserId(c) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
|
||||||
|
value, exists := c.Get(types.LoginUserCache)
|
||||||
|
if exists {
|
||||||
|
return value.(model.User), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userId, ok := c.Get(types.LoginUserID)
|
||||||
|
if !ok {
|
||||||
|
return model.User{}, errors.New("user not login")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user model.User
|
||||||
|
res := h.DB.First(&user, userId)
|
||||||
|
// 更新缓存
|
||||||
|
if res.Error == nil {
|
||||||
|
c.Set(types.LoginUserCache, user)
|
||||||
|
}
|
||||||
|
return user, res.Error
|
||||||
|
}
|
||||||
|
|||||||
@@ -45,3 +45,33 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SlideGet 获取滑动验证图片
|
||||||
|
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
|
||||||
|
data, err := h.service.SlideGet()
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlideCheck 滑动验证结果校验
|
||||||
|
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
X int `json:"x"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.service.SlideCheck(data) {
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
} else {
|
||||||
|
resp.ERROR(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,37 +12,34 @@ import (
|
|||||||
|
|
||||||
type ChatModelHandler struct {
|
type ChatModelHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||||
h := ChatModelHandler{db: db}
|
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 模型列表
|
// List 模型列表
|
||||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||||
var items []model.ChatModel
|
var items []model.ChatModel
|
||||||
var chatModels = make([]vo.ChatModel, 0)
|
var chatModels = make([]vo.ChatModel, 0)
|
||||||
// 只加载用户订阅的 AI 模型
|
var res *gorm.DB
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
// 如果用户没有登录,则加载所有开放模型
|
||||||
if err != nil {
|
if !h.IsLogin(c) {
|
||||||
resp.NotAuth(c)
|
res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items)
|
||||||
return
|
} else {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
var models []int
|
||||||
|
err := utils.JsonDecode(user.ChatModels, &models)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "当前用户没有订阅任何模型")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 查询用户有权限访问的模型以及所有开放的模型
|
||||||
|
res = h.DB.Where("enabled = ?", true).Where(
|
||||||
|
h.DB.Where("id IN ?", models).Or("open =?", true),
|
||||||
|
).Order("sort_num ASC").Find(&items)
|
||||||
}
|
}
|
||||||
|
|
||||||
var models []string
|
|
||||||
err = utils.JsonDecode(user.ChatModels, &models)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "当前用户没有订阅任何模型")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查询用户有权限访问的模型以及所有开放的模型
|
|
||||||
res := h.db.Where("enabled = ?", true).Where(
|
|
||||||
h.db.Where("value IN ?", models).Or("open =?", true),
|
|
||||||
).Order("sort_num ASC").Find(&items)
|
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var cm vo.ChatModel
|
var cm vo.ChatModel
|
||||||
|
|||||||
@@ -14,27 +14,25 @@ import (
|
|||||||
|
|
||||||
type ChatRoleHandler struct {
|
type ChatRoleHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
||||||
handler := &ChatRoleHandler{db: db}
|
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
handler.App = app
|
|
||||||
return handler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List get user list
|
// List 获取用户聊天应用列表
|
||||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||||
all := h.GetBool(c, "all")
|
all := h.GetBool(c, "all")
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
var roles []model.ChatRole
|
var roles []model.ChatRole
|
||||||
res := h.db.Where("enable", true).Order("sort_num ASC").Find(&roles)
|
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No roles found,"+res.Error.Error())
|
resp.ERROR(c, "No roles found,"+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取所有角色
|
// 获取所有角色
|
||||||
if all {
|
if userId == 0 || all {
|
||||||
// 转成 vo
|
// 转成 vo
|
||||||
var roleVos = make([]vo.ChatRole, 0)
|
var roleVos = make([]vo.ChatRole, 0)
|
||||||
for _, r := range roles {
|
for _, r := range roles {
|
||||||
@@ -49,13 +47,8 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
|
||||||
if userId == 0 {
|
|
||||||
resp.NotAuth(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var user model.User
|
var user model.User
|
||||||
h.db.First(&user, userId)
|
h.DB.First(&user, userId)
|
||||||
var roleKeys []string
|
var roleKeys []string
|
||||||
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -80,7 +73,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
|
|
||||||
// UpdateRole 更新用户聊天角色
|
// UpdateRole 更新用户聊天角色
|
||||||
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
@@ -94,7 +87,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
|
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("添加应用失败:", err)
|
logger.Error("添加应用失败:", err)
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
// 微软 Azure 模型消息发送实现
|
// 微软 Azure 模型消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendAzureMessage(
|
func (h *ChatHandler) sendAzureMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -103,8 +103,6 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
if message.Role == "" {
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
@@ -113,66 +111,64 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if h.App.ChatConfig.EnableContext {
|
if h.App.SysConfig.EnableContext {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
if h.App.ChatConfig.EnableHistory {
|
// for prompt
|
||||||
// for prompt
|
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
if err != nil {
|
||||||
if err != nil {
|
logger.Error(err)
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
|
||||||
totalTokens += getTotalTokens(req)
|
|
||||||
|
|
||||||
historyReplyMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
}
|
||||||
|
historyUserMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.PromptMsg,
|
||||||
|
Icon: userVo.Avatar,
|
||||||
|
Content: template.HTMLEscapeString(prompt),
|
||||||
|
Tokens: promptToken,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
|
res := h.DB.Save(&historyUserMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算本次对话消耗的总 token 数量
|
||||||
|
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||||
|
replyTokens += getTotalTokens(req)
|
||||||
|
|
||||||
|
historyReplyMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.ReplyMsg,
|
||||||
|
Icon: role.Icon,
|
||||||
|
Content: message.Content,
|
||||||
|
Tokens: replyTokens,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
|
res = h.DB.Create(&historyReplyMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户算力
|
||||||
|
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||||
|
|
||||||
// 保存当前会话
|
// 保存当前会话
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
chatItem.ChatId = session.ChatId
|
chatItem.ChatId = session.ChatId
|
||||||
chatItem.UserId = session.UserId
|
chatItem.UserId = session.UserId
|
||||||
@@ -184,7 +180,7 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
chatItem.Title = prompt
|
chatItem.Title = prompt
|
||||||
}
|
}
|
||||||
chatItem.Model = req.Model
|
chatItem.Model = req.Model
|
||||||
h.db.Create(&chatItem)
|
h.DB.Create(&chatItem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ type baiduResp struct {
|
|||||||
// 百度文心一言消息发送实现
|
// 百度文心一言消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendBaiduMessage(
|
func (h *ChatHandler) sendBaiduMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -128,9 +128,6 @@ func (h *ChatHandler) sendBaiduMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
if message.Role == "" {
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
}
|
}
|
||||||
@@ -138,65 +135,63 @@ func (h *ChatHandler) sendBaiduMessage(
|
|||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if h.App.ChatConfig.EnableContext {
|
if h.App.SysConfig.EnableContext {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
if h.App.ChatConfig.EnableHistory {
|
// for prompt
|
||||||
// for prompt
|
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
if err != nil {
|
||||||
if err != nil {
|
logger.Error(err)
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// for reply
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
|
||||||
totalTokens := replyToken + getTotalTokens(req)
|
|
||||||
historyReplyMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
}
|
||||||
|
historyUserMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.PromptMsg,
|
||||||
|
Icon: userVo.Avatar,
|
||||||
|
Content: template.HTMLEscapeString(prompt),
|
||||||
|
Tokens: promptToken,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
|
res := h.DB.Save(&historyUserMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for reply
|
||||||
|
// 计算本次对话消耗的总 token 数量
|
||||||
|
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||||
|
totalTokens := replyTokens + getTotalTokens(req)
|
||||||
|
historyReplyMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.ReplyMsg,
|
||||||
|
Icon: role.Icon,
|
||||||
|
Content: message.Content,
|
||||||
|
Tokens: totalTokens,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
|
res = h.DB.Create(&historyReplyMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
|
}
|
||||||
|
// 更新用户算力
|
||||||
|
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||||
|
|
||||||
// 保存当前会话
|
// 保存当前会话
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
chatItem.ChatId = session.ChatId
|
chatItem.ChatId = session.ChatId
|
||||||
chatItem.UserId = session.UserId
|
chatItem.UserId = session.UserId
|
||||||
@@ -208,7 +203,7 @@ func (h *ChatHandler) sendBaiduMessage(
|
|||||||
chatItem.Title = prompt
|
chatItem.Title = prompt
|
||||||
}
|
}
|
||||||
chatItem.Model = req.Model
|
chatItem.Model = req.Model
|
||||||
h.db.Create(&chatItem)
|
h.DB.Create(&chatItem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -35,19 +35,16 @@ var logger = logger2.GetLogger()
|
|||||||
|
|
||||||
type ChatHandler struct {
|
type ChatHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
|
||||||
h := ChatHandler{
|
return &ChatHandler{
|
||||||
db: db,
|
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||||
redis: redis,
|
redis: redis,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
}
|
}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) Init() {
|
func (h *ChatHandler) Init() {
|
||||||
@@ -57,8 +54,6 @@ func (h *ChatHandler) Init() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var chatConfig types.ChatConfig
|
|
||||||
|
|
||||||
// ChatHandle 处理聊天 WebSocket 请求
|
// ChatHandle 处理聊天 WebSocket 请求
|
||||||
func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||||
@@ -75,7 +70,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
client := types.NewWsClient(ws)
|
client := types.NewWsClient(ws)
|
||||||
// get model info
|
// get model info
|
||||||
var chatModel model.ChatModel
|
var chatModel model.ChatModel
|
||||||
res := h.db.First(&chatModel, modelId)
|
res := h.DB.First(&chatModel, modelId)
|
||||||
if res.Error != nil || chatModel.Enabled == false {
|
if res.Error != nil || chatModel.Enabled == false {
|
||||||
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
@@ -84,7 +79,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
|
|
||||||
session := h.App.ChatSession.Get(sessionId)
|
session := h.App.ChatSession.Get(sessionId)
|
||||||
if session == nil {
|
if session == nil {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("用户未登录")
|
logger.Info("用户未登录")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
@@ -101,7 +96,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
|
|
||||||
// use old chat data override the chat model and role ID
|
// use old chat data override the chat model and role ID
|
||||||
var chat model.ChatItem
|
var chat model.ChatItem
|
||||||
res = h.db.Where("chat_id = ?", chatId).First(&chat)
|
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
chatModel.Id = chat.ModelId
|
chatModel.Id = chat.ModelId
|
||||||
roleId = int(chat.RoleId)
|
roleId = int(chat.RoleId)
|
||||||
@@ -109,28 +104,24 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
|
|
||||||
session.ChatId = chatId
|
session.ChatId = chatId
|
||||||
session.Model = types.ChatModel{
|
session.Model = types.ChatModel{
|
||||||
Id: chatModel.Id,
|
Id: chatModel.Id,
|
||||||
Value: chatModel.Value,
|
Name: chatModel.Name,
|
||||||
Weight: chatModel.Weight,
|
Value: chatModel.Value,
|
||||||
Platform: types.Platform(chatModel.Platform)}
|
Power: chatModel.Power,
|
||||||
|
MaxTokens: chatModel.MaxTokens,
|
||||||
|
MaxContext: chatModel.MaxContext,
|
||||||
|
Temperature: chatModel.Temperature,
|
||||||
|
Platform: types.Platform(chatModel.Platform)}
|
||||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
||||||
var chatRole model.ChatRole
|
var chatRole model.ChatRole
|
||||||
res = h.db.First(&chatRole, roleId)
|
res = h.DB.First(&chatRole, roleId)
|
||||||
if res.Error != nil || !chatRole.Enable {
|
if res.Error != nil || !chatRole.Enable {
|
||||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化聊天配置
|
h.Init()
|
||||||
var config model.Config
|
|
||||||
h.db.Where("marker", "chat").First(&config)
|
|
||||||
err = utils.JsonDecode(config.Config, &chatConfig)
|
|
||||||
if err != nil {
|
|
||||||
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存会话连接
|
// 保存会话连接
|
||||||
h.App.ChatClients.Put(sessionId, client)
|
h.App.ChatClients.Put(sessionId, client)
|
||||||
@@ -188,9 +179,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.db.Model(&model.User{}).First(&user, session.UserId)
|
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
utils.ReplyMessage(ws, "非法用户,请联系管理员!")
|
utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作!")
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
var userVo vo.User
|
var userVo vo.User
|
||||||
@@ -206,14 +197,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if userVo.Calls < session.Model.Weight {
|
if userVo.Power < session.Model.Power {
|
||||||
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数(%d)已不足以支付当前模型的单次对话需要消耗的对话额度(%d)!", userVo.Calls, session.Model.Weight))
|
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型的单次对话需要消耗的算力(%d)!", userVo.Power, session.Model.Power))
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
|
|
||||||
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
|
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
utils.ReplyMessage(ws, ErrImg)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -223,35 +208,34 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
utils.ReplyMessage(ws, ErrImg)
|
utils.ReplyMessage(ws, ErrImg)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
|
||||||
|
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
|
||||||
|
if promptTokens > session.Model.MaxContext {
|
||||||
|
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var req = types.ApiRequest{
|
var req = types.ApiRequest{
|
||||||
Model: session.Model.Value,
|
Model: session.Model.Value,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
}
|
||||||
switch session.Model.Platform {
|
switch session.Model.Platform {
|
||||||
case types.Azure:
|
case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
|
||||||
req.Temperature = h.App.ChatConfig.Azure.Temperature
|
req.Temperature = session.Model.Temperature
|
||||||
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
|
req.MaxTokens = session.Model.MaxTokens
|
||||||
break
|
|
||||||
case types.ChatGLM:
|
|
||||||
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
|
|
||||||
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
|
|
||||||
break
|
|
||||||
case types.Baidu:
|
|
||||||
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
|
||||||
// TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
|
|
||||||
break
|
break
|
||||||
case types.OpenAI:
|
case types.OpenAI:
|
||||||
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
req.Temperature = session.Model.Temperature
|
||||||
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
|
req.MaxTokens = session.Model.MaxTokens
|
||||||
// OpenAI 支持函数功能
|
// OpenAI 支持函数功能
|
||||||
var items []model.Function
|
var items []model.Function
|
||||||
res := h.db.Where("enabled", true).Find(&items)
|
res := h.DB.Where("enabled", true).Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
var tools = make([]interface{}, 0)
|
var tools = make([]interface{}, 0)
|
||||||
var functions = make([]interface{}, 0)
|
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
var parameters map[string]interface{}
|
var parameters map[string]interface{}
|
||||||
err = utils.JsonDecode(v.Parameters, ¶meters)
|
err = utils.JsonDecode(v.Parameters, ¶meters)
|
||||||
@@ -269,30 +253,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
"required": required,
|
"required": required,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
functions = append(functions, gin.H{
|
|
||||||
"name": v.Name,
|
|
||||||
"description": v.Description,
|
|
||||||
"parameters": parameters,
|
|
||||||
"required": required,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//if len(tools) > 0 {
|
if len(tools) > 0 {
|
||||||
// req.Tools = tools
|
req.Tools = tools
|
||||||
// req.ToolChoice = "auto"
|
req.ToolChoice = "auto"
|
||||||
//}
|
|
||||||
if len(functions) > 0 {
|
|
||||||
req.Functions = functions
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case types.XunFei:
|
|
||||||
req.Temperature = h.App.ChatConfig.XunFei.Temperature
|
|
||||||
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
|
|
||||||
break
|
|
||||||
case types.QWen:
|
case types.QWen:
|
||||||
req.Input = map[string]interface{}{"messages": []map[string]string{{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}}}
|
req.Parameters = map[string]interface{}{
|
||||||
req.Parameters = map[string]interface{}{}
|
"max_tokens": session.Model.MaxTokens,
|
||||||
|
"temperature": session.Model.Temperature,
|
||||||
|
}
|
||||||
break
|
break
|
||||||
|
|
||||||
default:
|
default:
|
||||||
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
|
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
utils.ReplyMessage(ws, ErrImg)
|
||||||
@@ -300,40 +273,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 加载聊天上下文
|
// 加载聊天上下文
|
||||||
var chatCtx []interface{}
|
chatCtx := make([]types.Message, 0)
|
||||||
if h.App.ChatConfig.EnableContext {
|
messages := make([]types.Message, 0)
|
||||||
|
if h.App.SysConfig.EnableContext {
|
||||||
if h.App.ChatContexts.Has(session.ChatId) {
|
if h.App.ChatContexts.Has(session.ChatId) {
|
||||||
chatCtx = h.App.ChatContexts.Get(session.ChatId)
|
messages = h.App.ChatContexts.Get(session.ChatId)
|
||||||
} else {
|
} else {
|
||||||
// calculate the tokens of current request, to prevent to exceeding the max tokens num
|
_ = utils.JsonDecode(role.Context, &messages)
|
||||||
tokens := req.MaxTokens
|
if h.App.SysConfig.ContextDeep > 0 {
|
||||||
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
|
|
||||||
tokens += tks
|
|
||||||
// loading the role context
|
|
||||||
var messages []types.Message
|
|
||||||
err := utils.JsonDecode(role.Context, &messages)
|
|
||||||
if err == nil {
|
|
||||||
for _, v := range messages {
|
|
||||||
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
|
||||||
if tokens+tks >= types.GetModelMaxToken(req.Model) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
tokens += tks
|
|
||||||
chatCtx = append(chatCtx, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// loading recent chat history as chat context
|
|
||||||
if chatConfig.ContextDeep > 0 {
|
|
||||||
var historyMessages []model.ChatMessage
|
var historyMessages []model.ChatMessage
|
||||||
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
|
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||||
msg := historyMessages[i]
|
msg := historyMessages[i]
|
||||||
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
tokens += msg.Tokens
|
|
||||||
ms := types.Message{Role: "user", Content: msg.Content}
|
ms := types.Message{Role: "user", Content: msg.Content}
|
||||||
if msg.Type == types.ReplyMsg {
|
if msg.Type == types.ReplyMsg {
|
||||||
ms.Role = "assistant"
|
ms.Role = "assistant"
|
||||||
@@ -343,6 +295,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
|
||||||
|
// MaxContextLength = Response + Tool + Prompt + Context
|
||||||
|
tokens := req.MaxTokens // 最大响应长度
|
||||||
|
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
|
||||||
|
tokens += tks + promptTokens
|
||||||
|
|
||||||
|
for _, v := range messages {
|
||||||
|
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
||||||
|
// 上下文 token 超出了模型的最大上下文长度
|
||||||
|
if tokens+tks >= session.Model.MaxContext {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 上下文的深度超出了模型的最大上下文深度
|
||||||
|
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens += tks
|
||||||
|
chatCtx = append(chatCtx, v)
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debugf("聊天上下文:%+v", chatCtx)
|
logger.Debugf("聊天上下文:%+v", chatCtx)
|
||||||
}
|
}
|
||||||
reqMgs := make([]interface{}, 0)
|
reqMgs := make([]interface{}, 0)
|
||||||
@@ -350,10 +325,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
reqMgs = append(reqMgs, m)
|
reqMgs = append(reqMgs, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Messages = append(reqMgs, map[string]interface{}{
|
if session.Model.Platform == types.QWen {
|
||||||
"role": "user",
|
req.Input = map[string]interface{}{"prompt": prompt}
|
||||||
"content": prompt,
|
if len(reqMgs) > 0 {
|
||||||
})
|
req.Input["messages"] = reqMgs
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
req.Messages = append(reqMgs, map[string]interface{}{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
switch session.Model.Platform {
|
switch session.Model.Platform {
|
||||||
case types.Azure:
|
case types.Azure:
|
||||||
@@ -392,7 +374,7 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
|||||||
if data.Text == "" && data.ChatId != "" {
|
if data.Text == "" && data.ChatId != "" {
|
||||||
var item model.ChatMessage
|
var item model.ChatMessage
|
||||||
userId, _ := c.Get(types.LoginUserID)
|
userId, _ := c.Get(types.LoginUserID)
|
||||||
res := h.db.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
|
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -443,7 +425,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
|||||||
// 发送请求到 OpenAI 服务器
|
// 发送请求到 OpenAI 服务器
|
||||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
||||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
|
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
|
||||||
res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
|
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return nil, errors.New("no available key, please import key")
|
return nil, errors.New("no available key, please import key")
|
||||||
}
|
}
|
||||||
@@ -469,7 +451,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
|||||||
apiURL = apiKey.ApiURL
|
apiURL = apiKey.ApiURL
|
||||||
}
|
}
|
||||||
// 更新 API KEY 的最后使用时间
|
// 更新 API KEY 的最后使用时间
|
||||||
h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
// 百度文心,需要串接 access_token
|
// 百度文心,需要串接 access_token
|
||||||
if platform == types.Baidu {
|
if platform == types.Baidu {
|
||||||
token, err := h.getBaiduToken(apiKey.Value)
|
token, err := h.getBaiduToken(apiKey.Value)
|
||||||
@@ -496,9 +478,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
|||||||
request = request.WithContext(ctx)
|
request = request.WithContext(ctx)
|
||||||
request.Header.Set("Content-Type", "application/json")
|
request.Header.Set("Content-Type", "application/json")
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
if h.App.Config.ProxyURL != "" && apiKey.UseProxy { // 使用代理
|
if apiKey.ProxyURL != "" { // 使用代理
|
||||||
proxyURL = h.App.Config.ProxyURL
|
proxy, _ := url.Parse(apiKey.ProxyURL)
|
||||||
proxy, _ := url.Parse(proxyURL)
|
|
||||||
client = &http.Client{
|
client = &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
Proxy: http.ProxyURL(proxy),
|
Proxy: http.ProxyURL(proxy),
|
||||||
@@ -532,23 +513,30 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
|||||||
return client.Do(request)
|
return client.Do(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 扣减用户的对话次数
|
// 扣减用户算力
|
||||||
func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) {
|
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
|
||||||
// 仅当用户没有导入自己的 API KEY 时才进行扣减
|
power := 1
|
||||||
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
|
if session.Model.Power > 0 {
|
||||||
num := 1
|
power = session.Model.Power
|
||||||
if session.Model.Weight > 0 {
|
}
|
||||||
num = session.Model.Weight
|
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
|
||||||
}
|
if res.Error == nil {
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num))
|
// 记录算力消费日志
|
||||||
|
var u model.User
|
||||||
|
h.DB.Where("id", userVo.Id).First(&u)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
Username: userVo.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Balance: u.Power,
|
||||||
|
Model: session.Model.Value,
|
||||||
|
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (h *ChatHandler) incUserTokenFee(userId uint, tokens int) {
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userId).
|
|
||||||
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", tokens))
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userId).
|
|
||||||
UpdateColumn("tokens", gorm.Expr("tokens + ?", tokens))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 将AI回复消息中生成的图片链接下载到本地
|
// 将AI回复消息中生成的图片链接下载到本地
|
||||||
|
|||||||
@@ -6,27 +6,29 @@ import (
|
|||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// List 获取会话列表
|
// List 获取会话列表
|
||||||
func (h *ChatHandler) List(c *gin.Context) {
|
func (h *ChatHandler) List(c *gin.Context) {
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
if !h.IsLogin(c) {
|
||||||
if userId == 0 {
|
resp.SUCCESS(c)
|
||||||
resp.ERROR(c, "The parameter 'user_id' is needed.")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
var items = make([]vo.ChatItem, 0)
|
var items = make([]vo.ChatItem, 0)
|
||||||
var chats []model.ChatItem
|
var chats []model.ChatItem
|
||||||
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
|
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
var roleIds = make([]uint, 0)
|
var roleIds = make([]uint, 0)
|
||||||
for _, chat := range chats {
|
for _, chat := range chats {
|
||||||
roleIds = append(roleIds, chat.RoleId)
|
roleIds = append(roleIds, chat.RoleId)
|
||||||
}
|
}
|
||||||
var roles []model.ChatRole
|
var roles []model.ChatRole
|
||||||
res = h.db.Find(&roles, roleIds)
|
res = h.DB.Find(&roles, roleIds)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
roleMap := make(map[uint]model.ChatRole)
|
roleMap := make(map[uint]model.ChatRole)
|
||||||
for _, role := range roles {
|
for _, role := range roles {
|
||||||
@@ -58,7 +60,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
|
res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Failed to update database")
|
resp.ERROR(c, "Failed to update database")
|
||||||
return
|
return
|
||||||
@@ -70,14 +72,14 @@ func (h *ChatHandler) Update(c *gin.Context) {
|
|||||||
// Clear 清空所有聊天记录
|
// Clear 清空所有聊天记录
|
||||||
func (h *ChatHandler) Clear(c *gin.Context) {
|
func (h *ChatHandler) Clear(c *gin.Context) {
|
||||||
// 获取当前登录用户所有的聊天会话
|
// 获取当前登录用户所有的聊天会话
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var chats []model.ChatItem
|
var chats []model.ChatItem
|
||||||
res := h.db.Where("user_id = ?", user.Id).Find(&chats)
|
res := h.DB.Where("user_id = ?", user.Id).Find(&chats)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No chats found")
|
resp.ERROR(c, "No chats found")
|
||||||
return
|
return
|
||||||
@@ -89,13 +91,13 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
|||||||
// 清空会话上下文
|
// 清空会话上下文
|
||||||
h.App.ChatContexts.Delete(chat.ChatId)
|
h.App.ChatContexts.Delete(chat.ChatId)
|
||||||
}
|
}
|
||||||
err = h.db.Transaction(func(tx *gorm.DB) error {
|
err = h.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
|
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
|
res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
@@ -118,7 +120,7 @@ func (h *ChatHandler) History(c *gin.Context) {
|
|||||||
chatId := c.Query("chat_id") // 会话 ID
|
chatId := c.Query("chat_id") // 会话 ID
|
||||||
var items []model.ChatMessage
|
var items []model.ChatMessage
|
||||||
var messages = make([]vo.HistoryMessage, 0)
|
var messages = make([]vo.HistoryMessage, 0)
|
||||||
res := h.db.Where("chat_id = ?", chatId).Find(&items)
|
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No history message")
|
resp.ERROR(c, "No history message")
|
||||||
return
|
return
|
||||||
@@ -144,20 +146,20 @@ func (h *ChatHandler) Remove(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
|
res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Failed to update database")
|
resp.ERROR(c, "Failed to update database")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 删除当前会话的聊天记录
|
// 删除当前会话的聊天记录
|
||||||
res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
|
res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Failed to remove chat from database.")
|
resp.ERROR(c, "Failed to remove chat from database.")
|
||||||
return
|
return
|
||||||
@@ -179,7 +181,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
res := h.db.Where("chat_id = ?", chatId).First(&chatItem)
|
res := h.DB.Where("chat_id = ?", chatId).First(&chatItem)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No chat found")
|
resp.ERROR(c, "No chat found")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
// 清华大学 ChatGML 消息发送实现
|
// 清华大学 ChatGML 消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendChatGLMMessage(
|
func (h *ChatHandler) sendChatGLMMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -107,9 +107,6 @@ func (h *ChatHandler) sendChatGLMMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
if message.Role == "" {
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
}
|
}
|
||||||
@@ -117,65 +114,64 @@ func (h *ChatHandler) sendChatGLMMessage(
|
|||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if h.App.ChatConfig.EnableContext {
|
if h.App.SysConfig.EnableContext {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
if h.App.ChatConfig.EnableHistory {
|
// for prompt
|
||||||
// for prompt
|
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
if err != nil {
|
||||||
if err != nil {
|
logger.Error(err)
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// for reply
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
|
||||||
totalTokens := replyToken + getTotalTokens(req)
|
|
||||||
historyReplyMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
}
|
||||||
|
historyUserMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.PromptMsg,
|
||||||
|
Icon: userVo.Avatar,
|
||||||
|
Content: template.HTMLEscapeString(prompt),
|
||||||
|
Tokens: promptToken,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
|
res := h.DB.Save(&historyUserMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for reply
|
||||||
|
// 计算本次对话消耗的总 token 数量
|
||||||
|
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||||
|
totalTokens := replyTokens + getTotalTokens(req)
|
||||||
|
historyReplyMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.ReplyMsg,
|
||||||
|
Icon: role.Icon,
|
||||||
|
Content: message.Content,
|
||||||
|
Tokens: totalTokens,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
|
res = h.DB.Create(&historyReplyMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户算力
|
||||||
|
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||||
|
|
||||||
// 保存当前会话
|
// 保存当前会话
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
chatItem.ChatId = session.ChatId
|
chatItem.ChatId = session.ChatId
|
||||||
chatItem.UserId = session.UserId
|
chatItem.UserId = session.UserId
|
||||||
@@ -187,7 +183,7 @@ func (h *ChatHandler) sendChatGLMMessage(
|
|||||||
chatItem.Title = prompt
|
chatItem.Title = prompt
|
||||||
}
|
}
|
||||||
chatItem.Model = req.Model
|
chatItem.Model = req.Model
|
||||||
h.db.Create(&chatItem)
|
h.DB.Create(&chatItem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
|
|
||||||
// OPenAI 消息发送实现
|
// OPenAI 消息发送实现
|
||||||
func (h *ChatHandler) sendOpenAiMessage(
|
func (h *ChatHandler) sendOpenAiMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
|
|
||||||
utils.ReplyMessage(ws, ErrorMsg)
|
utils.ReplyMessage(ws, ErrorMsg)
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
utils.ReplyMessage(ws, ErrImg)
|
||||||
all, _ := io.ReadAll(response.Body)
|
if response.Body != nil {
|
||||||
logger.Error(string(all))
|
all, _ := io.ReadAll(response.Body)
|
||||||
|
logger.Error(string(all))
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
@@ -98,7 +100,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !utils.IsEmptyValue(tool) {
|
if !utils.IsEmptyValue(tool) {
|
||||||
res := h.db.Where("name = ?", tool.Function.Name).First(&function)
|
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
toolCall = true
|
toolCall = true
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
@@ -171,9 +173,6 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
if message.Role == "" {
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
}
|
}
|
||||||
@@ -181,79 +180,77 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if h.App.ChatConfig.EnableContext && toolCall == false {
|
if h.App.SysConfig.EnableContext && toolCall == false {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
if h.App.ChatConfig.EnableHistory {
|
useContext := true
|
||||||
useContext := true
|
if toolCall {
|
||||||
if toolCall {
|
useContext = false
|
||||||
useContext = false
|
|
||||||
}
|
|
||||||
|
|
||||||
// for prompt
|
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: useContext,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
var totalTokens = 0
|
|
||||||
if toolCall { // prompt + 函数名 + 参数 token
|
|
||||||
tokens, _ := utils.CalcTokens(function.Name, req.Model)
|
|
||||||
totalTokens += tokens
|
|
||||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
|
||||||
totalTokens += tokens
|
|
||||||
} else {
|
|
||||||
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
|
||||||
}
|
|
||||||
totalTokens += getTotalTokens(req)
|
|
||||||
|
|
||||||
historyReplyMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: h.extractImgUrl(message.Content),
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: useContext,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// for prompt
|
||||||
|
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
historyUserMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.PromptMsg,
|
||||||
|
Icon: userVo.Avatar,
|
||||||
|
Content: template.HTMLEscapeString(prompt),
|
||||||
|
Tokens: promptToken,
|
||||||
|
UseContext: useContext,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
|
res := h.DB.Save(&historyUserMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算本次对话消耗的总 token 数量
|
||||||
|
var replyTokens = 0
|
||||||
|
if toolCall { // prompt + 函数名 + 参数 token
|
||||||
|
tokens, _ := utils.CalcTokens(function.Name, req.Model)
|
||||||
|
replyTokens += tokens
|
||||||
|
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||||
|
replyTokens += tokens
|
||||||
|
} else {
|
||||||
|
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||||
|
}
|
||||||
|
replyTokens += getTotalTokens(req)
|
||||||
|
|
||||||
|
historyReplyMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.ReplyMsg,
|
||||||
|
Icon: role.Icon,
|
||||||
|
Content: h.extractImgUrl(message.Content),
|
||||||
|
Tokens: replyTokens,
|
||||||
|
UseContext: useContext,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
|
res = h.DB.Create(&historyReplyMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户算力
|
||||||
|
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||||
|
|
||||||
// 保存当前会话
|
// 保存当前会话
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
chatItem.ChatId = session.ChatId
|
chatItem.ChatId = session.ChatId
|
||||||
chatItem.UserId = session.UserId
|
chatItem.UserId = session.UserId
|
||||||
@@ -265,17 +262,19 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
chatItem.Title = prompt
|
chatItem.Title = prompt
|
||||||
}
|
}
|
||||||
chatItem.Model = req.Model
|
chatItem.Model = req.Model
|
||||||
h.db.Create(&chatItem)
|
h.DB.Create(&chatItem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
body, err := io.ReadAll(response.Body)
|
body, err := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
|
||||||
return fmt.Errorf("error with reading response: %v", err)
|
return fmt.Errorf("error with reading response: %v", err)
|
||||||
}
|
}
|
||||||
var res types.ApiError
|
var res types.ApiError
|
||||||
err = json.Unmarshal(body, &res)
|
err = json.Unmarshal(body, &res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
|
||||||
return fmt.Errorf("error with decode response: %v", err)
|
return fmt.Errorf("error with decode response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,7 +282,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
|
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
|
||||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
|
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
|
||||||
// 移除当前 API key
|
// 移除当前 API key
|
||||||
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
|
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
|
||||||
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
|
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
|
||||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
|
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
|
||||||
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
||||||
|
|||||||
@@ -20,18 +20,21 @@ type qWenResp struct {
|
|||||||
Output struct {
|
Output struct {
|
||||||
FinishReason string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
} `json:"output"`
|
} `json:"output,omitempty"`
|
||||||
Usage struct {
|
Usage struct {
|
||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"total_tokens"`
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
} `json:"usage"`
|
} `json:"usage,omitempty"`
|
||||||
RequestID string `json:"request_id"`
|
RequestID string `json:"request_id"`
|
||||||
|
|
||||||
|
Code string `json:"code,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// 通义千问消息发送实现
|
// 通义千问消息发送实现
|
||||||
func (h *ChatHandler) sendQWenMessage(
|
func (h *ChatHandler) sendQWenMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -70,6 +73,7 @@ func (h *ChatHandler) sendQWenMessage(
|
|||||||
scanner := bufio.NewScanner(response.Body)
|
scanner := bufio.NewScanner(response.Body)
|
||||||
|
|
||||||
var content, lastText, newText string
|
var content, lastText, newText string
|
||||||
|
var outPutStart = false
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
@@ -77,24 +81,32 @@ func (h *ChatHandler) sendQWenMessage(
|
|||||||
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
|
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(line, "data:") {
|
if strings.HasPrefix(line, "data:") {
|
||||||
content = line[5:]
|
content = line[5:]
|
||||||
}
|
}
|
||||||
// 处理代码换行
|
|
||||||
if len(content) == 0 {
|
|
||||||
content = "\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp qWenResp
|
var resp qWenResp
|
||||||
err := utils.JsonDecode(content, &resp)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with parse data line: ", err)
|
|
||||||
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(contents) == 0 { // 发送消息头
|
if len(contents) == 0 { // 发送消息头
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
if !outPutStart {
|
||||||
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
|
outPutStart = true
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
// 处理代码换行
|
||||||
|
content = "\n"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := utils.JsonDecode(content, &resp)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with parse data line: ", content)
|
||||||
|
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if resp.Message != "" {
|
||||||
|
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//通过比较 lastText(上一次的文本)和 currentText(当前的文本),
|
//通过比较 lastText(上一次的文本)和 currentText(当前的文本),
|
||||||
@@ -128,9 +140,6 @@ func (h *ChatHandler) sendQWenMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
if message.Role == "" {
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
}
|
}
|
||||||
@@ -138,65 +147,64 @@ func (h *ChatHandler) sendQWenMessage(
|
|||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if h.App.ChatConfig.EnableContext {
|
if h.App.SysConfig.EnableContext {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
if h.App.ChatConfig.EnableHistory {
|
// for prompt
|
||||||
// for prompt
|
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
if err != nil {
|
||||||
if err != nil {
|
logger.Error(err)
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// for reply
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
|
||||||
totalTokens := replyToken + getTotalTokens(req)
|
|
||||||
historyReplyMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
}
|
||||||
|
historyUserMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.PromptMsg,
|
||||||
|
Icon: userVo.Avatar,
|
||||||
|
Content: template.HTMLEscapeString(prompt),
|
||||||
|
Tokens: promptToken,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
|
res := h.DB.Save(&historyUserMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for reply
|
||||||
|
// 计算本次对话消耗的总 token 数量
|
||||||
|
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||||
|
totalTokens := replyTokens + getTotalTokens(req)
|
||||||
|
historyReplyMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.ReplyMsg,
|
||||||
|
Icon: role.Icon,
|
||||||
|
Content: message.Content,
|
||||||
|
Tokens: totalTokens,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
|
res = h.DB.Create(&historyReplyMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户算力
|
||||||
|
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||||
|
|
||||||
// 保存当前会话
|
// 保存当前会话
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
chatItem.ChatId = session.ChatId
|
chatItem.ChatId = session.ChatId
|
||||||
chatItem.UserId = session.UserId
|
chatItem.UserId = session.UserId
|
||||||
@@ -208,7 +216,7 @@ func (h *ChatHandler) sendQWenMessage(
|
|||||||
chatItem.Title = prompt
|
chatItem.Title = prompt
|
||||||
}
|
}
|
||||||
chatItem.Model = req.Model
|
chatItem.Model = req.Model
|
||||||
h.db.Create(&chatItem)
|
h.DB.Create(&chatItem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -50,15 +50,16 @@ type xunFeiResp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var Model2URL = map[string]string{
|
var Model2URL = map[string]string{
|
||||||
"general": "v1.1",
|
"general": "v1.1",
|
||||||
"generalv2": "v2.1",
|
"generalv2": "v2.1",
|
||||||
"generalv3": "v3.1",
|
"generalv3": "v3.1",
|
||||||
|
"generalv3.5": "v3.5",
|
||||||
}
|
}
|
||||||
|
|
||||||
// 科大讯飞消息发送实现
|
// 科大讯飞消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendXunFeiMessage(
|
func (h *ChatHandler) sendXunFeiMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -68,13 +69,13 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
ws *types.WsClient) error {
|
ws *types.WsClient) error {
|
||||||
promptCreatedAt := time.Now() // 记录提问时间
|
promptCreatedAt := time.Now() // 记录提问时间
|
||||||
var apiKey model.ApiKey
|
var apiKey model.ApiKey
|
||||||
res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// 更新 API KEY 的最后使用时间
|
// 更新 API KEY 的最后使用时间
|
||||||
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
|
|
||||||
d := websocket.Dialer{
|
d := websocket.Dialer{
|
||||||
HandshakeTimeout: 5 * time.Second,
|
HandshakeTimeout: 5 * time.Second,
|
||||||
@@ -86,6 +87,7 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
|
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
|
||||||
|
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
|
||||||
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
|
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
|
||||||
//握手并建立websocket 连接
|
//握手并建立websocket 连接
|
||||||
conn, resp, err := d.Dial(wsURL, nil)
|
conn, resp, err := d.Dial(wsURL, nil)
|
||||||
@@ -166,9 +168,6 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
|
||||||
h.subUserCalls(userVo, session)
|
|
||||||
|
|
||||||
if message.Role == "" {
|
if message.Role == "" {
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
}
|
}
|
||||||
@@ -176,65 +175,64 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if h.App.ChatConfig.EnableContext {
|
if h.App.SysConfig.EnableContext {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
if h.App.ChatConfig.EnableHistory {
|
// for prompt
|
||||||
// for prompt
|
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
if err != nil {
|
||||||
if err != nil {
|
logger.Error(err)
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
historyUserMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: userVo.Avatar,
|
|
||||||
Content: template.HTMLEscapeString(prompt),
|
|
||||||
Tokens: promptToken,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// for reply
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
|
||||||
totalTokens := replyToken + getTotalTokens(req)
|
|
||||||
historyReplyMsg := model.ChatMessage{
|
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: totalTokens,
|
|
||||||
UseContext: true,
|
|
||||||
Model: req.Model,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
// 更新用户信息
|
|
||||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
|
||||||
}
|
}
|
||||||
|
historyUserMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.PromptMsg,
|
||||||
|
Icon: userVo.Avatar,
|
||||||
|
Content: template.HTMLEscapeString(prompt),
|
||||||
|
Tokens: promptToken,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
|
res := h.DB.Save(&historyUserMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for reply
|
||||||
|
// 计算本次对话消耗的总 token 数量
|
||||||
|
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||||
|
totalTokens := replyTokens + getTotalTokens(req)
|
||||||
|
historyReplyMsg := model.ChatMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.ReplyMsg,
|
||||||
|
Icon: role.Icon,
|
||||||
|
Content: message.Content,
|
||||||
|
Tokens: totalTokens,
|
||||||
|
UseContext: true,
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
|
res = h.DB.Create(&historyReplyMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户算力
|
||||||
|
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||||
|
|
||||||
// 保存当前会话
|
// 保存当前会话
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
chatItem.ChatId = session.ChatId
|
chatItem.ChatId = session.ChatId
|
||||||
chatItem.UserId = session.UserId
|
chatItem.UserId = session.UserId
|
||||||
@@ -246,7 +244,7 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
chatItem.Title = prompt
|
chatItem.Title = prompt
|
||||||
}
|
}
|
||||||
chatItem.Model = req.Model
|
chatItem.Model = req.Model
|
||||||
h.db.Create(&chatItem)
|
h.DB.Create(&chatItem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -262,7 +260,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
|
|||||||
"parameter": map[string]interface{}{
|
"parameter": map[string]interface{}{
|
||||||
"chat": map[string]interface{}{
|
"chat": map[string]interface{}{
|
||||||
"domain": req.Model,
|
"domain": req.Model,
|
||||||
"temperature": float64(req.Temperature),
|
"temperature": req.Temperature,
|
||||||
"top_k": int64(6),
|
"top_k": int64(6),
|
||||||
"max_tokens": int64(req.MaxTokens),
|
"max_tokens": int64(req.MaxTokens),
|
||||||
"auditing": "default",
|
"auditing": "default",
|
||||||
|
|||||||
39
api/handler/config_handler.go
Normal file
39
api/handler/config_handler.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core"
|
||||||
|
"chatplus/store/model"
|
||||||
|
"chatplus/utils"
|
||||||
|
"chatplus/utils/resp"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConfigHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
|
||||||
|
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取指定的系统配置
|
||||||
|
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||||
|
key := c.Query("key")
|
||||||
|
var config model.Config
|
||||||
|
res := h.DB.Where("marker", key).First(&config)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var value map[string]interface{}
|
||||||
|
err := utils.JsonDecode(config.Config, &value)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, value)
|
||||||
|
}
|
||||||
@@ -19,21 +19,18 @@ import (
|
|||||||
|
|
||||||
type FunctionHandler struct {
|
type FunctionHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
config types.ChatPlusApiConfig
|
config types.ChatPlusApiConfig
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
proxyURL string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
|
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
|
||||||
return &FunctionHandler{
|
return &FunctionHandler{
|
||||||
BaseHandler: BaseHandler{
|
BaseHandler: BaseHandler{
|
||||||
App: server,
|
App: server,
|
||||||
|
DB: db,
|
||||||
},
|
},
|
||||||
db: db,
|
|
||||||
config: config.ApiConfig,
|
config: config.ApiConfig,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
proxyURL: config.ProxyURL,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,68 +189,49 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Debugf("绘画参数:%+v", params)
|
logger.Debugf("绘画参数:%+v", params)
|
||||||
// check img calls
|
|
||||||
var user model.User
|
var user model.User
|
||||||
tx := h.db.Where("id = ?", params["user_id"]).First(&user)
|
tx := h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
resp.ERROR(c, "当前用户不存在!")
|
resp.ERROR(c, "当前用户不存在!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.ImgCalls <= 0 {
|
if user.Power < h.App.SysConfig.DallPower {
|
||||||
resp.ERROR(c, "当前用户的绘图次数额度不足!")
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt := utils.InterfaceToString(params["prompt"])
|
prompt := utils.InterfaceToString(params["prompt"])
|
||||||
// get image generation API KEY
|
// get image generation API KEY
|
||||||
var apiKey model.ApiKey
|
var apiKey model.ApiKey
|
||||||
tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
|
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// get image generation api URL
|
|
||||||
var conf model.Config
|
|
||||||
var chatConfig types.ChatConfig
|
|
||||||
tx = h.db.Where("marker", "chat").First(&conf)
|
|
||||||
if tx.Error != nil {
|
|
||||||
resp.ERROR(c, "error with get chat configs:"+tx.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := utils.JsonDecode(conf.Config, &chatConfig)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "error with decode chat config: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// translate prompt
|
// translate prompt
|
||||||
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||||
pt, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, params["prompt"]), h.App.Config.ProxyURL)
|
pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
logger.Debugf("翻译绘画提示词,原文:%s,译文:%s", prompt, pt)
|
||||||
prompt = pt
|
prompt = pt
|
||||||
}
|
}
|
||||||
imgNum := chatConfig.DallImgNum
|
|
||||||
if imgNum <= 0 {
|
|
||||||
imgNum = 1
|
|
||||||
}
|
|
||||||
var res imgRes
|
var res imgRes
|
||||||
var errRes ErrRes
|
var errRes ErrRes
|
||||||
var request *req.Request
|
var request *req.Request
|
||||||
if apiKey.UseProxy && h.proxyURL != "" {
|
if apiKey.ProxyURL != "" {
|
||||||
request = req.C().SetProxyURL(h.proxyURL).R()
|
request = req.C().SetProxyURL(apiKey.ProxyURL).R()
|
||||||
} else {
|
} else {
|
||||||
request = req.C().R()
|
request = req.C().R()
|
||||||
}
|
}
|
||||||
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, h.proxyURL)
|
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
|
||||||
r, err := request.SetHeader("Content-Type", "application/json").
|
r, err := request.SetHeader("Content-Type", "application/json").
|
||||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
SetBody(imgReq{
|
SetBody(imgReq{
|
||||||
Model: "dall-e-3",
|
Model: "dall-e-3",
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
N: imgNum,
|
N: 1,
|
||||||
Size: "1024x1024",
|
Size: "1024x1024",
|
||||||
}).
|
}).
|
||||||
SetErrorResult(&errRes).
|
SetErrorResult(&errRes).
|
||||||
@@ -263,7 +241,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 更新 API KEY 的最后使用时间
|
// 更新 API KEY 的最后使用时间
|
||||||
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
logger.Debugf("%+v", res)
|
logger.Debugf("%+v", res)
|
||||||
// 存储图片
|
// 存储图片
|
||||||
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
|
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
|
||||||
@@ -273,8 +251,24 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n\n", prompt, imgURL)
|
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n\n", prompt, imgURL)
|
||||||
// update user's img_calls
|
// 更新用户算力
|
||||||
h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
var u model.User
|
||||||
|
h.DB.Where("id", user.Id).First(&u)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: h.App.SysConfig.DallPower,
|
||||||
|
Balance: u.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "dall-e-3",
|
||||||
|
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c, content)
|
resp.SUCCESS(c, content)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,32 +15,29 @@ import (
|
|||||||
// InviteHandler 用户邀请
|
// InviteHandler 用户邀请
|
||||||
type InviteHandler struct {
|
type InviteHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
|
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
|
||||||
h := InviteHandler{db: db}
|
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Code 获取当前用户邀请码
|
// Code 获取当前用户邀请码
|
||||||
func (h *InviteHandler) Code(c *gin.Context) {
|
func (h *InviteHandler) Code(c *gin.Context) {
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
var inviteCode model.InviteCode
|
var inviteCode model.InviteCode
|
||||||
res := h.db.Where("user_id = ?", userId).First(&inviteCode)
|
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
|
||||||
// 如果邀请码不存在,则创建一个
|
// 如果邀请码不存在,则创建一个
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
code := strings.ToUpper(utils.RandString(8))
|
code := strings.ToUpper(utils.RandString(8))
|
||||||
for {
|
for {
|
||||||
res = h.db.Where("code = ?", code).First(&inviteCode)
|
res = h.DB.Where("code = ?", code).First(&inviteCode)
|
||||||
if res.Error != nil { // 不存在相同的邀请码则退出
|
if res.Error != nil { // 不存在相同的邀请码则退出
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inviteCode.UserId = userId
|
inviteCode.UserId = userId
|
||||||
inviteCode.Code = code
|
inviteCode.Code = code
|
||||||
h.db.Create(&inviteCode)
|
h.DB.Create(&inviteCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var codeVo vo.InviteCode
|
var codeVo vo.InviteCode
|
||||||
@@ -65,7 +62,7 @@ func (h *InviteHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
session := h.db.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
|
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
|
||||||
var total int64
|
var total int64
|
||||||
session.Model(&model.InviteLog{}).Count(&total)
|
session.Model(&model.InviteLog{}).Count(&total)
|
||||||
var items []model.InviteLog
|
var items []model.InviteLog
|
||||||
@@ -91,6 +88,6 @@ func (h *InviteHandler) List(c *gin.Context) {
|
|||||||
// Hits 访问邀请码
|
// Hits 访问邀请码
|
||||||
func (h *InviteHandler) Hits(c *gin.Context) {
|
func (h *InviteHandler) Hits(c *gin.Context) {
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
h.db.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
|
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/service"
|
"chatplus/service"
|
||||||
"chatplus/service/mj"
|
"chatplus/service/mj"
|
||||||
"chatplus/service/mj/plus"
|
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
@@ -24,32 +23,32 @@ import (
|
|||||||
|
|
||||||
type MidJourneyHandler struct {
|
type MidJourneyHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
pool *mj.ServicePool
|
pool *mj.ServicePool
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
|
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
|
||||||
h := MidJourneyHandler{
|
return &MidJourneyHandler{
|
||||||
db: db,
|
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
pool: pool,
|
pool: pool,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: app,
|
||||||
|
DB: db,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.ImgCalls <= 0 {
|
if user.Power < h.App.SysConfig.MjPower {
|
||||||
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -160,14 +159,19 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
|
Power: h.App.SysConfig.MjPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
opt := "绘图"
|
||||||
if data.TaskType == types.TaskBlend.String() {
|
if data.TaskType == types.TaskBlend.String() {
|
||||||
data.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
|
job.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
|
||||||
|
opt = "融图"
|
||||||
} else if data.TaskType == types.TaskSwapFace.String() {
|
} else if data.TaskType == types.TaskSwapFace.String() {
|
||||||
data.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
|
job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
|
||||||
|
opt = "换脸"
|
||||||
}
|
}
|
||||||
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
|
||||||
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -187,8 +191,23 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
_ = client.Send([]byte("Task Updated"))
|
_ = client.Send([]byte("Task Updated"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// update user's img calls
|
// update user's power
|
||||||
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power - job.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "mid-journey",
|
||||||
|
Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,9 +245,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
|
Power: h.App.SysConfig.MjActionPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -249,7 +269,23 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
if client != nil {
|
if client != nil {
|
||||||
_ = client.Send([]byte("Task Updated"))
|
_ = client.Send([]byte("Task Updated"))
|
||||||
}
|
}
|
||||||
|
// update user's power
|
||||||
|
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power - job.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "mid-journey",
|
||||||
|
Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -276,9 +312,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
|
Power: h.App.SysConfig.MjActionPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -300,21 +337,60 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
_ = client.Send([]byte("Task Updated"))
|
_ = client.Send([]byte("Task Updated"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// update user's img calls
|
// update user's power
|
||||||
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power - job.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "mid-journey",
|
||||||
|
Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImgWall 照片墙
|
||||||
|
func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
|
||||||
|
page := h.GetInt(c, "page", 0)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
err, jobs := h.getData(true, 0, page, pageSize, true)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
// JobList 获取 MJ 任务列表
|
// JobList 获取 MJ 任务列表
|
||||||
func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
||||||
status := h.GetInt(c, "status", 0)
|
status := h.GetBool(c, "status")
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
page := h.GetInt(c, "page", 0)
|
page := h.GetInt(c, "page", 0)
|
||||||
pageSize := h.GetInt(c, "page_size", 0)
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
publish := h.GetBool(c, "publish")
|
publish := h.GetBool(c, "publish")
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
err, jobs := h.getData(status, userId, page, pageSize, publish)
|
||||||
if status == 1 {
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取 MJ 任务列表
|
||||||
|
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if finish {
|
||||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
session = session.Where("progress = ?", 100).Order("id DESC")
|
||||||
} else {
|
} else {
|
||||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||||
@@ -333,8 +409,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
var items []model.MidJourneyJob
|
var items []model.MidJourneyJob
|
||||||
res := session.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, types.NoData)
|
return res.Error, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs = make([]vo.MidJourneyJob, 0)
|
var jobs = make([]vo.MidJourneyJob, 0)
|
||||||
@@ -345,13 +420,6 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 失败的任务直接删除
|
|
||||||
if job.Progress == -1 {
|
|
||||||
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
|
|
||||||
jobs = append(jobs, job)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
|
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
|
||||||
// discord 服务器图片需要使用代理转发图片数据流
|
// discord 服务器图片需要使用代理转发图片数据流
|
||||||
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
|
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
|
||||||
@@ -366,7 +434,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
|
|
||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, jobs)
|
return nil, jobs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove task image
|
// Remove remove task image
|
||||||
@@ -382,7 +450,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// remove job recode
|
// remove job recode
|
||||||
res := h.db.Delete(&model.MidJourneyJob{Id: data.Id})
|
res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -402,27 +470,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
|||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify MidJourney Plus 服务任务回调处理
|
|
||||||
func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|
||||||
var data plus.CBReq
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
logger.Error("非法任务回调:%+v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err := h.pool.Notify(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
} else {
|
|
||||||
userId := h.GetLoginUserId(c)
|
|
||||||
client := h.pool.Clients.Get(userId)
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte("Task Updated"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Publish 发布图片到画廊显示
|
// Publish 发布图片到画廊显示
|
||||||
func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
@@ -434,7 +481,7 @@ func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
|
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败")
|
resp.ERROR(c, "更新数据库失败")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -7,19 +7,17 @@ import (
|
|||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OrderHandler struct {
|
type OrderHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||||
h := OrderHandler{db: db}
|
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OrderHandler) List(c *gin.Context) {
|
func (h *OrderHandler) List(c *gin.Context) {
|
||||||
@@ -31,8 +29,8 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user, _ := utils.GetLoginUser(c, h.db)
|
userId := h.GetLoginUserId(c)
|
||||||
session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess)
|
session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
|
||||||
var total int64
|
var total int64
|
||||||
session.Model(&model.Order{}).Count(&total)
|
session.Model(&model.Order{}).Count(&total)
|
||||||
var items []model.Order
|
var items []model.Order
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"embed"
|
"embed"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -34,7 +35,6 @@ type PaymentHandler struct {
|
|||||||
huPiPayService *payment.HuPiPayService
|
huPiPayService *payment.HuPiPayService
|
||||||
js *payment.PayJS
|
js *payment.PayJS
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
db *gorm.DB
|
|
||||||
fs embed.FS
|
fs embed.FS
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
@@ -44,20 +44,21 @@ func NewPaymentHandler(
|
|||||||
alipayService *payment.AlipayService,
|
alipayService *payment.AlipayService,
|
||||||
huPiPayService *payment.HuPiPayService,
|
huPiPayService *payment.HuPiPayService,
|
||||||
js *payment.PayJS,
|
js *payment.PayJS,
|
||||||
snowflake *service.Snowflake,
|
|
||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
|
snowflake *service.Snowflake,
|
||||||
fs embed.FS) *PaymentHandler {
|
fs embed.FS) *PaymentHandler {
|
||||||
h := PaymentHandler{
|
return &PaymentHandler{
|
||||||
alipayService: alipayService,
|
alipayService: alipayService,
|
||||||
huPiPayService: huPiPayService,
|
huPiPayService: huPiPayService,
|
||||||
js: js,
|
js: js,
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
fs: fs,
|
fs: fs,
|
||||||
db: db,
|
|
||||||
lock: sync.Mutex{},
|
lock: sync.Mutex{},
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: server,
|
||||||
|
DB: db,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
h.App = server
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PaymentHandler) DoPay(c *gin.Context) {
|
func (h *PaymentHandler) DoPay(c *gin.Context) {
|
||||||
@@ -70,7 +71,7 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var order model.Order
|
var order model.Order
|
||||||
res := h.db.Where("order_no = ?", orderNo).First(&order)
|
res := h.DB.Where("order_no = ?", orderNo).First(&order)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Order not found")
|
resp.ERROR(c, "Order not found")
|
||||||
return
|
return
|
||||||
@@ -83,7 +84,7 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 更新扫码状态
|
// 更新扫码状态
|
||||||
h.db.Model(&order).UpdateColumn("status", types.OrderScanned)
|
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
|
||||||
if payWay == "alipay" { // 支付宝
|
if payWay == "alipay" { // 支付宝
|
||||||
// 生成支付链接
|
// 生成支付链接
|
||||||
notifyURL := h.App.Config.AlipayConfig.NotifyURL
|
notifyURL := h.App.Config.AlipayConfig.NotifyURL
|
||||||
@@ -129,7 +130,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var order model.Order
|
var order model.Order
|
||||||
res := h.db.Where("order_no = ?", data.OrderNo).First(&order)
|
res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Order not found")
|
resp.ERROR(c, "Order not found")
|
||||||
return
|
return
|
||||||
@@ -144,7 +145,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
var item model.Order
|
var item model.Order
|
||||||
h.db.Where("order_no = ?", data.OrderNo).First(&item)
|
h.DB.Where("order_no = ?", data.OrderNo).First(&item)
|
||||||
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
|
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
|
||||||
order.Status = item.Status
|
order.Status = item.Status
|
||||||
break
|
break
|
||||||
@@ -168,7 +169,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var product model.Product
|
var product model.Product
|
||||||
res := h.db.First(&product, data.ProductId)
|
res := h.DB.First(&product, data.ProductId)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Product not found")
|
resp.ERROR(c, "Product not found")
|
||||||
return
|
return
|
||||||
@@ -180,7 +181,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var user model.User
|
var user model.User
|
||||||
res = h.db.First(&user, data.UserId)
|
res = h.DB.First(&user, data.UserId)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "Invalid user ID")
|
resp.ERROR(c, "Invalid user ID")
|
||||||
return
|
return
|
||||||
@@ -202,24 +203,25 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|||||||
// 创建订单
|
// 创建订单
|
||||||
remark := types.OrderRemark{
|
remark := types.OrderRemark{
|
||||||
Days: product.Days,
|
Days: product.Days,
|
||||||
Calls: product.Calls,
|
Power: product.Power,
|
||||||
ImgCalls: product.ImgCalls,
|
|
||||||
Name: product.Name,
|
Name: product.Name,
|
||||||
Price: product.Price,
|
Price: product.Price,
|
||||||
Discount: product.Discount,
|
Discount: product.Discount,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
||||||
order := model.Order{
|
order := model.Order{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
ProductId: product.Id,
|
ProductId: product.Id,
|
||||||
OrderNo: orderNo,
|
OrderNo: orderNo,
|
||||||
Subject: product.Name,
|
Subject: product.Name,
|
||||||
Amount: product.Price - product.Discount,
|
Amount: amount,
|
||||||
Status: types.OrderNotPaid,
|
Status: types.OrderNotPaid,
|
||||||
PayWay: payWay,
|
PayWay: payWay,
|
||||||
Remark: utils.JsonEncode(remark),
|
Remark: utils.JsonEncode(remark),
|
||||||
}
|
}
|
||||||
res = h.db.Create(&order)
|
res = h.DB.Create(&order)
|
||||||
if res.Error != nil || res.RowsAffected == 0 {
|
if res.Error != nil || res.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -275,10 +277,121 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
|
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mobile 移动端支付
|
||||||
|
func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
PayWay string `json:"pay_way"` // 支付方式
|
||||||
|
ProductId uint `json:"product_id"`
|
||||||
|
UserId int `json:"user_id"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var product model.Product
|
||||||
|
res := h.DB.First(&product, data.ProductId)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "Product not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
orderNo, err := h.snowflake.Next(false)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var user model.User
|
||||||
|
res = h.DB.First(&user, data.UserId)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
||||||
|
var payWay string
|
||||||
|
var notifyURL, returnURL string
|
||||||
|
var payURL string
|
||||||
|
switch data.PayWay {
|
||||||
|
case "hupi":
|
||||||
|
payWay = PayWayXunHu
|
||||||
|
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||||
|
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
||||||
|
params := payment.HuPiPayReq{
|
||||||
|
Version: "1.1",
|
||||||
|
TradeOrderId: orderNo,
|
||||||
|
TotalFee: fmt.Sprintf("%f", amount),
|
||||||
|
Title: product.Name,
|
||||||
|
NotifyURL: notifyURL,
|
||||||
|
ReturnURL: returnURL,
|
||||||
|
CallbackURL: returnURL,
|
||||||
|
WapName: "极客学长",
|
||||||
|
}
|
||||||
|
r, err := h.huPiPayService.Pay(params)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with generating Pay URL: ", err.Error())
|
||||||
|
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payURL = r.URL
|
||||||
|
case "payjs":
|
||||||
|
payWay = PayWayJs
|
||||||
|
notifyURL = h.App.Config.JPayConfig.NotifyURL
|
||||||
|
returnURL = h.App.Config.JPayConfig.ReturnURL
|
||||||
|
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("total_fee", fmt.Sprintf("%d", totalFee))
|
||||||
|
params.Add("out_trade_no", orderNo)
|
||||||
|
params.Add("body", product.Name)
|
||||||
|
params.Add("notify_url", notifyURL)
|
||||||
|
params.Add("auto", "0")
|
||||||
|
payURL = h.js.PayH5(params)
|
||||||
|
case "alipay":
|
||||||
|
payWay = PayWayAlipay
|
||||||
|
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||||
|
returnURL = h.App.Config.AlipayConfig.ReturnURL
|
||||||
|
payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "error with generating Pay URL: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 创建订单
|
||||||
|
remark := types.OrderRemark{
|
||||||
|
Days: product.Days,
|
||||||
|
Power: product.Power,
|
||||||
|
Name: product.Name,
|
||||||
|
Price: product.Price,
|
||||||
|
Discount: product.Discount,
|
||||||
|
}
|
||||||
|
|
||||||
|
order := model.Order{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
ProductId: product.Id,
|
||||||
|
OrderNo: orderNo,
|
||||||
|
Subject: product.Name,
|
||||||
|
Amount: amount,
|
||||||
|
Status: types.OrderNotPaid,
|
||||||
|
PayWay: payWay,
|
||||||
|
Remark: utils.JsonEncode(remark),
|
||||||
|
}
|
||||||
|
res = h.DB.Create(&order)
|
||||||
|
if res.Error != nil || res.RowsAffected == 0 {
|
||||||
|
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, payURL)
|
||||||
|
}
|
||||||
|
|
||||||
// 异步通知回调公共逻辑
|
// 异步通知回调公共逻辑
|
||||||
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||||
var order model.Order
|
var order model.Order
|
||||||
res := h.db.Where("order_no = ?", orderNo).First(&order)
|
res := h.DB.Where("order_no = ?", orderNo).First(&order)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
err := fmt.Errorf("error with fetch order: %v", res.Error)
|
err := fmt.Errorf("error with fetch order: %v", res.Error)
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -294,7 +407,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res = h.db.First(&user, order.UserId)
|
res = h.DB.First(&user, order.UserId)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
err := fmt.Errorf("error with fetch user info: %v", res.Error)
|
err := fmt.Errorf("error with fetch user info: %v", res.Error)
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -309,29 +422,33 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var opt string
|
||||||
|
var power int
|
||||||
if user.Vip { // 已经是 VIP 用户
|
if user.Vip { // 已经是 VIP 用户
|
||||||
if remark.Days > 0 { // 只延期 VIP,不增加调用次数
|
if remark.Days > 0 { // 只延期 VIP,不增加调用次数
|
||||||
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
|
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
|
||||||
} else { // 充值点卡,直接增加次数即可
|
} else { // 充值点卡,直接增加次数即可
|
||||||
user.Calls += remark.Calls
|
user.Power += remark.Power
|
||||||
user.ImgCalls += remark.ImgCalls
|
opt = "点卡充值"
|
||||||
|
power = remark.Power
|
||||||
}
|
}
|
||||||
|
|
||||||
} else { // 非 VIP 用户
|
} else { // 非 VIP 用户
|
||||||
if remark.Days > 0 { // vip 套餐:days > 0, calls == 0
|
if remark.Days > 0 { // vip 套餐:days > 0, power == 0
|
||||||
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
|
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
|
||||||
user.Calls += h.App.SysConfig.VipMonthCalls
|
user.Power += h.App.SysConfig.VipMonthPower
|
||||||
user.ImgCalls += h.App.SysConfig.VipMonthImgCalls
|
|
||||||
user.Vip = true
|
user.Vip = true
|
||||||
|
opt = "VIP充值"
|
||||||
|
power = h.App.SysConfig.VipMonthPower
|
||||||
} else { //点卡:days == 0, calls > 0
|
} else { //点卡:days == 0, calls > 0
|
||||||
user.Calls += remark.Calls
|
user.Power += remark.Power
|
||||||
user.ImgCalls += remark.ImgCalls
|
opt = "点卡充值"
|
||||||
|
power = remark.Power
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新用户信息
|
// 更新用户信息
|
||||||
res = h.db.Updates(&user)
|
res = h.DB.Updates(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
err := fmt.Errorf("error with update user info: %v", res.Error)
|
err := fmt.Errorf("error with update user info: %v", res.Error)
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -342,7 +459,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
|||||||
order.PayTime = time.Now().Unix()
|
order.PayTime = time.Now().Unix()
|
||||||
order.Status = types.OrderPaidSuccess
|
order.Status = types.OrderPaidSuccess
|
||||||
order.TradeNo = tradeNo
|
order.TradeNo = tradeNo
|
||||||
res = h.db.Updates(&order)
|
res = h.DB.Updates(&order)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
err := fmt.Errorf("error with update order info: %v", res.Error)
|
err := fmt.Errorf("error with update order info: %v", res.Error)
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -350,7 +467,23 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 更新产品销量
|
// 更新产品销量
|
||||||
h.db.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
|
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
|
||||||
|
|
||||||
|
// 记录算力充值日志
|
||||||
|
if opt != "" {
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerRecharge,
|
||||||
|
Amount: power,
|
||||||
|
Balance: user.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: order.PayWay,
|
||||||
|
Remark: fmt.Sprintf("%s,金额:%f,订单号:%s", opt, order.Amount, order.OrderNo),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
67
api/handler/power_log_handler.go
Normal file
67
api/handler/power_log_handler.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core"
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/store/model"
|
||||||
|
"chatplus/store/vo"
|
||||||
|
"chatplus/utils"
|
||||||
|
"chatplus/utils/resp"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PowerLogHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
||||||
|
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Date []string `json:"date"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
|
session = session.Where("user_id", userId)
|
||||||
|
if data.Model != "" {
|
||||||
|
session = session.Where("model", data.Model)
|
||||||
|
}
|
||||||
|
if len(data.Date) == 2 {
|
||||||
|
start := data.Date[0] + " 00:00:00"
|
||||||
|
end := data.Date[1] + " 00:00:00"
|
||||||
|
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
session.Model(&model.PowerLog{}).Count(&total)
|
||||||
|
var items []model.PowerLog
|
||||||
|
var list = make([]vo.PowerLog, 0)
|
||||||
|
offset := (data.Page - 1) * data.PageSize
|
||||||
|
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
||||||
|
if res.Error == nil {
|
||||||
|
for _, item := range items {
|
||||||
|
var log vo.PowerLog
|
||||||
|
err := utils.CopyObject(item, &log)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Id = item.Id
|
||||||
|
log.CreatedAt = item.CreatedAt.Unix()
|
||||||
|
log.TypeStr = item.Type.String()
|
||||||
|
list = append(list, log)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||||||
|
}
|
||||||
@@ -12,20 +12,17 @@ import (
|
|||||||
|
|
||||||
type ProductHandler struct {
|
type ProductHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
||||||
h := ProductHandler{db: db}
|
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 模型列表
|
// List 模型列表
|
||||||
func (h *ProductHandler) List(c *gin.Context) {
|
func (h *ProductHandler) List(c *gin.Context) {
|
||||||
var items []model.Product
|
var items []model.Product
|
||||||
var list = make([]vo.Product, 0)
|
var list = make([]vo.Product, 0)
|
||||||
res := h.db.Where("enabled", true).Order("sort_num ASC").Find(&items)
|
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
var product vo.Product
|
var product vo.Product
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/core"
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
|
|
||||||
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
|
||||||
|
|
||||||
type PromptHandler struct {
|
|
||||||
BaseHandler
|
|
||||||
db *gorm.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler {
|
|
||||||
h := &PromptHandler{db: db}
|
|
||||||
h.App = app
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rewrite translate and rewrite prompt with ChatGPT
|
|
||||||
func (h *PromptHandler) Rewrite(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(rewritePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PromptHandler) Translate(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, data.Prompt), h.App.Config.ProxyURL)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, content)
|
|
||||||
}
|
|
||||||
@@ -7,37 +7,35 @@ import (
|
|||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"math"
|
"math"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RewardHandler struct {
|
type RewardHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
|
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
|
||||||
h := RewardHandler{db: db, lock: sync.Mutex{}}
|
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||||
h.App = server
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify 打赏码核销
|
// Verify 打赏码核销
|
||||||
func (h *RewardHandler) Verify(c *gin.Context) {
|
func (h *RewardHandler) Verify(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
TxId string `json:"tx_id"`
|
TxId string `json:"tx_id"`
|
||||||
Type string `json:"type"`
|
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.HACKER(c)
|
resp.HACKER(c)
|
||||||
return
|
return
|
||||||
@@ -50,7 +48,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
|||||||
defer h.lock.Unlock()
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
var item model.Reward
|
var item model.Reward
|
||||||
res := h.db.Where("tx_id = ?", data.TxId).First(&item)
|
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "无效的众筹交易流水号!")
|
resp.ERROR(c, "无效的众筹交易流水号!")
|
||||||
return
|
return
|
||||||
@@ -61,18 +59,13 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tx := h.db.Begin()
|
tx := h.DB.Begin()
|
||||||
exchange := vo.RewardExchange{}
|
exchange := vo.RewardExchange{}
|
||||||
if data.Type == "chat" {
|
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
|
||||||
calls := math.Ceil(item.Amount / h.App.SysConfig.ChatCallPrice)
|
exchange.Power = int(power)
|
||||||
exchange.Calls = int(calls)
|
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
|
||||||
res = h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls + ?", calls))
|
|
||||||
} else if data.Type == "img" {
|
|
||||||
calls := math.Ceil(item.Amount / h.App.SysConfig.ImgCallPrice)
|
|
||||||
exchange.ImgCalls = int(calls)
|
|
||||||
res = h.db.Model(&user).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", calls))
|
|
||||||
}
|
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
tx.Rollback()
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -81,13 +74,25 @@ func (h *RewardHandler) Verify(c *gin.Context) {
|
|||||||
item.Status = true
|
item.Status = true
|
||||||
item.UserId = user.Id
|
item.UserId = user.Id
|
||||||
item.Exchange = utils.JsonEncode(exchange)
|
item.Exchange = utils.JsonEncode(exchange)
|
||||||
res = h.db.Updates(&item)
|
res = tx.Updates(&item)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录算力充值日志
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerReward,
|
||||||
|
Amount: exchange.Power,
|
||||||
|
Balance: user.Power + exchange.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "众筹支付",
|
||||||
|
Remark: fmt.Sprintf("众筹充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
|
|
||||||
|
|||||||
@@ -11,10 +11,11 @@ import (
|
|||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -23,19 +24,19 @@ import (
|
|||||||
type SdJobHandler struct {
|
type SdJobHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
db *gorm.DB
|
|
||||||
pool *sd.ServicePool
|
pool *sd.ServicePool
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler {
|
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler {
|
||||||
h := SdJobHandler{
|
return &SdJobHandler{
|
||||||
db: db,
|
|
||||||
pool: pool,
|
pool: pool,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: app,
|
||||||
|
DB: db,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
h.App = app
|
|
||||||
return &h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client WebSocket 客户端,用于通知任务状态变更
|
// Client WebSocket 客户端,用于通知任务状态变更
|
||||||
@@ -60,7 +61,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return false
|
return false
|
||||||
@@ -71,8 +72,8 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.ImgCalls <= 0 {
|
if user.Power < h.App.SysConfig.SdPower {
|
||||||
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,6 +133,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
HdScaleAlg: data.HdScaleAlg,
|
HdScaleAlg: data.HdScaleAlg,
|
||||||
HdSteps: data.HdSteps,
|
HdSteps: data.HdSteps,
|
||||||
}
|
}
|
||||||
|
|
||||||
job := model.SdJob{
|
job := model.SdJob{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Type: types.TaskImage.String(),
|
Type: types.TaskImage.String(),
|
||||||
@@ -139,9 +141,10 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
Params: utils.JsonEncode(params),
|
Params: utils.JsonEncode(params),
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
|
Power: h.App.SysConfig.SdPower,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
res := h.db.Create(&job)
|
res := h.DB.Create(&job)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "error with save job: "+res.Error.Error())
|
resp.ERROR(c, "error with save job: "+res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -151,27 +154,71 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
Id: int(job.Id),
|
Id: int(job.Id),
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Type: types.TaskImage,
|
Type: types.TaskImage,
|
||||||
Prompt: data.Prompt,
|
|
||||||
Params: params,
|
Params: params,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
})
|
})
|
||||||
|
|
||||||
// update user's img calls
|
client := h.pool.Clients.Get(uint(job.UserId))
|
||||||
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
if client != nil {
|
||||||
|
_ = client.Send([]byte("Task Updated"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// update user's power
|
||||||
|
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
|
||||||
|
// 记录算力变化日志
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
user, _ := h.GetLoginUser(c)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power - job.Power,
|
||||||
|
Mark: types.PowerSub,
|
||||||
|
Model: "stable-diffusion",
|
||||||
|
Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobList 获取 stable diffusion 任务列表
|
// ImgWall 照片墙
|
||||||
|
func (h *SdJobHandler) ImgWall(c *gin.Context) {
|
||||||
|
page := h.GetInt(c, "page", 0)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
err, jobs := h.getData(true, 0, page, pageSize, true)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取 SD 任务列表
|
||||||
func (h *SdJobHandler) JobList(c *gin.Context) {
|
func (h *SdJobHandler) JobList(c *gin.Context) {
|
||||||
status := h.GetInt(c, "status", 0)
|
status := h.GetBool(c, "status")
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
page := h.GetInt(c, "page", 0)
|
page := h.GetInt(c, "page", 0)
|
||||||
pageSize := h.GetInt(c, "page_size", 0)
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
publish := h.GetBool(c, "publish")
|
publish := h.GetBool(c, "publish")
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
err, jobs := h.getData(status, userId, page, pageSize, publish)
|
||||||
if status == 1 {
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取 MJ 任务列表
|
||||||
|
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
|
||||||
|
|
||||||
|
session := h.DB.Session(&gorm.Session{})
|
||||||
|
if finish {
|
||||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
session = session.Where("progress = ?", 100).Order("id DESC")
|
||||||
} else {
|
} else {
|
||||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||||
@@ -190,8 +237,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
|
|||||||
var items []model.SdJob
|
var items []model.SdJob
|
||||||
res := session.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, types.NoData)
|
return res.Error, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs = make([]vo.SdJob, 0)
|
var jobs = make([]vo.SdJob, 0)
|
||||||
@@ -202,18 +248,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if job.Progress == -1 {
|
|
||||||
h.db.Delete(&model.SdJob{Id: job.Id})
|
|
||||||
}
|
|
||||||
|
|
||||||
if item.Progress < 100 {
|
if item.Progress < 100 {
|
||||||
// 5 分钟还没完成的任务直接删除
|
|
||||||
if time.Now().Sub(item.CreatedAt) > time.Minute*5 {
|
|
||||||
h.db.Delete(&item)
|
|
||||||
// 退回绘图次数
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", item.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 正在运行中任务使用代理访问图片
|
// 正在运行中任务使用代理访问图片
|
||||||
image, err := utils.DownloadImage(item.ImgURL, "")
|
image, err := utils.DownloadImage(item.ImgURL, "")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -222,13 +257,15 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, jobs)
|
|
||||||
|
return nil, jobs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove task image
|
// Remove remove task image
|
||||||
func (h *SdJobHandler) Remove(c *gin.Context) {
|
func (h *SdJobHandler) Remove(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
ImgURL string `json:"img_url"`
|
ImgURL string `json:"img_url"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
@@ -237,7 +274,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// remove job recode
|
// remove job recode
|
||||||
res := h.db.Delete(&model.SdJob{Id: data.Id})
|
res := h.DB.Delete(&model.SdJob{Id: data.Id})
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, res.Error.Error())
|
resp.ERROR(c, res.Error.Error())
|
||||||
return
|
return
|
||||||
@@ -249,6 +286,11 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
|
|||||||
logger.Error("remove image failed: ", err)
|
logger.Error("remove image failed: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := h.pool.Clients.Get(data.UserId)
|
||||||
|
if client != nil {
|
||||||
|
_ = client.Send([]byte("Task Updated"))
|
||||||
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -263,7 +305,7 @@ func (h *SdJobHandler) Publish(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := h.db.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
|
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败")
|
resp.ERROR(c, "更新数据库失败")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -29,9 +29,12 @@ func NewSmsHandler(
|
|||||||
sms *sms.ServiceManager,
|
sms *sms.ServiceManager,
|
||||||
smtp *service.SmtpService,
|
smtp *service.SmtpService,
|
||||||
captcha *service.CaptchaService) *SmsHandler {
|
captcha *service.CaptchaService) *SmsHandler {
|
||||||
handler := &SmsHandler{redis: client, sms: sms, captcha: captcha, smtp: smtp}
|
return &SmsHandler{
|
||||||
handler.App = app
|
redis: client,
|
||||||
return handler
|
sms: sms,
|
||||||
|
captcha: captcha,
|
||||||
|
smtp: smtp,
|
||||||
|
BaseHandler: BaseHandler{App: app}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendCode 发送验证码
|
// SendCode 发送验证码
|
||||||
|
|||||||
@@ -3,12 +3,6 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"chatplus/service"
|
"chatplus/service"
|
||||||
"chatplus/service/payment"
|
"chatplus/service/payment"
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/utils"
|
|
||||||
"chatplus/utils/resp"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,208 +15,3 @@ type TestHandler struct {
|
|||||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
|
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
|
||||||
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
||||||
}
|
}
|
||||||
|
|
||||||
type reqBody struct {
|
|
||||||
BotType string `json:"botType"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Base64Array []interface{} `json:"base64Array,omitempty"`
|
|
||||||
AccountFilter struct {
|
|
||||||
InstanceId string `json:"instanceId"`
|
|
||||||
Modes []interface{} `json:"modes"`
|
|
||||||
Remix bool `json:"remix"`
|
|
||||||
RemixAutoConsidered bool `json:"remixAutoConsidered"`
|
|
||||||
} `json:"accountFilter,omitempty"`
|
|
||||||
NotifyHook string `json:"notifyHook"`
|
|
||||||
State string `json:"state,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type resBody struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Properties struct {
|
|
||||||
} `json:"properties"`
|
|
||||||
Result string `json:"result"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TestHandler) Test(c *gin.Context) {
|
|
||||||
image(c)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func upscale(c *gin.Context) {
|
|
||||||
apiURL := "https://api.openai1s.cn/mj/submit/action"
|
|
||||||
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
|
|
||||||
body := map[string]string{
|
|
||||||
"customId": "MJ::JOB::upsample::1::c80a8eb1-f2d1-4f40-8785-97eb99b7ba0a",
|
|
||||||
"taskId": "1704880156226095",
|
|
||||||
"notifyHook": "http://r9it.com:6004/api/test/mj",
|
|
||||||
}
|
|
||||||
var res resBody
|
|
||||||
var resErr errRes
|
|
||||||
r, err := req.C().R().
|
|
||||||
SetHeader("Authorization", "Bearer "+token).
|
|
||||||
SetBody(body).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
SetErrorResult(&resErr).
|
|
||||||
Post(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "请求出错:"+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, res)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
type queryRes struct {
|
|
||||||
Action string `json:"action"`
|
|
||||||
Buttons []struct {
|
|
||||||
CustomId string `json:"customId"`
|
|
||||||
Emoji string `json:"emoji"`
|
|
||||||
Label string `json:"label"`
|
|
||||||
Style int `json:"style"`
|
|
||||||
Type int `json:"type"`
|
|
||||||
} `json:"buttons"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
FailReason string `json:"failReason"`
|
|
||||||
FinishTime int `json:"finishTime"`
|
|
||||||
Id string `json:"id"`
|
|
||||||
ImageUrl string `json:"imageUrl"`
|
|
||||||
Progress string `json:"progress"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
PromptEn string `json:"promptEn"`
|
|
||||||
Properties struct {
|
|
||||||
} `json:"properties"`
|
|
||||||
StartTime int `json:"startTime"`
|
|
||||||
State string `json:"state"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
SubmitTime int `json:"submitTime"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func query(c *gin.Context) {
|
|
||||||
apiURL := "https://api.openai1s.cn/mj/task/1704960661008372/fetch"
|
|
||||||
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
|
|
||||||
var res queryRes
|
|
||||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+token).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
Get(apiURL)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "请求出错:"+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
resp.ERROR(c, "返回错误状态:"+r.Status)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, res)
|
|
||||||
}
|
|
||||||
|
|
||||||
type errRes struct {
|
|
||||||
Error struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
} `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func image(c *gin.Context) {
|
|
||||||
apiURL := "https://api.openai1s.cn/mj-fast/mj/submit/imagine"
|
|
||||||
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
|
|
||||||
body := reqBody{
|
|
||||||
BotType: "MID_JOURNEY",
|
|
||||||
Prompt: "一个中国美女,手上拿着一桶爆米花,脸上带着迷人的微笑,白色衣服 --s 750 --v 6",
|
|
||||||
NotifyHook: "http://r9it.com:6004/api/test/mj",
|
|
||||||
}
|
|
||||||
var res resBody
|
|
||||||
var resErr errRes
|
|
||||||
r, err := req.C().R().
|
|
||||||
SetHeader("Authorization", "Bearer "+token).
|
|
||||||
SetBody(body).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
SetErrorResult(&resErr).
|
|
||||||
Post(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, "请求出错:"+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.IsErrorState() {
|
|
||||||
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c, res)
|
|
||||||
}
|
|
||||||
|
|
||||||
type cbReq struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Action string `json:"action"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
PromptEn string `json:"promptEn"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
SubmitTime int64 `json:"submitTime"`
|
|
||||||
StartTime int64 `json:"startTime"`
|
|
||||||
FinishTime int64 `json:"finishTime"`
|
|
||||||
Progress string `json:"progress"`
|
|
||||||
ImageUrl string `json:"imageUrl"`
|
|
||||||
FailReason interface{} `json:"failReason"`
|
|
||||||
Properties struct {
|
|
||||||
FinalPrompt string `json:"finalPrompt"`
|
|
||||||
} `json:"properties"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TestHandler) Mj(c *gin.Context) {
|
|
||||||
var data cbReq
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
logger.Debugf("任务ID:%s,任务进度:%s,图片地址:%s, 最终提示词:%s", data.Id, data.Progress, data.ImageUrl, data.Properties.FinalPrompt)
|
|
||||||
apiURL := "https://api.openai1s.cn/mj/task/" + data.Id + "/fetch"
|
|
||||||
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
|
|
||||||
var res queryRes
|
|
||||||
_, _ = req.C().R().SetHeader("Authorization", "Bearer "+token).
|
|
||||||
SetSuccessResult(&res).
|
|
||||||
Get(apiURL)
|
|
||||||
|
|
||||||
fmt.Println(res.State, ",", res.ImageUrl, ",", res.Progress)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TestHandler) initUserNickname(c *gin.Context) {
|
|
||||||
var users []model.User
|
|
||||||
tx := h.db.Find(&users)
|
|
||||||
if tx.Error != nil {
|
|
||||||
resp.ERROR(c, tx.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, u := range users {
|
|
||||||
u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
|
|
||||||
h.db.Updates(&u)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TestHandler) initMjTaskId(c *gin.Context) {
|
|
||||||
var jobs []model.MidJourneyJob
|
|
||||||
tx := h.db.Find(&jobs)
|
|
||||||
if tx.Error != nil {
|
|
||||||
resp.ERROR(c, tx.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, job := range jobs {
|
|
||||||
id, _ := h.snowflake.Next(true)
|
|
||||||
job.TaskId = id
|
|
||||||
h.db.Updates(&job)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -14,14 +14,11 @@ import (
|
|||||||
|
|
||||||
type UploadHandler struct {
|
type UploadHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
uploaderManager *oss.UploaderManager
|
uploaderManager *oss.UploaderManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
|
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
|
||||||
handler := &UploadHandler{db: db, uploaderManager: manager}
|
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
|
||||||
handler.App = app
|
|
||||||
return handler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UploadHandler) Upload(c *gin.Context) {
|
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||||
@@ -32,8 +29,8 @@ func (h *UploadHandler) Upload(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
res := h.db.Create(&model.File{
|
res := h.DB.Create(&model.File{
|
||||||
UserId: userId,
|
UserId: int(userId),
|
||||||
Name: file.Name,
|
Name: file.Name,
|
||||||
ObjKey: file.ObjKey,
|
ObjKey: file.ObjKey,
|
||||||
URL: file.URL,
|
URL: file.URL,
|
||||||
@@ -53,7 +50,7 @@ func (h *UploadHandler) List(c *gin.Context) {
|
|||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
var items []model.File
|
var items []model.File
|
||||||
var files = make([]vo.File, 0)
|
var files = make([]vo.File, 0)
|
||||||
h.db.Where("user_id = ?", userId).Find(&items)
|
h.DB.Where("user_id = ?", userId).Find(&items)
|
||||||
if len(items) > 0 {
|
if len(items) > 0 {
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
var file vo.File
|
var file vo.File
|
||||||
@@ -75,14 +72,14 @@ func (h *UploadHandler) Remove(c *gin.Context) {
|
|||||||
userId := h.GetLoginUserId(c)
|
userId := h.GetLoginUserId(c)
|
||||||
id := h.GetInt(c, "id", 0)
|
id := h.GetInt(c, "id", 0)
|
||||||
var file model.File
|
var file model.File
|
||||||
tx := h.db.Where("user_id = ? AND id = ?", userId, id).First(&file)
|
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
|
||||||
if tx.Error != nil || file.Id == 0 {
|
if tx.Error != nil || file.Id == 0 {
|
||||||
resp.ERROR(c, "file not existed")
|
resp.ERROR(c, "file not existed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove database
|
// remove database
|
||||||
tx = h.db.Model(&model.File{}).Delete("id = ?", id)
|
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
|
||||||
if tx.Error != nil || tx.RowsAffected == 0 {
|
if tx.Error != nil || tx.RowsAffected == 0 {
|
||||||
resp.ERROR(c, "failed to update database")
|
resp.ERROR(c, "failed to update database")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
|
|
||||||
type UserHandler struct {
|
type UserHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
|
||||||
searcher *xdb.Searcher
|
searcher *xdb.Searcher
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
}
|
}
|
||||||
@@ -31,15 +30,14 @@ func NewUserHandler(
|
|||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
searcher *xdb.Searcher,
|
searcher *xdb.Searcher,
|
||||||
client *redis.Client) *UserHandler {
|
client *redis.Client) *UserHandler {
|
||||||
handler := &UserHandler{db: db, searcher: searcher, redis: client}
|
return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client}
|
||||||
handler.App = app
|
|
||||||
return handler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register user register
|
// Register user register
|
||||||
func (h *UserHandler) Register(c *gin.Context) {
|
func (h *UserHandler) Register(c *gin.Context) {
|
||||||
// parameters process
|
// parameters process
|
||||||
var data struct {
|
var data struct {
|
||||||
|
RegWay string `json:"reg_way"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
@@ -57,8 +55,7 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
|
|
||||||
// 检查验证码
|
// 检查验证码
|
||||||
var key string
|
var key string
|
||||||
if utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") ||
|
if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" {
|
||||||
utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
|
|
||||||
key = CodeStorePrefix + data.Username
|
key = CodeStorePrefix + data.Username
|
||||||
code, err := h.redis.Get(c, key).Result()
|
code, err := h.redis.Get(c, key).Result()
|
||||||
if err != nil || code != data.Code {
|
if err != nil || code != data.Code {
|
||||||
@@ -70,7 +67,7 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
// 验证邀请码
|
// 验证邀请码
|
||||||
inviteCode := model.InviteCode{}
|
inviteCode := model.InviteCode{}
|
||||||
if data.InviteCode != "" {
|
if data.InviteCode != "" {
|
||||||
res := h.db.Where("code = ?", data.InviteCode).First(&inviteCode)
|
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "无效的邀请码")
|
resp.ERROR(c, "无效的邀请码")
|
||||||
return
|
return
|
||||||
@@ -79,8 +76,8 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
|
|
||||||
// check if the username is exists
|
// check if the username is exists
|
||||||
var item model.User
|
var item model.User
|
||||||
res := h.db.Where("username = ?", data.Username).First(&item)
|
res := h.DB.Where("username = ?", data.Username).First(&item)
|
||||||
if res.RowsAffected > 0 {
|
if item.Id > 0 {
|
||||||
resp.ERROR(c, "该用户名已经被注册")
|
resp.ERROR(c, "该用户名已经被注册")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -95,18 +92,10 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
Status: true,
|
Status: true,
|
||||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||||
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
|
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
|
||||||
ChatConfig: utils.JsonEncode(types.UserChatConfig{
|
Power: h.App.SysConfig.InitPower,
|
||||||
ApiKeys: map[types.Platform]string{
|
|
||||||
types.OpenAI: "",
|
|
||||||
types.Azure: "",
|
|
||||||
types.ChatGLM: "",
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
Calls: h.App.SysConfig.InitChatCalls,
|
|
||||||
ImgCalls: h.App.SysConfig.InitImgCalls,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.db.Create(&user)
|
res = h.DB.Create(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "保存数据失败")
|
resp.ERROR(c, "保存数据失败")
|
||||||
logger.Error(res.Error)
|
logger.Error(res.Error)
|
||||||
@@ -116,21 +105,32 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
// 记录邀请关系
|
// 记录邀请关系
|
||||||
if data.InviteCode != "" {
|
if data.InviteCode != "" {
|
||||||
// 增加邀请数量
|
// 增加邀请数量
|
||||||
h.db.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||||||
if h.App.SysConfig.InviteChatCalls > 0 {
|
if h.App.SysConfig.InvitePower > 0 {
|
||||||
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("calls", gorm.Expr("calls + ?", h.App.SysConfig.InviteChatCalls))
|
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
|
||||||
}
|
// 记录邀请算力充值日志
|
||||||
if h.App.SysConfig.InviteImgCalls > 0 {
|
var inviter model.User
|
||||||
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", h.App.SysConfig.InviteImgCalls))
|
h.DB.Where("id", inviteCode.UserId).First(&inviter)
|
||||||
|
h.DB.Create(&model.PowerLog{
|
||||||
|
UserId: inviter.Id,
|
||||||
|
Username: inviter.Username,
|
||||||
|
Type: types.PowerInvite,
|
||||||
|
Amount: h.App.SysConfig.InvitePower,
|
||||||
|
Balance: inviter.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "",
|
||||||
|
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加邀请记录
|
// 添加邀请记录
|
||||||
h.db.Create(&model.InviteLog{
|
h.DB.Create(&model.InviteLog{
|
||||||
InviterId: inviteCode.UserId,
|
InviterId: inviteCode.UserId,
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
InviteCode: inviteCode.Code,
|
InviteCode: inviteCode.Code,
|
||||||
Reward: utils.JsonEncode(types.InviteReward{ChatCalls: h.App.SysConfig.InviteChatCalls, ImgCalls: h.App.SysConfig.InviteImgCalls}),
|
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +166,7 @@ func (h *UserHandler) Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.db.Where("username = ?", data.Username).First(&user)
|
res := h.DB.Where("username = ?", data.Username).First(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "用户名不存在")
|
resp.ERROR(c, "用户名不存在")
|
||||||
return
|
return
|
||||||
@@ -186,9 +186,9 @@ func (h *UserHandler) Login(c *gin.Context) {
|
|||||||
// 更新最后登录时间和IP
|
// 更新最后登录时间和IP
|
||||||
user.LastLoginIp = c.ClientIP()
|
user.LastLoginIp = c.ClientIP()
|
||||||
user.LastLoginAt = time.Now().Unix()
|
user.LastLoginAt = time.Now().Unix()
|
||||||
h.db.Model(&user).Updates(user)
|
h.DB.Model(&user).Updates(user)
|
||||||
|
|
||||||
h.db.Create(&model.UserLoginLog{
|
h.DB.Create(&model.UserLoginLog{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
LoginIp: c.ClientIP(),
|
LoginIp: c.ClientIP(),
|
||||||
@@ -233,7 +233,7 @@ func (h *UserHandler) Logout(c *gin.Context) {
|
|||||||
|
|
||||||
// Session 获取/验证会话
|
// Session 获取/验证会话
|
||||||
func (h *UserHandler) Session(c *gin.Context) {
|
func (h *UserHandler) Session(c *gin.Context) {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var userVo vo.User
|
var userVo vo.User
|
||||||
err := utils.CopyObject(user, &userVo)
|
err := utils.CopyObject(user, &userVo)
|
||||||
@@ -249,27 +249,23 @@ func (h *UserHandler) Session(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type userProfile struct {
|
type userProfile struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Nickname string `json:"nickname"`
|
Nickname string `json:"nickname"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Avatar string `json:"avatar"`
|
Avatar string `json:"avatar"`
|
||||||
ChatConfig types.UserChatConfig `json:"chat_config"`
|
Power int `json:"power"`
|
||||||
Calls int `json:"calls"`
|
ExpiredTime int64 `json:"expired_time"`
|
||||||
ImgCalls int `json:"img_calls"`
|
Vip bool `json:"vip"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
|
||||||
Tokens int64 `json:"tokens"`
|
|
||||||
ExpiredTime int64 `json:"expired_time"`
|
|
||||||
Vip bool `json:"vip"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UserHandler) Profile(c *gin.Context) {
|
func (h *UserHandler) Profile(c *gin.Context) {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.db.First(&user, user.Id)
|
h.DB.First(&user, user.Id)
|
||||||
var profile userProfile
|
var profile userProfile
|
||||||
err = utils.CopyObject(user, &profile)
|
err = utils.CopyObject(user, &profile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -289,15 +285,15 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.db.First(&user, user.Id)
|
h.DB.First(&user, user.Id)
|
||||||
user.Avatar = data.Avatar
|
user.Avatar = data.Avatar
|
||||||
user.Nickname = data.Nickname
|
user.Nickname = data.Nickname
|
||||||
res := h.db.Updates(&user)
|
res := h.DB.Updates(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新用户信息失败")
|
resp.ERROR(c, "更新用户信息失败")
|
||||||
return
|
return
|
||||||
@@ -322,21 +318,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
password := utils.GenPassword(data.OldPass, user.Salt)
|
password := utils.GenPassword(data.OldPass, user.Salt)
|
||||||
logger.Info(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
|
logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
|
||||||
if password != user.Password {
|
if password != user.Password {
|
||||||
resp.ERROR(c, "原密码错误")
|
resp.ERROR(c, "原密码错误")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newPass := utils.GenPassword(data.Password, user.Salt)
|
newPass := utils.GenPassword(data.Password, user.Salt)
|
||||||
res := h.db.Model(&user).UpdateColumn("password", newPass)
|
res := h.DB.Model(&user).UpdateColumn("password", newPass)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("更新数据库失败: ", res.Error)
|
logger.Error("更新数据库失败: ", res.Error)
|
||||||
resp.ERROR(c, "更新数据库失败")
|
resp.ERROR(c, "更新数据库失败")
|
||||||
@@ -359,7 +355,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
res := h.db.Where("username", data.Username).First(&user)
|
res := h.DB.Where("username", data.Username).First(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "用户不存在!")
|
resp.ERROR(c, "用户不存在!")
|
||||||
return
|
return
|
||||||
@@ -375,7 +371,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
|||||||
|
|
||||||
password := utils.GenPassword(data.Password, user.Salt)
|
password := utils.GenPassword(data.Password, user.Salt)
|
||||||
user.Password = password
|
user.Password = password
|
||||||
res = h.db.Updates(&user)
|
res = h.DB.Updates(&user)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c)
|
resp.ERROR(c)
|
||||||
} else {
|
} else {
|
||||||
@@ -405,19 +401,19 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
|
|||||||
|
|
||||||
// 检查手机号是否被其他账号绑定
|
// 检查手机号是否被其他账号绑定
|
||||||
var item model.User
|
var item model.User
|
||||||
res := h.db.Where("username = ?", data.Username).First(&item)
|
res := h.DB.Where("username = ?", data.Username).First(&item)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
resp.ERROR(c, "该账号已经被其他账号绑定")
|
resp.ERROR(c, "该账号已经被其他账号绑定")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := h.GetLoginUser(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res = h.db.Model(&user).UpdateColumn("username", data.Username)
|
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败")
|
resp.ERROR(c, "更新数据库失败")
|
||||||
return
|
return
|
||||||
|
|||||||
64
api/main.go
64
api/main.go
@@ -125,6 +125,8 @@ func main() {
|
|||||||
fx.Provide(handler.NewPaymentHandler),
|
fx.Provide(handler.NewPaymentHandler),
|
||||||
fx.Provide(handler.NewOrderHandler),
|
fx.Provide(handler.NewOrderHandler),
|
||||||
fx.Provide(handler.NewProductHandler),
|
fx.Provide(handler.NewProductHandler),
|
||||||
|
fx.Provide(handler.NewConfigHandler),
|
||||||
|
fx.Provide(handler.NewPowerLogHandler),
|
||||||
|
|
||||||
fx.Provide(admin.NewConfigHandler),
|
fx.Provide(admin.NewConfigHandler),
|
||||||
fx.Provide(admin.NewAdminHandler),
|
fx.Provide(admin.NewAdminHandler),
|
||||||
@@ -137,6 +139,7 @@ func main() {
|
|||||||
fx.Provide(admin.NewProductHandler),
|
fx.Provide(admin.NewProductHandler),
|
||||||
fx.Provide(admin.NewOrderHandler),
|
fx.Provide(admin.NewOrderHandler),
|
||||||
fx.Provide(admin.NewChatHandler),
|
fx.Provide(admin.NewChatHandler),
|
||||||
|
fx.Provide(admin.NewPowerLogHandler),
|
||||||
|
|
||||||
// 创建服务
|
// 创建服务
|
||||||
fx.Provide(sms.NewSendServiceManager),
|
fx.Provide(sms.NewSendServiceManager),
|
||||||
@@ -172,6 +175,12 @@ func main() {
|
|||||||
|
|
||||||
// Stable Diffusion 机器人
|
// Stable Diffusion 机器人
|
||||||
fx.Provide(sd.NewServicePool),
|
fx.Provide(sd.NewServicePool),
|
||||||
|
fx.Invoke(func(pool *sd.ServicePool) {
|
||||||
|
if pool.HasAvailableService() {
|
||||||
|
pool.CheckTaskNotify()
|
||||||
|
pool.CheckTaskStatus()
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
fx.Provide(payment.NewAlipayService),
|
fx.Provide(payment.NewAlipayService),
|
||||||
fx.Provide(payment.NewHuPiPay),
|
fx.Provide(payment.NewHuPiPay),
|
||||||
@@ -229,6 +238,8 @@ func main() {
|
|||||||
group := s.Engine.Group("/api/captcha/")
|
group := s.Engine.Group("/api/captcha/")
|
||||||
group.GET("get", h.Get)
|
group.GET("get", h.Get)
|
||||||
group.POST("check", h.Check)
|
group.POST("check", h.Check)
|
||||||
|
group.GET("slide/get", h.SlideGet)
|
||||||
|
group.POST("slide/check", h.SlideCheck)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
||||||
group := s.Engine.Group("/api/reward/")
|
group := s.Engine.Group("/api/reward/")
|
||||||
@@ -241,17 +252,23 @@ func main() {
|
|||||||
group.POST("upscale", h.Upscale)
|
group.POST("upscale", h.Upscale)
|
||||||
group.POST("variation", h.Variation)
|
group.POST("variation", h.Variation)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
group.POST("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
group.POST("notify", h.Notify)
|
|
||||||
group.POST("publish", h.Publish)
|
group.POST("publish", h.Publish)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||||
group := s.Engine.Group("/api/sd")
|
group := s.Engine.Group("/api/sd")
|
||||||
|
group.Any("client", h.Client)
|
||||||
group.POST("image", h.Image)
|
group.POST("image", h.Image)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
group.POST("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
group.POST("publish", h.Publish)
|
group.POST("publish", h.Publish)
|
||||||
}),
|
}),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
|
||||||
|
group := s.Engine.Group("/api/config/")
|
||||||
|
group.GET("get", h.Get)
|
||||||
|
}),
|
||||||
|
|
||||||
// 管理后台控制器
|
// 管理后台控制器
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||||
@@ -264,13 +281,18 @@ func main() {
|
|||||||
group.POST("login", h.Login)
|
group.POST("login", h.Login)
|
||||||
group.GET("logout", h.Logout)
|
group.GET("logout", h.Logout)
|
||||||
group.GET("session", h.Session)
|
group.GET("session", h.Session)
|
||||||
|
group.GET("list", h.List)
|
||||||
|
group.POST("save", h.Save)
|
||||||
|
group.POST("enable", h.Enable)
|
||||||
|
group.GET("remove", h.Remove)
|
||||||
|
group.POST("resetPass", h.ResetPass)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
|
||||||
group := s.Engine.Group("/api/admin/apikey/")
|
group := s.Engine.Group("/api/admin/apikey/")
|
||||||
group.POST("save", h.Save)
|
group.POST("save", h.Save)
|
||||||
group.GET("list", h.List)
|
group.GET("list", h.List)
|
||||||
group.POST("set", h.Set)
|
group.POST("set", h.Set)
|
||||||
group.GET("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
|
||||||
group := s.Engine.Group("/api/admin/user/")
|
group := s.Engine.Group("/api/admin/user/")
|
||||||
@@ -286,12 +308,12 @@ func main() {
|
|||||||
group.POST("save", h.Save)
|
group.POST("save", h.Save)
|
||||||
group.POST("sort", h.Sort)
|
group.POST("sort", h.Sort)
|
||||||
group.POST("set", h.Set)
|
group.POST("set", h.Set)
|
||||||
group.GET("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
|
||||||
group := s.Engine.Group("/api/admin/reward/")
|
group := s.Engine.Group("/api/admin/reward/")
|
||||||
group.GET("list", h.List)
|
group.GET("list", h.List)
|
||||||
group.GET("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
||||||
group := s.Engine.Group("/api/admin/dashboard/")
|
group := s.Engine.Group("/api/admin/dashboard/")
|
||||||
@@ -315,6 +337,7 @@ func main() {
|
|||||||
group.GET("payWays", h.GetPayWays)
|
group.GET("payWays", h.GetPayWays)
|
||||||
group.POST("query", h.OrderQuery)
|
group.POST("query", h.OrderQuery)
|
||||||
group.POST("qrcode", h.PayQrcode)
|
group.POST("qrcode", h.PayQrcode)
|
||||||
|
group.POST("mobile", h.Mobile)
|
||||||
group.POST("alipay/notify", h.AlipayNotify)
|
group.POST("alipay/notify", h.AlipayNotify)
|
||||||
group.POST("hupipay/notify", h.HuPiPayNotify)
|
group.POST("hupipay/notify", h.HuPiPayNotify)
|
||||||
group.POST("payjs/notify", h.PayJsNotify)
|
group.POST("payjs/notify", h.PayJsNotify)
|
||||||
@@ -349,13 +372,6 @@ func main() {
|
|||||||
group.GET("hits", h.Hits)
|
group.GET("hits", h.Hits)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
fx.Provide(handler.NewPromptHandler),
|
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
|
||||||
group := s.Engine.Group("/api/prompt/")
|
|
||||||
group.POST("rewrite", h.Rewrite)
|
|
||||||
group.POST("translate", h.Translate)
|
|
||||||
}),
|
|
||||||
|
|
||||||
fx.Provide(admin.NewFunctionHandler),
|
fx.Provide(admin.NewFunctionHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
|
||||||
group := s.Engine.Group("/api/admin/function/")
|
group := s.Engine.Group("/api/admin/function/")
|
||||||
@@ -366,6 +382,18 @@ func main() {
|
|||||||
group.GET("token", h.GenToken)
|
group.GET("token", h.GenToken)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
// 验证码
|
||||||
|
fx.Provide(admin.NewCaptchaHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
|
||||||
|
group := s.Engine.Group("/api/admin/login/")
|
||||||
|
group.GET("captcha", h.GetCaptcha)
|
||||||
|
}),
|
||||||
|
|
||||||
|
fx.Provide(admin.NewUploadHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
|
||||||
|
s.Engine.POST("/api/admin/upload", h.Upload)
|
||||||
|
}),
|
||||||
|
|
||||||
fx.Provide(handler.NewFunctionHandler),
|
fx.Provide(handler.NewFunctionHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
|
||||||
group := s.Engine.Group("/api/function/")
|
group := s.Engine.Group("/api/function/")
|
||||||
@@ -381,10 +409,13 @@ func main() {
|
|||||||
group.GET("remove", h.RemoveChat)
|
group.GET("remove", h.RemoveChat)
|
||||||
group.GET("message/remove", h.RemoveMessage)
|
group.GET("message/remove", h.RemoveMessage)
|
||||||
}),
|
}),
|
||||||
fx.Provide(handler.NewTestHandler),
|
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
group := s.Engine.Group("/api/powerLog/")
|
||||||
s.Engine.GET("/api/test", h.Test)
|
group.POST("list", h.List)
|
||||||
s.Engine.POST("/api/test/mj", h.Mj)
|
}),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
|
||||||
|
group := s.Engine.Group("/api/admin/powerLog/")
|
||||||
|
group.POST("list", h.List)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||||
err := s.Run(db)
|
err := s.Run(db)
|
||||||
@@ -392,9 +423,6 @@ func main() {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(h *chatimpl.ChatHandler) {
|
|
||||||
h.Init()
|
|
||||||
}),
|
|
||||||
// 注册生命周期回调函数
|
// 注册生命周期回调函数
|
||||||
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
|
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
|
||||||
lifecycle.Append(fx.Hook{
|
lifecycle.Append(fx.Hook{
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
{
|
|
||||||
"data": [
|
|
||||||
"task(cxvkpawy8onnfti)",
|
|
||||||
"a cute girl",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
20,
|
|
||||||
"DPM++ 2M Karras",
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
7,
|
|
||||||
512,
|
|
||||||
512,
|
|
||||||
false,
|
|
||||||
0.7,
|
|
||||||
2,
|
|
||||||
"Latent",
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
"Use same checkpoint",
|
|
||||||
"Use same sampler",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
"None",
|
|
||||||
false,
|
|
||||||
"",
|
|
||||||
0.8,
|
|
||||||
-1,
|
|
||||||
false,
|
|
||||||
-1,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
"positive",
|
|
||||||
"comma",
|
|
||||||
0,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
"",
|
|
||||||
"Seed",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
"Nothing",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
"Nothing",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
0,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
50,
|
|
||||||
[],
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
""
|
|
||||||
],
|
|
||||||
"event_data": null,
|
|
||||||
"fn_index": 446,
|
|
||||||
"session_hash": "nk5noh1rz1o"
|
|
||||||
}
|
|
||||||
@@ -60,3 +60,44 @@ func (s *CaptchaService) Check(data interface{}) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *CaptchaService) SlideGet() (interface{}, error) {
|
||||||
|
if s.config.Token == "" {
|
||||||
|
return nil, errors.New("无效的 API Token")
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
|
||||||
|
var res types.BizVo
|
||||||
|
r, err := s.client.R().
|
||||||
|
SetHeader("AppId", s.config.AppId).
|
||||||
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||||
|
SetSuccessResult(&res).Get(url)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != types.Success {
|
||||||
|
return nil, fmt.Errorf("请求 API 失败:%s", res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CaptchaService) SlideCheck(data interface{}) bool {
|
||||||
|
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
|
||||||
|
var res types.BizVo
|
||||||
|
r, err := s.client.R().
|
||||||
|
SetHeader("AppId", s.config.AppId).
|
||||||
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||||
|
SetBodyJsonMarshal(data).
|
||||||
|
SetSuccessResult(&res).Post(url)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != types.Success {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,233 +0,0 @@
|
|||||||
package mj
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/core/types"
|
|
||||||
logger2 "chatplus/logger"
|
|
||||||
"chatplus/utils"
|
|
||||||
discordgo "github.com/bg5t/mydiscordgo"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MidJourney 机器人
|
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
|
||||||
|
|
||||||
type Bot struct {
|
|
||||||
config types.MidJourneyConfig
|
|
||||||
bot *discordgo.Session
|
|
||||||
name string
|
|
||||||
service *Service
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) {
|
|
||||||
bot, err := discordgo.New("Bot " + config.BotToken)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// use CDN reverse proxy
|
|
||||||
if config.UseCDN {
|
|
||||||
discordgo.SetEndpointDiscord(config.DiscordAPI)
|
|
||||||
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
|
|
||||||
discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
|
|
||||||
bot.MjGateway = config.DiscordGateway + "/"
|
|
||||||
} else { // use proxy
|
|
||||||
discordgo.SetEndpointDiscord("https://discord.com")
|
|
||||||
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
|
|
||||||
discordgo.SetEndpointStatus("https://discord.com/api/v2/")
|
|
||||||
bot.MjGateway = "wss://gateway.discord.gg"
|
|
||||||
|
|
||||||
if proxy != "" {
|
|
||||||
proxy, _ := url.Parse(proxy)
|
|
||||||
bot.Client = &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
Proxy: http.ProxyURL(proxy),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
bot.Dialer = &websocket.Dialer{
|
|
||||||
Proxy: http.ProxyURL(proxy),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Bot{
|
|
||||||
config: config,
|
|
||||||
bot: bot,
|
|
||||||
name: name,
|
|
||||||
service: service,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) Run() error {
|
|
||||||
b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
|
|
||||||
b.bot.AddHandler(b.messageCreate)
|
|
||||||
b.bot.AddHandler(b.messageUpdate)
|
|
||||||
|
|
||||||
logger.Infof("Starting MidJourney %s", b.name)
|
|
||||||
err := b.bot.Open()
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logger.Infof("Starting MidJourney %s successfully!", b.name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TaskStatus string
|
|
||||||
|
|
||||||
const (
|
|
||||||
Start = TaskStatus("Started")
|
|
||||||
Running = TaskStatus("Running")
|
|
||||||
Stopped = TaskStatus("Stopped")
|
|
||||||
Finished = TaskStatus("Finished")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Image struct {
|
|
||||||
URL string `json:"url"`
|
|
||||||
ProxyURL string `json:"proxy_url"`
|
|
||||||
Filename string `json:"filename"`
|
|
||||||
Width int `json:"width"`
|
|
||||||
Height int `json:"height"`
|
|
||||||
Size int `json:"size"`
|
|
||||||
Hash string `json:"hash"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
|
||||||
// ignore messages for other channels
|
|
||||||
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// ignore messages for self
|
|
||||||
if m.Author == nil || m.Author.ID == s.State.User.ID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debugf("CREATE: %s", utils.JsonEncode(m))
|
|
||||||
var referenceId = ""
|
|
||||||
if m.ReferencedMessage != nil {
|
|
||||||
referenceId = m.ReferencedMessage.ID
|
|
||||||
}
|
|
||||||
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
|
|
||||||
// parse content
|
|
||||||
req := CBReq{
|
|
||||||
ChannelId: m.ChannelID,
|
|
||||||
MessageId: m.ID,
|
|
||||||
ReferenceId: referenceId,
|
|
||||||
Prompt: extractPrompt(m.Content),
|
|
||||||
Content: m.Content,
|
|
||||||
Progress: 0,
|
|
||||||
Status: Start}
|
|
||||||
b.service.Notify(req)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
|
||||||
// ignore messages for other channels
|
|
||||||
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// ignore messages for self
|
|
||||||
if m.Author == nil || m.Author.ID == s.State.User.ID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
|
|
||||||
|
|
||||||
var referenceId = ""
|
|
||||||
if m.ReferencedMessage != nil {
|
|
||||||
referenceId = m.ReferencedMessage.ID
|
|
||||||
}
|
|
||||||
if strings.Contains(m.Content, "(Stopped)") {
|
|
||||||
req := CBReq{
|
|
||||||
ChannelId: m.ChannelID,
|
|
||||||
MessageId: m.ID,
|
|
||||||
ReferenceId: referenceId,
|
|
||||||
Prompt: extractPrompt(m.Content),
|
|
||||||
Content: m.Content,
|
|
||||||
Progress: extractProgress(m.Content),
|
|
||||||
Status: Stopped}
|
|
||||||
b.service.Notify(req)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
|
|
||||||
progress := extractProgress(content)
|
|
||||||
var status TaskStatus
|
|
||||||
if progress == 100 {
|
|
||||||
status = Finished
|
|
||||||
} else {
|
|
||||||
status = Running
|
|
||||||
}
|
|
||||||
for _, attachment := range attachments {
|
|
||||||
if attachment.Width == 0 || attachment.Height == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
image := Image{
|
|
||||||
URL: attachment.URL,
|
|
||||||
Height: attachment.Height,
|
|
||||||
ProxyURL: attachment.ProxyURL,
|
|
||||||
Width: attachment.Width,
|
|
||||||
Size: attachment.Size,
|
|
||||||
Filename: attachment.Filename,
|
|
||||||
Hash: extractHashFromFilename(attachment.Filename),
|
|
||||||
}
|
|
||||||
req := CBReq{
|
|
||||||
ChannelId: channelId,
|
|
||||||
MessageId: messageId,
|
|
||||||
ReferenceId: referenceId,
|
|
||||||
Image: image,
|
|
||||||
Prompt: extractPrompt(content),
|
|
||||||
Content: content,
|
|
||||||
Progress: progress,
|
|
||||||
Status: status,
|
|
||||||
}
|
|
||||||
b.service.Notify(req)
|
|
||||||
break // only get one image
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract prompt from string
|
|
||||||
func extractPrompt(input string) string {
|
|
||||||
pattern := `\*\*(.*?)\*\*`
|
|
||||||
re := regexp.MustCompile(pattern)
|
|
||||||
matches := re.FindStringSubmatch(input)
|
|
||||||
if len(matches) > 1 {
|
|
||||||
return strings.TrimSpace(matches[1])
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractProgress(input string) int {
|
|
||||||
pattern := `\((\d+)\%\)`
|
|
||||||
re := regexp.MustCompile(pattern)
|
|
||||||
matches := re.FindStringSubmatch(input)
|
|
||||||
if len(matches) > 1 {
|
|
||||||
return utils.IntValue(matches[1], 0)
|
|
||||||
}
|
|
||||||
return 100
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractHashFromFilename(filename string) string {
|
|
||||||
if !strings.HasSuffix(filename, ".png") {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
index := strings.LastIndex(filename, "_")
|
|
||||||
if index != -1 {
|
|
||||||
return filename[index+1 : len(filename)-4]
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -1,159 +1,61 @@
|
|||||||
package mj
|
package mj
|
||||||
|
|
||||||
import (
|
import "chatplus/core/types"
|
||||||
"chatplus/core/types"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
type Client interface {
|
||||||
)
|
Imagine(task types.MjTask) (ImageRes, error)
|
||||||
|
Blend(task types.MjTask) (ImageRes, error)
|
||||||
// MidJourney client
|
SwapFace(task types.MjTask) (ImageRes, error)
|
||||||
|
Upscale(task types.MjTask) (ImageRes, error)
|
||||||
type Client struct {
|
Variation(task types.MjTask) (ImageRes, error)
|
||||||
client *req.Client
|
QueryTask(taskId string) (QueryRes, error)
|
||||||
Config types.MidJourneyConfig
|
|
||||||
apiURL string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
|
type ImageReq struct {
|
||||||
client := req.C().SetTimeout(10 * time.Second)
|
BotType string `json:"botType,omitempty"`
|
||||||
var apiURL string
|
Prompt string `json:"prompt,omitempty"`
|
||||||
// set proxy URL
|
Dimensions string `json:"dimensions,omitempty"`
|
||||||
if config.UseCDN {
|
Base64Array []string `json:"base64Array,omitempty"`
|
||||||
apiURL = config.DiscordAPI + "/api/v9/interactions"
|
AccountFilter interface{} `json:"accountFilter,omitempty"`
|
||||||
} else {
|
NotifyHook string `json:"notifyHook,omitempty"`
|
||||||
apiURL = "https://discord.com/api/v9/interactions"
|
State string `json:"state,omitempty"`
|
||||||
if proxy != "" {
|
|
||||||
client.SetProxyURL(proxy)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Client{client: client, Config: config, apiURL: apiURL}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Imagine(task types.MjTask) error {
|
type ImageRes struct {
|
||||||
interactionsReq := &InteractionsRequest{
|
Code int `json:"code"`
|
||||||
Type: 2,
|
Description string `json:"description"`
|
||||||
ApplicationID: ApplicationID,
|
Properties struct {
|
||||||
GuildID: c.Config.GuildId,
|
} `json:"properties"`
|
||||||
ChannelID: c.Config.ChanelId,
|
Result string `json:"result"`
|
||||||
SessionID: SessionID,
|
|
||||||
Data: map[string]any{
|
|
||||||
"version": "1166847114203123795",
|
|
||||||
"id": "938956540159881230",
|
|
||||||
"name": "imagine",
|
|
||||||
"type": "1",
|
|
||||||
"options": []map[string]any{
|
|
||||||
{
|
|
||||||
"type": 3,
|
|
||||||
"name": "prompt",
|
|
||||||
"value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"application_command": map[string]any{
|
|
||||||
"id": "938956540159881230",
|
|
||||||
"application_id": ApplicationID,
|
|
||||||
"version": "1118961510123847772",
|
|
||||||
"default_permission": true,
|
|
||||||
"default_member_permissions": nil,
|
|
||||||
"type": 1,
|
|
||||||
"nsfw": false,
|
|
||||||
"name": "imagine",
|
|
||||||
"description": "Create images with Midjourney",
|
|
||||||
"dm_permission": true,
|
|
||||||
"options": []map[string]any{
|
|
||||||
{
|
|
||||||
"type": 3,
|
|
||||||
"name": "prompt",
|
|
||||||
"description": "The prompt to imagine",
|
|
||||||
"required": true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"attachments": []any{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(interactionsReq).
|
|
||||||
Post(c.apiURL)
|
|
||||||
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("error with http request: %w%v", err, r.Err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Blend(task types.MjTask) error {
|
type ErrRes struct {
|
||||||
return errors.New("function not implemented")
|
Error struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
} `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) SwapFace(task types.MjTask) error {
|
type QueryRes struct {
|
||||||
return errors.New("function not implemented")
|
Action string `json:"action"`
|
||||||
}
|
Buttons []struct {
|
||||||
|
CustomId string `json:"customId"`
|
||||||
// Upscale 放大指定的图片
|
Emoji string `json:"emoji"`
|
||||||
func (c *Client) Upscale(task types.MjTask) error {
|
Label string `json:"label"`
|
||||||
flags := 0
|
Style int `json:"style"`
|
||||||
interactionsReq := &InteractionsRequest{
|
Type int `json:"type"`
|
||||||
Type: 3,
|
} `json:"buttons"`
|
||||||
ApplicationID: ApplicationID,
|
Description string `json:"description"`
|
||||||
GuildID: c.Config.GuildId,
|
FailReason string `json:"failReason"`
|
||||||
ChannelID: c.Config.ChanelId,
|
FinishTime int `json:"finishTime"`
|
||||||
MessageFlags: flags,
|
Id string `json:"id"`
|
||||||
MessageID: task.MessageId,
|
ImageUrl string `json:"imageUrl"`
|
||||||
SessionID: SessionID,
|
Progress string `json:"progress"`
|
||||||
Data: map[string]any{
|
Prompt string `json:"prompt"`
|
||||||
"component_type": 2,
|
PromptEn string `json:"promptEn"`
|
||||||
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
Properties struct {
|
||||||
},
|
} `json:"properties"`
|
||||||
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
StartTime int `json:"startTime"`
|
||||||
}
|
State string `json:"state"`
|
||||||
|
Status string `json:"status"`
|
||||||
var res InteractionsResult
|
SubmitTime int `json:"submitTime"`
|
||||||
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(interactionsReq).
|
|
||||||
SetErrorResult(&res).
|
|
||||||
Post(c.apiURL)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
|
||||||
func (c *Client) Variation(task types.MjTask) error {
|
|
||||||
flags := 0
|
|
||||||
interactionsReq := &InteractionsRequest{
|
|
||||||
Type: 3,
|
|
||||||
ApplicationID: ApplicationID,
|
|
||||||
GuildID: c.Config.GuildId,
|
|
||||||
ChannelID: c.Config.ChanelId,
|
|
||||||
MessageFlags: flags,
|
|
||||||
MessageID: task.MessageId,
|
|
||||||
SessionID: SessionID,
|
|
||||||
Data: map[string]any{
|
|
||||||
"component_type": 2,
|
|
||||||
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
|
||||||
},
|
|
||||||
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
|
||||||
}
|
|
||||||
|
|
||||||
var res InteractionsResult
|
|
||||||
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(interactionsReq).
|
|
||||||
SetErrorResult(&res).
|
|
||||||
Post(c.apiURL)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,204 +0,0 @@
|
|||||||
package plus
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/store"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/utils"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Service MJ 绘画服务
|
|
||||||
type Service struct {
|
|
||||||
Name string // service Name
|
|
||||||
Client *Client // MJ Client
|
|
||||||
taskQueue *store.RedisQueue
|
|
||||||
notifyQueue *store.RedisQueue
|
|
||||||
db *gorm.DB
|
|
||||||
maxHandleTaskNum int32 // max task number current service can handle
|
|
||||||
HandledTaskNum int32 // already handled task number
|
|
||||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
|
||||||
taskTimeout int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
|
|
||||||
return &Service{
|
|
||||||
Name: name,
|
|
||||||
db: db,
|
|
||||||
taskQueue: taskQueue,
|
|
||||||
notifyQueue: notifyQueue,
|
|
||||||
Client: client,
|
|
||||||
taskTimeout: timeout,
|
|
||||||
maxHandleTaskNum: maxTaskNum,
|
|
||||||
taskStartTimes: make(map[int]time.Time, 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Run() {
|
|
||||||
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
|
||||||
for {
|
|
||||||
s.checkTasks()
|
|
||||||
if !s.canHandleTask() {
|
|
||||||
// current service is full, can not handle more task
|
|
||||||
// waiting for running task finish
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var task types.MjTask
|
|
||||||
err := s.taskQueue.LPop(&task)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("taking task with error: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// if it's reference message, check if it's this channel's message
|
|
||||||
//if task.ChannelId != "" && task.ChannelId != s.Name {
|
|
||||||
// logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
|
|
||||||
// s.taskQueue.RPush(task)
|
|
||||||
// time.Sleep(time.Second)
|
|
||||||
// continue
|
|
||||||
//}
|
|
||||||
|
|
||||||
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
|
|
||||||
var res ImageRes
|
|
||||||
switch task.Type {
|
|
||||||
case types.TaskImage:
|
|
||||||
res, err = s.Client.Imagine(task)
|
|
||||||
break
|
|
||||||
case types.TaskUpscale:
|
|
||||||
res, err = s.Client.Upscale(task)
|
|
||||||
break
|
|
||||||
case types.TaskVariation:
|
|
||||||
res, err = s.Client.Variation(task)
|
|
||||||
break
|
|
||||||
case types.TaskBlend:
|
|
||||||
res, err = s.Client.Blend(task)
|
|
||||||
break
|
|
||||||
case types.TaskSwapFace:
|
|
||||||
res, err = s.Client.SwapFace(task)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
var job model.MidJourneyJob
|
|
||||||
s.db.Where("id = ?", task.Id).First(&job)
|
|
||||||
if err != nil || (res.Code != 1 && res.Code != 22) {
|
|
||||||
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
|
|
||||||
logger.Error("绘画任务执行失败:", errMsg)
|
|
||||||
job.Progress = -1
|
|
||||||
job.ErrMsg = errMsg
|
|
||||||
// update the task progress
|
|
||||||
s.db.Updates(&job)
|
|
||||||
// 任务失败,通知前端
|
|
||||||
s.notifyQueue.RPush(task.UserId)
|
|
||||||
// restore img_call quota
|
|
||||||
if task.Type.String() != types.TaskUpscale.String() {
|
|
||||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: 任务提交失败,加入队列重试
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.Infof("任务提交成功:%+v", res)
|
|
||||||
// lock the task until the execute timeout
|
|
||||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
|
||||||
atomic.AddInt32(&s.HandledTaskNum, 1)
|
|
||||||
// 更新任务 ID/频道
|
|
||||||
job.TaskId = res.Result
|
|
||||||
job.ChannelId = s.Name
|
|
||||||
s.db.Updates(&job)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if current service instance can handle more task
|
|
||||||
func (s *Service) canHandleTask() bool {
|
|
||||||
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
|
|
||||||
return handledNum < s.maxHandleTaskNum
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove the expired tasks
|
|
||||||
func (s *Service) checkTasks() {
|
|
||||||
for k, t := range s.taskStartTimes {
|
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
|
||||||
delete(s.taskStartTimes, k)
|
|
||||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
|
||||||
// delete task from database
|
|
||||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type CBReq struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Action string `json:"action"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
PromptEn string `json:"promptEn"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
SubmitTime int64 `json:"submitTime"`
|
|
||||||
StartTime int64 `json:"startTime"`
|
|
||||||
FinishTime int64 `json:"finishTime"`
|
|
||||||
Progress string `json:"progress"`
|
|
||||||
ImageUrl string `json:"imageUrl"`
|
|
||||||
FailReason interface{} `json:"failReason"`
|
|
||||||
Properties struct {
|
|
||||||
FinalPrompt string `json:"finalPrompt"`
|
|
||||||
} `json:"properties"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Notify(job model.MidJourneyJob) error {
|
|
||||||
task, err := s.Client.QueryTask(job.TaskId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 任务执行失败了
|
|
||||||
if task.FailReason != "" {
|
|
||||||
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
|
||||||
"progress": -1,
|
|
||||||
"err_msg": task.FailReason,
|
|
||||||
})
|
|
||||||
return fmt.Errorf("task failed: %v", task.FailReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(task.Buttons) > 0 {
|
|
||||||
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
|
||||||
}
|
|
||||||
oldProgress := job.Progress
|
|
||||||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
|
||||||
job.Prompt = task.PromptEn
|
|
||||||
if task.ImageUrl != "" {
|
|
||||||
if s.Client.Config.CdnURL != "" {
|
|
||||||
job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
|
|
||||||
} else {
|
|
||||||
job.OrgURL = task.ImageUrl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
job.MessageId = task.Id
|
|
||||||
tx := s.db.Updates(&job)
|
|
||||||
if tx.Error != nil {
|
|
||||||
return fmt.Errorf("error with update database: %v", tx.Error)
|
|
||||||
}
|
|
||||||
if task.Status == "SUCCESS" {
|
|
||||||
// release lock task
|
|
||||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
|
||||||
}
|
|
||||||
// 通知前端更新任务进度
|
|
||||||
if oldProgress != job.Progress {
|
|
||||||
s.notifyQueue.RPush(job.UserId)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetImageHash(action string) string {
|
|
||||||
split := strings.Split(action, "::")
|
|
||||||
if len(split) > 5 {
|
|
||||||
return split[4]
|
|
||||||
}
|
|
||||||
return split[len(split)-1]
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
package plus
|
package mj
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
logger2 "chatplus/logger"
|
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -13,62 +12,21 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
// PlusClient MidJourney Plus ProxyClient
|
||||||
|
type PlusClient struct {
|
||||||
// Client MidJourney Plus Client
|
Config types.MjPlusConfig
|
||||||
type Client struct {
|
|
||||||
Config types.MidJourneyPlusConfig
|
|
||||||
apiURL string
|
apiURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(config types.MidJourneyPlusConfig) *Client {
|
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
|
||||||
var apiURL string
|
return &PlusClient{Config: config, apiURL: config.ApiURL}
|
||||||
if config.CdnURL != "" {
|
|
||||||
apiURL = config.CdnURL
|
|
||||||
} else {
|
|
||||||
apiURL = config.ApiURL
|
|
||||||
}
|
|
||||||
if config.Mode == "" {
|
|
||||||
config.Mode = "fast"
|
|
||||||
}
|
|
||||||
return &Client{Config: config, apiURL: apiURL}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ImageReq struct {
|
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||||
BotType string `json:"botType"`
|
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
Dimensions string `json:"dimensions,omitempty"`
|
|
||||||
Base64Array []string `json:"base64Array,omitempty"`
|
|
||||||
AccountFilter struct {
|
|
||||||
InstanceId string `json:"instanceId"`
|
|
||||||
Modes []interface{} `json:"modes"`
|
|
||||||
Remix bool `json:"remix"`
|
|
||||||
RemixAutoConsidered bool `json:"remixAutoConsidered"`
|
|
||||||
} `json:"accountFilter,omitempty"`
|
|
||||||
NotifyHook string `json:"notifyHook"`
|
|
||||||
State string `json:"state,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageRes struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Properties struct {
|
|
||||||
} `json:"properties"`
|
|
||||||
Result string `json:"result"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ErrRes struct {
|
|
||||||
Error struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
} `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
|
||||||
body := ImageReq{
|
body := ImageReq{
|
||||||
BotType: "MID_JOURNEY",
|
BotType: "MID_JOURNEY",
|
||||||
Prompt: task.Prompt,
|
Prompt: task.Prompt,
|
||||||
NotifyHook: c.Config.NotifyURL,
|
|
||||||
Base64Array: make([]string, 0),
|
Base64Array: make([]string, 0),
|
||||||
}
|
}
|
||||||
// 生成图片 Base64 编码
|
// 生成图片 Base64 编码
|
||||||
@@ -81,6 +39,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
logger.Info("API URL: ", apiURL)
|
||||||
var res ImageRes
|
var res ImageRes
|
||||||
var errRes ErrRes
|
var errRes ErrRes
|
||||||
r, err := req.C().R().
|
r, err := req.C().R().
|
||||||
@@ -90,9 +49,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
|||||||
SetErrorResult(&errRes).
|
SetErrorResult(&errRes).
|
||||||
Post(apiURL)
|
Post(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr, _ := io.ReadAll(r.Body)
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL)
|
|
||||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.IsErrorState() {
|
if r.IsErrorState() {
|
||||||
@@ -104,12 +61,11 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Blend 融图
|
// Blend 融图
|
||||||
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
|
||||||
body := ImageReq{
|
body := ImageReq{
|
||||||
BotType: "MID_JOURNEY",
|
BotType: "MID_JOURNEY",
|
||||||
Dimensions: "SQUARE",
|
Dimensions: "SQUARE",
|
||||||
NotifyHook: c.Config.NotifyURL,
|
|
||||||
Base64Array: make([]string, 0),
|
Base64Array: make([]string, 0),
|
||||||
}
|
}
|
||||||
// 生成图片 Base64 编码
|
// 生成图片 Base64 编码
|
||||||
@@ -132,8 +88,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
|||||||
SetErrorResult(&errRes).
|
SetErrorResult(&errRes).
|
||||||
Post(apiURL)
|
Post(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr, _ := io.ReadAll(r.Body)
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.IsErrorState() {
|
if r.IsErrorState() {
|
||||||
@@ -144,7 +99,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SwapFace 换脸
|
// SwapFace 换脸
|
||||||
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
||||||
// 生成图片 Base64 编码
|
// 生成图片 Base64 编码
|
||||||
if len(task.ImgArr) != 2 {
|
if len(task.ImgArr) != 2 {
|
||||||
@@ -171,8 +126,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
|||||||
"accountFilter": gin.H{
|
"accountFilter": gin.H{
|
||||||
"instanceId": "",
|
"instanceId": "",
|
||||||
},
|
},
|
||||||
"notifyHook": c.Config.NotifyURL,
|
"state": "",
|
||||||
"state": "",
|
|
||||||
}
|
}
|
||||||
var res ImageRes
|
var res ImageRes
|
||||||
var errRes ErrRes
|
var errRes ErrRes
|
||||||
@@ -183,8 +137,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
|||||||
SetErrorResult(&errRes).
|
SetErrorResult(&errRes).
|
||||||
Post(apiURL)
|
Post(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr, _ := io.ReadAll(r.Body)
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.IsErrorState() {
|
if r.IsErrorState() {
|
||||||
@@ -195,11 +148,10 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Upscale 放大指定的图片
|
// Upscale 放大指定的图片
|
||||||
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
|
||||||
body := map[string]string{
|
body := map[string]string{
|
||||||
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
||||||
"taskId": task.MessageId,
|
"taskId": task.MessageId,
|
||||||
"notifyHook": c.Config.NotifyURL,
|
|
||||||
}
|
}
|
||||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
||||||
var res ImageRes
|
var res ImageRes
|
||||||
@@ -222,11 +174,10 @@ func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||||
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
|
||||||
body := map[string]string{
|
body := map[string]string{
|
||||||
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
||||||
"taskId": task.MessageId,
|
"taskId": task.MessageId,
|
||||||
"notifyHook": c.Config.NotifyURL,
|
|
||||||
}
|
}
|
||||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
||||||
var res ImageRes
|
var res ImageRes
|
||||||
@@ -248,32 +199,7 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryRes struct {
|
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
|
||||||
Action string `json:"action"`
|
|
||||||
Buttons []struct {
|
|
||||||
CustomId string `json:"customId"`
|
|
||||||
Emoji string `json:"emoji"`
|
|
||||||
Label string `json:"label"`
|
|
||||||
Style int `json:"style"`
|
|
||||||
Type int `json:"type"`
|
|
||||||
} `json:"buttons"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
FailReason string `json:"failReason"`
|
|
||||||
FinishTime int `json:"finishTime"`
|
|
||||||
Id string `json:"id"`
|
|
||||||
ImageUrl string `json:"imageUrl"`
|
|
||||||
Progress string `json:"progress"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
PromptEn string `json:"promptEn"`
|
|
||||||
Properties struct {
|
|
||||||
} `json:"properties"`
|
|
||||||
StartTime int `json:"startTime"`
|
|
||||||
State string `json:"state"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
SubmitTime int `json:"submitTime"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) QueryTask(taskId string) (QueryRes, error) {
|
|
||||||
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
||||||
var res QueryRes
|
var res QueryRes
|
||||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||||
@@ -290,3 +216,5 @@ func (c *Client) QueryTask(taskId string) (QueryRes, error) {
|
|||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Client = &PlusClient{}
|
||||||
@@ -2,13 +2,12 @@ package mj
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/service/mj/plus"
|
logger2 "chatplus/logger"
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -16,7 +15,7 @@ import (
|
|||||||
|
|
||||||
// ServicePool Mj service pool
|
// ServicePool Mj service pool
|
||||||
type ServicePool struct {
|
type ServicePool struct {
|
||||||
services []interface{}
|
services []*Service
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
@@ -24,8 +23,10 @@ type ServicePool struct {
|
|||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||||
services := make([]interface{}, 0)
|
services := make([]*Service, 0)
|
||||||
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
||||||
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
||||||
|
|
||||||
@@ -33,45 +34,26 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
|||||||
if config.Enabled == false {
|
if config.Enabled == false {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
client := plus.NewClient(config)
|
cli := NewPlusClient(config)
|
||||||
name := fmt.Sprintf("mj-service-plus-%d", k)
|
name := fmt.Sprintf("mj-plus-service-%d", k)
|
||||||
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
|
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||||
go func() {
|
go func() {
|
||||||
servicePlus.Run()
|
service.Run()
|
||||||
}()
|
}()
|
||||||
services = append(services, servicePlus)
|
services = append(services, service)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(services) == 0 {
|
for k, config := range appConfig.MjProxyConfigs {
|
||||||
// create mj client and service
|
if config.Enabled == false {
|
||||||
for k, config := range appConfig.MjConfigs {
|
continue
|
||||||
if config.Enabled == false {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// create mj client
|
|
||||||
client := NewClient(config, appConfig.ProxyURL)
|
|
||||||
|
|
||||||
name := fmt.Sprintf("MjService-%d", k)
|
|
||||||
// create mj service
|
|
||||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
|
|
||||||
botName := fmt.Sprintf("MjBot-%d", k)
|
|
||||||
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = bot.Run()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// run mj service
|
|
||||||
go func() {
|
|
||||||
service.Run()
|
|
||||||
}()
|
|
||||||
|
|
||||||
services = append(services, service)
|
|
||||||
}
|
}
|
||||||
|
cli := NewProxyClient(config)
|
||||||
|
name := fmt.Sprintf("mj-proxy-service-%d", k)
|
||||||
|
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||||
|
go func() {
|
||||||
|
service.Run()
|
||||||
|
}()
|
||||||
|
services = append(services, service)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ServicePool{
|
return &ServicePool{
|
||||||
@@ -92,11 +74,11 @@ func (p *ServicePool) CheckTaskNotify() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
client := p.Clients.Get(userId)
|
cli := p.Clients.Get(userId)
|
||||||
if client == nil {
|
if cli == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = client.Send([]byte("Task Updated"))
|
err = cli.Send([]byte("Task Updated"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -122,10 +104,10 @@ func (p *ServicePool) DownloadImages() {
|
|||||||
logger.Infof("try to download image: %s", v.OrgURL)
|
logger.Infof("try to download image: %s", v.OrgURL)
|
||||||
var imgURL string
|
var imgURL string
|
||||||
var err error
|
var err error
|
||||||
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
|
if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
|
||||||
task, _ := servicePlus.Client.QueryTask(v.TaskId)
|
task, _ := servicePlus.Client.QueryTask(v.TaskId)
|
||||||
if len(task.Buttons) > 0 {
|
if len(task.Buttons) > 0 {
|
||||||
v.Hash = plus.GetImageHash(task.Buttons[0].CustomId)
|
v.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||||
}
|
}
|
||||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
|
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
|
||||||
} else {
|
} else {
|
||||||
@@ -141,11 +123,11 @@ func (p *ServicePool) DownloadImages() {
|
|||||||
v.ImgURL = imgURL
|
v.ImgURL = imgURL
|
||||||
p.db.Updates(&v)
|
p.db.Updates(&v)
|
||||||
|
|
||||||
client := p.Clients.Get(uint(v.UserId))
|
cli := p.Clients.Get(uint(v.UserId))
|
||||||
if client == nil {
|
if cli == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = client.Send([]byte("Task Updated"))
|
err = cli.Send([]byte("Task Updated"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -167,52 +149,42 @@ func (p *ServicePool) HasAvailableService() bool {
|
|||||||
return len(p.services) > 0
|
return len(p.services) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ServicePool) Notify(data plus.CBReq) error {
|
|
||||||
logger.Debugf("收到任务回调:%+v", data)
|
|
||||||
var job model.MidJourneyJob
|
|
||||||
res := p.db.Where("task_id = ?", data.Id).First(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
return fmt.Errorf("非法任务:%s", data.Id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 任务已经拉取完成
|
|
||||||
if job.Progress == 100 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
|
|
||||||
return servicePlus.Notify(job)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncTaskProgress 异步拉取任务
|
// SyncTaskProgress 异步拉取任务
|
||||||
func (p *ServicePool) SyncTaskProgress() {
|
func (p *ServicePool) SyncTaskProgress() {
|
||||||
go func() {
|
go func() {
|
||||||
var items []model.MidJourneyJob
|
var items []model.MidJourneyJob
|
||||||
for {
|
for {
|
||||||
res := p.db.Where("progress >= ? AND progress < ?", 0, 100).Find(&items)
|
res := p.db.Where("progress < ?", 100).Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range items {
|
for _, job := range items {
|
||||||
// 30 分钟还没完成的任务直接删除
|
// 失败或者 30 分钟还没完成的任务删除并退回算力
|
||||||
if time.Now().Sub(v.CreatedAt) > time.Minute*30 {
|
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
|
||||||
p.db.Delete(&v)
|
// 删除任务
|
||||||
// 非放大任务,退回绘图次数
|
p.db.Delete(&job)
|
||||||
if v.Type != types.TaskUpscale.String() {
|
// 退回算力
|
||||||
p.db.Model(&model.User{}).Where("id = ?", v.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
||||||
|
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||||
|
var user model.User
|
||||||
|
p.db.Where("id = ?", job.UserId).First(&user)
|
||||||
|
p.db.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power + job.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "mid-journey",
|
||||||
|
Remark: fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(v.ChannelId, "mj-service-plus") {
|
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
|
||||||
continue
|
_ = servicePlus.Notify(job)
|
||||||
}
|
|
||||||
|
|
||||||
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
|
|
||||||
_ = servicePlus.Notify(v)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,12 +193,10 @@ func (p *ServicePool) SyncTaskProgress() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ServicePool) getServicePlus(name string) *plus.Service {
|
func (p *ServicePool) getService(name string) *Service {
|
||||||
for _, s := range p.services {
|
for _, s := range p.services {
|
||||||
if servicePlus, ok := s.(*plus.Service); ok {
|
if s.Name == name {
|
||||||
if servicePlus.Name == name {
|
return s
|
||||||
return servicePlus
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
176
api/service/mj/proxy_client.go
Normal file
176
api/service/mj/proxy_client.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package mj
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/utils"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxyClient MidJourney Proxy Client
|
||||||
|
type ProxyClient struct {
|
||||||
|
Config types.MjProxyConfig
|
||||||
|
apiURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
|
||||||
|
return &ProxyClient{Config: config, apiURL: config.ApiURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
|
||||||
|
body := ImageReq{
|
||||||
|
Prompt: task.Prompt,
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
logger.Info("API URL: ", apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
all, err := io.ReadAll(r.Body)
|
||||||
|
logger.Info(string(all))
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
errStr, _ := io.ReadAll(r.Body)
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Blend 融图
|
||||||
|
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
|
||||||
|
body := ImageReq{
|
||||||
|
Dimensions: "SQUARE",
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
for _, imgURL := range task.ImgArr {
|
||||||
|
imageData, err := utils.DownloadImage(imgURL, "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwapFace 换脸
|
||||||
|
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
|
||||||
|
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upscale 放大指定的图片
|
||||||
|
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"action": "UPSCALE",
|
||||||
|
"index": task.Index,
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||||
|
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"action": "VARIATION",
|
||||||
|
"index": task.Index,
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
||||||
|
var res QueryRes
|
||||||
|
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
Get(apiURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return QueryRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return QueryRes{}, errors.New("error status:" + r.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Client = &ProxyClient{}
|
||||||
@@ -2,8 +2,11 @@ package mj
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
|
"chatplus/service"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
|
"chatplus/utils"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -13,24 +16,24 @@ import (
|
|||||||
|
|
||||||
// Service MJ 绘画服务
|
// Service MJ 绘画服务
|
||||||
type Service struct {
|
type Service struct {
|
||||||
name string // service name
|
Name string // service Name
|
||||||
client *Client // MJ client
|
Client Client // MJ Client
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
maxHandleTaskNum int32 // max task number current service can handle
|
maxHandleTaskNum int32 // max task number current service can handle
|
||||||
handledTaskNum int32 // already handled task number
|
HandledTaskNum int32 // already handled task number
|
||||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
||||||
taskTimeout int64
|
taskTimeout int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
|
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
name: name,
|
Name: name,
|
||||||
db: db,
|
db: db,
|
||||||
taskQueue: taskQueue,
|
taskQueue: taskQueue,
|
||||||
notifyQueue: notifyQueue,
|
notifyQueue: notifyQueue,
|
||||||
client: client,
|
Client: cli,
|
||||||
taskTimeout: timeout,
|
taskTimeout: timeout,
|
||||||
maxHandleTaskNum: maxTaskNum,
|
maxHandleTaskNum: maxTaskNum,
|
||||||
taskStartTimes: make(map[int]time.Time, 0),
|
taskStartTimes: make(map[int]time.Time, 0),
|
||||||
@@ -38,7 +41,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
logger.Infof("Starting MidJourney job consumer for %s", s.name)
|
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
||||||
for {
|
for {
|
||||||
s.checkTasks()
|
s.checkTasks()
|
||||||
if !s.canHandleTask() {
|
if !s.canHandleTask() {
|
||||||
@@ -55,57 +58,72 @@ func (s *Service) Run() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// if it's reference message, check if it's this channel's message
|
// 如果配置了多个中转平台的 API KEY
|
||||||
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
|
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
|
||||||
|
if task.ChannelId != "" && task.ChannelId != s.Name {
|
||||||
|
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
|
||||||
s.taskQueue.RPush(task)
|
s.taskQueue.RPush(task)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
|
// 如果是 mj-proxy 则自动翻译提示词
|
||||||
|
if utils.HasChinese(task.Prompt) && strings.HasPrefix(s.Name, "mj-proxy-service") {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
|
||||||
|
if err == nil {
|
||||||
|
task.Prompt = content
|
||||||
|
} else {
|
||||||
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
|
||||||
|
var res ImageRes
|
||||||
switch task.Type {
|
switch task.Type {
|
||||||
case types.TaskImage:
|
case types.TaskImage:
|
||||||
err = s.client.Imagine(task)
|
res, err = s.Client.Imagine(task)
|
||||||
break
|
break
|
||||||
case types.TaskUpscale:
|
case types.TaskUpscale:
|
||||||
err = s.client.Upscale(task)
|
res, err = s.Client.Upscale(task)
|
||||||
break
|
break
|
||||||
case types.TaskVariation:
|
case types.TaskVariation:
|
||||||
err = s.client.Variation(task)
|
res, err = s.Client.Variation(task)
|
||||||
break
|
break
|
||||||
case types.TaskBlend:
|
case types.TaskBlend:
|
||||||
err = s.client.Blend(task)
|
res, err = s.Client.Blend(task)
|
||||||
break
|
break
|
||||||
case types.TaskSwapFace:
|
case types.TaskSwapFace:
|
||||||
err = s.client.SwapFace(task)
|
res, err = s.Client.SwapFace(task)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
var job model.MidJourneyJob
|
||||||
logger.Error("绘画任务执行失败:", err.Error())
|
s.db.Where("id = ?", task.Id).First(&job)
|
||||||
|
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||||||
|
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
|
||||||
|
logger.Error("绘画任务执行失败:", errMsg)
|
||||||
|
job.Progress = -1
|
||||||
|
job.ErrMsg = errMsg
|
||||||
// update the task progress
|
// update the task progress
|
||||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
s.db.Updates(&job)
|
||||||
"progress": -1,
|
// 任务失败,通知前端
|
||||||
"err_msg": err.Error(),
|
|
||||||
})
|
|
||||||
s.notifyQueue.RPush(task.UserId)
|
s.notifyQueue.RPush(task.UserId)
|
||||||
// restore img_call quota
|
|
||||||
if task.Type.String() != types.TaskUpscale.String() {
|
|
||||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
logger.Infof("任务提交成功:%+v", res)
|
||||||
// lock the task until the execute timeout
|
// lock the task until the execute timeout
|
||||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
s.taskStartTimes[int(task.Id)] = time.Now()
|
||||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
atomic.AddInt32(&s.HandledTaskNum, 1)
|
||||||
|
// 更新任务 ID/频道
|
||||||
|
job.TaskId = res.Result
|
||||||
|
job.ChannelId = s.Name
|
||||||
|
s.db.Updates(&job)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if current service instance can handle more task
|
// check if current service instance can handle more task
|
||||||
func (s *Service) canHandleTask() bool {
|
func (s *Service) canHandleTask() bool {
|
||||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
|
||||||
return handledNum < s.maxHandleTaskNum
|
return handledNum < s.maxHandleTaskNum
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,64 +132,75 @@ func (s *Service) checkTasks() {
|
|||||||
for k, t := range s.taskStartTimes {
|
for k, t := range s.taskStartTimes {
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||||
delete(s.taskStartTimes, k)
|
delete(s.taskStartTimes, k)
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||||
// delete task from database
|
// delete task from database
|
||||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Notify(data CBReq) {
|
type CBReq struct {
|
||||||
// extract the task ID
|
Id string `json:"id"`
|
||||||
split := strings.Split(data.Prompt, " ")
|
Action string `json:"action"`
|
||||||
var job model.MidJourneyJob
|
Status string `json:"status"`
|
||||||
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
|
Prompt string `json:"prompt"`
|
||||||
if res.Error == nil && data.Status == Finished {
|
PromptEn string `json:"promptEn"`
|
||||||
logger.Warn("重复消息:", data.MessageId)
|
Description string `json:"description"`
|
||||||
return
|
SubmitTime int64 `json:"submitTime"`
|
||||||
}
|
StartTime int64 `json:"startTime"`
|
||||||
|
FinishTime int64 `json:"finishTime"`
|
||||||
tx := s.db.Session(&gorm.Session{}).Where("progress < ?", 100).Order("id ASC")
|
Progress string `json:"progress"`
|
||||||
if data.ReferenceId != "" {
|
ImageUrl string `json:"imageUrl"`
|
||||||
tx = tx.Where("reference_id = ?", data.ReferenceId)
|
FailReason interface{} `json:"failReason"`
|
||||||
} else {
|
Properties struct {
|
||||||
tx = tx.Where("task_id = ?", split[0])
|
FinalPrompt string `json:"finalPrompt"`
|
||||||
}
|
} `json:"properties"`
|
||||||
// fixed: 修复 U/V 操作任务混淆覆盖的 Bug
|
}
|
||||||
if strings.Contains(data.Prompt, "** - Image #") { // for upscale
|
|
||||||
tx = tx.Where("type = ?", types.TaskUpscale.String())
|
func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||||
} else if strings.Contains(data.Prompt, "** - Variations (Strong)") { // for Variations
|
task, err := s.Client.QueryTask(job.TaskId)
|
||||||
tx = tx.Where("type = ?", types.TaskVariation.String())
|
if err != nil {
|
||||||
}
|
return err
|
||||||
res = tx.First(&job)
|
}
|
||||||
if res.Error != nil {
|
|
||||||
logger.Warn("非法任务:", res.Error)
|
// 任务执行失败了
|
||||||
return
|
if task.FailReason != "" {
|
||||||
}
|
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||||
|
"progress": -1,
|
||||||
job.ChannelId = data.ChannelId
|
"err_msg": task.FailReason,
|
||||||
job.MessageId = data.MessageId
|
})
|
||||||
job.ReferenceId = data.ReferenceId
|
return fmt.Errorf("task failed: %v", task.FailReason)
|
||||||
job.Progress = data.Progress
|
}
|
||||||
job.Prompt = data.Prompt
|
|
||||||
job.Hash = data.Image.Hash
|
if len(task.Buttons) > 0 {
|
||||||
job.OrgURL = data.Image.URL
|
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||||
if s.client.Config.UseCDN {
|
}
|
||||||
job.UseProxy = true
|
oldProgress := job.Progress
|
||||||
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.ImgCdnURL)
|
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||||
}
|
job.Prompt = task.PromptEn
|
||||||
|
if task.ImageUrl != "" {
|
||||||
res = s.db.Updates(&job)
|
job.OrgURL = task.ImageUrl
|
||||||
if res.Error != nil {
|
}
|
||||||
logger.Error("error with update job: ", res.Error)
|
job.MessageId = task.Id
|
||||||
return
|
tx := s.db.Updates(&job)
|
||||||
}
|
if tx.Error != nil {
|
||||||
|
return fmt.Errorf("error with update database: %v", tx.Error)
|
||||||
if data.Status == Finished {
|
}
|
||||||
// release lock task
|
if task.Status == "SUCCESS" {
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
// release lock task
|
||||||
}
|
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||||
|
}
|
||||||
s.notifyQueue.RPush(job.UserId)
|
// 通知前端更新任务进度
|
||||||
|
if oldProgress != job.Progress {
|
||||||
|
s.notifyQueue.RPush(job.UserId)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetImageHash(action string) string {
|
||||||
|
split := strings.Split(action, "::")
|
||||||
|
if len(split) > 5 {
|
||||||
|
return split[4]
|
||||||
|
}
|
||||||
|
return split[len(split)-1]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +0,0 @@
|
|||||||
package mj
|
|
||||||
|
|
||||||
const (
|
|
||||||
ApplicationID string = "936929561302675456"
|
|
||||||
SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe"
|
|
||||||
)
|
|
||||||
|
|
||||||
type InteractionsRequest struct {
|
|
||||||
Type int `json:"type"`
|
|
||||||
ApplicationID string `json:"application_id"`
|
|
||||||
MessageFlags int `json:"message_flags,omitempty"`
|
|
||||||
MessageID string `json:"message_id,omitempty"`
|
|
||||||
GuildID string `json:"guild_id"`
|
|
||||||
ChannelID string `json:"channel_id"`
|
|
||||||
SessionID string `json:"session_id"`
|
|
||||||
Data map[string]any `json:"data"`
|
|
||||||
Nonce string `json:"nonce,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type InteractionsResult struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string
|
|
||||||
Error map[string]any
|
|
||||||
}
|
|
||||||
|
|
||||||
type CBReq struct {
|
|
||||||
ChannelId string `json:"channel_id"`
|
|
||||||
MessageId string `json:"message_id"`
|
|
||||||
ReferenceId string `json:"reference_id"`
|
|
||||||
Image Image `json:"image"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Status TaskStatus `json:"status"`
|
|
||||||
Progress int `json:"progress"`
|
|
||||||
}
|
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -101,6 +102,20 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||||
|
// 上传文件字节数据
|
||||||
|
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s AliYunOss) Delete(fileURL string) error {
|
func (s AliYunOss) Delete(fileURL string) error {
|
||||||
var objectKey string
|
var objectKey string
|
||||||
if strings.HasPrefix(fileURL, "http") {
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ package oss
|
|||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LocalStorage struct {
|
type LocalStorage struct {
|
||||||
@@ -73,6 +73,20 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
|
||||||
|
err = os.WriteFile(filePath, imageData, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error writing to file:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s LocalStorage) Delete(fileURL string) error {
|
func (s LocalStorage) Delete(fileURL string) error {
|
||||||
if _, err := os.Stat(fileURL); err == nil {
|
if _, err := os.Stat(fileURL); err == nil {
|
||||||
return os.Remove(fileURL)
|
return os.Remove(fileURL)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -96,6 +97,25 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||||
|
info, err := s.client.PutObject(
|
||||||
|
context.Background(),
|
||||||
|
s.config.Bucket,
|
||||||
|
objectKey,
|
||||||
|
strings.NewReader(string(imageData)),
|
||||||
|
int64(len(imageData)),
|
||||||
|
minio.PutObjectOptions{ContentType: "image/png"})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s MiniOss) Delete(fileURL string) error {
|
func (s MiniOss) Delete(fileURL string) error {
|
||||||
var objectKey string
|
var objectKey string
|
||||||
if strings.HasPrefix(fileURL, "http") {
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -112,6 +113,22 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||||
|
ret := storage.PutRet{}
|
||||||
|
extra := storage.PutExtra{}
|
||||||
|
// 上传文件字节数据
|
||||||
|
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s QinNiuOss) Delete(fileURL string) error {
|
func (s QinNiuOss) Delete(fileURL string) error {
|
||||||
var objectKey string
|
var objectKey string
|
||||||
if strings.HasPrefix(fileURL, "http") {
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
|
|||||||
@@ -17,5 +17,6 @@ type File struct {
|
|||||||
type Uploader interface {
|
type Uploader interface {
|
||||||
PutFile(ctx *gin.Context, name string) (File, error)
|
PutFile(ctx *gin.Context, name string) (File, error)
|
||||||
PutImg(imageURL string, useProxy bool) (string, error)
|
PutImg(imageURL string, useProxy bool) (string, error)
|
||||||
|
PutBase64(imageData string) (string, error)
|
||||||
Delete(fileURL string) error
|
Delete(fileURL string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,16 +29,17 @@ type JPayReq struct {
|
|||||||
OutTradeNo string `json:"out_trade_no"`
|
OutTradeNo string `json:"out_trade_no"`
|
||||||
Subject string `json:"body"`
|
Subject string `json:"body"`
|
||||||
NotifyURL string `json:"notify_url"`
|
NotifyURL string `json:"notify_url"`
|
||||||
|
ReturnURL string `json:"callback_url"`
|
||||||
}
|
}
|
||||||
type JPayReps struct {
|
type JPayReps struct {
|
||||||
CodeUrl string `json:"code_url"`
|
|
||||||
OutTradeNo string `json:"out_trade_no"`
|
OutTradeNo string `json:"out_trade_no"`
|
||||||
OrderId string `json:"payjs_order_id"`
|
OrderId string `json:"payjs_order_id"`
|
||||||
Qrcode string `json:"qrcode"`
|
|
||||||
ReturnCode int `json:"return_code"`
|
ReturnCode int `json:"return_code"`
|
||||||
ReturnMsg string `json:"return_msg"`
|
ReturnMsg string `json:"return_msg"`
|
||||||
Sign string `json:"Sign"`
|
Sign string `json:"Sign"`
|
||||||
TotalFee string `json:"total_fee"`
|
TotalFee string `json:"total_fee"`
|
||||||
|
CodeUrl string `json:"code_url,omitempty"`
|
||||||
|
Qrcode string `json:"qrcode,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r JPayReps) IsOK() bool {
|
func (r JPayReps) IsOK() bool {
|
||||||
@@ -78,8 +79,14 @@ func (js *PayJS) Pay(param JPayReq) JPayReps {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (js *PayJS) PayH5(p url.Values) string {
|
||||||
|
p.Add("mchid", js.config.AppId)
|
||||||
|
p.Add("sign", js.sign(p))
|
||||||
|
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
|
||||||
|
}
|
||||||
|
|
||||||
func (js *PayJS) sign(params url.Values) string {
|
func (js *PayJS) sign(params url.Values) string {
|
||||||
params.Del(`Sign`)
|
params.Del(`sign`)
|
||||||
var keys = make([]string, 0, 0)
|
var keys = make([]string, 0, 0)
|
||||||
for key := range params {
|
for key := range params {
|
||||||
if params.Get(key) != `` {
|
if params.Get(key) != `` {
|
||||||
@@ -109,7 +116,7 @@ func (js *PayJS) Check(tradeNo string) error {
|
|||||||
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
|
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Add("payjs_order_id", tradeNo)
|
params.Add("payjs_order_id", tradeNo)
|
||||||
params.Add("Sign", js.sign(params))
|
params.Add("sign", js.sign(params))
|
||||||
data := strings.NewReader(params.Encode())
|
data := strings.NewReader(params.Encode())
|
||||||
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
|
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@@ -135,6 +142,7 @@ func (js *PayJS) Check(tradeNo string) error {
|
|||||||
if r.ReturnCode == 1 && r.Status == 1 {
|
if r.ReturnCode == 1 && r.Status == 1 {
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
|
logger.Errorf("PayJs 支付验证响应:%s", string(body))
|
||||||
return errors.New("order not paid")
|
return errors.New("order not paid")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
|
"chatplus/store/model"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -14,6 +16,7 @@ type ServicePool struct {
|
|||||||
services []*Service
|
services []*Service
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
|
db *gorm.DB
|
||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,14 +25,14 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
|||||||
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
||||||
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
||||||
// create mj client and service
|
// create mj client and service
|
||||||
for k, config := range appConfig.SdConfigs {
|
for _, config := range appConfig.SdConfigs {
|
||||||
if config.Enabled == false {
|
if config.Enabled == false {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// create sd service
|
// create sd service
|
||||||
name := fmt.Sprintf("StableDifffusion Service-%d", k)
|
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
|
||||||
service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager)
|
service := NewService(name, config, taskQueue, notifyQueue, db, manager)
|
||||||
// run sd service
|
// run sd service
|
||||||
go func() {
|
go func() {
|
||||||
service.Run()
|
service.Run()
|
||||||
@@ -42,6 +45,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
|||||||
taskQueue: taskQueue,
|
taskQueue: taskQueue,
|
||||||
notifyQueue: notifyQueue,
|
notifyQueue: notifyQueue,
|
||||||
services: services,
|
services: services,
|
||||||
|
db: db,
|
||||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -52,6 +56,68 @@ func (p *ServicePool) PushTask(task types.SdTask) {
|
|||||||
p.taskQueue.RPush(task)
|
p.taskQueue.RPush(task)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ServicePool) CheckTaskNotify() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||||
|
for {
|
||||||
|
var userId uint
|
||||||
|
err := p.notifyQueue.LPop(&userId)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
client := p.Clients.Get(userId)
|
||||||
|
if client == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = client.Send([]byte("Task Updated"))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||||
|
func (p *ServicePool) CheckTaskStatus() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||||
|
for {
|
||||||
|
var jobs []model.SdJob
|
||||||
|
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
||||||
|
if res.Error != nil {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, job := range jobs {
|
||||||
|
// 5 分钟还没完成的任务直接删除
|
||||||
|
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
|
||||||
|
p.db.Delete(&job)
|
||||||
|
var user model.User
|
||||||
|
p.db.Where("id = ?", job.UserId).First(&user)
|
||||||
|
// 退回绘图次数
|
||||||
|
res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
||||||
|
if res.Error == nil && res.RowsAffected > 0 {
|
||||||
|
p.db.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power + job.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "stable-diffusion",
|
||||||
|
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// HasAvailableService check if it has available mj service in pool
|
// HasAvailableService check if it has available mj service in pool
|
||||||
func (p *ServicePool) HasAvailableService() bool {
|
func (p *ServicePool) HasAvailableService() bool {
|
||||||
return len(p.services) > 0
|
return len(p.services) > 0
|
||||||
|
|||||||
@@ -2,69 +2,59 @@ package sd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
|
"chatplus/service"
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SD 绘画服务
|
// SD 绘画服务
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
httpClient *req.Client
|
httpClient *req.Client
|
||||||
config types.StableDiffusionConfig
|
config types.StableDiffusionConfig
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
name string // service name
|
name string // service name
|
||||||
maxHandleTaskNum int32 // max task number current service can handle
|
|
||||||
handledTaskNum int32 // already handled task number
|
|
||||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
|
||||||
taskTimeout int64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
||||||
|
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
|
||||||
return &Service{
|
return &Service{
|
||||||
name: name,
|
name: name,
|
||||||
config: config,
|
config: config,
|
||||||
httpClient: req.C(),
|
httpClient: req.C(),
|
||||||
taskQueue: taskQueue,
|
taskQueue: taskQueue,
|
||||||
notifyQueue: notifyQueue,
|
notifyQueue: notifyQueue,
|
||||||
db: db,
|
db: db,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
taskTimeout: timeout,
|
|
||||||
maxHandleTaskNum: maxTaskNum,
|
|
||||||
taskStartTimes: make(map[int]time.Time),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
for {
|
for {
|
||||||
s.checkTasks()
|
|
||||||
if !s.canHandleTask() {
|
|
||||||
// current service is full, can not handle more task
|
|
||||||
// waiting for running task finish
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var task types.SdTask
|
var task types.SdTask
|
||||||
err := s.taskQueue.LPop(&task)
|
err := s.taskQueue.LPop(&task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("taking task with error: %v", err)
|
logger.Errorf("taking task with error: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// 翻译提示词
|
||||||
|
if utils.HasChinese(task.Params.Prompt) {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
|
||||||
|
if err == nil {
|
||||||
|
task.Params.Prompt = content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
|
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
|
||||||
err = s.Txt2Img(task)
|
err = s.Txt2Img(task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -74,240 +64,135 @@ func (s *Service) Run() {
|
|||||||
"progress": -1,
|
"progress": -1,
|
||||||
"err_msg": err.Error(),
|
"err_msg": err.Error(),
|
||||||
})
|
})
|
||||||
// restore img_call quota
|
|
||||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
// release task num
|
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
|
||||||
// 通知前端,任务失败
|
// 通知前端,任务失败
|
||||||
s.notifyQueue.RPush(task.UserId)
|
s.notifyQueue.RPush(task.UserId)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock the task until the execute timeout
|
|
||||||
s.taskStartTimes[task.Id] = time.Now()
|
|
||||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if current service instance can handle more task
|
// Txt2ImgReq 文生图请求实体
|
||||||
func (s *Service) canHandleTask() bool {
|
type Txt2ImgReq struct {
|
||||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
Prompt string `json:"prompt"`
|
||||||
return handledNum < s.maxHandleTaskNum
|
NegativePrompt string `json:"negative_prompt"`
|
||||||
|
Seed int64 `json:"seed,omitempty"`
|
||||||
|
Steps int `json:"steps"`
|
||||||
|
CfgScale float32 `json:"cfg_scale"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
SamplerName string `json:"sampler_name"`
|
||||||
|
EnableHr bool `json:"enable_hr,omitempty"`
|
||||||
|
HrScale int `json:"hr_scale,omitempty"`
|
||||||
|
HrUpscaler string `json:"hr_upscaler,omitempty"`
|
||||||
|
HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"`
|
||||||
|
DenoisingStrength float32 `json:"denoising_strength,omitempty"`
|
||||||
|
ForceTaskId string `json:"force_task_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove the expired tasks
|
// Txt2ImgResp 文生图响应实体
|
||||||
func (s *Service) checkTasks() {
|
type Txt2ImgResp struct {
|
||||||
for k, t := range s.taskStartTimes {
|
Images []string `json:"images"`
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
Parameters struct {
|
||||||
delete(s.taskStartTimes, k)
|
} `json:"parameters"`
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
Info string `json:"info"`
|
||||||
// delete task from database
|
}
|
||||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
|
||||||
}
|
// TaskProgressResp 任务进度响应实体
|
||||||
}
|
type TaskProgressResp struct {
|
||||||
|
Progress float64 `json:"progress"`
|
||||||
|
EtaRelative float64 `json:"eta_relative"`
|
||||||
|
CurrentImage string `json:"current_image"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Txt2Img 文生图 API
|
// Txt2Img 文生图 API
|
||||||
func (s *Service) Txt2Img(task types.SdTask) error {
|
func (s *Service) Txt2Img(task types.SdTask) error {
|
||||||
var taskInfo TaskInfo
|
body := Txt2ImgReq{
|
||||||
bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath)
|
Prompt: task.Params.Prompt,
|
||||||
if err != nil {
|
NegativePrompt: task.Params.NegativePrompt,
|
||||||
return fmt.Errorf("error with load text2img json template file: %s", err.Error())
|
Steps: task.Params.Steps,
|
||||||
|
CfgScale: task.Params.CfgScale,
|
||||||
|
Width: task.Params.Width,
|
||||||
|
Height: task.Params.Height,
|
||||||
|
SamplerName: task.Params.Sampler,
|
||||||
}
|
}
|
||||||
|
if task.Params.Seed > 0 {
|
||||||
err = json.Unmarshal(bytes, &taskInfo)
|
body.Seed = task.Params.Seed
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode json params: %s", err.Error())
|
|
||||||
}
|
}
|
||||||
|
if task.Params.HdFix {
|
||||||
data := taskInfo.Data
|
body.EnableHr = true
|
||||||
params := task.Params
|
body.HrScale = task.Params.HdScale
|
||||||
data[ParamKeys["task_id"]] = params.TaskId
|
body.HrUpscaler = task.Params.HdScaleAlg
|
||||||
data[ParamKeys["prompt"]] = params.Prompt
|
body.HrSecondPassSteps = task.Params.HdSteps
|
||||||
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
|
body.DenoisingStrength = task.Params.HdRedrawRate
|
||||||
data[ParamKeys["steps"]] = params.Steps
|
}
|
||||||
data[ParamKeys["sampler"]] = params.Sampler
|
var res Txt2ImgResp
|
||||||
// @fix bug: 有些 stable diffusion 没有面部修复功能
|
var errChan = make(chan error)
|
||||||
//data[ParamKeys["face_fix"]] = params.FaceFix
|
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||||
data[ParamKeys["cfg_scale"]] = params.CfgScale
|
logger.Debugf("send image request to %s", apiURL)
|
||||||
data[ParamKeys["seed"]] = params.Seed
|
|
||||||
data[ParamKeys["height"]] = params.Height
|
|
||||||
data[ParamKeys["width"]] = params.Width
|
|
||||||
data[ParamKeys["hd_fix"]] = params.HdFix
|
|
||||||
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
|
|
||||||
data[ParamKeys["hd_scale"]] = params.HdScale
|
|
||||||
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
|
|
||||||
data[ParamKeys["hd_sample_num"]] = params.HdSteps
|
|
||||||
|
|
||||||
taskInfo.SessionId = task.SessionId
|
|
||||||
taskInfo.TaskId = params.TaskId
|
|
||||||
taskInfo.Data = data
|
|
||||||
taskInfo.JobId = task.Id
|
|
||||||
taskInfo.UserId = uint(task.UserId)
|
|
||||||
go func() {
|
go func() {
|
||||||
s.runTask(taskInfo, s.httpClient)
|
response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 执行任务
|
|
||||||
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
|
|
||||||
body := map[string]any{
|
|
||||||
"data": taskInfo.Data,
|
|
||||||
"event_data": taskInfo.EventData,
|
|
||||||
"fn_index": taskInfo.FnIndex,
|
|
||||||
"session_hash": taskInfo.SessionHash,
|
|
||||||
}
|
|
||||||
var result = make(chan CBReq)
|
|
||||||
go func() {
|
|
||||||
var res struct {
|
|
||||||
Data []interface{} `json:"data"`
|
|
||||||
IsGenerating bool `json:"is_generating"`
|
|
||||||
Duration float64 `json:"duration"`
|
|
||||||
AverageDuration float64 `json:"average_duration"`
|
|
||||||
}
|
|
||||||
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
|
||||||
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cbReq.Message = "error with send request: " + err.Error()
|
errChan <- err
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.IsErrorState() {
|
if response.IsErrorState() {
|
||||||
bytes, _ := io.ReadAll(response.Body)
|
errChan <- fmt.Errorf("error http code status: %v", response.Status)
|
||||||
cbReq.Message = "error http status code: " + string(bytes)
|
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var images []struct {
|
// 保存 Base64 图片
|
||||||
Name string `json:"name"`
|
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
||||||
Data interface{} `json:"data"`
|
|
||||||
IsFile bool `json:"is_file"`
|
|
||||||
}
|
|
||||||
err = utils.ForceCovert(res.Data[0], &images)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cbReq.Message = "error with decode image:" + err.Error()
|
errChan <- fmt.Errorf("error with upload image: %v", err)
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 获取绘画真实的 seed
|
||||||
var info map[string]any
|
var info map[string]interface{}
|
||||||
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
|
err = utils.JsonDecode(res.Info, &info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(res.Data)
|
errChan <- fmt.Errorf("error with decode task response: %v", err)
|
||||||
cbReq.Message = "error with decode image url:" + err.Error()
|
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
|
||||||
// 获取真实的 seed 值
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
|
||||||
cbReq.ImageName = images[0].Name
|
errChan <- nil
|
||||||
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
|
|
||||||
cbReq.Seed = seed
|
|
||||||
cbReq.Success = true
|
|
||||||
cbReq.Progress = 100
|
|
||||||
result <- cbReq
|
|
||||||
close(result)
|
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case value := <-result:
|
case err := <-errChan: // 任务完成
|
||||||
s.callback(value)
|
if err != nil {
|
||||||
return
|
return err
|
||||||
|
}
|
||||||
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||||
|
s.notifyQueue.RPush(task.UserId)
|
||||||
|
return nil
|
||||||
default:
|
default:
|
||||||
var progressReq = map[string]any{
|
err, resp := s.checkTaskProgress()
|
||||||
"id_task": taskInfo.TaskId,
|
// 更新任务进度
|
||||||
"id_live_preview": 1,
|
if err == nil && resp.Progress > 0 {
|
||||||
|
logger.Debugf("Check task progress: %+v", resp.Progress)
|
||||||
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||||
|
// 发送更新状态信号
|
||||||
|
s.notifyQueue.RPush(task.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
var progressRes struct {
|
|
||||||
Active bool `json:"active"`
|
|
||||||
Queued bool `json:"queued"`
|
|
||||||
Completed bool `json:"completed"`
|
|
||||||
Progress float64 `json:"progress"`
|
|
||||||
Eta float64 `json:"eta"`
|
|
||||||
LivePreview string `json:"live_preview"`
|
|
||||||
IDLivePreview int `json:"id_live_preview"`
|
|
||||||
TextInfo interface{} `json:"textinfo"`
|
|
||||||
}
|
|
||||||
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
|
|
||||||
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
|
||||||
if err != nil { // TODO: 这里可以考虑设置失败重试次数
|
|
||||||
logger.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if response.IsErrorState() {
|
|
||||||
bytes, _ := io.ReadAll(response.Body)
|
|
||||||
logger.Error(string(bytes))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cbReq.ImageData = progressRes.LivePreview
|
|
||||||
cbReq.Progress = int(progressRes.Progress * 100)
|
|
||||||
s.callback(cbReq)
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) callback(data CBReq) {
|
// 执行任务
|
||||||
// release task num
|
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
||||||
if data.Success { // 任务成功
|
var res TaskProgressResp
|
||||||
var job model.SdJob
|
response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
|
||||||
res := s.db.Where("id = ?", data.JobId).First(&job)
|
if err != nil {
|
||||||
if res.Error != nil {
|
return err, nil
|
||||||
logger.Warn("非法任务:", res.Error)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 更新任务进度
|
|
||||||
job.Progress = data.Progress
|
|
||||||
// 更新任务 seed
|
|
||||||
var params types.SdTaskParams
|
|
||||||
err := utils.JsonDecode(job.Params, ¶ms)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("任务解析失败:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
params.Seed = data.Seed
|
|
||||||
if data.ImageName != "" { // 下载图片
|
|
||||||
job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
|
|
||||||
if data.Progress == 100 {
|
|
||||||
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download img: ", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
job.ImgURL = imageURL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
job.Params = utils.JsonEncode(params)
|
|
||||||
res = s.db.Updates(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("error with update job: ", res.Error)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debugf("绘图进度:%d", data.Progress)
|
|
||||||
} else { // 任务失败
|
|
||||||
logger.Error("任务执行失败:", data.Message)
|
|
||||||
// update the task progress
|
|
||||||
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{
|
|
||||||
"progress": -1,
|
|
||||||
"err_msg": data.Message,
|
|
||||||
})
|
|
||||||
// restore img_calls
|
|
||||||
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", data.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
}
|
}
|
||||||
|
if response.IsErrorState() {
|
||||||
|
return fmt.Errorf("error http code status: %v", response.Status), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, &res
|
||||||
}
|
}
|
||||||
|
|||||||
4
api/service/types.go
Normal file
4
api/service/types.go
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
|
||||||
|
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||||
@@ -41,7 +41,7 @@ func parseTransactionMessage(xmlData string) *Message {
|
|||||||
}
|
}
|
||||||
if se.Name.Local == "weapp_path" || se.Name.Local == "url" {
|
if se.Name.Local == "weapp_path" || se.Name.Local == "url" {
|
||||||
if err := decoder.DecodeElement(&value, &se); err == nil {
|
if err := decoder.DecodeElement(&value, &se); err == nil {
|
||||||
if strings.Contains(value, "trans_id=") {
|
if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") {
|
||||||
message.Url = value
|
message.Url = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,13 +39,14 @@ func NewXXLJobExecutor(config *types.AppConfig, db *gorm.DB) *XXLJobExecutor {
|
|||||||
|
|
||||||
func (e *XXLJobExecutor) Run() error {
|
func (e *XXLJobExecutor) Run() error {
|
||||||
e.executor.RegTask("ClearOrders", e.ClearOrders)
|
e.executor.RegTask("ClearOrders", e.ClearOrders)
|
||||||
e.executor.RegTask("ResetVipCalls", e.ResetVipCalls)
|
e.executor.RegTask("ResetVipPower", e.ResetVipPower)
|
||||||
|
e.executor.RegTask("ResetUserPower", e.ResetUserPower)
|
||||||
return e.executor.Run()
|
return e.executor.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearOrders 清理未支付的订单,如果没有抛出异常则表示执行成功
|
// ClearOrders 清理未支付的订单,如果没有抛出异常则表示执行成功
|
||||||
func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (msg string) {
|
func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (msg string) {
|
||||||
logger.Debug("执行清理未支付订单...")
|
logger.Info("执行清理未支付订单...")
|
||||||
var sysConfig model.Config
|
var sysConfig model.Config
|
||||||
res := e.db.Where("marker", "system").First(&sysConfig)
|
res := e.db.Where("marker", "system").First(&sysConfig)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
@@ -64,15 +65,17 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms
|
|||||||
timeout := time.Now().Unix() - int64(config.OrderPayTimeout)
|
timeout := time.Now().Unix() - int64(config.OrderPayTimeout)
|
||||||
start := utils.Stamp2str(timeout)
|
start := utils.Stamp2str(timeout)
|
||||||
// 这里不是用软删除,而是永久删除订单
|
// 这里不是用软删除,而是永久删除订单
|
||||||
res = e.db.Unscoped().Where("status != ? AND created_at < ?", types.OrderPaidSuccess, start).Delete(&model.Order{})
|
res = e.db.Unscoped().Where("status IN ? AND created_at < ?", []types.OrderStatus{types.OrderNotPaid, types.OrderScanned}, start).Delete(&model.Order{})
|
||||||
return fmt.Sprintf("Clear order successfully, affect rows: %d", res.RowsAffected)
|
logger.Infof("Clear order successfully, affect rows: %d", res.RowsAffected)
|
||||||
|
return "success"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetVipCalls 清理过期的 VIP 会员
|
// ResetVipPower 重置VIP会员算力
|
||||||
func (e *XXLJobExecutor) ResetVipCalls(cxt context.Context, param *xxl.RunReq) (msg string) {
|
// 自动将 VIP 会员的算力补充到每月赠送的最大值
|
||||||
|
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
|
||||||
logger.Info("开始进行月底账号盘点...")
|
logger.Info("开始进行月底账号盘点...")
|
||||||
var users []model.User
|
var users []model.User
|
||||||
res := e.db.Where("vip = ?", 1).Find(&users)
|
res := e.db.Where("vip", 1).Where("status", 1).Find(&users)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return "No vip users found"
|
return "No vip users found"
|
||||||
}
|
}
|
||||||
@@ -89,60 +92,92 @@ func (e *XXLJobExecutor) ResetVipCalls(cxt context.Context, param *xxl.RunReq) (
|
|||||||
return "error with decode system config: " + err.Error()
|
return "error with decode system config: " + err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取本月月初时间
|
|
||||||
currentTime := time.Now()
|
|
||||||
year, month, _ := currentTime.Date()
|
|
||||||
firstOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, currentTime.Location()).Unix()
|
|
||||||
for _, u := range users {
|
for _, u := range users {
|
||||||
// 账号到期,直接清零
|
// 处理过期的 VIP
|
||||||
if u.ExpiredTime <= currentTime.Unix() {
|
if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() {
|
||||||
logger.Info("账号过期:", u.Username)
|
|
||||||
u.Calls = 0
|
|
||||||
u.Vip = false
|
u.Vip = false
|
||||||
} else {
|
e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
|
||||||
if u.Calls <= 0 {
|
continue
|
||||||
u.Calls = 0
|
|
||||||
}
|
|
||||||
if u.ImgCalls <= 0 {
|
|
||||||
u.ImgCalls = 0
|
|
||||||
}
|
|
||||||
// 如果该用户当月有充值点卡,则将点卡中未用完的点数结余到下个月
|
|
||||||
var orders []model.Order
|
|
||||||
e.db.Debug().Where("user_id = ? AND pay_time > ?", u.Id, firstOfMonth).Find(&orders)
|
|
||||||
var calls = 0
|
|
||||||
var imgCalls = 0
|
|
||||||
for _, o := range orders {
|
|
||||||
var remark types.OrderRemark
|
|
||||||
err = utils.JsonDecode(o.Remark, &remark)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if remark.Days > 0 { // 会员续费
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
calls += remark.Calls
|
|
||||||
imgCalls += remark.ImgCalls
|
|
||||||
}
|
|
||||||
if u.Calls > calls { // 本月套餐没有用完
|
|
||||||
u.Calls = calls + config.VipMonthCalls
|
|
||||||
} else {
|
|
||||||
u.Calls = u.Calls + config.VipMonthCalls
|
|
||||||
}
|
|
||||||
if u.ImgCalls > imgCalls { // 本月套餐没有用完
|
|
||||||
u.ImgCalls = imgCalls + config.VipMonthImgCalls
|
|
||||||
} else {
|
|
||||||
u.ImgCalls = u.ImgCalls + config.VipMonthImgCalls
|
|
||||||
}
|
|
||||||
logger.Infof("%s 点卡结余:%d", u.Username, calls)
|
|
||||||
}
|
}
|
||||||
u.Tokens = 0
|
|
||||||
// update user
|
// update user
|
||||||
e.db.Updates(&u)
|
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", config.VipMonthPower))
|
||||||
|
// 记录算力变动日志
|
||||||
|
if tx.Error == nil {
|
||||||
|
var user model.User
|
||||||
|
e.db.Where("id", u.Id).First(&user)
|
||||||
|
e.db.Create(&model.PowerLog{
|
||||||
|
UserId: u.Id,
|
||||||
|
Username: u.Username,
|
||||||
|
Type: types.PowerRecharge,
|
||||||
|
Amount: config.VipMonthPower,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Balance: user.Power,
|
||||||
|
Model: "系统盘点",
|
||||||
|
Remark: fmt.Sprintf("VIP会员每月算力派发,:%d", config.VipMonthPower),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
logger.Info("月底盘点完成!")
|
logger.Info("月底盘点完成!")
|
||||||
return "success"
|
return "success"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *XXLJobExecutor) ResetUserPower(cxt context.Context, param *xxl.RunReq) (msg string) {
|
||||||
|
logger.Info("今日算力派发开始:", time.Now())
|
||||||
|
var users []model.User
|
||||||
|
res := e.db.Where("status", 1).Find(&users)
|
||||||
|
if res.Error != nil {
|
||||||
|
return "No matching users"
|
||||||
|
}
|
||||||
|
|
||||||
|
var sysConfig model.Config
|
||||||
|
res = e.db.Where("marker", "system").First(&sysConfig)
|
||||||
|
if res.Error != nil {
|
||||||
|
return "error with get system config: " + res.Error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
var config types.SystemConfig
|
||||||
|
err := utils.JsonDecode(sysConfig.Config, &config)
|
||||||
|
if err != nil {
|
||||||
|
return "error with decode system config: " + err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.DailyPower <= 0 {
|
||||||
|
return "success"
|
||||||
|
}
|
||||||
|
|
||||||
|
var counter = 0
|
||||||
|
var totalPower = 0
|
||||||
|
for _, u := range users {
|
||||||
|
if u.Power >= config.DailyPower {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var power = config.DailyPower - u.Power
|
||||||
|
// update user
|
||||||
|
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power))
|
||||||
|
// 记录算力充值日志
|
||||||
|
if tx.Error == nil {
|
||||||
|
var user model.User
|
||||||
|
e.db.Where("id", u.Id).First(&user)
|
||||||
|
e.db.Create(&model.PowerLog{
|
||||||
|
UserId: u.Id,
|
||||||
|
Username: u.Username,
|
||||||
|
Type: types.PowerGift,
|
||||||
|
Amount: power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Balance: user.Power,
|
||||||
|
Model: "系统赠送",
|
||||||
|
Remark: fmt.Sprintf("系统每日算力派发,今日额度:%d", config.DailyPower),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
counter++
|
||||||
|
totalPower += power
|
||||||
|
}
|
||||||
|
logger.Infof("今日派发算力结束!累计派发 %d 人,累计派发算力:%d", counter, totalPower)
|
||||||
|
return "success"
|
||||||
|
}
|
||||||
|
|
||||||
type customLogger struct{}
|
type customLogger struct{}
|
||||||
|
|
||||||
func (l *customLogger) Info(format string, a ...interface{}) {
|
func (l *customLogger) Info(format string, a ...interface{}) {
|
||||||
|
|||||||
11
api/store/model/admin_user.go
Normal file
11
api/store/model/admin_user.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
type AdminUser struct {
|
||||||
|
BaseModel
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
Salt string // 密码盐
|
||||||
|
Status bool `gorm:"default:true"` // 当前状态
|
||||||
|
LastLoginAt int64 // 最后登录时间
|
||||||
|
LastLoginIp string // 最后登录 IP
|
||||||
|
}
|
||||||
@@ -9,6 +9,6 @@ type ApiKey struct {
|
|||||||
Value string // API Key 的值
|
Value string // API Key 的值
|
||||||
ApiURL string // 当前 KEY 的 API 地址
|
ApiURL string // 当前 KEY 的 API 地址
|
||||||
Enabled bool // 是否启用
|
Enabled bool // 是否启用
|
||||||
UseProxy bool // 是否使用代理访问 API URL
|
ProxyURL string // 代理地址
|
||||||
LastUsedAt int64 // 最后使用时间
|
LastUsedAt int64 // 最后使用时间
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ package model
|
|||||||
|
|
||||||
type ChatModel struct {
|
type ChatModel struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
Platform string
|
Platform string
|
||||||
Name string
|
Name string
|
||||||
Value string // API Key 的值
|
Value string // API Key 的值
|
||||||
SortNum int
|
SortNum int
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Weight int // 对话权重,每次对话扣减多少次对话额度
|
Power int // 每次对话消耗算力
|
||||||
Open bool // 是否开放模型给所有人使用
|
Open bool // 是否开放模型给所有人使用
|
||||||
|
MaxTokens int // 最大响应长度
|
||||||
|
MaxContext int // 最大上下文长度
|
||||||
|
Temperature float32 // 模型温度
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import "time"
|
|||||||
|
|
||||||
type File struct {
|
type File struct {
|
||||||
Id uint `gorm:"primarykey;column:id"`
|
Id uint `gorm:"primarykey;column:id"`
|
||||||
UserId uint
|
UserId int
|
||||||
Name string
|
Name string
|
||||||
ObjKey string
|
ObjKey string
|
||||||
URL string
|
URL string
|
||||||
|
|||||||
@@ -10,6 +10,6 @@ type InviteLog struct {
|
|||||||
UserId uint
|
UserId uint
|
||||||
Username string
|
Username string
|
||||||
InviteCode string
|
InviteCode string
|
||||||
Reward string `gorm:"column:reward_json"` // 邀请奖励
|
Remark string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type MidJourneyJob struct {
|
|||||||
UseProxy bool // 是否使用反代加载图片
|
UseProxy bool // 是否使用反代加载图片
|
||||||
Publish bool //是否发布图片到画廊
|
Publish bool //是否发布图片到画廊
|
||||||
ErrMsg string // 报错信息
|
ErrMsg string // 报错信息
|
||||||
|
Power int // 消耗算力
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
20
api/store/model/power_log.go
Normal file
20
api/store/model/power_log.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PowerLog 算力消费日志
|
||||||
|
type PowerLog struct {
|
||||||
|
Id uint `gorm:"primarykey;column:id"`
|
||||||
|
UserId uint
|
||||||
|
Username string
|
||||||
|
Type types.PowerType
|
||||||
|
Amount int
|
||||||
|
Balance int
|
||||||
|
Model string // 模型
|
||||||
|
Remark string // 备注
|
||||||
|
Mark types.PowerMark // 资金类型
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
@@ -7,8 +7,7 @@ type Product struct {
|
|||||||
Price float64
|
Price float64
|
||||||
Discount float64
|
Discount float64
|
||||||
Days int
|
Days int
|
||||||
Calls int
|
Power int
|
||||||
ImgCalls int
|
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Sales int
|
Sales int
|
||||||
SortNum int
|
SortNum int
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ type SdJob struct {
|
|||||||
Params string
|
Params string
|
||||||
Publish bool //是否发布图片到画廊
|
Publish bool //是否发布图片到画廊
|
||||||
ErrMsg string // 报错信息
|
ErrMsg string // 报错信息
|
||||||
|
Power int // 消耗算力
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ type User struct {
|
|||||||
Password string
|
Password string
|
||||||
Avatar string
|
Avatar string
|
||||||
Salt string // 密码盐
|
Salt string // 密码盐
|
||||||
TotalTokens int64 // 总消耗 tokens
|
Power int // 剩余算力
|
||||||
Calls int // 剩余对话次数
|
|
||||||
ImgCalls int // 剩余绘图次数
|
|
||||||
ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json
|
ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json
|
||||||
ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色
|
ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色
|
||||||
ChatModels string `gorm:"column:chat_models_json"` // AI 模型,不同的用户拥有不同的聊天模型
|
ChatModels string `gorm:"column:chat_models_json"` // AI 模型,不同的用户拥有不同的聊天模型
|
||||||
@@ -18,5 +16,4 @@ type User struct {
|
|||||||
LastLoginAt int64 // 最后登录时间
|
LastLoginAt int64 // 最后登录时间
|
||||||
LastLoginIp string // 最后登录 IP
|
LastLoginIp string // 最后登录 IP
|
||||||
Vip bool // 是否 VIP 会员
|
Vip bool // 是否 VIP 会员
|
||||||
Tokens int
|
|
||||||
}
|
}
|
||||||
|
|||||||
10
api/store/vo/admin_user.go
Normal file
10
api/store/vo/admin_user.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package vo
|
||||||
|
|
||||||
|
type AdminUser struct {
|
||||||
|
BaseVo
|
||||||
|
Username string `json:"username"`
|
||||||
|
Status bool `json:"status"` // 当前状态
|
||||||
|
LastLoginAt int64 `json:"last_login_at"` // 最后登录时间
|
||||||
|
LastLoginIp string `json:"last_login_ip"` // 最后登录 IP
|
||||||
|
RoleIds interface{} `json:"role_ids"` //角色ids
|
||||||
|
}
|
||||||
@@ -9,6 +9,6 @@ type ApiKey struct {
|
|||||||
Value string `json:"value"` // API Key 的值
|
Value string `json:"value"` // API Key 的值
|
||||||
ApiURL string `json:"api_url"`
|
ApiURL string `json:"api_url"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
UseProxy bool `json:"use_proxy"`
|
ProxyURL string `json:"proxy_url"`
|
||||||
LastUsedAt int64 `json:"last_used_at"` // 最后使用时间
|
LastUsedAt int64 `json:"last_used_at"` // 最后使用时间
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ package vo
|
|||||||
|
|
||||||
type ChatModel struct {
|
type ChatModel struct {
|
||||||
BaseVo
|
BaseVo
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
SortNum int `json:"sort_num"`
|
SortNum int `json:"sort_num"`
|
||||||
Weight int `json:"weight"`
|
Power int `json:"power"`
|
||||||
Open bool `json:"open"`
|
Open bool `json:"open"`
|
||||||
|
MaxTokens int `json:"max_tokens"` // 最大响应长度
|
||||||
|
MaxContext int `json:"max_context"` // 最大上下文长度
|
||||||
|
Temperature float32 `json:"temperature"` // 模型温度
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,5 @@ import "chatplus/core/types"
|
|||||||
type Config struct {
|
type Config struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
ChatConfig types.ChatConfig `json:"chat_config"`
|
|
||||||
SystemConfig types.SystemConfig `json:"system_config"`
|
SystemConfig types.SystemConfig `json:"system_config"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,11 @@
|
|||||||
package vo
|
package vo
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/core/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
type InviteLog struct {
|
type InviteLog struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
InviterId uint `json:"inviter_id"`
|
InviterId uint `json:"inviter_id"`
|
||||||
UserId uint `json:"user_id"`
|
UserId uint `json:"user_id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
InviteCode string `json:"invite_code"`
|
InviteCode string `json:"invite_code"`
|
||||||
Reward types.InviteReward `json:"reward"`
|
Remark string `json:"remark"`
|
||||||
CreatedAt int64 `json:"created_at"`
|
CreatedAt int64 `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,5 +18,6 @@ type MidJourneyJob struct {
|
|||||||
UseProxy bool `json:"use_proxy"`
|
UseProxy bool `json:"use_proxy"`
|
||||||
Publish bool `json:"publish"`
|
Publish bool `json:"publish"`
|
||||||
ErrMsg string `json:"err_msg"`
|
ErrMsg string `json:"err_msg"`
|
||||||
|
Power int `json:"power"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|||||||
17
api/store/vo/power_log.go
Normal file
17
api/store/vo/power_log.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package vo
|
||||||
|
|
||||||
|
import "chatplus/core/types"
|
||||||
|
|
||||||
|
type PowerLog struct {
|
||||||
|
Id uint `json:"id"`
|
||||||
|
UserId uint `json:"user_id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Type types.PowerType `json:"type"`
|
||||||
|
TypeStr string `json:"type_str"`
|
||||||
|
Amount int `json:"amount"`
|
||||||
|
Mark types.PowerMark `json:"mark"`
|
||||||
|
Balance int `json:"balance"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Remark string `json:"remark"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
}
|
||||||
@@ -6,8 +6,7 @@ type Product struct {
|
|||||||
Price float64 `json:"price"`
|
Price float64 `json:"price"`
|
||||||
Discount float64 `json:"discount"`
|
Discount float64 `json:"discount"`
|
||||||
Days int `json:"days"`
|
Days int `json:"days"`
|
||||||
Calls int `json:"calls"`
|
Power int `json:"power"`
|
||||||
ImgCalls int `json:"img_calls"`
|
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Sales int `json:"sales"`
|
Sales int `json:"sales"`
|
||||||
SortNum int `json:"sort_num"`
|
SortNum int `json:"sort_num"`
|
||||||
|
|||||||
@@ -12,6 +12,5 @@ type Reward struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RewardExchange struct {
|
type RewardExchange struct {
|
||||||
Calls int `json:"calls"`
|
Power int `json:"power"`
|
||||||
ImgCalls int `json:"img_calls"`
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,5 +16,6 @@ type SdJob struct {
|
|||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Publish bool `json:"publish"`
|
Publish bool `json:"publish"`
|
||||||
ErrMsg string `json:"err_msg"`
|
ErrMsg string `json:"err_msg"`
|
||||||
|
Power int `json:"power"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,23 +1,17 @@
|
|||||||
package vo
|
package vo
|
||||||
|
|
||||||
import "chatplus/core/types"
|
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
BaseVo
|
BaseVo
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Nickname string `json:"nickname"`
|
Nickname string `json:"nickname"`
|
||||||
Avatar string `json:"avatar"`
|
Avatar string `json:"avatar"`
|
||||||
Salt string `json:"salt"` // 密码盐
|
Salt string `json:"salt"` // 密码盐
|
||||||
TotalTokens int64 `json:"total_tokens"` // 总消耗tokens
|
Power int `json:"power"` // 剩余算力
|
||||||
Calls int `json:"calls"` // 剩余对话次数
|
ChatRoles []string `json:"chat_roles"` // 聊天角色集合
|
||||||
ImgCalls int `json:"img_calls"`
|
ChatModels []int `json:"chat_models"` // AI模型集合
|
||||||
ChatConfig types.UserChatConfig `json:"chat_config"` // 聊天配置
|
ExpiredTime int64 `json:"expired_time"` // 账户到期时间
|
||||||
ChatRoles []string `json:"chat_roles"` // 聊天角色集合
|
Status bool `json:"status"` // 当前状态
|
||||||
ChatModels []string `json:"chat_models"` // AI模型集合
|
LastLoginAt int64 `json:"last_login_at"` // 最后登录时间
|
||||||
ExpiredTime int64 `json:"expired_time"` // 账户到期时间
|
LastLoginIp string `json:"last_login_ip"` // 最后登录 IP
|
||||||
Status bool `json:"status"` // 当前状态
|
Vip bool `json:"vip"`
|
||||||
LastLoginAt int64 `json:"last_login_at"` // 最后登录时间
|
|
||||||
LastLoginIp string `json:"last_login_ip"` // 最后登录 IP
|
|
||||||
Vip bool `json:"vip"`
|
|
||||||
Tokens int `json:"token"` // 当月消耗的 fee
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,49 +1,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chatplus/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
text := `
|
text := "一只 蜗牛在树干上爬,阳光透过树叶照在蜗牛的背上 --ar 1:1 --iw 0.250000 --v 6"
|
||||||
> search("Shenzhen weather January 15, 2024")
|
fmt.Println(utils.HasChinese(text))
|
||||||
|
|
||||||
> mclick([0, 9, 16])
|
|
||||||
|
|
||||||
> **end-searching**
|
|
||||||
|
|
||||||
今天深圳的天气情况如下:
|
|
||||||
|
|
||||||
- 白天气温预计在21°C至24°C之间,天气晴朗。
|
|
||||||
- 晚上气温预计在21°C左右,云量较多,可能会有间断性小雨。
|
|
||||||
- 风向主要是东南风,风速大约在6至12公里每小时之间。
|
|
||||||
|
|
||||||
这些信息表明深圳今天的天气相对舒适,适合户外活动。晚上可能需要带伞以应对间断性小雨。温度较为宜人,早晚可能稍微凉爽一些【[Shenzhen weather in January 2024 | Shenzhen 14 day weather](https://www.weather25.com/asia/china/guangdong/shenzhen?page=month&month=January)】【[Hourly forecast for Shenzhen, Guangdong, China](https://www.timeanddate.com/weather/china/shenzhen/hourly)】【[Shenzhen Guangdong China 15 Day Weather Forecast](https://www.weatheravenue.com/en/asia/cn/guangdong/shenzhen-weather-15-days.html)】。
|
|
||||||
|
|
||||||
我将根据这些信息生成一张气象图,展示深圳今天的天气情况。
|
|
||||||
|
|
||||||
{"prompt":"A detailed weather map for Shenzhen, China, on January 15, 2024. The map shows a sunny day with clear skies during the day and partly cloudy skies at night. Temperatures range from 21\u00b0C to 24\u00b0C during the day and around 21\u00b0C at night. There are indications of light southeast winds during the day and evening, with wind speeds ranging from 6 to 12 km/h. The map includes symbols for sunshine, light clouds, and wind direction arrows, along with temperature readings for different times of the day. The layout is clear, with a focus on Shenzhen's geographical location and the surrounding region.","size":"1024x1024"}
|
|
||||||
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
And here is another image link: .
|
|
||||||
|
|
||||||
|
|
||||||
这是根据今天深圳的天气情况制作的气象图。图中展示了白天晴朗、夜间部分多云的天气,以及相关的温度和风向信息。`
|
|
||||||
pattern := `!\[([^\]]*)]\(([^)]+)\)`
|
|
||||||
|
|
||||||
// 编译正则表达式
|
|
||||||
re := regexp.MustCompile(pattern)
|
|
||||||
|
|
||||||
// 查找匹配的字符串
|
|
||||||
matches := re.FindAllStringSubmatch(text, -1)
|
|
||||||
|
|
||||||
// 提取链接并打印
|
|
||||||
for _, match := range matches {
|
|
||||||
fmt.Println(match[2])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user