fix: fixed conflicts

This commit is contained in:
RockYang
2023-07-10 10:11:17 +08:00
64 changed files with 14 additions and 21 deletions

View File

@@ -0,0 +1,138 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"strings"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type ManagerHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewAdminHandler(app *core.AppServer, db *gorm.DB) *ManagerHandler {
h := ManagerHandler{db: db}
h.App = app
return &h
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data types.Manager
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
manager := h.App.Config.Manager
if data.Username == manager.Username && data.Password == manager.Password {
err := utils.SetLoginAdmin(c, manager)
if err != nil {
resp.ERROR(c, "Save session failed")
return
}
manager.Password = "" // 清空密码]
resp.SUCCESS(c, manager)
} else {
resp.ERROR(c, "用户名或者密码错误")
}
}
// Logout 注销
func (h *ManagerHandler) Logout(c *gin.Context) {
session := sessions.Default(c)
session.Delete(types.SessionAdmin)
err := session.Save()
if err != nil {
resp.ERROR(c, "Save session failed")
} else {
resp.SUCCESS(c)
}
}
// Session 会话检测
func (h *ManagerHandler) Session(c *gin.Context) {
session := sessions.Default(c)
admin := session.Get(types.SessionAdmin)
if admin == nil {
resp.NotAuth(c)
} else {
resp.SUCCESS(c)
}
}
// Migrate 数据修正
func (h *ManagerHandler) Migrate(c *gin.Context) {
opt := c.Query("opt")
switch opt {
case "user":
// 将用户订阅角色的数据结构从 map 改成数组
var users []model.User
h.db.Find(&users)
for _, u := range users {
var m map[string]int
var roleKeys = make([]string, 0)
err := utils.JsonDecode(u.ChatRoles, &m)
if err != nil {
continue
}
for k := range m {
roleKeys = append(roleKeys, k)
}
u.ChatRoles = utils.JsonEncode(roleKeys)
h.db.Updates(&u)
}
break
case "role":
// 修改角色图片,改成绝对路径
var roles []model.ChatRole
h.db.Find(&roles)
for _, r := range roles {
if !strings.HasPrefix(r.Icon, "/") {
r.Icon = "/" + r.Icon
h.db.Updates(&r)
}
}
break
case "history":
// 修改角色图片,改成绝对路径
var message []model.HistoryMessage
h.db.Find(&message)
for _, r := range message {
if !strings.HasPrefix(r.Icon, "/") {
r.Icon = "/" + r.Icon
h.db.Updates(&r)
}
}
break
case "avatar":
// 更新用户的头像地址
var users []model.User
h.db.Find(&users)
for _, u := range users {
if !strings.HasPrefix(u.Avatar, "/") {
u.Avatar = "/" + u.Avatar
h.db.Updates(&u)
}
}
break
}
resp.SUCCESS(c, "SUCCESS")
}

View File

