mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-12 04:03:42 +08:00
使用 leveldb 存储用户 token 和聊天记录
This commit is contained in:
@@ -23,14 +23,17 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
logger.Fatal(err)
|
||||
return
|
||||
}
|
||||
token := c.Query("token")
|
||||
sessionId := c.Query("sessionId")
|
||||
role := c.Query("role")
|
||||
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
|
||||
client := NewWsClient(ws)
|
||||
// TODO: 这里需要先判断一下角色是否存在,并且角色是被启用的
|
||||
if !s.ChatRoles[role].Enable { // 角色未启用
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// 发送打招呼信息
|
||||
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.Config.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client)
|
||||
go func() {
|
||||
for {
|
||||
@@ -43,7 +46,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
|
||||
err = s.sendMessage(token, role, string(message), client)
|
||||
err = s.sendMessage(sessionId, role, string(message), client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@@ -64,9 +67,13 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
if v, ok := s.ChatContext[key]; ok && s.Config.Chat.EnableContext {
|
||||
context = v
|
||||
} else {
|
||||
context = s.Config.ChatRoles[role].Context
|
||||
context = s.ChatRoles[role].Context
|
||||
}
|
||||
logger.Infof("会话上下文:%+v", context)
|
||||
|
||||
if s.DebugMode {
|
||||
logger.Infof("会话上下文:%+v", context)
|
||||
}
|
||||
|
||||
r.Messages = append(context, types.Message{
|
||||
Role: "user",
|
||||
Content: text,
|
||||
@@ -179,6 +186,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
context = append(context, message)
|
||||
// 保存上下文
|
||||
s.ChatContext[key] = context
|
||||
_ = response.Body.Close() // 关闭资源
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
@@ -86,6 +86,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
|
||||
}
|
||||
|
||||
// AddToken 添加 Token
|
||||
func (s *Server) AddToken(c *gin.Context) {
|
||||
var data map[string]string
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||
@@ -95,22 +96,38 @@ func (s *Server) AddToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if token, ok := data["token"]; ok {
|
||||
if !utils.ContainsItem(s.Config.Tokens, token) {
|
||||
s.Config.Tokens = append(s.Config.Tokens, token)
|
||||
}
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
var name = data["name"]
|
||||
var maxCalls = data["max_calls"]
|
||||
if name == "" || maxCalls == "" {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
|
||||
n, err := strconv.Atoi(maxCalls)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{
|
||||
Code: types.InvalidParams,
|
||||
Message: "enable_auth must be a int parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var tokens = GetTokens()
|
||||
if utils.ContainToken(tokens, name) {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + name + " already exists"})
|
||||
return
|
||||
}
|
||||
|
||||
err = PutToken(types.Token{Name: name, MaxCalls: n, RemainingCalls: n})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
|
||||
}
|
||||
|
||||
// RemoveToken 删除 Token
|
||||
func (s *Server) RemoveToken(c *gin.Context) {
|
||||
var data map[string]string
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||
@@ -121,22 +138,14 @@ func (s *Server) RemoveToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if token, ok := data["token"]; ok {
|
||||
for i, v := range s.Config.Tokens {
|
||||
if v == token {
|
||||
s.Config.Tokens = append(s.Config.Tokens[:i], s.Config.Tokens[i+1:]...)
|
||||
break
|
||||
}
|
||||
err = RemoveToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
|
||||
}
|
||||
|
||||
// AddApiKey 添加一个 API key
|
||||
@@ -153,7 +162,7 @@ func (s *Server) AddApiKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
@@ -181,7 +190,7 @@ func (s *Server) RemoveApiKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
@@ -196,22 +205,22 @@ func (s *Server) ListApiKeys(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (s *Server) GetChatRoles(c *gin.Context) {
|
||||
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
|
||||
var roles = make([]interface{}, 0)
|
||||
for _, k := range rolesOrder {
|
||||
if v, ok := s.Config.ChatRoles[k]; ok && v.Enable {
|
||||
roles = append(roles, struct {
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
}{
|
||||
Key: v.Key,
|
||||
Name: v.Name,
|
||||
Icon: v.Icon,
|
||||
})
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles})
|
||||
//var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
|
||||
//var roles = make([]interface{}, 0)
|
||||
//for _, k := range rolesOrder {
|
||||
// if v, ok := s.Config.ChatRoles[k]; ok && v.Enable {
|
||||
// roles = append(roles, struct {
|
||||
// Key string `json:"key"`
|
||||
// Name string `json:"name"`
|
||||
// Icon string `json:"icon"`
|
||||
// }{
|
||||
// Key: v.Key,
|
||||
// Name: v.Name,
|
||||
// Icon: v.Icon,
|
||||
// })
|
||||
// }
|
||||
//}
|
||||
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles})
|
||||
}
|
||||
|
||||
// UpdateChatRole 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作
|
||||
@@ -229,39 +238,39 @@ func (s *Server) UpdateChatRole(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
role := s.Config.ChatRoles[key]
|
||||
if enable, ok := data["enable"]; ok {
|
||||
v, err := strconv.ParseBool(enable)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{
|
||||
Code: types.InvalidParams,
|
||||
Message: "enable must be a bool parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
role.Enable = v
|
||||
}
|
||||
//role := s.Config.ChatRoles[key]
|
||||
//if enable, ok := data["enable"]; ok {
|
||||
// v, err := strconv.ParseBool(enable)
|
||||
// if err != nil {
|
||||
// c.JSON(http.StatusOK, types.BizVo{
|
||||
// Code: types.InvalidParams,
|
||||
// Message: "enable must be a bool parameter",
|
||||
// })
|
||||
// return
|
||||
// }
|
||||
// role.Enable = v
|
||||
//}
|
||||
|
||||
if name, ok := data["name"]; ok {
|
||||
role.Name = name
|
||||
}
|
||||
if helloMsg, ok := data["hello_msg"]; ok {
|
||||
role.HelloMsg = helloMsg
|
||||
}
|
||||
if icon, ok := data["icon"]; ok {
|
||||
role.Icon = icon
|
||||
}
|
||||
|
||||
s.Config.ChatRoles[key] = role
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
|
||||
//if name, ok := data["name"]; ok {
|
||||
// role.Name = name
|
||||
//}
|
||||
//if helloMsg, ok := data["hello_msg"]; ok {
|
||||
// role.HelloMsg = helloMsg
|
||||
//}
|
||||
//if icon, ok := data["icon"]; ok {
|
||||
// role.Icon = icon
|
||||
//}
|
||||
//
|
||||
//s.Config.ChatRoles[key] = role
|
||||
//
|
||||
//// 保存配置文件
|
||||
//err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
//if err != nil {
|
||||
// c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
|
||||
}
|
||||
|
||||
// AddProxy 添加一个代理
|
||||
@@ -275,13 +284,13 @@ func (s *Server) AddProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
if proxy, ok := data["proxy"]; ok {
|
||||
if !utils.ContainsItem(s.Config.ProxyURL, proxy) {
|
||||
if !utils.ContainsStr(s.Config.ProxyURL, proxy) {
|
||||
s.Config.ProxyURL = append(s.Config.ProxyURL, proxy)
|
||||
}
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
@@ -309,7 +318,7 @@ func (s *Server) RemoveProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
|
||||
61
server/db.go
Normal file
61
server/db.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"openai/types"
|
||||
"openai/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenPrefix = "chat/tokens/"
|
||||
ChatRolePrefix = "chat/roles/"
|
||||
ChatHistoryPrefix = "chat/history/"
|
||||
)
|
||||
|
||||
var db *utils.LevelDB
|
||||
|
||||
func init() {
|
||||
leveldb, err := utils.NewLevelDB("data")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db = leveldb
|
||||
}
|
||||
|
||||
// GetTokens 获取 token 信息
|
||||
// chat/tokens
|
||||
func GetTokens() []types.Token {
|
||||
items := db.Search(TokenPrefix)
|
||||
var tokens = make([]types.Token, 0)
|
||||
for _, v := range items {
|
||||
var token types.Token
|
||||
err := json.Unmarshal([]byte(v), &token)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func PutToken(token types.Token) error {
|
||||
key := TokenPrefix + token.Name
|
||||
return db.Put(key, token)
|
||||
}
|
||||
|
||||
func RemoveToken(token string) error {
|
||||
key := TokenPrefix + token
|
||||
return db.Delete(key)
|
||||
}
|
||||
|
||||
// GetChatRoles 获取聊天角色
|
||||
// chat/roles
|
||||
func GetChatRoles() map[string]types.ChatRole {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatHistory 获取聊天历史记录
|
||||
// chat/history/{token}/{role}
|
||||
func GetChatHistory() []types.Message {
|
||||
return nil
|
||||
}
|
||||
@@ -38,29 +38,34 @@ type Server struct {
|
||||
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次
|
||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||
WsSession map[string]string
|
||||
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
|
||||
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
|
||||
DebugMode bool // 是否开启调试模式
|
||||
ChatRoles map[string]types.ChatRole // 保存预设角色信息
|
||||
}
|
||||
|
||||
func NewServer(configPath string) (*Server, error) {
|
||||
// load service configs
|
||||
config, err := types.LoadConfig(configPath)
|
||||
if config.ChatRoles == nil {
|
||||
config.ChatRoles = types.GetDefaultChatRole()
|
||||
}
|
||||
config, err := utils.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roles := GetChatRoles()
|
||||
if roles == nil {
|
||||
roles = types.GetDefaultChatRole()
|
||||
}
|
||||
return &Server{
|
||||
Config: config,
|
||||
ConfigPath: configPath,
|
||||
ChatContext: make(map[string][]types.Message, 16),
|
||||
WsSession: make(map[string]string),
|
||||
ApiKeyAccessStat: make(map[string]int64),
|
||||
ChatRoles: roles,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
|
||||
s.DebugMode = debug
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
engine := gin.Default()
|
||||
if debug {
|
||||
@@ -225,7 +230,7 @@ func (s *Server) LoginHandle(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
token := data["token"]
|
||||
if !utils.ContainsItem(s.Config.Tokens, token) {
|
||||
if !utils.ContainToken(GetTokens(), token) {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user