feat: load preview page do not require user to login

This commit is contained in:
RockYang
2024-03-19 18:25:01 +08:00
parent 549f618cff
commit c5e583b215
57 changed files with 758 additions and 1550 deletions

View File

@@ -30,19 +30,15 @@ type Manager struct {
type ManagerHandler struct {
handler.BaseHandler
db *gorm.DB
redis *redis.Client
}
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
h := ManagerHandler{db: db, redis: client}
h.App = app
return &h
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data Manager
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -56,7 +52,7 @@ func (h *ManagerHandler) Login(c *gin.Context) {
}
var manager model.AdminUser
res := h.db.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil {
resp.ERROR(c, "请检查用户名或者密码是否填写正确")
return
@@ -78,7 +74,7 @@ func (h *ManagerHandler) Login(c *gin.Context) {
"user_id": manager.Username,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
@@ -93,35 +89,19 @@ func (h *ManagerHandler) Login(c *gin.Context) {
// 更新最后登录时间和IP
manager.LastLoginIp = c.ClientIP()
manager.LastLoginAt = time.Now().Unix()
h.db.Model(&manager).Updates(manager)
h.DB.Model(&manager).Updates(manager)
permissions := h.GetAdminSlugs(manager.Id)
var result = struct {
IsSuperAdmin bool `json:"is_super_admin"`
Token string `json:"token"`
Permissions []string `json:"permissions"`
IsSuperAdmin bool `json:"is_super_admin"`
Token string `json:"token"`
}{
IsSuperAdmin: manager.Id == 1,
Token: tokenString,
Permissions: permissions,
}
resp.SUCCESS(c, result)
}
func (h *ManagerHandler) GetAdminSlugs(userId uint) []string {
var permissions []string
err := h.db.Raw("SELECT distinct p.slug "+
"FROM chatgpt_admin_user_roles as ur "+
"LEFT JOIN chatgpt_admin_role_permissions as rp ON ur.role_id = rp.role_id "+
"LEFT JOIN chatgpt_admin_permissions as p ON rp.permission_id = p.id "+
"WHERE ur.admin_id = ?", userId).Scan(&permissions)
if err.Error == nil {
return []string{}
}
return permissions
}
// Logout 注销
func (h *ManagerHandler) Logout(c *gin.Context) {
key := h.GetUserKey(c)

View File

@@ -1,132 +0,0 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type SysPermissionHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewSysPermissionHandler(app *core.AppServer, db *gorm.DB) *SysPermissionHandler {
h := SysPermissionHandler{db: db}
h.App = app
return &h
}
func (h *SysPermissionHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
resp.NotPermission(c)
return
}
var items []model.AdminPermission
var data = make([]vo.AdminPermission, 0)
res := h.db.Find(&items)
if res.Error != nil {
resp.ERROR(c, "暂无数据")
return
}
for _, item := range items {
adminPermissionVo := vo.AdminPermission{}
_ = utils.CopyObject(item, &adminPermissionVo)
data = append(data, adminPermissionVo)
}
data = ArrayToTree(data)
resp.SUCCESS(c, data)
}
func ArrayToTree(dates []vo.AdminPermission) []vo.AdminPermission {
group := make(map[int][]vo.AdminPermission, 0)
for _, node := range dates {
group[node.Pid] = append(group[node.Pid], node)
}
// 初始化递归,从根节点开始构建树
result := FindSiblings(group[0], group)
return result
}
func FindSiblings(siblings []vo.AdminPermission, group map[int][]vo.AdminPermission) []vo.AdminPermission {
result := make([]vo.AdminPermission, 0)
for _, sibling := range siblings {
children, ok := group[sibling.Id]
if ok {
sibling.Children = FindSiblings(children, group)
}
result = append(result, sibling)
}
return result
}
func (h *SysPermissionHandler) Save(c *gin.Context) {
var data struct {
Id int `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
Sort int `json:"sort"`
Pid int `json:"pid"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var permission = model.AdminPermission{}
var res *gorm.DB
if data.Id > 0 { // 更新
permission.Id = data.Id
// 此处需要用 map 更新,用结构体无法更新 0 值
res = h.db.Model(&permission).Updates(map[string]interface{}{
"name": data.Name,
"slug": data.Slug,
"sort": data.Sort,
"pid": data.Pid,
})
} else {
p := model.AdminPermission{
Name: data.Name,
Slug: data.Slug,
Sort: data.Sort,
Pid: data.Pid,
}
res = h.db.Create(&p)
}
if res.Error != nil {
fmt.Println(res.Error)
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}
func (h *SysPermissionHandler) Remove(c *gin.Context) {
var data struct {
Id int
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
res := h.db.Where("id = ?", data.Id).Delete(&model.AdminPermission{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -1,166 +0,0 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type SysRoleHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewSysRoleHandler(app *core.AppServer, db *gorm.DB) *SysRoleHandler {
h := SysRoleHandler{db: db}
h.App = app
return &h
}
type permission struct {
Id int `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
}
func (h *SysRoleHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
resp.NotPermission(c)
return
}
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
name := h.GetTrim(c, "name")
offset := (page - 1) * pageSize
var items []model.AdminRole
var data = make([]vo.AdminRole, 0)
var total int64
session := h.db.Session(&gorm.Session{})
if name != "" {
session = session.Where("name LIKE ?", "%"+name+"%")
}
session.Model(&model.AdminRole{}).Count(&total)
res := session.Offset(offset).Limit(pageSize).Find(&items)
if res.Error != nil {
resp.ERROR(c, "暂无数据")
return
}
for _, item := range items {
adminRoleVo := vo.AdminRole{}
err := utils.CopyObject(item, &adminRoleVo)
if err == nil {
var permissions []permission
h.db.Raw("SELECT p.id,p.name,p.slug "+
"FROM chatgpt_admin_role_permissions as rp "+
"LEFT JOIN chatgpt_admin_permissions as p ON rp.permission_id = p.id "+
"WHERE rp.role_id = ?", item.Id).Scan(&permissions)
adminRoleVo.Permissions = permissions
adminRoleVo.CreatedAt = item.CreatedAt.Format("2006-01-02 15:04:05")
data = append(data, adminRoleVo)
}
}
pageVo := vo.NewPage(total, page, pageSize, data)
resp.SUCCESS(c, pageVo)
}
func (h *SysRoleHandler) Save(c *gin.Context) {
var data struct {
Id int
Name string
Description string
Permissions []int
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var role = model.AdminRole{}
var res *gorm.DB
tx := h.db.Begin()
if data.Id > 0 { // 更新
role.Id = data.Id
//删除角色对应的权限
err := tx.Where("role_id = ?", role.Id).Delete(model.AdminRolePermission{})
if err.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败")
return
}
//更新角色名
res = tx.Model(&role).Updates(map[string]interface{}{
"name": data.Name,
"description": data.Description,
})
} else {
//新建角色
role.Name = data.Name
role.Description = data.Description
res = tx.Create(&role)
}
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败")
return
}
rp := make([]model.AdminRolePermission, 0)
if len(data.Permissions) > 0 {
for _, per := range data.Permissions {
rp = append(rp, model.AdminRolePermission{
RoleId: role.Id,
PermissionId: per,
})
}
res2 := tx.CreateInBatches(rp, len(rp))
if res2.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败")
return
}
}
tx.Commit()
resp.SUCCESS(c)
}
func (h *SysRoleHandler) Remove(c *gin.Context) {
var data struct {
Id int
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
tx := h.db.Begin()
res := tx.Where("id = ?", data.Id).Delete(&model.AdminRole{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
res = tx.Where("role_id = ?", data.Id).Delete(&model.AdminRolePermission{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
tx.Commit()
}
resp.SUCCESS(c)
}

View File

@@ -1,219 +0,0 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type SysUserHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewSysUserHandler(app *core.AppServer, db *gorm.DB) *SysUserHandler {
h := SysUserHandler{db: db}
h.App = app
return &h
}
type role struct {
Id int `json:"id"`
Name string `json:"name"`
}
// List 用户列表
func (h *SysUserHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
resp.NotPermission(c)
return
}
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
username := h.GetTrim(c, "username")
offset := (page - 1) * pageSize
var items []model.AdminUser
var users = make([]vo.AdminUser, 0)
var total int64
session := h.db.Session(&gorm.Session{})
if username != "" {
session = session.Where("username LIKE ?", "%"+username+"%")
}
// 查询total
session.Model(&model.AdminUser{}).Count(&total)
res := session.Offset(offset).Limit(pageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var userVo vo.AdminUser
err := utils.CopyObject(item, &userVo)
if err == nil {
var roles []role
h.db.Raw("SELECT r.id,r.name "+
"FROM chatgpt_admin_user_roles as ur "+
"LEFT JOIN chatgpt_admin_roles as r ON ur.role_id = r.id "+
"WHERE ur.admin_id = ?", item.Id).Scan(&roles)
userVo.Id = item.Id
userVo.CreatedAt = item.CreatedAt.Unix()
userVo.UpdatedAt = item.UpdatedAt.Unix()
userVo.RoleIds = roles
users = append(users, userVo)
} else {
logger.Error(err)
}
}
}
pageVo := vo.NewPage(total, page, pageSize, users)
resp.SUCCESS(c, pageVo)
}
// Save 更新或者新增
func (h *SysUserHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Password string `json:"password"`
Username string `json:"username"`
Status bool `json:"status"`
RoleIds []int `json:"role_ids"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 默认id为1是超级管理员
if data.Id == 1 {
resp.ERROR(c, "超级管理员不支持更新")
return
}
var user = model.AdminUser{}
var res *gorm.DB
var userVo vo.AdminUser
tx := h.db.Begin()
if data.Id > 0 { // 更新
user.Id = data.Id
err := tx.Where("admin_id = ?", user.Id).Delete(&model.AdminUserRole{})
if err.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败")
return
}
// 此处需要用 map 更新,用结构体无法更新 0 值
res = tx.Model(&user).Updates(map[string]interface{}{
"username": data.Username,
"status": data.Status,
})
} else {
salt := utils.RandString(8)
user.Username = data.Username
user.Password = utils.GenPassword(data.Password, salt)
user.Salt = salt
user.Status = true
res = tx.Create(&user)
_ = utils.CopyObject(user, &userVo)
userVo.Id = user.Id
userVo.CreatedAt = user.CreatedAt.Unix()
userVo.UpdatedAt = user.UpdatedAt.Unix()
}
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败")
return
}
// 添加角色
userRole := make([]model.AdminUserRole, 0)
if len(data.RoleIds) > 0 {
for _, roleId := range data.RoleIds {
userRole = append(userRole, model.AdminUserRole{
AdminId: user.Id,
RoleId: roleId,
})
}
err := tx.CreateInBatches(userRole, len(userRole))
if err.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败")
return
}
}
tx.Commit()
resp.SUCCESS(c, userVo)
}
// ResetPass 重置密码
func (h *SysUserHandler) ResetPass(c *gin.Context) {
var data struct {
Id uint
Password string
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.AdminUser
res := h.db.First(&user, data.Id)
if res.Error != nil {
resp.ERROR(c, "No user found")
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.db.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
} else {
resp.SUCCESS(c)
}
}
// Remove 删除
func (h *SysUserHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 默认id为1是超级管理员
if data.Id == 1 {
resp.ERROR(c, "超级管理员不能删除")
return
}
if data.Id > 0 {
tx := h.db.Begin()
res := tx.Where("id = ?", data.Id).Delete(&model.AdminUser{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
res2 := tx.Where("admin_id = ?", data.Id).Delete(&model.AdminUserRole{})
if res2.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
tx.Commit()
}
resp.SUCCESS(c)
}

View File

@@ -14,13 +14,10 @@ import (
type ApiKeyHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
h := ApiKeyHandler{db: db}
h.App = app
return &h
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
}
func (h *ApiKeyHandler) Save(c *gin.Context) {
@@ -41,7 +38,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey := model.ApiKey{}
if data.Id > 0 {
h.db.Find(&apiKey, data.Id)
h.DB.Find(&apiKey, data.Id)
}
apiKey.Platform = data.Platform
apiKey.Value = data.Value
@@ -50,7 +47,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey.Enabled = data.Enabled
apiKey.ProxyURL = data.ProxyURL
apiKey.Name = data.Name
res := h.db.Save(&apiKey)
res := h.DB.Save(&apiKey)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -68,14 +65,14 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
}
func (h *ApiKeyHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)
res := h.db.Find(&items)
res := h.DB.Find(&items)
if res.Error == nil {
for _, item := range items {
var key vo.ApiKey
@@ -105,7 +102,7 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -122,7 +119,7 @@ func (h *ApiKeyHandler) Remove(c *gin.Context) {
return
}
if data.Id > 0 {
res := h.db.Where("id = ?", data.Id).Delete(&model.ApiKey{})
res := h.DB.Where("id = ?", data.Id).Delete(&model.ApiKey{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -13,9 +13,7 @@ type CaptchaHandler struct {
}
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
h := CaptchaHandler{}
h.App = app
return &h
return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
}
type CaptchaVo struct {

View File

@@ -14,13 +14,10 @@ import (
type ChatHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
h := ChatHandler{db: db}
h.App = app
return &h
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
type chatItemVo struct {
@@ -36,7 +33,7 @@ type chatItemVo struct {
}
func (h *ChatHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
@@ -54,7 +51,7 @@ func (h *ChatHandler) List(c *gin.Context) {
return
}
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if data.Title != "" {
session = session.Where("title LIKE ?", "%"+data.Title+"%")
}
@@ -88,9 +85,9 @@ func (h *ChatHandler) List(c *gin.Context) {
var messages []model.ChatMessage
var users []model.User
var roles []model.ChatRole
h.db.Where("chat_id IN ?", chatIds).Find(&messages)
h.db.Where("id IN ?", userIds).Find(&users)
h.db.Where("id IN ?", roleIds).Find(&roles)
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
h.DB.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", roleIds).Find(&roles)
tokenMap := make(map[string]int)
userMap := make(map[uint]string)
@@ -155,7 +152,7 @@ func (h *ChatHandler) Messages(c *gin.Context) {
return
}
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if data.Content != "" {
session = session.Where("content LIKE ?", "%"+data.Content+"%")
}
@@ -183,7 +180,7 @@ func (h *ChatHandler) Messages(c *gin.Context) {
userIds = append(userIds, item.UserId)
}
var users []model.User
h.db.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", userIds).Find(&users)
userMap := make(map[uint]string)
for _, user := range users {
userMap[user.Id] = user.Username
@@ -210,7 +207,7 @@ func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.db.Where("chat_id = ?", chatId).Find(&items)
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
@@ -237,7 +234,7 @@ func (h *ChatHandler) RemoveChat(c *gin.Context) {
return
}
tx := h.db.Begin()
tx := h.DB.Begin()
// 删除聊天记录
res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
if res.Error != nil {
@@ -260,7 +257,7 @@ func (h *ChatHandler) RemoveChat(c *gin.Context) {
// RemoveMessage 删除聊天记录
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tx := h.db.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
if tx.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -15,13 +15,10 @@ import (
type ChatModelHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db}
h.App = app
return &h
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ChatModelHandler) Save(c *gin.Context) {
@@ -59,7 +56,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&item)
res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -78,12 +75,12 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
if enable {
session = session.Where("enabled", enable)
@@ -120,7 +117,7 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -140,7 +137,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -157,7 +154,7 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
return
}
res := h.db.Where("id = ?", id).Delete(&model.ChatModel{})
res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -15,13 +15,10 @@ import (
type ChatRoleHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
h := ChatRoleHandler{db: db}
h.App = app
return &h
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// Save 创建或者更新某个角色
@@ -41,7 +38,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&role)
res := h.DB.Save(&role)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -53,14 +50,14 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
}
func (h *ChatRoleHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.ChatRole
var roles = make([]vo.ChatRole, 0)
res := h.db.Order("sort_num ASC").Find(&items)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
return
@@ -93,7 +90,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -115,7 +112,7 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -135,7 +132,7 @@ func (h *ChatRoleHandler) Remove(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Where("id = ?", data.Id).Delete(&model.ChatRole{})
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
if res.Error != nil {
resp.ERROR(c, "删除失败!")
return

View File

@@ -14,13 +14,10 @@ import (
type ConfigHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
h := ConfigHandler{db: db}
h.App = app
return &h
return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ConfigHandler) Update(c *gin.Context) {
@@ -40,7 +37,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
value := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: value}
res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key})
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -48,7 +45,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
if config.Id > 0 {
config.Config = value
res := h.db.Updates(&config)
res := h.DB.Updates(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -56,7 +53,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// update config cache for AppServer
var cfg model.Config
h.db.Where("marker", data.Key).First(&cfg)
h.DB.Where("marker", data.Key).First(&cfg)
var err error
if data.Key == "system" {
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
@@ -73,14 +70,14 @@ func (h *ConfigHandler) Update(c *gin.Context) {
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
key := c.Query("key")
var config model.Config
res := h.db.Where("marker", key).First(&config)
res := h.DB.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return

View File

@@ -14,13 +14,10 @@ import (
type DashboardHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
h := DashboardHandler{db: db}
h.App = app
return &h
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
type statsVo struct {
@@ -37,35 +34,35 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
var userCount int64
now := time.Now()
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
res := h.db.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
if res.Error == nil {
stats.Users = userCount
}
// new chats statistic
var chatCount int64
res = h.db.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
if res.Error == nil {
stats.Chats = chatCount
}
// tokens took stats
var historyMessages []model.ChatMessage
res = h.db.Where("created_at > ?", zeroTime).Find(&historyMessages)
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
for _, item := range historyMessages {
stats.Tokens += item.Tokens
}
// 众筹收入
var rewards []model.Reward
res = h.db.Where("created_at > ?", zeroTime).Find(&rewards)
res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
for _, item := range rewards {
stats.Income += item.Amount
}
// 订单收入
var orders []model.Order
res = h.db.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
for _, item := range orders {
stats.Income += item.Amount
}
@@ -84,7 +81,7 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
// 统计用户7天增加的曲线
var users []model.User
res = h.db.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
if res.Error == nil {
for _, item := range users {
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
@@ -92,20 +89,20 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
}
// 统计7天Token 消耗
res = h.db.Where("created_at > ?", startDate).Find(&historyMessages)
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
for _, item := range historyMessages {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
}
// 浮点数相加?
// 统计最近7天的众筹
res = h.db.Where("created_at > ?", startDate).Find(&rewards)
res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
for _, item := range rewards {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
// 统计最近7天的订单
res = h.db.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
for _, item := range orders {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}

View File

@@ -17,13 +17,10 @@ import (
type FunctionHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
h := FunctionHandler{db: db}
h.App = app
return &h
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *FunctionHandler) Save(c *gin.Context) {
@@ -44,7 +41,7 @@ func (h *FunctionHandler) Save(c *gin.Context) {
Enabled: data.Enabled,
}
res := h.db.Save(&f)
res := h.DB.Save(&f)
if res.Error != nil {
resp.ERROR(c, "error with save data:"+res.Error.Error())
return
@@ -65,7 +62,7 @@ func (h *FunctionHandler) Set(c *gin.Context) {
return
}
res := h.db.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -74,13 +71,13 @@ func (h *FunctionHandler) Set(c *gin.Context) {
}
func (h *FunctionHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Function
res := h.db.Find(&items)
res := h.DB.Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
return
@@ -102,7 +99,7 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.db.Delete(&model.Function{Id: uint(id)})
res := h.DB.Delete(&model.Function{Id: uint(id)})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -15,17 +15,14 @@ import (
type OrderHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db}
h.App = app
return &h
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *OrderHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
@@ -42,7 +39,7 @@ func (h *OrderHandler) List(c *gin.Context) {
return
}
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if data.OrderNo != "" {
session = session.Where("order_no", data.OrderNo)
}
@@ -82,7 +79,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
if id > 0 {
var item model.Order
res := h.db.First(&item, id)
res := h.DB.First(&item, id)
if res.Error != nil {
resp.ERROR(c, "记录不存在!")
return
@@ -93,7 +90,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
return
}
res = h.db.Unscoped().Where("id = ?", id).Delete(&model.Order{})
res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -15,13 +15,10 @@ import (
type ProductHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db}
h.App = app
return &h
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *ProductHandler) Save(c *gin.Context) {
@@ -51,7 +48,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&item)
res := h.DB.Save(&item)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -70,12 +67,12 @@ func (h *ProductHandler) Save(c *gin.Context) {
// List 模型列表
func (h *ProductHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
if enable {
session = session.Where("enabled", enable)
@@ -111,7 +108,7 @@ func (h *ProductHandler) Enable(c *gin.Context) {
return
}
res := h.db.Model(&model.Product{}).Where("id = ?", data.Id).Update("enabled", data.Enabled)
res := h.DB.Model(&model.Product{}).Where("id = ?", data.Id).Update("enabled", data.Enabled)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -131,7 +128,7 @@ func (h *ProductHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
res := h.db.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
res := h.DB.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
@@ -145,7 +142,7 @@ func (h *ProductHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.Product{})
res := h.DB.Where("id = ?", id).Delete(&model.Product{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -14,23 +14,20 @@ import (
type RewardHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db}
h.App = app
return &h
return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
func (h *RewardHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
var items []model.Reward
res := h.db.Order("id DESC").Find(&items)
res := h.DB.Order("id DESC").Find(&items)
var rewards = make([]vo.Reward, 0)
if res.Error == nil {
userIds := make([]uint, 0)
@@ -38,7 +35,7 @@ func (h *RewardHandler) List(c *gin.Context) {
userIds = append(userIds, v.UserId)
}
var users []model.User
h.db.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", userIds).Find(&users)
var userMap = make(map[uint]model.User)
for _, u := range users {
userMap[u.Id] = u
@@ -71,7 +68,7 @@ func (h *RewardHandler) Remove(c *gin.Context) {
return
}
if data.Id > 0 {
res := h.db.Where("id = ?", data.Id).Delete(&model.Reward{})
res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@@ -13,14 +13,11 @@ import (
type UploadHandler struct {
handler.BaseHandler
db *gorm.DB
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
adminHandler := &UploadHandler{db: db, uploaderManager: manager}
adminHandler.App = app
return adminHandler
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
@@ -30,7 +27,7 @@ func (h *UploadHandler) Upload(c *gin.Context) {
return
}
userId := 0
res := h.db.Create(&model.File{
res := h.DB.Create(&model.File{
UserId: userId,
Name: file.Name,
ObjKey: file.ObjKey,

View File

@@ -16,18 +16,15 @@ import (
type UserHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
h := UserHandler{db: db}
h.App = app
return &h
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// List 用户列表
func (h *UserHandler) List(c *gin.Context) {
if err := utils.CheckPermission(c, h.db); err != nil {
if err := utils.CheckPermission(c, h.DB); err != nil {
resp.NotPermission(c)
return
}
@@ -41,7 +38,7 @@ func (h *UserHandler) List(c *gin.Context) {
var users = make([]vo.User, 0)
var total int64
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if username != "" {
session = session.Where("username LIKE ?", "%"+username+"%")
}
@@ -87,7 +84,7 @@ func (h *UserHandler) Save(c *gin.Context) {
if data.Id > 0 { // 更新
user.Id = data.Id
// 此处需要用 map 更新,用结构体无法更新 0 值
res = h.db.Model(&user).Updates(map[string]interface{}{
res = h.DB.Model(&user).Updates(map[string]interface{}{
"username": data.Username,
"status": data.Status,
"vip": data.Vip,
@@ -108,7 +105,7 @@ func (h *UserHandler) Save(c *gin.Context) {
ChatModels: utils.JsonEncode(data.ChatModels),
ExpiredTime: utils.Str2stamp(data.ExpiredTime),
}
res = h.db.Create(&u)
res = h.DB.Create(&u)
_ = utils.CopyObject(u, &userVo)
userVo.Id = u.Id
userVo.CreatedAt = u.CreatedAt.Unix()
@@ -135,7 +132,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
var user model.User
res := h.db.First(&user, data.Id)
res := h.DB.First(&user, data.Id)
if res.Error != nil {
resp.ERROR(c, "No user found")
return
@@ -143,7 +140,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
} else {
@@ -152,43 +149,33 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
func (h *UserHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
tx := h.db.Begin()
res := h.db.Where("id = ?", data.Id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
res = h.db.Where("user_id = ?", data.Id).Delete(&model.ChatItem{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除聊天历史记录
res = h.db.Where("user_id = ?", data.Id).Delete(&model.ChatMessage{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除登录日志
res = h.db.Where("user_id = ?", data.Id).Delete(&model.UserLoginLog{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
tx.Commit()
// 删除用户
res := h.DB.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
// 删除聊天历史记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
// 删除登录日志
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
// 删除算力日志
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
// 删除众筹日志
h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
// 删除绘图任务
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
// 删除订单
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
resp.SUCCESS(c)
}
@@ -196,10 +183,10 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
var total int64
h.db.Model(&model.UserLoginLog{}).Count(&total)
h.DB.Model(&model.UserLoginLog{}).Count(&total)
offset := (page - 1) * pageSize
var items []model.UserLoginLog
res := h.db.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "获取数据失败")
return

View File

@@ -4,8 +4,11 @@ import (
"chatplus/core"
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/utils"
"errors"
"fmt"
"gorm.io/gorm"
"strings"
"github.com/gin-gonic/gin"
@@ -15,6 +18,7 @@ var logger = logger2.GetLogger()
type BaseHandler struct {
App *core.AppServer
DB *gorm.DB
}
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
@@ -57,3 +61,27 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
}
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
}
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
return h.GetLoginUserId(c) > 0
}
func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
value, exists := c.Get(types.LoginUserCache)
if exists {
return value.(model.User), nil
}
userId, ok := c.Get(types.LoginUserID)
if !ok {
return model.User{}, errors.New("user not login")
}
var user model.User
res := h.DB.First(&user, userId)
// 更新缓存
if res.Error == nil {
c.Set(types.LoginUserCache, user)
}
return user, res.Error
}

View File

@@ -12,37 +12,34 @@ import (
type ChatModelHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
h := ChatModelHandler{db: db}
h.App = app
return &h
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0)
// 只加载用户订阅的 AI 模型
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
var res *gorm.DB
// 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) {
res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items)
} else {
user, _ := h.GetLoginUser(c)
var models []int
err := utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res = h.DB.Where("enabled = ?", true).Where(
h.DB.Where("id IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
}
var models []int
err = utils.JsonDecode(user.ChatModels, &models)
if err != nil {
resp.ERROR(c, "当前用户没有订阅任何模型")
return
}
// 查询用户有权限访问的模型以及所有开放的模型
res := h.db.Where("enabled = ?", true).Where(
h.db.Where("id IN ?", models).Or("open =?", true),
).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var cm vo.ChatModel

View File

@@ -14,27 +14,24 @@ import (
type ChatRoleHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
handler := &ChatRoleHandler{db: db}
handler.App = app
return handler
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List get user list
// List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) {
all := h.GetBool(c, "all")
userId := h.GetLoginUserId(c)
var roles []model.ChatRole
res := h.db.Where("enable", true).Order("sort_num ASC").Find(&roles)
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
if res.Error != nil {
resp.ERROR(c, "No roles found,"+res.Error.Error())
return
}
// 获取所有角色
if all {
if userId == 0 {
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
@@ -49,13 +46,8 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.NotAuth(c)
return
}
var user model.User
h.db.First(&user, userId)
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
@@ -80,7 +72,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
// UpdateRole 更新用户聊天角色
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
@@ -94,7 +86,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
return
}
res := h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
if res.Error != nil {
logger.Error("添加应用失败:", err)
resp.ERROR(c, "更新数据库失败!")

View File

@@ -136,7 +136,7 @@ func (h *ChatHandler) sendAzureMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -158,7 +158,7 @@ func (h *ChatHandler) sendAzureMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -168,7 +168,7 @@ func (h *ChatHandler) sendAzureMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -180,7 +180,7 @@ func (h *ChatHandler) sendAzureMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -160,7 +160,7 @@ func (h *ChatHandler) sendBaiduMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -182,7 +182,7 @@ func (h *ChatHandler) sendBaiduMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -191,7 +191,7 @@ func (h *ChatHandler) sendBaiduMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -203,7 +203,7 @@ func (h *ChatHandler) sendBaiduMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -35,19 +35,16 @@ var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
db *gorm.DB
redis *redis.Client
uploadManager *oss.UploaderManager
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
h := ChatHandler{
db: db,
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
}
h.App = app
return &h
}
func (h *ChatHandler) Init() {
@@ -73,7 +70,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
client := types.NewWsClient(ws)
// get model info
var chatModel model.ChatModel
res := h.db.First(&chatModel, modelId)
res := h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort()
@@ -82,7 +79,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session := h.App.ChatSession.Get(sessionId)
if session == nil {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
logger.Info("用户未登录")
c.Abort()
@@ -99,7 +96,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
// use old chat data override the chat model and role ID
var chat model.ChatItem
res = h.db.Where("chat_id = ?", chatId).First(&chat)
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
if res.Error == nil {
chatModel.Id = chat.ModelId
roleId = int(chat.RoleId)
@@ -116,7 +113,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
res = h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
@@ -181,7 +178,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
var user model.User
res := h.db.Model(&model.User{}).First(&user, session.UserId)
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
utils.ReplyMessage(ws, "非法用户,请联系管理员!")
return res.Error
@@ -238,7 +235,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res := h.db.Where("enabled", true).Find(&items)
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
@@ -290,7 +287,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
_ = utils.JsonDecode(role.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
@@ -382,7 +379,7 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
if data.Text == "" && data.ChatId != "" {
var item model.ChatMessage
userId, _ := c.Get(types.LoginUserID)
res := h.db.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -433,7 +430,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
@@ -459,7 +456,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
apiURL = apiKey.ApiURL
}
// 更新 API KEY 的最后使用时间
h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value)
@@ -527,10 +524,10 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
if session.Model.Power > 0 {
power = session.Model.Power
}
res := h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
h.db.Create(&model.PowerLog{
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,

View File

@@ -13,28 +13,22 @@ import (
// List 获取会话列表
func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.ERROR(c, "The parameter 'user_id' is needed.")
return
}
// fix: 只能读取本人的消息列表
if uint(userId) != h.GetLoginUserId(c) {
resp.ERROR(c, "Hacker attempt, you can ONLY get yourself chats.")
if !h.IsLogin(c) {
resp.SUCCESS(c)
return
}
userId := h.GetLoginUserId(c)
var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
if res.Error == nil {
var roleIds = make([]uint, 0)
for _, chat := range chats {
roleIds = append(roleIds, chat.RoleId)
}
var roles []model.ChatRole
res = h.db.Find(&roles, roleIds)
res = h.DB.Find(&roles, roleIds)
if res.Error == nil {
roleMap := make(map[uint]model.ChatRole)
for _, role := range roles {
@@ -66,7 +60,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
@@ -78,14 +72,14 @@ func (h *ChatHandler) Update(c *gin.Context) {
// Clear 清空所有聊天记录
func (h *ChatHandler) Clear(c *gin.Context) {
// 获取当前登录用户所有的聊天会话
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
var chats []model.ChatItem
res := h.db.Where("user_id = ?", user.Id).Find(&chats)
res := h.DB.Where("user_id = ?", user.Id).Find(&chats)
if res.Error != nil {
resp.ERROR(c, "No chats found")
return
@@ -97,13 +91,13 @@ func (h *ChatHandler) Clear(c *gin.Context) {
// 清空会话上下文
h.App.ChatContexts.Delete(chat.ChatId)
}
err = h.db.Transaction(func(tx *gorm.DB) error {
res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
err = h.DB.Transaction(func(tx *gorm.DB) error {
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
if res.Error != nil {
return res.Error
}
res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
if res.Error != nil {
return res.Error
}
@@ -126,7 +120,7 @@ func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.db.Where("chat_id = ?", chatId).Find(&items)
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
@@ -152,20 +146,20 @@ func (h *ChatHandler) Remove(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
}
// 删除当前会话的聊天记录
res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to remove chat from database.")
return
@@ -187,7 +181,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
}
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", chatId).First(&chatItem)
res := h.DB.Where("chat_id = ?", chatId).First(&chatItem)
if res.Error != nil {
resp.ERROR(c, "No chat found")
return

View File

@@ -139,7 +139,7 @@ func (h *ChatHandler) sendChatGLMMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -161,7 +161,7 @@ func (h *ChatHandler) sendChatGLMMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -171,7 +171,7 @@ func (h *ChatHandler) sendChatGLMMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -183,7 +183,7 @@ func (h *ChatHandler) sendChatGLMMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -100,7 +100,7 @@ func (h *ChatHandler) sendOpenAiMessage(
}
if !utils.IsEmptyValue(tool) {
res := h.db.Where("name = ?", tool.Function.Name).First(&function)
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
@@ -210,7 +210,7 @@ func (h *ChatHandler) sendOpenAiMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -240,7 +240,7 @@ func (h *ChatHandler) sendOpenAiMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -250,7 +250,7 @@ func (h *ChatHandler) sendOpenAiMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -262,18 +262,19 @@ func (h *ChatHandler) sendOpenAiMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
logger.Debug(string(body))
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
return fmt.Errorf("error with decode response: %v", err)
}
@@ -281,7 +282,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {

View File

@@ -172,7 +172,7 @@ func (h *ChatHandler) sendQWenMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -194,7 +194,7 @@ func (h *ChatHandler) sendQWenMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -204,7 +204,7 @@ func (h *ChatHandler) sendQWenMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -216,7 +216,7 @@ func (h *ChatHandler) sendQWenMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -69,13 +69,13 @@ func (h *ChatHandler) sendXunFeiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
@@ -200,7 +200,7 @@ func (h *ChatHandler) sendXunFeiMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -222,7 +222,7 @@ func (h *ChatHandler) sendXunFeiMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -232,7 +232,7 @@ func (h *ChatHandler) sendXunFeiMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -244,7 +244,7 @@ func (h *ChatHandler) sendXunFeiMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}

View File

@@ -12,20 +12,17 @@ import (
type ConfigHandler struct {
BaseHandler
db *gorm.DB
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
h := ConfigHandler{db: db}
h.App = app
return &h
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key")
var config model.Config
res := h.db.Where("marker", key).First(&config)
res := h.DB.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return

View File

@@ -19,7 +19,6 @@ import (
type FunctionHandler struct {
BaseHandler
db *gorm.DB
config types.ChatPlusApiConfig
uploadManager *oss.UploaderManager
}
@@ -28,8 +27,8 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo
return &FunctionHandler{
BaseHandler: BaseHandler{
App: server,
DB: db,
},
db: db,
config: config.ApiConfig,
uploadManager: manager,
}
@@ -191,7 +190,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
logger.Debugf("绘画参数:%+v", params)
var user model.User
tx := h.db.Where("id = ?", params["user_id"]).First(&user)
tx := h.DB.Where("id = ?", params["user_id"]).First(&user)
if tx.Error != nil {
resp.ERROR(c, "当前用户不存在!")
return
@@ -205,7 +204,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY
var apiKey model.ApiKey
tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
return
@@ -213,7 +212,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
// translate prompt
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
pt, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
if err == nil {
logger.Debugf("翻译绘画提示词,原文:%s译文%s", prompt, pt)
prompt = pt
@@ -242,7 +241,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
return
}
// 更新 API KEY 的最后使用时间
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
logger.Debugf("%+v", res)
// 存储图片
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
@@ -253,10 +252,10 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL)
// 更新用户算力
tx = h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
tx = h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
h.db.Create(&model.PowerLog{
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,

View File

@@ -15,32 +15,29 @@ import (
// InviteHandler 用户邀请
type InviteHandler struct {
BaseHandler
db *gorm.DB
}
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
h := InviteHandler{db: db}
h.App = app
return &h
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Code 获取当前用户邀请码
func (h *InviteHandler) Code(c *gin.Context) {
userId := h.GetLoginUserId(c)
var inviteCode model.InviteCode
res := h.db.Where("user_id = ?", userId).First(&inviteCode)
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
// 如果邀请码不存在,则创建一个
if res.Error != nil {
code := strings.ToUpper(utils.RandString(8))
for {
res = h.db.Where("code = ?", code).First(&inviteCode)
res = h.DB.Where("code = ?", code).First(&inviteCode)
if res.Error != nil { // 不存在相同的邀请码则退出
break
}
}
inviteCode.UserId = userId
inviteCode.Code = code
h.db.Create(&inviteCode)
h.DB.Create(&inviteCode)
}
var codeVo vo.InviteCode
@@ -65,7 +62,7 @@ func (h *InviteHandler) List(c *gin.Context) {
return
}
userId := h.GetLoginUserId(c)
session := h.db.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
var total int64
session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog
@@ -91,6 +88,6 @@ func (h *InviteHandler) List(c *gin.Context) {
// Hits 访问邀请码
func (h *InviteHandler) Hits(c *gin.Context) {
code := c.Query("code")
h.db.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
resp.SUCCESS(c)
}

View File

@@ -24,25 +24,25 @@ import (
type MidJourneyHandler struct {
BaseHandler
db *gorm.DB
pool *mj.ServicePool
snowflake *service.Snowflake
uploader *oss.UploaderManager
}
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
h := MidJourneyHandler{
db: db,
return &MidJourneyHandler{
snowflake: snowflake,
pool: pool,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
h.App = app
return &h
}
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
@@ -172,7 +172,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
opt = "换脸"
}
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
@@ -193,11 +193,11 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
}
// update user's power
tx := h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := utils.GetLoginUser(c, h.db)
h.db.Create(&model.PowerLog{
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
@@ -248,7 +248,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
@@ -299,7 +299,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
@@ -322,11 +322,11 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
}
// update user's power
tx := h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := utils.GetLoginUser(c, h.db)
h.db.Create(&model.PowerLog{
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
@@ -373,7 +373,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
@@ -434,7 +434,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
}
// remove job recode
res := h.db.Delete(&model.MidJourneyJob{Id: data.Id})
res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -486,7 +486,7 @@ func (h *MidJourneyHandler) Publish(c *gin.Context) {
return
}
res := h.db.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return

View File

@@ -14,13 +14,10 @@ import (
type OrderHandler struct {
BaseHandler
db *gorm.DB
}
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
h := OrderHandler{db: db}
h.App = app
return &h
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
func (h *OrderHandler) List(c *gin.Context) {
@@ -33,7 +30,7 @@ func (h *OrderHandler) List(c *gin.Context) {
return
}
userId := h.GetLoginUserId(c)
session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
var total int64
session.Model(&model.Order{}).Count(&total)
var items []model.Order

View File

@@ -35,7 +35,6 @@ type PaymentHandler struct {
huPiPayService *payment.HuPiPayService
js *payment.PayJS
snowflake *service.Snowflake
db *gorm.DB
fs embed.FS
lock sync.Mutex
}
@@ -45,20 +44,21 @@ func NewPaymentHandler(
alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService,
js *payment.PayJS,
snowflake *service.Snowflake,
db *gorm.DB,
snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler {
h := PaymentHandler{
return &PaymentHandler{
alipayService: alipayService,
huPiPayService: huPiPayService,
js: js,
snowflake: snowflake,
fs: fs,
db: db,
lock: sync.Mutex{},
BaseHandler: BaseHandler{
App: server,
DB: db,
},
}
h.App = server
return &h
}
func (h *PaymentHandler) DoPay(c *gin.Context) {
@@ -71,7 +71,7 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
}
var order model.Order
res := h.db.Where("order_no = ?", orderNo).First(&order)
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
@@ -84,7 +84,7 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
}
// 更新扫码状态
h.db.Model(&order).UpdateColumn("status", types.OrderScanned)
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
if payWay == "alipay" { // 支付宝
// 生成支付链接
notifyURL := h.App.Config.AlipayConfig.NotifyURL
@@ -130,7 +130,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
}
var order model.Order
res := h.db.Where("order_no = ?", data.OrderNo).First(&order)
res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
if res.Error != nil {
resp.ERROR(c, "Order not found")
return
@@ -145,7 +145,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
for {
time.Sleep(time.Second)
var item model.Order
h.db.Where("order_no = ?", data.OrderNo).First(&item)
h.DB.Where("order_no = ?", data.OrderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
@@ -169,7 +169,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
}
var product model.Product
res := h.db.First(&product, data.ProductId)
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
@@ -181,7 +181,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
return
}
var user model.User
res = h.db.First(&user, data.UserId)
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
@@ -221,7 +221,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.db.Create(&order)
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
@@ -291,7 +291,7 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
}
var product model.Product
res := h.db.First(&product, data.ProductId)
res := h.DB.First(&product, data.ProductId)
if res.Error != nil {
resp.ERROR(c, "Product not found")
return
@@ -303,7 +303,7 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
return
}
var user model.User
res = h.db.First(&user, data.UserId)
res = h.DB.First(&user, data.UserId)
if res.Error != nil {
resp.ERROR(c, "Invalid user ID")
return
@@ -343,7 +343,7 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
PayWay: payWay,
Remark: utils.JsonEncode(remark),
}
res = h.db.Create(&order)
res = h.DB.Create(&order)
if res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "error with create order: "+res.Error.Error())
return
@@ -402,7 +402,7 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
// 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var order model.Order
res := h.db.Where("order_no = ?", orderNo).First(&order)
res := h.DB.Where("order_no = ?", orderNo).First(&order)
if res.Error != nil {
err := fmt.Errorf("error with fetch order: %v", res.Error)
logger.Error(err)
@@ -418,7 +418,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
}
var user model.User
res = h.db.First(&user, order.UserId)
res = h.DB.First(&user, order.UserId)
if res.Error != nil {
err := fmt.Errorf("error with fetch user info: %v", res.Error)
logger.Error(err)
@@ -444,7 +444,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
power = remark.Power
}
} else { // 非 VIP 用户
} else { // 非 VIP 用户
if remark.Days > 0 { // vip 套餐days > 0, power == 0
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Power += h.App.SysConfig.VipMonthPower
@@ -459,7 +459,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
}
// 更新用户信息
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
err := fmt.Errorf("error with update user info: %v", res.Error)
logger.Error(err)
@@ -470,7 +470,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
order.PayTime = time.Now().Unix()
order.Status = types.OrderPaidSuccess
order.TradeNo = tradeNo
res = h.db.Updates(&order)
res = h.DB.Updates(&order)
if res.Error != nil {
err := fmt.Errorf("error with update order info: %v", res.Error)
logger.Error(err)
@@ -478,11 +478,11 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
}
// 更新产品销量
h.db.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
// 记录算力充值日志
if opt != "" {
h.db.Create(&model.PowerLog{
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRecharge,

View File

@@ -12,20 +12,17 @@ import (
type ProductHandler struct {
BaseHandler
db *gorm.DB
}
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
h := ProductHandler{db: db}
h.App = app
return &h
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// List 模型列表
func (h *ProductHandler) List(c *gin.Context) {
var items []model.Product
var list = make([]vo.Product, 0)
res := h.db.Where("enabled", true).Order("sort_num ASC").Find(&items)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Product

View File

@@ -16,13 +16,10 @@ const translatePromptTemplate = "Translate the following painting prompt words i
type PromptHandler struct {
BaseHandler
db *gorm.DB
}
func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler {
h := &PromptHandler{db: db}
h.App = app
return h
return &PromptHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Rewrite translate and rewrite prompt with ChatGPT
@@ -35,7 +32,7 @@ func (h *PromptHandler) Rewrite(c *gin.Context) {
return
}
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(rewritePromptTemplate, data.Prompt))
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(rewritePromptTemplate, data.Prompt))
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -53,7 +50,7 @@ func (h *PromptHandler) Translate(c *gin.Context) {
return
}
content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, data.Prompt))
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, data.Prompt))
if err != nil {
resp.ERROR(c, err.Error())
return

View File

@@ -18,14 +18,11 @@ import (
type RewardHandler struct {
BaseHandler
db *gorm.DB
lock sync.Mutex
}
func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
h := RewardHandler{db: db, lock: sync.Mutex{}}
h.App = server
return &h
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// Verify 打赏码核销
@@ -38,7 +35,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.HACKER(c)
return
@@ -51,7 +48,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
defer h.lock.Unlock()
var item model.Reward
res := h.db.Where("tx_id = ?", data.TxId).First(&item)
res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
if res.Error != nil {
resp.ERROR(c, "无效的众筹交易流水号!")
return
@@ -62,7 +59,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
return
}
tx := h.db.Begin()
tx := h.DB.Begin()
exchange := vo.RewardExchange{}
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
exchange.Power = int(power)
@@ -85,7 +82,7 @@ func (h *RewardHandler) Verify(c *gin.Context) {
}
// 记录算力充值日志
h.db.Create(&model.PowerLog{
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerReward,

View File

@@ -24,19 +24,19 @@ import (
type SdJobHandler struct {
BaseHandler
redis *redis.Client
db *gorm.DB
pool *sd.ServicePool
uploader *oss.UploaderManager
}
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager) *SdJobHandler {
h := SdJobHandler{
db: db,
return &SdJobHandler{
pool: pool,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
h.App = app
return &h
}
// Client WebSocket 客户端,用于通知任务状态变更
@@ -61,7 +61,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
@@ -143,7 +143,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Power: h.App.SysConfig.SdPower,
CreatedAt: time.Now(),
}
res := h.db.Create(&job)
res := h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error())
return
@@ -164,11 +164,11 @@ func (h *SdJobHandler) Image(c *gin.Context) {
}
// update user's power
tx := h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := utils.GetLoginUser(c, h.db)
h.db.Create(&model.PowerLog{
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
@@ -217,7 +217,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
// JobList 获取 MJ 任务列表
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
session := h.db.Session(&gorm.Session{})
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
@@ -274,7 +274,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
}
// remove job recode
res := h.db.Delete(&model.SdJob{Id: data.Id})
res := h.DB.Delete(&model.SdJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -305,7 +305,7 @@ func (h *SdJobHandler) Publish(c *gin.Context) {
return
}
res := h.db.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return

View File

@@ -29,9 +29,12 @@ func NewSmsHandler(
sms *sms.ServiceManager,
smtp *service.SmtpService,
captcha *service.CaptchaService) *SmsHandler {
handler := &SmsHandler{redis: client, sms: sms, captcha: captcha, smtp: smtp}
handler.App = app
return handler
return &SmsHandler{
redis: client,
sms: sms,
captcha: captcha,
smtp: smtp,
BaseHandler: BaseHandler{App: app}}
}
// SendCode 发送验证码

View File

@@ -3,12 +3,6 @@ package handler
import (
"chatplus/service"
"chatplus/service/payment"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
@@ -21,208 +15,3 @@ type TestHandler struct {
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js}
}
type reqBody struct {
BotType string `json:"botType"`
Prompt string `json:"prompt"`
Base64Array []interface{} `json:"base64Array,omitempty"`
AccountFilter struct {
InstanceId string `json:"instanceId"`
Modes []interface{} `json:"modes"`
Remix bool `json:"remix"`
RemixAutoConsidered bool `json:"remixAutoConsidered"`
} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook"`
State string `json:"state,omitempty"`
}
type resBody struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
func (h *TestHandler) Test(c *gin.Context) {
image(c)
}
func upscale(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj/submit/action"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
body := map[string]string{
"customId": "MJ::JOB::upsample::1::c80a8eb1-f2d1-4f40-8785-97eb99b7ba0a",
"taskId": "1704880156226095",
"notifyHook": "http://r9it.com:6004/api/test/mj",
}
var res resBody
var resErr errRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+token).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&resErr).
Post(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
return
}
resp.SUCCESS(c, res)
}
type queryRes struct {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}
func query(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj/task/1704960661008372/fetch"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
var res queryRes
r, err := req.C().R().SetHeader("Authorization", "Bearer "+token).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+r.Status)
return
}
resp.SUCCESS(c, res)
}
type errRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
func image(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj-fast/mj/submit/imagine"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
body := reqBody{
BotType: "MID_JOURNEY",
Prompt: "一个中国美女,手上拿着一桶爆米花,脸上带着迷人的微笑,白色衣服 --s 750 --v 6",
NotifyHook: "http://r9it.com:6004/api/test/mj",
}
var res resBody
var resErr errRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+token).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&resErr).
Post(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
return
}
resp.SUCCESS(c, res)
}
type cbReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (h *TestHandler) Mj(c *gin.Context) {
var data cbReq
if err := c.ShouldBindJSON(&data); err != nil {
logger.Error(err)
}
logger.Debugf("任务ID%s,任务进度:%s,图片地址:%s, 最终提示词:%s", data.Id, data.Progress, data.ImageUrl, data.Properties.FinalPrompt)
apiURL := "https://api.openai1s.cn/mj/task/" + data.Id + "/fetch"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
var res queryRes
_, _ = req.C().R().SetHeader("Authorization", "Bearer "+token).
SetSuccessResult(&res).
Get(apiURL)
fmt.Println(res.State, ",", res.ImageUrl, ",", res.Progress)
}
func (h *TestHandler) initUserNickname(c *gin.Context) {
var users []model.User
tx := h.db.Find(&users)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
for _, u := range users {
u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
h.db.Updates(&u)
}
resp.SUCCESS(c)
}
func (h *TestHandler) initMjTaskId(c *gin.Context) {
var jobs []model.MidJourneyJob
tx := h.db.Find(&jobs)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
for _, job := range jobs {
id, _ := h.snowflake.Next(true)
job.TaskId = id
h.db.Updates(&job)
}
resp.SUCCESS(c)
}

View File

@@ -14,14 +14,11 @@ import (
type UploadHandler struct {
BaseHandler
db *gorm.DB
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
handler := &UploadHandler{db: db, uploaderManager: manager}
handler.App = app
return handler
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
@@ -32,7 +29,7 @@ func (h *UploadHandler) Upload(c *gin.Context) {
}
userId := h.GetLoginUserId(c)
res := h.db.Create(&model.File{
res := h.DB.Create(&model.File{
UserId: int(userId),
Name: file.Name,
ObjKey: file.ObjKey,
@@ -53,7 +50,7 @@ func (h *UploadHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
var items []model.File
var files = make([]vo.File, 0)
h.db.Where("user_id = ?", userId).Find(&items)
h.DB.Where("user_id = ?", userId).Find(&items)
if len(items) > 0 {
for _, v := range items {
var file vo.File
@@ -75,14 +72,14 @@ func (h *UploadHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0)
var file model.File
tx := h.db.Where("user_id = ? AND id = ?", userId, id).First(&file)
tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
if tx.Error != nil || file.Id == 0 {
resp.ERROR(c, "file not existed")
return
}
// remove database
tx = h.db.Model(&model.File{}).Delete("id = ?", id)
tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
if tx.Error != nil || tx.RowsAffected == 0 {
resp.ERROR(c, "failed to update database")
return

View File

@@ -21,7 +21,6 @@ import (
type UserHandler struct {
BaseHandler
db *gorm.DB
searcher *xdb.Searcher
redis *redis.Client
}
@@ -31,15 +30,14 @@ func NewUserHandler(
db *gorm.DB,
searcher *xdb.Searcher,
client *redis.Client) *UserHandler {
handler := &UserHandler{db: db, searcher: searcher, redis: client}
handler.App = app
return handler
return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client}
}
// Register user register
func (h *UserHandler) Register(c *gin.Context) {
// parameters process
var data struct {
RegWay string `json:"reg_way"`
Username string `json:"username"`
Password string `json:"password"`
Code string `json:"code"`
@@ -57,8 +55,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// 检查验证码
var key string
if utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") ||
utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
if data.RegWay == "email" || data.RegWay == "mobile" {
key = CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
@@ -70,7 +67,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// 验证邀请码
inviteCode := model.InviteCode{}
if data.InviteCode != "" {
res := h.db.Where("code = ?", data.InviteCode).First(&inviteCode)
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
if res.Error != nil {
resp.ERROR(c, "无效的邀请码")
return
@@ -79,7 +76,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// check if the username is exists
var item model.User
res := h.db.Where("username = ?", data.Username).First(&item)
res := h.DB.Where("username = ?", data.Username).First(&item)
if item.Id > 0 {
resp.ERROR(c, "该用户名已经被注册")
return
@@ -98,7 +95,7 @@ func (h *UserHandler) Register(c *gin.Context) {
Power: h.App.SysConfig.InitPower,
}
res = h.db.Create(&user)
res = h.DB.Create(&user)
if res.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(res.Error)
@@ -108,13 +105,13 @@ func (h *UserHandler) Register(c *gin.Context) {
// 记录邀请关系
if data.InviteCode != "" {
// 增加邀请数量
h.db.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 {
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
}
// 添加邀请记录
h.db.Create(&model.InviteLog{
h.DB.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Username,
@@ -155,7 +152,7 @@ func (h *UserHandler) Login(c *gin.Context) {
return
}
var user model.User
res := h.db.Where("username = ?", data.Username).First(&user)
res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户名不存在")
return
@@ -175,9 +172,9 @@ func (h *UserHandler) Login(c *gin.Context) {
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.db.Model(&user).Updates(user)
h.DB.Model(&user).Updates(user)
h.db.Create(&model.UserLoginLog{
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
@@ -222,7 +219,7 @@ func (h *UserHandler) Logout(c *gin.Context) {
// Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err == nil {
var userVo vo.User
err := utils.CopyObject(user, &userVo)
@@ -248,13 +245,13 @@ type userProfile struct {
}
func (h *UserHandler) Profile(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
h.DB.First(&user, user.Id)
var profile userProfile
err = utils.CopyObject(user, &profile)
if err != nil {
@@ -274,15 +271,15 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
h.DB.First(&user, user.Id)
user.Avatar = data.Avatar
user.Nickname = data.Nickname
res := h.db.Updates(&user)
res := h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新用户信息失败")
return
@@ -307,7 +304,7 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
@@ -321,7 +318,7 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
}
newPass := utils.GenPassword(data.Password, user.Salt)
res := h.db.Model(&user).UpdateColumn("password", newPass)
res := h.DB.Model(&user).UpdateColumn("password", newPass)
if res.Error != nil {
logger.Error("更新数据库失败: ", res.Error)
resp.ERROR(c, "更新数据库失败")
@@ -344,7 +341,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
var user model.User
res := h.db.Where("username", data.Username).First(&user)
res := h.DB.Where("username", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户不存在!")
return
@@ -360,7 +357,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.db.Updates(&user)
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
} else {
@@ -390,19 +387,19 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
// 检查手机号是否被其他账号绑定
var item model.User
res := h.db.Where("username = ?", data.Username).First(&item)
res := h.DB.Where("username = ?", data.Username).First(&item)
if res.Error == nil {
resp.ERROR(c, "该账号已经被其他账号绑定")
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
res = h.db.Model(&user).UpdateColumn("username", data.Username)
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return