rename Token to User, the chat history function is ready

This commit is contained in:
RockYang 2023-03-28 16:03:41 +08:00
parent 95f9dfa9cb
commit ebc2041e8a
10 changed files with 261 additions and 141 deletions

View File

@ -12,6 +12,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"openai/types" "openai/types"
"openai/utils"
"strings" "strings"
"time" "time"
) )
@ -28,7 +29,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
sessionId := c.Query("sessionId") sessionId := c.Query("sessionId")
roleKey := c.Query("role") roleKey := c.Query("role")
session := s.ChatSession[sessionId] session := s.ChatSession[sessionId]
logger.Infof("New websocket connected, IP: %s, Token: %s", c.Request.RemoteAddr, session.Token) logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
client := NewWsClient(ws) client := NewWsClient(ws)
var roles = GetChatRoles() var roles = GetChatRoles()
var chatRole = roles[roleKey] var chatRole = roles[roleKey]
@ -36,10 +37,11 @@ func (s *Server) ChatHandle(c *gin.Context) {
c.Abort() c.Abort()
return return
} }
// 发送打招呼信息 // 加载历史消息,如果历史消息为空则发送打招呼消息
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client) _, err = GetChatHistory(session.Username, roleKey)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: chatRole.HelloMsg, IsHelloMsg: true}, client) if err != nil {
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client) replyMessage(client, chatRole.HelloMsg, true)
}
go func() { go func() {
for { for {
_, message, err := client.Receive() _, message, err := client.Receive()
@ -60,15 +62,15 @@ func (s *Server) ChatHandle(c *gin.Context) {
} }
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端 // 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, text string, ws Client) error { func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client) error {
token, err := GetToken(session.Token) user, err := GetUser(session.Username)
if err != nil { if err != nil {
replyError(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!") replyMessage(ws, "当前 user 无效,请使用合法的 user 登录!", false)
return err return err
} }
if token.MaxCalls > 0 && token.RemainingCalls <= 0 { if user.MaxCalls > 0 && user.RemainingCalls <= 0 {
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用或者联系管理员!") replyMessage(ws, "当前 user 点数已经用尽,请充值后再使用或者联系管理员!", false)
return nil return nil
} }
var r = types.ApiRequest{ var r = types.ApiRequest{
@ -91,7 +93,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
r.Messages = append(context, types.Message{ r.Messages = append(context, types.Message{
Role: "user", Role: "user",
Content: text, Content: prompt,
}) })
requestBody, err := json.Marshal(r) requestBody, err := json.Marshal(r)
@ -107,7 +109,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
} }
request.Header.Add("Content-Type", "application/json") request.Header.Add("Content-Type", "application/json")
var retryCount = 3 var retryCount = 5
var response *http.Response var response *http.Response
var failedKey = "" var failedKey = ""
var failedProxyURL = "" var failedProxyURL = ""
@ -145,10 +147,11 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
// 如果三次请求都失败的话,则返回对应的错误信息 // 如果三次请求都失败的话,则返回对应的错误信息
if err != nil { if err != nil {
replyError(ws, ErrorMsg) replyMessage(ws, ErrorMsg, false)
return err return err
} }
// 循环读取 Chunk 消息
var message = types.Message{} var message = types.Message{}
var contents = make([]string, 0) var contents = make([]string, 0)
var responseBody = types.ApiResponse{} var responseBody = types.ApiResponse{}
@ -161,58 +164,59 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
} }
if line == "" { if line == "" {
replyMessage(types.WsMessage{Type: types.WsEnd}, ws) replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
break break
} else if len(line) < 20 { } else if len(line) < 20 {
continue continue
} }
err = json.Unmarshal([]byte(line[6:]), &responseBody) err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { if err != nil { // 数据解析出错
logger.Error(line) logger.Error(err, line)
replyError(ws, ErrorMsg) replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
break break
} }
// 初始化 role // 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role message.Role = responseBody.Choices[0].Delta.Role
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: false}, ws) replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: false})
continue continue
} else if responseBody.Choices[0].FinishReason != "" { // 输出完成或者输出中断了 } else if responseBody.Choices[0].FinishReason != "" { // 输出完成或者输出中断了
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: false}, ws) replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
break break
} else { } else {
content := responseBody.Choices[0].Delta.Content content := responseBody.Choices[0].Delta.Content
contents = append(contents, content) contents = append(contents, content)
replyMessage(types.WsMessage{ replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle, Type: types.WsMiddle,
Content: responseBody.Choices[0].Delta.Content, Content: responseBody.Choices[0].Delta.Content,
IsHelloMsg: false, IsHelloMsg: false,
}, ws) })
} }
} }
// 当前 Token 调用次数减 1 _ = response.Body.Close() // 关闭资源
if token.MaxCalls > 0 {
token.RemainingCalls -= 1 // 当前 Username 调用次数减 1
_ = PutToken(*token) if user.MaxCalls > 0 {
user.RemainingCalls -= 1
_ = PutUser(*user)
} }
// 追加历史消息 // 追加上下文消息
context = append(context, types.Message{ useMsg := types.Message{Role: "user", Content: prompt}
Role: "user", context = append(context, useMsg)
Content: text,
})
message.Content = strings.Join(contents, "") message.Content = strings.Join(contents, "")
context = append(context, message) context = append(context, message)
// 保存上下文
s.ChatContext[key] = context s.ChatContext[key] = context
_ = response.Body.Close() // 关闭资源
return nil
}
func replyError(ws Client, message string) { // 追加历史消息
replyMessage(types.WsMessage{Type: types.WsStart}, ws) if user.EnableHistory {
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: message}, ws) err = AppendChatHistory(user.Name, role.Key, useMsg)
replyMessage(types.WsMessage{Type: types.WsEnd}, ws) if err != nil {
return err
}
err = AppendChatHistory(user.Name, role.Key, message)
}
return err
} }
// 随机获取一个 API Key如果请求失败则更换 API Key 重试 // 随机获取一个 API Key如果请求失败则更换 API Key 重试
@ -267,8 +271,8 @@ func (s *Server) getProxyURL(failedProxyURL string) string {
return "" return ""
} }
// 回复客户端消息 // 回复客户片段端消息
func replyMessage(message types.WsMessage, client Client) { func replyChunkMessage(client Client, message types.WsMessage) {
msg, err := json.Marshal(message) msg, err := json.Marshal(message)
if err != nil { if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error()) logger.Errorf("Error for decoding json data: %v", err.Error())
@ -279,3 +283,61 @@ func replyMessage(message types.WsMessage, client Client) {
logger.Errorf("Error for reply message: %v", err.Error()) logger.Errorf("Error for reply message: %v", err.Error())
} }
} }
// 回复客户端一条完整的消息
func replyMessage(ws Client, message string, isHelloMsg bool) {
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: isHelloMsg})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message, IsHelloMsg: isHelloMsg})
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: isHelloMsg})
}
func (s *Server) GetChatHistoryHandle(c *gin.Context) {
sessionId := c.GetHeader(types.TokenName)
var data struct {
Role string `json:"role"`
}
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return
}
session := s.ChatSession[sessionId]
history, err := GetChatHistory(session.Username, data.Role)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "No history message"})
return
}
var messages = make([]types.HistoryMessage, 0)
role, err := GetChatRole(data.Role)
if err == nil {
// 先将打招呼的消息追加上去
messages = append(messages, types.HistoryMessage{
Type: "reply",
Id: utils.RandString(32),
Icon: role.Icon,
Content: role.HelloMsg,
})
for _, v := range history {
if v.Role == "user" {
messages = append(messages, types.HistoryMessage{
Type: "prompt",
Id: utils.RandString(32),
Icon: "images/avatar/user.png",
Content: v.Content,
})
} else if v.Role == "assistant" {
messages = append(messages, types.HistoryMessage{
Type: "reply",
Id: utils.RandString(32),
Icon: role.Icon,
Content: v.Content,
})
}
}
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: messages})
}