@@ -0,0 +1,100 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ApiKeyHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
h := ApiKeyHandler{db: db}
h.App = app
return &h
}
func (h *ApiKeyHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
Value string `json:"value"`
LastUsedAt string `json:"last_used_at"`
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
apiKey := model.ApiKey{Value: data.Value, UserId: data.UserId, LastUsedAt: utils.Str2stamp(data.LastUsedAt)}
apiKey.Id = data.Id
if apiKey.Id > 0 {
apiKey.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&apiKey)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
var keyVo vo.ApiKey
err := utils.CopyObject(apiKey, &keyVo)
if err != nil {
resp.ERROR(c, "数据拷贝失败!")
return
}
keyVo.Id = apiKey.Id
keyVo.CreatedAt = apiKey.CreatedAt.Unix()
resp.SUCCESS(c, keyVo)
}
func (h *ApiKeyHandler) List(c *gin.Context) {
userId := h.GetInt(c, "user_id", -1)
query := h.db.Session(&gorm.Session{})
if userId >= 0 {
query = query.Where("user_id", userId)
}
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)
res := query.Find(&items)
if res.Error == nil {
for _, item := range items {
var key vo.ApiKey
err := utils.CopyObject(item, &key)
if err == nil {
key.Id = item.Id
key.CreatedAt = item.CreatedAt.Unix()
key.UpdatedAt = item.UpdatedAt.Unix()
keys = append(keys, key)
} else {
logger.Error(err)
}
}
}
resp.SUCCESS(c, keys)
}
func (h *ApiKeyHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
res := h.db.Where("id = ?", id).Delete(&model.ApiKey{})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,114 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type ChatRoleHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
h := ChatRoleHandler{db: db}
h.App = app
return &h
}
// Save 创建或者更新某个角色
func (h *ChatRoleHandler) Save(c *gin.Context) {
var data vo.ChatRole
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var role model.ChatRole
err := utils.CopyObject(data, &role)
if err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
role.Id = data.Id
if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0)
}
res := h.db.Save(&role)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
// 填充 ID 数据
data.Id = role.Id
data.CreatedAt = role.CreatedAt.Unix()
resp.SUCCESS(c, data)
}
func (h *ChatRoleHandler) List(c *gin.Context) {
var items []model.ChatRole
var roles = make([]vo.ChatRole, 0)
res := h.db.Order("sort ASC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
return
}
for _, v := range items {
var role vo.ChatRole
err := utils.CopyObject(v, &role)
if err == nil {
role.Id = v.Id
role.CreatedAt = v.CreatedAt.Unix()
role.UpdatedAt = v.UpdatedAt.Unix()
roles = append(roles, role)
}
}
resp.SUCCESS(c, roles)
}
// SetSort 更新角色排序
func (h *ChatRoleHandler) SetSort(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Sort int `json:"sort"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id <= 0 {
resp.HACKER(c)
return
}
res := h.db.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update("sort", data.Sort)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return
}
resp.SUCCESS(c)
}
func (h *ChatRoleHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Where("id = ?", id).Delete(&model.ChatRole{})
if res.Error != nil {
resp.ERROR(c, "删除失败!")
return
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,74 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ConfigHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
h := ConfigHandler{db: db}
h.App = app
return &h
}
func (h *ConfigHandler) Update(c *gin.Context) {
var data struct {
Key string `json:"key"`
Config map[string]interface{} `json:"config"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
str := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: str}
res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
if config.Id > 0 {
config.Config = str
res := h.db.Updates(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
}
resp.SUCCESS(c, config)
}
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key")
var config model.Config
res := h.db.Where("marker", key).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
var m map[string]interface{}
err := utils.JsonDecode(config.Config, &m)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, m)
}

View File

