diff --git a/server/chat_handler.go b/server/chat_handler.go index 1017e76b..42362e77 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -12,6 +12,7 @@ import ( "net/http" "net/url" "openai/types" + "openai/utils" "strings" "time" ) @@ -28,7 +29,7 @@ func (s *Server) ChatHandle(c *gin.Context) { sessionId := c.Query("sessionId") roleKey := c.Query("role") session := s.ChatSession[sessionId] - logger.Infof("New websocket connected, IP: %s, Token: %s", c.Request.RemoteAddr, session.Token) + logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username) client := NewWsClient(ws) var roles = GetChatRoles() var chatRole = roles[roleKey] @@ -36,10 +37,11 @@ func (s *Server) ChatHandle(c *gin.Context) { c.Abort() return } - // 发送打招呼信息 - replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client) - replyMessage(types.WsMessage{Type: types.WsMiddle, Content: chatRole.HelloMsg, IsHelloMsg: true}, client) - replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client) + // 加载历史消息,如果历史消息为空则发送打招呼消息 + _, err = GetChatHistory(session.Username, roleKey) + if err != nil { + replyMessage(client, chatRole.HelloMsg, true) + } go func() { for { _, message, err := client.Receive() @@ -60,15 +62,15 @@ func (s *Server) ChatHandle(c *gin.Context) { } // 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端 -func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, text string, ws Client) error { - token, err := GetToken(session.Token) +func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client) error { + user, err := GetUser(session.Username) if err != nil { - replyError(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!") + replyMessage(ws, "当前 user 无效,请使用合法的 user 登录!", false) return err } - if token.MaxCalls > 0 && token.RemainingCalls <= 0 { - replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用或者联系管理员!") + if user.MaxCalls > 0 && user.RemainingCalls <= 0 { + replyMessage(ws, "当前 user 点数已经用尽,请充值后再使用或者联系管理员!", false) return nil } var r = types.ApiRequest{ @@ -91,7 +93,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex r.Messages = append(context, types.Message{ Role: "user", - Content: text, + Content: prompt, }) requestBody, err := json.Marshal(r) @@ -107,7 +109,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex } request.Header.Add("Content-Type", "application/json") - var retryCount = 3 + var retryCount = 5 var response *http.Response var failedKey = "" var failedProxyURL = "" @@ -145,10 +147,11 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex // 如果三次请求都失败的话,则返回对应的错误信息 if err != nil { - replyError(ws, ErrorMsg) + replyMessage(ws, ErrorMsg, false) return err } + // 循环读取 Chunk 消息 var message = types.Message{} var contents = make([]string, 0) var responseBody = types.ApiResponse{} @@ -161,58 +164,59 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex } if line == "" { - replyMessage(types.WsMessage{Type: types.WsEnd}, ws) + replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd}) break } else if len(line) < 20 { continue } err = json.Unmarshal([]byte(line[6:]), &responseBody) - if err != nil { - logger.Error(line) - replyError(ws, ErrorMsg) + if err != nil { // 数据解析出错 + logger.Error(err, line) + replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false}) break } // 初始化 role if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { message.Role = responseBody.Choices[0].Delta.Role - replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: false}, ws) + replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: false}) continue } else if responseBody.Choices[0].FinishReason != "" { // 输出完成或者输出中断了 - replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: false}, ws) + replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false}) break } else { content := responseBody.Choices[0].Delta.Content contents = append(contents, content) - replyMessage(types.WsMessage{ + replyChunkMessage(ws, types.WsMessage{ Type: types.WsMiddle, Content: responseBody.Choices[0].Delta.Content, IsHelloMsg: false, - }, ws) + }) } } - // 当前 Token 调用次数减 1 - if token.MaxCalls > 0 { - token.RemainingCalls -= 1 - _ = PutToken(*token) + _ = response.Body.Close() // 关闭资源 + + // 当前 Username 调用次数减 1 + if user.MaxCalls > 0 { + user.RemainingCalls -= 1 + _ = PutUser(*user) } - // 追加历史消息 - context = append(context, types.Message{ - Role: "user", - Content: text, - }) + // 追加上下文消息 + useMsg := types.Message{Role: "user", Content: prompt} + context = append(context, useMsg) message.Content = strings.Join(contents, "") context = append(context, message) - // 保存上下文 s.ChatContext[key] = context - _ = response.Body.Close() // 关闭资源 - return nil -} -func replyError(ws Client, message string) { - replyMessage(types.WsMessage{Type: types.WsStart}, ws) - replyMessage(types.WsMessage{Type: types.WsMiddle, Content: message}, ws) - replyMessage(types.WsMessage{Type: types.WsEnd}, ws) + // 追加历史消息 + if user.EnableHistory { + err = AppendChatHistory(user.Name, role.Key, useMsg) + if err != nil { + return err + } + err = AppendChatHistory(user.Name, role.Key, message) + } + return err } // 随机获取一个 API Key,如果请求失败,则更换 API Key 重试 @@ -267,8 +271,8 @@ func (s *Server) getProxyURL(failedProxyURL string) string { return "" } -// 回复客户端消息 -func replyMessage(message types.WsMessage, client Client) { +// 回复客户片段端消息 +func replyChunkMessage(client Client, message types.WsMessage) { msg, err := json.Marshal(message) if err != nil { logger.Errorf("Error for decoding json data: %v", err.Error()) @@ -279,3 +283,61 @@ func replyMessage(message types.WsMessage, client Client) { logger.Errorf("Error for reply message: %v", err.Error()) } } + +// 回复客户端一条完整的消息 +func replyMessage(ws Client, message string, isHelloMsg bool) { + replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: isHelloMsg}) + replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message, IsHelloMsg: isHelloMsg}) + replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: isHelloMsg}) +} + +func (s *Server) GetChatHistoryHandle(c *gin.Context) { + sessionId := c.GetHeader(types.TokenName) + var data struct { + Role string `json:"role"` + } + err := json.NewDecoder(c.Request.Body).Decode(&data) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) + return + } + + session := s.ChatSession[sessionId] + history, err := GetChatHistory(session.Username, data.Role) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "No history message"}) + return + } + + var messages = make([]types.HistoryMessage, 0) + role, err := GetChatRole(data.Role) + if err == nil { + // 先将打招呼的消息追加上去 + messages = append(messages, types.HistoryMessage{ + Type: "reply", + Id: utils.RandString(32), + Icon: role.Icon, + Content: role.HelloMsg, + }) + + for _, v := range history { + if v.Role == "user" { + messages = append(messages, types.HistoryMessage{ + Type: "prompt", + Id: utils.RandString(32), + Icon: "images/avatar/user.png", + Content: v.Content, + }) + } else if v.Role == "assistant" { + messages = append(messages, types.HistoryMessage{ + Type: "reply", + Id: utils.RandString(32), + Icon: role.Icon, + Content: v.Content, + }) + } + } + } + + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: messages}) +} diff --git a/server/config_handler.go b/server/config_handler.go index 4f0a4e27..1b96c984 100644 --- a/server/config_handler.go +++ b/server/config_handler.go @@ -37,13 +37,13 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { s.Config.Chat.Temperature = float32(v) } - // max_tokens + // max_users if maxTokens, ok := data["max_tokens"]; ok { v, err := strconv.Atoi(maxTokens) if err != nil { c.JSON(http.StatusOK, types.BizVo{ Code: types.InvalidParams, - Message: "max_tokens must be a int parameter", + Message: "max_users must be a int parameter", }) return } @@ -86,8 +86,8 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) } -// SetDebug 开启/关闭调试模式 -func (s *Server) SetDebug(c *gin.Context) { +// SetDebugHandle 开启/关闭调试模式 +func (s *Server) SetDebugHandle(c *gin.Context) { var data struct { Debug bool `json:"debug"` } @@ -101,9 +101,9 @@ func (s *Server) SetDebug(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) } -// AddToken 添加 Token -func (s *Server) AddToken(c *gin.Context) { - var data types.Token +// AddUserHandle 添加 Username +func (s *Server) AddUserHandle(c *gin.Context) { + var data types.User err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) @@ -116,25 +116,25 @@ func (s *Server) AddToken(c *gin.Context) { return } - // 检查当前要添加的 token 是否已经存在 - _, err = GetToken(data.Name) + // 检查当前要添加的 Username 是否已经存在 + _, err = GetUser(data.Name) if err == nil { - c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + data.Name + " already exists"}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Username " + data.Name + " already exists"}) return } - token := types.Token{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls} - err = PutToken(token) + user := types.User{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls} + err = PutUser(user) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) return } - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: user}) } -// BatchAddToken 批量生成 Token -func (s *Server) BatchAddToken(c *gin.Context) { +// BatchAddUserHandle 批量生成 Username +func (s *Server) BatchAddUserHandle(c *gin.Context) { var data struct { Number int `json:"number"` MaxCalls int `json:"max_calls"` @@ -145,24 +145,24 @@ func (s *Server) BatchAddToken(c *gin.Context) { return } - var tokens = make([]string, 0) + var users = make([]string, 0) for i := 0; i < data.Number; i++ { name := utils.RandString(12) - _, err := GetToken(name) + _, err := GetUser(name) for err == nil { name = utils.RandString(12) } - err = PutToken(types.Token{Name: name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls}) + err = PutUser(types.User{Name: name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls}) if err == nil { - tokens = append(tokens, name) + users = append(users, name) } } - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: tokens}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: users}) } -func (s *Server) SetToken(c *gin.Context) { - var data types.Token +func (s *Server) SetUserHandle(c *gin.Context) { + var data types.User err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) @@ -175,49 +175,50 @@ func (s *Server) SetToken(c *gin.Context) { return } - token, err := GetToken(data.Name) + user, err := GetUser(data.Name) if err != nil { - c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Username not found"}) return } if data.MaxCalls > 0 { - token.RemainingCalls += data.MaxCalls - token.MaxCalls - if token.RemainingCalls < 0 { - token.RemainingCalls = 0 + user.RemainingCalls += data.MaxCalls - user.MaxCalls + if user.RemainingCalls < 0 { + user.RemainingCalls = 0 } } - token.MaxCalls = data.MaxCalls + user.MaxCalls = data.MaxCalls + user.EnableHistory = data.EnableHistory - err = PutToken(*token) + err = PutUser(*user) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) return } - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: user}) } -// RemoveToken 删除 Token -func (s *Server) RemoveToken(c *gin.Context) { - var data types.Token +// RemoveUserHandle 删除 Username +func (s *Server) RemoveUserHandle(c *gin.Context) { + var data types.User err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) return } - err = RemoveToken(data.Name) + err = RemoveUser(data.Name) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) return } - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetUsers()}) } -// AddApiKey 添加一个 API key -func (s *Server) AddApiKey(c *gin.Context) { +// AddApiKeyHandle 添加一个 API key +func (s *Server) AddApiKeyHandle(c *gin.Context) { var data map[string]string err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { @@ -239,8 +240,8 @@ func (s *Server) AddApiKey(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys}) } -// RemoveApiKey 移除一个 API key -func (s *Server) RemoveApiKey(c *gin.Context) { +// RemoveApiKeyHandle 移除一个 API key +func (s *Server) RemoveApiKeyHandle(c *gin.Context) { var data map[string]string err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { @@ -267,12 +268,13 @@ func (s *Server) RemoveApiKey(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys}) } -// ListApiKeys 获取 API key 列表 -func (s *Server) ListApiKeys(c *gin.Context) { +// ListApiKeysHandle 获取 API key 列表 +func (s *Server) ListApiKeysHandle(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys}) } -func (s *Server) GetChatRoleList(c *gin.Context) { +// GetChatRoleListHandle 获取聊天角色列表 +func (s *Server) GetChatRoleListHandle(c *gin.Context) { var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"} var res = make([]interface{}, 0) var roles = GetChatRoles() @@ -292,8 +294,8 @@ func (s *Server) GetChatRoleList(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: res}) } -// UpdateChatRole 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作 -func (s *Server) UpdateChatRole(c *gin.Context) { +// UpdateChatRoleHandle 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作 +func (s *Server) UpdateChatRoleHandle(c *gin.Context) { var data map[string]string err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { @@ -345,8 +347,8 @@ func (s *Server) UpdateChatRole(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: role}) } -// AddProxy 添加一个代理 -func (s *Server) AddProxy(c *gin.Context) { +// AddProxyHandle 添加一个代理 +func (s *Server) AddProxyHandle(c *gin.Context) { var data map[string]string err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { @@ -371,7 +373,8 @@ func (s *Server) AddProxy(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.ProxyURL}) } -func (s *Server) RemoveProxy(c *gin.Context) { +// RemoveProxyHandle 删除一个代理 +func (s *Server) RemoveProxyHandle(c *gin.Context) { var data map[string]string err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { diff --git a/server/db.go b/server/db.go index 80f6a281..4af1b826 100644 --- a/server/db.go +++ b/server/db.go @@ -7,7 +7,7 @@ import ( ) const ( - TokenPrefix = "chat/tokens/" + UserPrefix = "chat/users/" ChatRolePrefix = "chat/roles/" ChatHistoryPrefix = "chat/history/" ) @@ -22,45 +22,45 @@ func init() { db = leveldb } -// GetTokens 获取 token 信息 -// chat/tokens -func GetTokens() []types.Token { - items := db.Search(TokenPrefix) - var tokens = make([]types.Token, 0) +// GetUsers 获取 user 信息 +// chat/users +func GetUsers() []types.User { + items := db.Search(UserPrefix) + var users = make([]types.User, 0) for _, v := range items { - var token types.Token - err := json.Unmarshal([]byte(v), &token) + var user types.User + err := json.Unmarshal([]byte(v), &user) if err != nil { continue } - tokens = append(tokens, token) + users = append(users, user) } - return tokens + return users } -func PutToken(token types.Token) error { - key := TokenPrefix + token.Name - return db.Put(key, token) +func PutUser(user types.User) error { + key := UserPrefix + user.Name + return db.Put(key, user) } -func GetToken(name string) (*types.Token, error) { - key := TokenPrefix + name +func GetUser(username string) (*types.User, error) { + key := UserPrefix + username bytes, err := db.Get(key) if err != nil { return nil, err } - var token types.Token - err = json.Unmarshal(bytes, &token) + var user types.User + err = json.Unmarshal(bytes, &user) if err != nil { return nil, err } - return &token, nil + return &user, nil } -func RemoveToken(token string) error { - key := TokenPrefix + token +func RemoveUser(username string) error { + key := UserPrefix + username return db.Delete(key) } @@ -86,7 +86,7 @@ func PutChatRole(role types.ChatRole) error { } func GetChatRole(key string) (*types.ChatRole, error) { - key = ChatHistoryPrefix + key + key = ChatRolePrefix + key bytes, err := db.Get(key) if err != nil { return nil, err @@ -102,7 +102,37 @@ func GetChatRole(key string) (*types.ChatRole, error) { } // GetChatHistory 获取聊天历史记录 -// chat/history/{token}/{role} -func GetChatHistory() []types.Message { - return nil +// chat/history/{user}/{role} +func GetChatHistory(user string, role string) ([]types.Message, error) { + key := ChatHistoryPrefix + user + "/" + role + bytes, err := db.Get(key) + if err != nil { + return nil, err + } + + var message []types.Message + err = json.Unmarshal(bytes, &message) + if err != nil { + return nil, err + } + + return message, nil +} + +// AppendChatHistory 追加聊天记录 +func AppendChatHistory(user string, role string, message types.Message) error { + messages, err := GetChatHistory(user, role) + if err != nil { + messages = make([]types.Message, 0) + } + + messages = append(messages, message) + key := ChatHistoryPrefix + user + "/" + role + return db.Put(key, messages) +} + +// ClearChatHistory 清空某个角色下的聊天记录 +func ClearChatHistory(user string, role string) error { + key := ChatHistoryPrefix + user + "/" + role + return db.Delete(key) } diff --git a/server/server.go b/server/server.go index 32e7ba7c..5576bfa1 100644 --- a/server/server.go +++ b/server/server.go @@ -7,7 +7,6 @@ import ( "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" "io/fs" - "log" "net/http" logger2 "openai/logger" "openai/types" @@ -36,7 +35,7 @@ type Server struct { ConfigPath string ChatContext map[string][]types.Message // 聊天上下文 [SessionID] => []Messages - // 保存 Websocket 会话 Token, 每个 Token 只能连接一次 + // 保存 Websocket 会话 Username, 每个 Username 只能连接一次 // 防止第三方直接连接 socket 调用 OpenAI API ChatSession map[string]types.ChatSession ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 @@ -83,19 +82,21 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { engine.GET("/api/session/get", s.GetSessionHandle) engine.POST("/api/login", s.LoginHandle) engine.Any("/api/chat", s.ChatHandle) + engine.POST("api/chat/history", s.GetChatHistoryHandle) + engine.POST("/api/config/set", s.ConfigSetHandle) - engine.GET("/api/config/chat-roles/get", s.GetChatRoleList) - engine.POST("api/config/token/add", s.AddToken) - engine.POST("api/config/token/batch-add", s.BatchAddToken) - engine.POST("api/config/token/set", s.SetToken) - engine.POST("api/config/token/remove", s.RemoveToken) - engine.POST("api/config/apikey/add", s.AddApiKey) - engine.POST("api/config/apikey/remove", s.RemoveApiKey) - engine.POST("api/config/apikey/list", s.ListApiKeys) - engine.POST("api/config/role/set", s.UpdateChatRole) - engine.POST("api/config/proxy/add", s.AddProxy) - engine.POST("api/config/proxy/remove", s.RemoveProxy) - engine.POST("api/config/debug", s.SetDebug) + engine.GET("/api/config/chat-roles/get", s.GetChatRoleListHandle) + engine.POST("api/config/user/add", s.AddUserHandle) + engine.POST("api/config/user/batch-add", s.BatchAddUserHandle) + engine.POST("api/config/user/set", s.SetUserHandle) + engine.POST("api/config/user/remove", s.RemoveUserHandle) + engine.POST("api/config/apikey/add", s.AddApiKeyHandle) + engine.POST("api/config/apikey/remove", s.RemoveApiKeyHandle) + engine.POST("api/config/apikey/list", s.ListApiKeysHandle) + engine.POST("api/config/role/set", s.UpdateChatRoleHandle) + engine.POST("api/config/proxy/add", s.AddProxyHandle) + engine.POST("api/config/proxy/remove", s.RemoveProxyHandle) + engine.POST("api/config/debug", s.SetDebugHandle) engine.NoRoute(func(c *gin.Context) { if c.Request.URL.Path == "/favicon.ico" { @@ -122,7 +123,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { func Recover(c *gin.Context) { defer func() { if r := recover(); r != nil { - log.Printf("panic: %v\n", r) + logger.Error("panic: %v\n", r) debug.PrintStack() c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) c.Abort() @@ -156,7 +157,7 @@ func corsMiddleware() gin.HandlerFunc { c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") //允许跨域设置可以返回其他子段,可以自定义字段 - c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-Token") + c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-Username") // 允许浏览器(客户端)可以解析的头部 (重要) c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") //设置缓存时间 @@ -171,7 +172,7 @@ func corsMiddleware() gin.HandlerFunc { defer func() { if err := recover(); err != nil { - log.Printf("Panic info is: %v", err) + logger.Info("Panic info is: %v", err) } }() @@ -242,27 +243,28 @@ func (s *Server) GetSessionHandle(c *gin.Context) { } func (s *Server) LoginHandle(c *gin.Context) { - var data map[string]string + var data struct { + Token string `json:"token"` + } err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) return } - token := data["token"] - if !utils.ContainToken(GetTokens(), token) { - c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid token"}) + if !utils.Containuser(GetUsers(), data.Token) { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid user"}) return } sessionId := utils.RandString(42) session := sessions.Default(c) - session.Set(sessionId, token) + session.Set(sessionId, data.Token) err = session.Save() if err != nil { logger.Error("Error for save session: ", err) } // 记录客户端 IP 地址 - s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Token: token, SessionId: sessionId} + s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Username: data.Token, SessionId: sessionId} c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: sessionId}) } diff --git a/types/chat.go b/types/chat.go index fde48851..ceed1a69 100644 --- a/types/chat.go +++ b/types/chat.go @@ -14,6 +14,14 @@ type Message struct { Content string `json:"content"` } +// HistoryMessage 历史聊天消息 +type HistoryMessage struct { + Type string `json:"type"` + Id string `json:"id"` + Icon string `json:"icon"` + Content string `json:"content"` +} + type ApiResponse struct { Choices []ChoiceItem `json:"choices"` } @@ -37,7 +45,7 @@ type ChatRole struct { type ChatSession struct { SessionId string `json:"session_id"` ClientIP string `json:"client_ip"` // 客户端 IP - Token string `json:"token"` // 当前登录的 token + Username string `json:"user"` // 当前登录的 user } func GetDefaultChatRole() map[string]ChatRole { diff --git a/types/config.go b/types/config.go index 3e9174b3..5b620815 100644 --- a/types/config.go +++ b/types/config.go @@ -13,10 +13,11 @@ type Config struct { Chat Chat } -type Token struct { +type User struct { Name string `json:"name"` MaxCalls int `json:"max_calls"` // 最多调用次数,如果为 0 则表示不限制 RemainingCalls int `json:"remaining_calls"` // 剩余调用次数 + EnableHistory bool `json:"enable_history"` // 是否启用聊天记录 } // Chat configs struct diff --git a/types/web.go b/types/web.go index e6045ae2..91e235c0 100644 --- a/types/web.go +++ b/types/web.go @@ -6,7 +6,7 @@ type BizVo struct { Page int `json:"page,omitempty"` PageSize int `json:"page_size,omitempty"` Total int `json:"total,omitempty"` - Message string `json:"message"` + Message string `json:"message,omitempty"` Data interface{} `json:"data,omitempty"` } @@ -36,5 +36,5 @@ const ( ErrorMsg = "系统开小差了" ) -const TokenName = "ChatGPT-Token" +const TokenName = "ChatGPT-TOKEN" const SessionKey = "WEB_SSH_SESSION" diff --git a/utils/utils.go b/utils/utils.go index 5d903596..a934f1e6 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -41,9 +41,9 @@ func ContainsStr(slice []string, item string) bool { return false } -func ContainToken(slice []types.Token, token string) bool { +func Containuser(slice []types.User, user string) bool { for _, e := range slice { - if e.Name == token { + if e.Name == user { return true } } diff --git a/web/src/utils/http.js b/web/src/utils/http.js index e5899e09..b76bdda7 100644 --- a/web/src/utils/http.js +++ b/web/src/utils/http.js @@ -10,7 +10,7 @@ axios.defaults.headers.post['Content-Type'] = 'application/json' axios.interceptors.request.use( config => { // set token - config.headers['ChatGPT-Token'] = getSessionId(); + config.headers['ChatGPT-TOKEN'] = getSessionId(); return config }, error => { return Promise.reject(error) diff --git a/web/src/views/Chat.vue b/web/src/views/Chat.vue index c1098f9f..1b740380 100644 --- a/web/src/views/Chat.vue +++ b/web/src/views/Chat.vue @@ -185,10 +185,13 @@ export default defineComponent({ window.addEventListener("resize", () => { this.chatBoxHeight = window.innerHeight - this.toolBoxHeight; + this.inputBoxWidth = window.innerWidth - 20; }); this.connect(); + this.fetchChatHistory(); + }, methods: { @@ -300,6 +303,17 @@ export default defineComponent({ break; } } + + this.fetchChatHistory(); + }, + + // 从后端获取聊天历史记录 + fetchChatHistory: function () { + httpPost("/api/chat/history", {role: this.role}).then((res) => { + this.chatData = res.data + }).catch((e) => { + console.error(e.message) + }) }, inputKeyDown: function (e) {