使用 leveldb 存储用户 token 和聊天记录

This commit is contained in:
RockYang
2023-03-27 18:27:33 +08:00
parent 6a38de7eaa
commit a6bab7b12d
15 changed files with 370 additions and 175 deletions

View File

@@ -23,14 +23,17 @@ func (s *Server) ChatHandle(c *gin.Context) {
logger.Fatal(err)
return
}
token := c.Query("token")
sessionId := c.Query("sessionId")
role := c.Query("role")
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
client := NewWsClient(ws)
// TODO: 这里需要先判断一下角色是否存在,并且角色是被启用
if !s.ChatRoles[role].Enable { // 角色未启用
c.Abort()
return
}
// 发送打招呼信息
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.Config.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client)
go func() {
for {
@@ -43,7 +46,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
logger.Info("Receive a message: ", string(message))
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
err = s.sendMessage(token, role, string(message), client)
err = s.sendMessage(sessionId, role, string(message), client)
if err != nil {
logger.Error(err)
}
@@ -64,9 +67,13 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
if v, ok := s.ChatContext[key]; ok && s.Config.Chat.EnableContext {
context = v
} else {
context = s.Config.ChatRoles[role].Context
context = s.ChatRoles[role].Context
}
logger.Infof("会话上下文:%+v", context)
if s.DebugMode {
logger.Infof("会话上下文:%+v", context)
}
r.Messages = append(context, types.Message{
Role: "user",
Content: text,
@@ -179,6 +186,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
context = append(context, message)
// 保存上下文
s.ChatContext[key] = context
_ = response.Body.Close() // 关闭资源
return nil
}

View File

@@ -77,7 +77,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
}
// 保存配置文件
err = types.SaveConfig(s.Config, s.ConfigPath)
err = utils.SaveConfig(s.Config, s.ConfigPath)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
return
@@ -86,6 +86,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
}
// AddToken 添加 Token
func (s *Server) AddToken(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
@@ -95,22 +96,38 @@ func (s *Server) AddToken(c *gin.Context) {
return
}
if token, ok := data["token"]; ok {
if !utils.ContainsItem(s.Config.Tokens, token) {
s.Config.Tokens = append(s.Config.Tokens, token)
}
}
// 保存配置文件
err = types.SaveConfig(s.Config, s.ConfigPath)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
var name = data["name"]
var maxCalls = data["max_calls"]
if name == "" || maxCalls == "" {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
n, err := strconv.Atoi(maxCalls)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{
Code: types.InvalidParams,
Message: "enable_auth must be a int parameter",
})
return
}
var tokens = GetTokens()
if utils.ContainToken(tokens, name) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + name + " already exists"})
return
}
err = PutToken(types.Token{Name: name, MaxCalls: n, RemainingCalls: n})
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
}
// RemoveToken 删除 Token
func (s *Server) RemoveToken(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
@@ -121,22 +138,14 @@ func (s *Server) RemoveToken(c *gin.Context) {
}
if token, ok := data["token"]; ok {
for i, v := range s.Config.Tokens {
if v == token {
s.Config.Tokens = append(s.Config.Tokens[:i], s.Config.Tokens[i+1:]...)
break
}
err = RemoveToken(token)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return
}
}
// 保存配置文件
err = types.SaveConfig(s.Config, s.ConfigPath)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
}
// AddApiKey 添加一个 API key
@@ -153,7 +162,7 @@ func (s *Server) AddApiKey(c *gin.Context) {
}
// 保存配置文件
err = types.SaveConfig(s.Config, s.ConfigPath)
err = utils.SaveConfig(s.Config, s.ConfigPath)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
return
@@ -181,7 +190,7 @@ func (s *Server) RemoveApiKey(c *gin.Context) {
}
// 保存配置文件
err = types.SaveConfig(s.Config, s.ConfigPath)
err = utils.SaveConfig(s.Config, s.ConfigPath)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
return
@@ -196,22 +205,22 @@ func (s *Server) ListApiKeys(c *gin.Context) {
}
func (s *Server) GetChatRoles(c *gin.Context) {
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
var roles = make([]interface{}, 0)
for _, k := range rolesOrder {
if v, ok := s.Config.ChatRoles[k]; ok && v.Enable {
roles = append(roles, struct {
Key string `json:"key"`
Name string `json:"name"`
Icon string `json:"icon"`
}{
Key: v.Key,
Name: v.Name,
Icon: v.Icon,
})
}
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles})
//var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
//var roles = make([]interface{}, 0)
//for _, k := range rolesOrder {
// if v, ok := s.Config.ChatRoles[k]; ok && v.Enable {
// roles = append(roles, struct {
// Key string `json:"key"`
// Name string `json:"name"`
// Icon string `json:"icon"`
// }{
// Key: v.Key,
// Name: v.Name,
// Icon: v.Icon,
// })
// }
//}
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles})
}
// UpdateChatRole 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作
@@ -229,39 +238,39 @@ func (s *Server) UpdateChatRole(c *gin.Context) {
return
}
role := s.Config.ChatRoles[key]
if enable, ok := data["enable"]; ok {
v, err := strconv.ParseBool(enable)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{
Code: types.InvalidParams,
Message: "enable must be a bool parameter",
})
return
}
role.Enable = v
}
//role := s.Config.ChatRoles[key]
//if enable, ok := data["enable"]; ok {
// v, err := strconv.ParseBool(enable)
// if err != nil {
// c.JSON(http.StatusOK, types.BizVo{
// Code: types.InvalidParams,
// Message: "enable must be a bool parameter",
// })
// return
// }
// role.Enable = v
//}
if name, ok := data["name"]; ok {
role.Name = name
}
if helloMsg, ok := data["hello_msg"]; ok {
role.HelloMsg = helloMsg
}
if icon, ok := data["icon"]; ok {
role.Icon = icon
}
s.Config.ChatRoles[key] = role
// 保存配置文件
err = types.SaveConfig(s.Config, s.ConfigPath)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
//if name, ok := data["name"]; ok {
// role.Name = name
//}
//if helloMsg, ok := data["hello_msg"]; ok {
// role.HelloMsg = helloMsg
//}
//if icon, ok := data["icon"]; ok {
// role.Icon = icon
//}
//
//s.Config.ChatRoles[key] = role
//
//// 保存配置文件
//err = types.SaveConfig(s.Config, s.ConfigPath)
//if err != nil {
// c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
// return
//}
//
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
}
// AddProxy 添加一个代理
@@ -275,13 +284,13 @@ func (s *Server) AddProxy(c *gin.Context) {
}
if proxy, ok := data["proxy"]; ok {
if !utils.ContainsItem(s.Config.ProxyURL, proxy) {
if !utils.ContainsStr(s.Config.ProxyURL, proxy) {
s.Config.ProxyURL = append(s.Config.ProxyURL, proxy)
}
}
// 保存配置文件
err = types.SaveConfig(s.Config, s.ConfigPath)
err = utils.SaveConfig(s.Config, s.ConfigPath)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
return
@@ -309,7 +318,7 @@ func (s *Server) RemoveProxy(c *gin.Context) {
}
// 保存配置文件
err = types.SaveConfig(s.Config, s.ConfigPath)
err = utils.SaveConfig(s.Config, s.ConfigPath)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
return

61
server/db.go Normal file
View File

@@ -0,0 +1,61 @@
package server
import (
"encoding/json"
"openai/types"
"openai/utils"
)
const (
TokenPrefix = "chat/tokens/"
ChatRolePrefix = "chat/roles/"
ChatHistoryPrefix = "chat/history/"
)
var db *utils.LevelDB
func init() {
leveldb, err := utils.NewLevelDB("data")
if err != nil {
panic(err)
}
db = leveldb
}
// GetTokens 获取 token 信息
// chat/tokens
func GetTokens() []types.Token {
items := db.Search(TokenPrefix)
var tokens = make([]types.Token, 0)
for _, v := range items {
var token types.Token
err := json.Unmarshal([]byte(v), &token)
if err != nil {
continue
}
tokens = append(tokens, token)
}
return tokens
}
func PutToken(token types.Token) error {
key := TokenPrefix + token.Name
return db.Put(key, token)
}
func RemoveToken(token string) error {
key := TokenPrefix + token
return db.Delete(key)
}
// GetChatRoles 获取聊天角色
// chat/roles
func GetChatRoles() map[string]types.ChatRole {
return nil
}
// GetChatHistory 获取聊天历史记录
// chat/history/{token}/{role}
func GetChatHistory() []types.Message {
return nil
}

View File

@@ -38,29 +38,34 @@ type Server struct {
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
WsSession map[string]string
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
DebugMode bool // 是否开启调试模式
ChatRoles map[string]types.ChatRole // 保存预设角色信息
}
func NewServer(configPath string) (*Server, error) {
// load service configs
config, err := types.LoadConfig(configPath)
if config.ChatRoles == nil {
config.ChatRoles = types.GetDefaultChatRole()
}
config, err := utils.LoadConfig(configPath)
if err != nil {
return nil, err
}
roles := GetChatRoles()
if roles == nil {
roles = types.GetDefaultChatRole()
}
return &Server{
Config: config,
ConfigPath: configPath,
ChatContext: make(map[string][]types.Message, 16),
WsSession: make(map[string]string),
ApiKeyAccessStat: make(map[string]int64),
ChatRoles: roles,
}, nil
}
func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
s.DebugMode = debug
gin.SetMode(gin.ReleaseMode)
engine := gin.Default()
if debug {
@@ -225,7 +230,7 @@ func (s *Server) LoginHandle(c *gin.Context) {
return
}
token := data["token"]
if !utils.ContainsItem(s.Config.Tokens, token) {
if !utils.ContainToken(GetTokens(), token) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid token"})
return
}