@@ -0,0 +1,145 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type UserHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
h := UserHandler{db: db}
h.App = app
return &h
}
// List 用户列表
func (h *UserHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
offset := (page - 1) * pageSize
var items []model.User
var users = make([]vo.User, 0)
var total int64
h.db.Model(&model.User{}).Count(&total)
res := h.db.Offset(offset).Limit(pageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var user vo.User
err := utils.CopyObject(item, &user)
if err == nil {
user.Id = item.Id
user.CreatedAt = item.CreatedAt.Unix()
user.UpdatedAt = item.UpdatedAt.Unix()
users = append(users, user)
} else {
logger.Error(err)
}
}
}
pageVo := vo.NewPage(total, page, pageSize, users)
resp.SUCCESS(c, pageVo)
}
func (h *UserHandler) Update(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Nickname string `json:"nickname"`
Calls int `json:"calls"`
ChatRoles []string `json:"chat_roles"`
ExpiredTime string `json:"expired_time"`
Status bool `json:"status"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user = model.User{}
user.Id = data.Id
// 此处需要用 map 更新,用结构体无法更新 0 值
res := h.db.Model(&user).Updates(map[string]interface{}{
"nickname": data.Nickname,
"calls": data.Calls,
"status": data.Status,
"chat_roles_json": utils.JsonEncode(data.ChatRoles),
"expired_time": utils.Str2stamp(data.ExpiredTime),
})
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}
func (h *UserHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id > 0 {
tx := h.db.Begin()
res := h.db.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
resp.ERROR(c, "删除失败")
return
}
// 删除聊天记录
res = h.db.Where("user_id = ?", id).Delete(&model.ChatItem{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除聊天历史记录
res = h.db.Where("user_id = ?", id).Delete(&model.HistoryMessage{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
// 删除登录日志
res = h.db.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "删除失败")
return
}
tx.Commit()
}
resp.SUCCESS(c)
}
func (h *UserHandler) LoginLog(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
var total int64
h.db.Model(&model.UserLoginLog{}).Count(&total)
offset := (page - 1) * pageSize
var items []model.UserLoginLog
res := h.db.Offset(offset).Limit(pageSize).Find(&items)
if res.Error != nil {
resp.ERROR(c, "获取数据失败")
return
}
var logs []vo.UserLoginLog
for _, v := range items {
var log vo.UserLoginLog
err := utils.CopyObject(v, &log)
if err == nil {
log.Id = v.Id
log.CreatedAt = v.CreatedAt.Unix()
logs = append(logs, log)
}
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, logs))
}

View File

@@ -0,0 +1,66 @@
package handler
import (
"chatplus/core"
logger2 "chatplus/logger"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
var logger = logger2.GetLogger()
type BaseHandler struct {
App *core.AppServer
}
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
return strings.TrimSpace(c.Query(key))
}
func (h *BaseHandler) PostInt(c *gin.Context, key string, defaultValue int) int {
return intValue(c.PostForm(key), defaultValue)
}
func (h *BaseHandler) GetInt(c *gin.Context, key string, defaultValue int) int {
return intValue(c.Query(key), defaultValue)
}
func intValue(str string, defaultValue int) int {
value, err := strconv.Atoi(str)
if err != nil {
return defaultValue
}
return value
}
func (h *BaseHandler) GetFloat(c *gin.Context, key string) float64 {
return floatValue(c.Query(key))
}
func (h *BaseHandler) PostFloat(c *gin.Context, key string) float64 {
return floatValue(c.PostForm(key))
}
func floatValue(str string) float64 {
value, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0
}
return value
}
func (h *BaseHandler) GetBool(c *gin.Context, key string) bool {
return boolValue(c.Query(key))
}
func (h *BaseHandler) PostBool(c *gin.Context, key string) bool {
return boolValue(c.PostForm(key))
}
func boolValue(str string) bool {
value, err := strconv.ParseBool(str)
if err != nil {
return false
}
return value
}

473
api/handler/chat_handler.go Normal file
View File

@@ -0,0 +1,473 @@
package handler
import (
"bufio"
"bytes"
"chatplus/core"
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
type ChatHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
handler := ChatHandler{db: db}
handler.App = app
return &handler
}
// ChatHandle 处理聊天 WebSocket 请求
func (h *ChatHandler) ChatHandle(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
chatModel := c.Query("model")
session := h.App.ChatSession.Get(sessionId)
if session.SessionId == "" {
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
logger.Info("用户未登录")
c.Abort()
return
}
session = types.ChatSession{
SessionId: sessionId,
ClientIP: c.ClientIP(),
Username: user.Username,
UserId: user.Id,
}
h.App.ChatSession.Put(sessionId, session)
}
// use old chat data override the chat model and role ID
var chat model.ChatItem
res := h.db.Where("chat_id=?", chatId).First(&chat)
if res.Error == nil {
chatModel = chat.Model
roleId = int(chat.RoleId)
}
session.ChatId = chatId
session.Model = chatModel
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
replyMessage(client, "当前聊天角色不存在或者未启用!!!")
c.Abort()
return
}
// 保存会话连接
h.App.ChatClients.Put(sessionId, client)
go func() {
for {
_, message, err := client.Receive()
if err != nil {
logger.Error(err)
client.Close()
h.App.ChatClients.Delete(sessionId)
h.App.ReqCancelFunc.Delete(sessionId)
return
}
logger.Info("Receive a message: ", string(message))
//replyMessage(client, "这是一条测试消息!")
ctx, cancel := context.WithCancel(context.Background())
h.App.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, string(message), client)
if err != nil {
logger.Error(err)
} else {
replyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
logger.Info("回答完毕: " + string(message))
}
}
}()
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws types.Client) error {
promptCreatedAt := time.Now() // 记录提问时间
var user model.User
res := h.db.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
replyMessage(ws, "非法用户,请联系管理员!")
return res.Error
}
var userVo vo.User
err := utils.CopyObject(user, &userVo)
userVo.Id = user.Id
if err != nil {
return errors.New("User 对象转换失败," + err.Error())
}
if userVo.Status == false {
replyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
replyMessage(ws, "![](/images/wx.png)")
return nil
}
if userVo.Calls <= 0 {
replyMessage(ws, "您的对话次数已经用尽,请联系管理员充值!")
replyMessage(ws, "![](/images/wx.png)")
return nil
}
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
replyMessage(ws, "您的账号已经过期,请联系管理员!")
replyMessage(ws, "![](/images/wx.png)")
return nil
}
var req = types.ApiRequest{
Model: session.Model,
Temperature: userVo.ChatConfig.Temperature,
MaxTokens: userVo.ChatConfig.MaxTokens,
Stream: true,
}
// 加载聊天上下文
var chatCtx []types.Message
if userVo.ChatConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
} else {
// 加载角色信息
var messages []types.Message
err := utils.JsonDecode(role.Context, &messages)
if err == nil {
chatCtx = messages
}
// TODO: 这里默认加载最近 4 条聊天记录作为上下文,后期应该做成可配置的
var historyMessages []model.HistoryMessage
res := h.db.Where("chat_id = ?", session.ChatId).Limit(4).Order("created_at desc").Find(&historyMessages)
if res.Error == nil {
for _, msg := range historyMessages {
ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg {
ms.Role = "assistant"
}
chatCtx = append(chatCtx, ms)
}
}
}
if h.App.Debug { // 调试打印聊天上下文
logger.Info("聊天上下文:", chatCtx)
}
}
req.Messages = append(chatCtx, types.Message{
Role: "user",
Content: prompt,
})
var apiKey string
response, err := h.doRequest(ctx, userVo, &apiKey, req)
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
replyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏")
return nil
} else {
logger.Error(err)
}
replyMessage(ws, ErrorMsg)
replyMessage(ws, "![](/images/wx.png)")
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now()
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var responseBody = types.ApiResponse{}
reader := bufio.NewReader(response.Body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error(err)
}
break
}
if !strings.Contains(line, "data:") {
continue
}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
logger.Error(err, line)
replyMessage(ws, ErrorMsg)
replyMessage(ws, "![](/images/wx.png)")
break
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, content)
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: responseBody.Choices[0].Delta.Content,
})
}
} // end for
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
if res.Error != nil {
return res.Error
}
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息
if userVo.ChatConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if userVo.ChatConfig.EnableHistory {
// for prompt
token, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: user.Avatar,
Content: prompt,
Tokens: token,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
token, err = utils.CalcTokens(message.Content, req.Model)
if err != nil {
logger.Error(err)
}
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: token,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 统计用户 token 数量
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
historyUserMsg.Tokens+historyReplyMsg.Tokens))
}
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.Model = session.Model
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
// OpenAI API 调用异常处理
// TODO: 是否考虑重发消息?
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
replyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
replyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
// 只保留最近的三条记录
chatContext := h.App.ChatContexts.Get(session.ChatId)
if len(chatContext) > 3 {
chatContext = chatContext[len(chatContext)-3:]
}
h.App.ChatContexts.Put(session.ChatId, chatContext)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
}
}
return nil
}
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) {
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
// 创建 HttpClient 请求对象
request, err := http.NewRequest(http.MethodPost, h.App.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request = request.WithContext(ctx)
request.Header.Add("Content-Type", "application/json")
proxyURL := h.App.Config.ProxyURL
if proxyURL == "" {
client = &http.Client{}
} else { // 使用代理
uri := url.URL{}
proxy, _ := uri.Parse(proxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
}
// 查询当前用户是否导入了自己的 API KEY
if user.ChatConfig.ApiKey != "" {
logger.Info("使用用户自己的 API KEY: ", user.ChatConfig.ApiKey)
*apiKey = user.ChatConfig.ApiKey
} else { // 获取系统的 API KEY
var key model.ApiKey
res := h.db.Where("user_id = ?", 0).Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
*apiKey = key.Value
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
}
logger.Infof("Sending OpenAI request, KEY: %s, PROXY: %s, Model: %s", *apiKey, proxyURL, req.Model)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
return client.Do(request)
}
// 回复客户片段端消息
func replyChunkMessage(client types.Client, message types.WsMessage) {
msg, err := json.Marshal(message)
if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error())
return
}
err = client.(*types.WsClient).Send(msg)
if err != nil {
logger.Errorf("Error for reply message: %v", err.Error())
}
}
// 回复客户端一条完整的消息
func replyMessage(ws types.Client, message string) {
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
}
// Tokens 统计 token 数量
func (h *ChatHandler) Tokens(c *gin.Context) {
text := c.Query("text")
md := c.Query("model")
tokens, err := utils.CalcTokens(text, md)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, tokens)
}
// StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id")
if h.App.ReqCancelFunc.Has(sessionId) {
h.App.ReqCancelFunc.Get(sessionId)()
h.App.ReqCancelFunc.Delete(sessionId)
}
resp.SUCCESS(c, types.OkMsg)
}

View File

@@ -0,0 +1,157 @@
package handler
import (
"chatplus/core/types"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
)
// List 获取会话列表
func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.ERROR(c, "The parameter 'user_id' is needed.")
return
}
var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
if res.Error == nil {
var roleIds = make([]uint, 0)
for _, chat := range chats {
roleIds = append(roleIds, chat.RoleId)
}
var roles []model.ChatRole
res = h.db.Find(&roles, roleIds)
if res.Error == nil {
roleMap := make(map[uint]model.ChatRole)
for _, role := range roles {
roleMap[role.Id] = role
}
for _, chat := range chats {
var item vo.ChatItem
err := utils.CopyObject(chat, &item)
if err == nil {
item.Id = chat.Id
item.Icon = roleMap[chat.RoleId].Icon
items = append(items, item)
}
}
}
}
resp.SUCCESS(c, items)
}
// Update 更新会话标题
func (h *ChatHandler) Update(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Title string `json:"title"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var m = model.ChatItem{}
m.Id = data.Id
res := h.db.Model(&m).UpdateColumn("title", data.Title)
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
}
resp.SUCCESS(c, types.OkMsg)
}
// Remove 删除会话
func (h *ChatHandler) Remove(c *gin.Context) {
chatId := h.GetTrim(c, "chat_id")
if chatId == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
}
res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
}
// 清空会话上下文
h.App.ChatContexts.Delete(chatId)
resp.SUCCESS(c, types.OkMsg)
}
// History 获取聊天历史记录
func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
}
var items []model.HistoryMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.db.Where("chat_id = ? AND user_id = ?", chatId, user.Id).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
} else {
for _, item := range items {
var v vo.HistoryMessage
err := utils.CopyObject(item, &v)
v.CreatedAt = item.CreatedAt.Unix()
v.UpdatedAt = item.UpdatedAt.Unix()
if err == nil {
messages = append(messages, v)
}
}
}
resp.SUCCESS(c, messages)
}
// Clear 清空所有聊天记录
func (h *ChatHandler) Clear(c *gin.Context) {
// 获取当前登录用户所有的聊天会话
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
}
var chats []model.ChatItem
res := h.db.Where("user_id = ?", user.Id).Find(&chats)
if res.Error != nil {
resp.ERROR(c, "No chats found")
return
}
// 清空聊天记录
for _, chat := range chats {
err := h.db.Where("chat_id = ? AND user_id = ?", chat.ChatId, user.Id).Delete(&model.HistoryMessage{})
if err != nil {
logger.Warnf("Failed to delele chat history for ChatID: %s", chat.ChatId)
}
// 清空会话上下文
h.App.ChatContexts.Delete(chat.ChatId)
}
// 删除所有的会话记录
res = h.db.Where("user_id = ?", user.Id).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to remove chat from database.")
return
}
resp.SUCCESS(c, types.OkMsg)
}