View File

@ -37,13 +37,13 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
s.Config.Chat.Temperature = float32(v) s.Config.Chat.Temperature = float32(v)
} }
// max_tokens // max_users
if maxTokens, ok := data["max_tokens"]; ok { if maxTokens, ok := data["max_tokens"]; ok {
v, err := strconv.Atoi(maxTokens) v, err := strconv.Atoi(maxTokens)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{ c.JSON(http.StatusOK, types.BizVo{
Code: types.InvalidParams, Code: types.InvalidParams,
Message: "max_tokens must be a int parameter", Message: "max_users must be a int parameter",
}) })
return return
} }
@ -86,8 +86,8 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
} }
// SetDebug 开启/关闭调试模式 // SetDebugHandle 开启/关闭调试模式
func (s *Server) SetDebug(c *gin.Context) { func (s *Server) SetDebugHandle(c *gin.Context) {
var data struct { var data struct {
Debug bool `json:"debug"` Debug bool `json:"debug"`
} }
@ -101,9 +101,9 @@ func (s *Server) SetDebug(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
} }
// AddToken 添加 Token // AddUserHandle 添加 Username
func (s *Server) AddToken(c *gin.Context) { func (s *Server) AddUserHandle(c *gin.Context) {
var data types.Token var data types.User
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
@ -116,25 +116,25 @@ func (s *Server) AddToken(c *gin.Context) {
return return
} }
// 检查当前要添加的 token 是否已经存在 // 检查当前要添加的 Username 是否已经存在
_, err = GetToken(data.Name) _, err = GetUser(data.Name)
if err == nil { if err == nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + data.Name + " already exists"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Username " + data.Name + " already exists"})
return return
} }
token := types.Token{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls} user := types.User{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls}
err = PutToken(token) err = PutUser(user)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return return
} }
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: user})
} }
// BatchAddToken 批量生成 Token // BatchAddUserHandle 批量生成 Username
func (s *Server) BatchAddToken(c *gin.Context) { func (s *Server) BatchAddUserHandle(c *gin.Context) {
var data struct { var data struct {
Number int `json:"number"` Number int `json:"number"`
MaxCalls int `json:"max_calls"` MaxCalls int `json:"max_calls"`
@ -145,24 +145,24 @@ func (s *Server) BatchAddToken(c *gin.Context) {
return return
} }
var tokens = make([]string, 0) var users = make([]string, 0)
for i := 0; i < data.Number; i++ { for i := 0; i < data.Number; i++ {
name := utils.RandString(12) name := utils.RandString(12)
_, err := GetToken(name) _, err := GetUser(name)
for err == nil { for err == nil {
name = utils.RandString(12) name = utils.RandString(12)
} }
err = PutToken(types.Token{Name: name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls}) err = PutUser(types.User{Name: name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls})
if err == nil { if err == nil {
tokens = append(tokens, name) users = append(users, name)
} }
} }
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: tokens}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: users})
} }
func (s *Server) SetToken(c *gin.Context) { func (s *Server) SetUserHandle(c *gin.Context) {
var data types.Token var data types.User
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
@ -175,49 +175,50 @@ func (s *Server) SetToken(c *gin.Context) {
return return
} }
token, err := GetToken(data.Name) user, err := GetUser(data.Name)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Username not found"})
return return
} }
if data.MaxCalls > 0 { if data.MaxCalls > 0 {
token.RemainingCalls += data.MaxCalls - token.MaxCalls user.RemainingCalls += data.MaxCalls - user.MaxCalls
if token.RemainingCalls < 0 { if user.RemainingCalls < 0 {
token.RemainingCalls = 0 user.RemainingCalls = 0
} }
} }
token.MaxCalls = data.MaxCalls user.MaxCalls = data.MaxCalls
user.EnableHistory = data.EnableHistory
err = PutToken(*token) err = PutUser(*user)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return return
} }
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: user})
} }
// RemoveToken 删除 Token // RemoveUserHandle 删除 Username
func (s *Server) RemoveToken(c *gin.Context) { func (s *Server) RemoveUserHandle(c *gin.Context) {
var data types.Token var data types.User
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return return
} }
err = RemoveToken(data.Name) err = RemoveUser(data.Name)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return return
} }
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetUsers()})
} }
// AddApiKey 添加一个 API key // AddApiKeyHandle 添加一个 API key
func (s *Server) AddApiKey(c *gin.Context) { func (s *Server) AddApiKeyHandle(c *gin.Context) {
var data map[string]string var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
@ -239,8 +240,8 @@ func (s *Server) AddApiKey(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
} }
// RemoveApiKey 移除一个 API key // RemoveApiKeyHandle 移除一个 API key
func (s *Server) RemoveApiKey(c *gin.Context) { func (s *Server) RemoveApiKeyHandle(c *gin.Context) {
var data map[string]string var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
@ -267,12 +268,13 @@ func (s *Server) RemoveApiKey(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
} }
// ListApiKeys 获取 API key 列表 // ListApiKeysHandle 获取 API key 列表
func (s *Server) ListApiKeys(c *gin.Context) { func (s *Server) ListApiKeysHandle(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
} }
func (s *Server) GetChatRoleList(c *gin.Context) { // GetChatRoleListHandle 获取聊天角色列表
func (s *Server) GetChatRoleListHandle(c *gin.Context) {
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"} var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
var res = make([]interface{}, 0) var res = make([]interface{}, 0)
var roles = GetChatRoles() var roles = GetChatRoles()
@ -292,8 +294,8 @@ func (s *Server) GetChatRoleList(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: res}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: res})
} }
// UpdateChatRole 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作 // UpdateChatRoleHandle 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作
func (s *Server) UpdateChatRole(c *gin.Context) { func (s *Server) UpdateChatRoleHandle(c *gin.Context) {
var data map[string]string var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
@ -345,8 +347,8 @@ func (s *Server) UpdateChatRole(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: role}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: role})
} }
// AddProxy 添加一个代理 // AddProxyHandle 添加一个代理
func (s *Server) AddProxy(c *gin.Context) { func (s *Server) AddProxyHandle(c *gin.Context) {
var data map[string]string var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
@ -371,7 +373,8 @@ func (s *Server) AddProxy(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.ProxyURL}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.ProxyURL})
} }
func (s *Server) RemoveProxy(c *gin.Context) { // RemoveProxyHandle 删除一个代理
func (s *Server) RemoveProxyHandle(c *gin.Context) {
var data map[string]string var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {

View File

@ -7,7 +7,7 @@ import (
) )
const ( const (
TokenPrefix = "chat/tokens/" UserPrefix = "chat/users/"
ChatRolePrefix = "chat/roles/" ChatRolePrefix = "chat/roles/"
ChatHistoryPrefix = "chat/history/" ChatHistoryPrefix = "chat/history/"
) )
@ -22,45 +22,45 @@ func init() {
db = leveldb db = leveldb
} }
// GetTokens 获取 token 信息 // GetUsers 获取 user 信息
// chat/tokens // chat/users
func GetTokens() []types.Token { func GetUsers() []types.User {
items := db.Search(TokenPrefix) items := db.Search(UserPrefix)
var tokens = make([]types.Token, 0) var users = make([]types.User, 0)
for _, v := range items { for _, v := range items {
var token types.Token var user types.User
err := json.Unmarshal([]byte(v), &token) err := json.Unmarshal([]byte(v), &user)
if err != nil { if err != nil {
continue continue
} }
tokens = append(tokens, token) users = append(users, user)
} }
return tokens return users
} }
func PutToken(token types.Token) error { func PutUser(user types.User) error {
key := TokenPrefix + token.Name key := UserPrefix + user.Name
return db.Put(key, token) return db.Put(key, user)
} }
func GetToken(name string) (*types.Token, error) { func GetUser(username string) (*types.User, error) {
key := TokenPrefix + name key := UserPrefix + username
bytes, err := db.Get(key) bytes, err := db.Get(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var token types.Token var user types.User
err = json.Unmarshal(bytes, &token) err = json.Unmarshal(bytes, &user)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &token, nil return &user, nil
} }
func RemoveToken(token string) error { func RemoveUser(username string) error {
key := TokenPrefix + token key := UserPrefix + username
return db.Delete(key) return db.Delete(key)
} }
@ -86,7 +86,7 @@ func PutChatRole(role types.ChatRole) error {
} }
func GetChatRole(key string) (*types.ChatRole, error) { func GetChatRole(key string) (*types.ChatRole, error) {
key = ChatHistoryPrefix + key key = ChatRolePrefix + key
bytes, err := db.Get(key) bytes, err := db.Get(key)
if err != nil { if err != nil {
return nil, err return nil, err
@ -102,7 +102,37 @@ func GetChatRole(key string) (*types.ChatRole, error) {
} }
// GetChatHistory 获取聊天历史记录 // GetChatHistory 获取聊天历史记录
// chat/history/{token}/{role} // chat/history/{user}/{role}
func GetChatHistory() []types.Message { func GetChatHistory(user string, role string) ([]types.Message, error) {
return nil key := ChatHistoryPrefix + user + "/" + role
bytes, err := db.Get(key)
if err != nil {
return nil, err
}
var message []types.Message
err = json.Unmarshal(bytes, &message)
if err != nil {
return nil, err
}
return message, nil
}
// AppendChatHistory 追加聊天记录
func AppendChatHistory(user string, role string, message types.Message) error {
messages, err := GetChatHistory(user, role)
if err != nil {
messages = make([]types.Message, 0)
}
messages = append(messages, message)
key := ChatHistoryPrefix + user + "/" + role
return db.Put(key, messages)
}
// ClearChatHistory 清空某个角色下的聊天记录
func ClearChatHistory(user string, role string) error {
key := ChatHistoryPrefix + user + "/" + role
return db.Delete(key)
} }

View File

@ -7,7 +7,6 @@ import (
"github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io/fs" "io/fs"
"log"
"net/http" "net/http"
logger2 "openai/logger" logger2 "openai/logger"
"openai/types" "openai/types"
@ -36,7 +35,7 @@ type Server struct {
ConfigPath string ConfigPath string
ChatContext map[string][]types.Message // 聊天上下文 [SessionID] => []Messages ChatContext map[string][]types.Message // 聊天上下文 [SessionID] => []Messages
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次 // 保存 Websocket 会话 Username, 每个 Username 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API // 防止第三方直接连接 socket 调用 OpenAI API
ChatSession map[string]types.ChatSession ChatSession map[string]types.ChatSession
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
@ -83,19 +82,21 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
engine.GET("/api/session/get", s.GetSessionHandle) engine.GET("/api/session/get", s.GetSessionHandle)
engine.POST("/api/login", s.LoginHandle) engine.POST("/api/login", s.LoginHandle)
engine.Any("/api/chat", s.ChatHandle) engine.Any("/api/chat", s.ChatHandle)
engine.POST("api/chat/history", s.GetChatHistoryHandle)
engine.POST("/api/config/set", s.ConfigSetHandle) engine.POST("/api/config/set", s.ConfigSetHandle)
engine.GET("/api/config/chat-roles/get", s.GetChatRoleList) engine.GET("/api/config/chat-roles/get", s.GetChatRoleListHandle)
engine.POST("api/config/token/add", s.AddToken) engine.POST("api/config/user/add", s.AddUserHandle)
engine.POST("api/config/token/batch-add", s.BatchAddToken) engine.POST("api/config/user/batch-add", s.BatchAddUserHandle)
engine.POST("api/config/token/set", s.SetToken) engine.POST("api/config/user/set", s.SetUserHandle)
engine.POST("api/config/token/remove", s.RemoveToken) engine.POST("api/config/user/remove", s.RemoveUserHandle)
engine.POST("api/config/apikey/add", s.AddApiKey) engine.POST("api/config/apikey/add", s.AddApiKeyHandle)
engine.POST("api/config/apikey/remove", s.RemoveApiKey) engine.POST("api/config/apikey/remove", s.RemoveApiKeyHandle)
engine.POST("api/config/apikey/list", s.ListApiKeys) engine.POST("api/config/apikey/list", s.ListApiKeysHandle)
engine.POST("api/config/role/set", s.UpdateChatRole) engine.POST("api/config/role/set", s.UpdateChatRoleHandle)
engine.POST("api/config/proxy/add", s.AddProxy) engine.POST("api/config/proxy/add", s.AddProxyHandle)
engine.POST("api/config/proxy/remove", s.RemoveProxy) engine.POST("api/config/proxy/remove", s.RemoveProxyHandle)
engine.POST("api/config/debug", s.SetDebug) engine.POST("api/config/debug", s.SetDebugHandle)
engine.NoRoute(func(c *gin.Context) { engine.NoRoute(func(c *gin.Context) {
if c.Request.URL.Path == "/favicon.ico" { if c.Request.URL.Path == "/favicon.ico" {
@ -122,7 +123,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
func Recover(c *gin.Context) { func Recover(c *gin.Context) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Printf("panic: %v\n", r) logger.Error("panic: %v\n", r)
debug.PrintStack() debug.PrintStack()
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
c.Abort() c.Abort()
@ -156,7 +157,7 @@ func corsMiddleware() gin.HandlerFunc {
c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
//允许跨域设置可以返回其他子段,可以自定义字段 //允许跨域设置可以返回其他子段,可以自定义字段
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-Token") c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-Username")
// 允许浏览器(客户端)可以解析的头部 (重要) // 允许浏览器(客户端)可以解析的头部 (重要)
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
//设置缓存时间 //设置缓存时间
@ -171,7 +172,7 @@ func corsMiddleware() gin.HandlerFunc {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
log.Printf("Panic info is: %v", err) logger.Info("Panic info is: %v", err)
} }
}() }()
@ -242,27 +243,28 @@ func (s *Server) GetSessionHandle(c *gin.Context) {
} }
func (s *Server) LoginHandle(c *gin.Context) { func (s *Server) LoginHandle(c *gin.Context) {
var data map[string]string var data struct {
Token string `json:"token"`
}
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
return return
} }
token := data["token"] if !utils.Containuser(GetUsers(), data.Token) {
if !utils.ContainToken(GetTokens(), token) { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid user"})
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid token"})
return return
} }
sessionId := utils.RandString(42) sessionId := utils.RandString(42)
session := sessions.Default(c) session := sessions.Default(c)
session.Set(sessionId, token) session.Set(sessionId, data.Token)
err = session.Save() err = session.Save()
if err != nil { if err != nil {
logger.Error("Error for save session: ", err) logger.Error("Error for save session: ", err)
} }
// 记录客户端 IP 地址 // 记录客户端 IP 地址
s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Token: token, SessionId: sessionId} s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Username: data.Token, SessionId: sessionId}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: sessionId}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: sessionId})
} }

View File

@ -14,6 +14,14 @@ type Message struct {
Content string `json:"content"` Content string `json:"content"`
} }
// HistoryMessage 历史聊天消息
type HistoryMessage struct {
Type string `json:"type"`
Id string `json:"id"`
Icon string `json:"icon"`
Content string `json:"content"`
}
type ApiResponse struct { type ApiResponse struct {
Choices []ChoiceItem `json:"choices"` Choices []ChoiceItem `json:"choices"`
} }
@ -37,7 +45,7 @@ type ChatRole struct {
type ChatSession struct { type ChatSession struct {
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
ClientIP string `json:"client_ip"` // 客户端 IP ClientIP string `json:"client_ip"` // 客户端 IP
Token string `json:"token"` // 当前登录的 token Username string `json:"user"` // 当前登录的 user
} }
func GetDefaultChatRole() map[string]ChatRole { func GetDefaultChatRole() map[string]ChatRole {

View File

@ -13,10 +13,11 @@ type Config struct {
Chat Chat Chat Chat
} }
type Token struct { type User struct {
Name string `json:"name"` Name string `json:"name"`
MaxCalls int `json:"max_calls"` // 最多调用次数,如果为 0 则表示不限制 MaxCalls int `json:"max_calls"` // 最多调用次数,如果为 0 则表示不限制
RemainingCalls int `json:"remaining_calls"` // 剩余调用次数 RemainingCalls int `json:"remaining_calls"` // 剩余调用次数
EnableHistory bool `json:"enable_history"` // 是否启用聊天记录
} }
// Chat configs struct // Chat configs struct

View File

@ -6,7 +6,7 @@ type BizVo struct {
Page int `json:"page,omitempty"` Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"` PageSize int `json:"page_size,omitempty"`
Total int `json:"total,omitempty"` Total int `json:"total,omitempty"`
Message string `json:"message"` Message string `json:"message,omitempty"`
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
} }
@ -36,5 +36,5 @@ const (
ErrorMsg = "系统开小差了" ErrorMsg = "系统开小差了"
) )
const TokenName = "ChatGPT-Token" const TokenName = "ChatGPT-TOKEN"
const SessionKey = "WEB_SSH_SESSION" const SessionKey = "WEB_SSH_SESSION"

View File

@ -41,9 +41,9 @@ func ContainsStr(slice []string, item string) bool {
return false return false
} }
func ContainToken(slice []types.Token, token string) bool { func Containuser(slice []types.User, user string) bool {
for _, e := range slice { for _, e := range slice {
if e.Name == token { if e.Name == user {
return true return true
} }
} }

View File

@ -10,7 +10,7 @@ axios.defaults.headers.post['Content-Type'] = 'application/json'
axios.interceptors.request.use( axios.interceptors.request.use(
config => { config => {
// set token // set token
config.headers['ChatGPT-Token'] = getSessionId(); config.headers['ChatGPT-TOKEN'] = getSessionId();
return config return config
}, error => { }, error => {
return Promise.reject(error) return Promise.reject(error)

View File

@ -185,10 +185,13 @@ export default defineComponent({
window.addEventListener("resize", () => { window.addEventListener("resize", () => {
this.chatBoxHeight = window.innerHeight - this.toolBoxHeight; this.chatBoxHeight = window.innerHeight - this.toolBoxHeight;
this.inputBoxWidth = window.innerWidth - 20;
}); });
this.connect(); this.connect();
this.fetchChatHistory();
}, },
methods: { methods: {
@ -300,6 +303,17 @@ export default defineComponent({
break; break;
} }
} }
this.fetchChatHistory();
},
//
fetchChatHistory: function () {
httpPost("/api/chat/history", {role: this.role}).then((res) => {
this.chatData = res.data
}).catch((e) => {
console.error(e.message)
})
}, },
inputKeyDown: function (e) { inputKeyDown: function (e) {