View File

@@ -0,0 +1,59 @@
package handler
import (
"chatplus/core"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ChatRoleHandler struct {
BaseHandler
db *gorm.DB
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
handler := &ChatRoleHandler{db: db}
handler.App = app
return handler
}
// List get user list
func (h *ChatRoleHandler) List(c *gin.Context) {
var roles []model.ChatRole
res := h.db.Where("enable", true).Order("sort ASC").Find(&roles)
if res.Error != nil {
resp.ERROR(c, "No roles found,"+res.Error.Error())
return
}
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
}
var roleKeys []string
err = utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
resp.ERROR(c, "角色解析失败!")
return
}
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
if !utils.ContainsStr(roleKeys, r.Key) {
continue
}
var v vo.ChatRole
err := utils.CopyObject(r, &v)
if err == nil {
v.Id = r.Id
roleVos = append(roleVos, v)
}
}
resp.SUCCESS(c, roleVos)
}

View File

@@ -0,0 +1,67 @@
package handler
import (
"chatplus/core"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"os"
"path/filepath"
"time"
)
type UploadHandler struct {
BaseHandler
db *gorm.DB
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB) *UploadHandler {
handler := &UploadHandler{db: db}
handler.App = app
return handler
}
func (h *UploadHandler) Upload(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
resp.ERROR(c, fmt.Sprintf("文件上传失败: %s", err.Error()))
return
}
filePath, err := h.genFilePath(file.Filename)
if err != nil {
resp.ERROR(c, fmt.Sprintf("文件上传失败: %s", err.Error()))
return
}
// 将文件保存到指定路径
err = c.SaveUploadedFile(file, filePath)
if err != nil {
resp.ERROR(c, fmt.Sprintf("文件保存失败: %s", err.Error()))
return
}
resp.SUCCESS(c, h.genFileUrl(filePath))
}
// 生成上传文件路径
func (h *UploadHandler) genFilePath(filename string) (string, error) {
now := time.Now()
dir := fmt.Sprintf("%s/upload/%d/%d", h.App.Config.StaticDir, now.Year(), now.Month())
_, err := os.Stat(dir)
if err != nil {
err = os.MkdirAll(dir, 0755)
if err != nil {
return "", fmt.Errorf("创建上传目录失败:%s", err)
}
}
fileExt := filepath.Ext(filename)
return fmt.Sprintf("%s/%d%s", dir, now.UnixMilli(), fileExt), nil
}
// 生成上传文件 URL
func (h *UploadHandler) genFileUrl(filePath string) string {
now := time.Now()
filename := filepath.Base(filePath)
return fmt.Sprintf("%s/upload/%d/%d/%s", h.App.Config.StaticUrl, now.Year(), now.Month(), filename)
}

389
api/handler/user_handler.go Normal file
View File

@@ -0,0 +1,389 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/store"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"strings"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
"gorm.io/gorm"
)
type UserHandler struct {
BaseHandler
db *gorm.DB
searcher *xdb.Searcher
levelDB *store.LevelDB
}
func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher, levelDB *store.LevelDB) *UserHandler {
handler := &UserHandler{db: db, searcher: searcher, levelDB: levelDB}
handler.App = app
return handler
}
// Register user register
func (h *UserHandler) Register(c *gin.Context) {
// parameters process
var data struct {
Username string `json:"username"`
Password string `json:"password"`
Mobile string `json:"mobile"`
Code int `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
data.Username = strings.TrimSpace(data.Username)
data.Password = strings.TrimSpace(data.Password)
if len(data.Username) < 5 {
resp.ERROR(c, "用户名长度不能少于5个字符")
return
}
if len(data.Password) < 8 {
resp.ERROR(c, "密码长度不能少于8个字符")
return
}
// 检查验证码
key := CodeStorePrefix + data.Mobile
var code int
err := h.levelDB.Get(key, &code)
if err != nil || code != data.Code {
logger.Info(code)
resp.ERROR(c, "短信验证码错误")
return
}
// check if the username is exists
var item model.User
res := h.db.Where("username = ?", data.Username).First(&item)
if res.RowsAffected > 0 {
resp.ERROR(c, "用户名已存在")
return
}
res = h.db.Where("mobile = ?", data.Mobile).First(&item)
if res.RowsAffected > 0 {
resp.ERROR(c, "该手机号码以及被注册,请更换其他手机号")
return
}
// 默认订阅所有角色
var chatRoles []model.ChatRole
h.db.Find(&chatRoles)
var roleKeys = make([]string, 0)
for _, r := range chatRoles {
roleKeys = append(roleKeys, r.Key)
}
salt := utils.RandString(8)
user := model.User{
Username: data.Username,
Password: utils.GenPassword(data.Password, salt),
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(5)),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
Mobile: data.Mobile,
ChatRoles: utils.JsonEncode(roleKeys),
ChatConfig: utils.JsonEncode(types.ChatConfig{
Temperature: h.App.ChatConfig.Temperature,
MaxTokens: h.App.ChatConfig.MaxTokens,
EnableContext: h.App.ChatConfig.EnableContext,
EnableHistory: true,
Model: h.App.ChatConfig.Model,
ApiKey: "",
}),
}
// 初始化调用次数
var cfg model.Config
h.db.Where("marker = ?", "system").First(&cfg)
var config types.SystemConfig
err = utils.JsonDecode(cfg.Config, &config)
if err != nil || config.UserInitCalls <= 0 {
user.Calls = types.UserInitCalls
} else {
user.Calls = config.UserInitCalls
}
res = h.db.Create(&user)
if res.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(res.Error)
return
}
_ = h.levelDB.Delete(key) // 注册成功,删除短信验证码
resp.SUCCESS(c, user)
}
// Login 用户登录
func (h *UserHandler) Login(c *gin.Context) {
var data struct {
Username string
Password string
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var user model.User
res := h.db.Where("username = ?", data.Username).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户名不存在")
return
}
password := utils.GenPassword(data.Password, user.Salt)
if password != user.Password {
resp.ERROR(c, "用户名或密码错误")
return
}
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.db.Model(&user).Updates(user)
sessionId := utils.RandString(42)
err := utils.SetLoginUser(c, user)
if err != nil {
resp.ERROR(c, "保存会话失败")
logger.Error("Error for save session: ", err)
return
}
// 记录登录信息在服务端
h.App.ChatSession.Put(sessionId, types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId})
h.db.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
var chatConfig types.ChatConfig
err = utils.JsonDecode(user.ChatConfig, &chatConfig)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, gin.H{
"session_id": sessionId,
"id": user.Id,
"nickname": user.Nickname,
"avatar": user.Avatar,
"username": user.Username,
"tokens": user.Tokens,
"calls": user.Calls,
"expired_time": user.ExpiredTime,
"api_key": chatConfig.ApiKey,
"model": chatConfig.Model,
"temperature": chatConfig.Temperature,
"max_tokens": chatConfig.MaxTokens,
"enable_context": chatConfig.EnableContext,
"enable_history": chatConfig.EnableHistory,
})
}
// Logout 注 销
func (h *UserHandler) Logout(c *gin.Context) {
sessionId := c.GetHeader(types.SessionName)
session := sessions.Default(c)
session.Delete(types.SessionUser)
err := session.Save()
if err != nil {
logger.Error("Error for save session: ", err)
}
// 删除 websocket 会话列表
h.App.ChatSession.Delete(sessionId)
// 关闭 socket 连接
client := h.App.ChatClients.Get(sessionId)
if client != nil {
client.Close()
}
resp.SUCCESS(c)
}
// Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
if err == nil {
var userVo vo.User
err := utils.CopyObject(user, &userVo)
if err != nil {
resp.ERROR(c)
}
userVo.Id = user.Id
resp.SUCCESS(c, userVo)
} else {
resp.NotAuth(c)
}
}
type userProfile struct {
Id uint `json:"id"`
Username string `json:"username"`
Nickname string `json:"nickname"`
Mobile string `json:"mobile"`
Avatar string `json:"avatar"`
ChatConfig types.ChatConfig `json:"chat_config"`
Calls int `json:"calls"`
Tokens int64 `json:"tokens"`
}
func (h *UserHandler) Profile(c *gin.Context) {
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
var profile userProfile
err = utils.CopyObject(user, &profile)
if err != nil {
logger.Error("对象拷贝失败:", err.Error())
resp.ERROR(c, "获取用户信息失败")
return
}
profile.Id = user.Id
resp.SUCCESS(c, profile)
}
func (h *UserHandler) ProfileUpdate(c *gin.Context) {
var data userProfile
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
}
h.db.First(&user, user.Id)
user.Nickname = data.Nickname
user.Avatar = data.Avatar
var chatConfig types.ChatConfig
err = utils.JsonDecode(user.ChatConfig, &chatConfig)
if err != nil {
resp.ERROR(c, "用户配置解析失败")
return
}
chatConfig.EnableHistory = data.ChatConfig.EnableHistory
chatConfig.EnableContext = data.ChatConfig.EnableContext
chatConfig.Model = data.ChatConfig.Model
chatConfig.MaxTokens = data.ChatConfig.MaxTokens
chatConfig.ApiKey = data.ChatConfig.ApiKey
chatConfig.Temperature = data.ChatConfig.Temperature
user.ChatConfig = utils.JsonEncode(chatConfig)
res := h.db.Updates(&user)
if res.Error != nil {
resp.ERROR(c, "更新用户信息失败")
return
}
resp.SUCCESS(c)
}
// Password 更新密码
func (h *UserHandler) Password(c *gin.Context) {
var data struct {
OldPass string `json:"old_pass"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if len(data.Password) < 8 {
resp.ERROR(c, "密码长度不能少于8个字符")
return
}
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
}
password := utils.GenPassword(data.OldPass, user.Salt)
logger.Info(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
if password != user.Password {
resp.ERROR(c, "原密码错误")
return
}
newPass := utils.GenPassword(data.Password, user.Salt)
res := h.db.Model(&user).UpdateColumn("password", newPass)
if res.Error != nil {
logger.Error("更新数据库失败: ", res.Error)
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}
// BindMobile 绑定手机号
func (h *UserHandler) BindMobile(c *gin.Context) {
var data struct {
Mobile string `json:"mobile"`
Code int `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 检查手机号是否被其他账号绑定
var item model.User
res := h.db.Where("mobile = ?", data.Mobile).First(&item)
if res.Error == nil {
resp.ERROR(c, "该手机号已经被其他账号绑定")
return
}
// 检查验证码
key := CodeStorePrefix + data.Mobile
var code int
err := h.levelDB.Get(key, &code)
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
return
}
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return
}
res = h.db.Model(&user).UpdateColumn("mobile", data.Mobile)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
_ = h.levelDB.Delete(key) // 删除短信验证码
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,150 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/store"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"time"
"github.com/gin-gonic/gin"
)
// 生成验证的控制器
type VerifyHandler struct {
BaseHandler
sms *service.AliYunSmsService
db *store.LevelDB
}
const TokenStorePrefix = "/verify/tokens/"
const CodeStorePrefix = "/verify/codes/"
const MobileStatPrefix = "/verify/stats/"
func NewVerifyHandler(app *core.AppServer, sms *service.AliYunSmsService, db *store.LevelDB) *VerifyHandler {
handler := &VerifyHandler{sms: sms, db: db}
handler.App = app
return handler
}
type VerifyToken struct {
Token string
Timestamp int64
}
// CodeStats 验证码发送统计
type CodeStats struct {
Mobile string
Count uint
Time int64
}
// Token 生成自验证 token
func (h *VerifyHandler) Token(c *gin.Context) {
// 如果不是通过浏览器访问,则返回错误的 token
if c.GetHeader("Sec-Fetch-Mode") != "cors" {
token := fmt.Sprintf("%s:%d", utils.RandString(32), time.Now().Unix())
encrypt, err := utils.AesEncrypt(h.App.Config.AesEncryptKey, []byte(token))
if err != nil {
resp.ERROR(c, "Token 加密出错")
return
}
resp.SUCCESS(c, encrypt)
return
}
token := VerifyToken{
Token: utils.RandString(32),
Timestamp: time.Now().Unix(),
}
json := utils.JsonEncode(token)
encrypt, err := utils.AesEncrypt(h.App.Config.AesEncryptKey, []byte(json))
if err != nil {
resp.ERROR(c, "Token 加密出错")
return
}
err = h.db.Put(TokenStorePrefix+token.Token, token)
if err != nil {
resp.ERROR(c, "Token 存储失败")
return
}
resp.SUCCESS(c, encrypt)
}
// SendMsg 发送验证码短信
func (h *VerifyHandler) SendMsg(c *gin.Context) {
var data struct {
Mobile string `json:"mobile"`
Token string `json:"token"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
decrypt, err := utils.AesDecrypt(h.App.Config.AesEncryptKey, data.Token)
if err != nil {
resp.ERROR(c, "Token 解密失败")
return
}
var token VerifyToken
err = utils.JsonDecode(string(decrypt), &token)
if err != nil {
resp.ERROR(c, "Token 解码失败")
return
}
if time.Now().Unix()-token.Timestamp > 30 {
resp.ERROR(c, "Token 已过期,请刷新页面重试")
return
}
// 验证当前手机号发送次数24 小时内相同手机号只允许发送 2 次
var stat CodeStats
err = h.db.Get(MobileStatPrefix+data.Mobile, &stat)
if err != nil {
logger.Error(err)
stat = CodeStats{
Mobile: data.Mobile,
Count: 0,
Time: time.Now().Unix(),
}
} else if stat.Count == 2 {
if time.Now().Unix()-stat.Time > 86400 {
stat.Count = 0
stat.Time = time.Now().Unix()
} else {
resp.ERROR(c, "触发流量预警,请 24 小时后再操作!")
return
}
}
code := utils.RandomNumber(6)
err = h.sms.SendVerifyCode(data.Mobile, code)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 每个 token 用完一次立即失效
_ = h.db.Delete(TokenStorePrefix + token.Token)
// 存储验证码,等待后面注册验证
err = h.db.Put(CodeStorePrefix+data.Mobile, code)
if err != nil {
resp.ERROR(c, "验证码保存失败")
return
}
// 更新发送次数
stat.Count = stat.Count + 1
_ = h.db.Put(MobileStatPrefix+data.Mobile, stat)
logger.Infof("%+v", stat)
resp.SUCCESS(c)
}