rename Token to User, the chat history function is ready

This commit is contained in:
RockYang 2023-03-28 16:03:41 +08:00
parent 95f9dfa9cb
commit ebc2041e8a
10 changed files with 261 additions and 141 deletions

View File

@ -12,6 +12,7 @@ import (
"net/http"
"net/url"
"openai/types"
"openai/utils"
"strings"
"time"
)
@ -28,7 +29,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
sessionId := c.Query("sessionId")
roleKey := c.Query("role")
session := s.ChatSession[sessionId]
logger.Infof("New websocket connected, IP: %s, Token: %s", c.Request.RemoteAddr, session.Token)
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
client := NewWsClient(ws)
var roles = GetChatRoles()
var chatRole = roles[roleKey]
@ -36,10 +37,11 @@ func (s *Server) ChatHandle(c *gin.Context) {
c.Abort()
return
}
// 发送打招呼信息
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: chatRole.HelloMsg, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client)
// 加载历史消息,如果历史消息为空则发送打招呼消息
_, err = GetChatHistory(session.Username, roleKey)
if err != nil {
replyMessage(client, chatRole.HelloMsg, true)
}
go func() {
for {
_, message, err := client.Receive()
@ -60,15 +62,15 @@ func (s *Server) ChatHandle(c *gin.Context) {
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, text string, ws Client) error {
token, err := GetToken(session.Token)
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client) error {
user, err := GetUser(session.Username)
if err != nil {
replyError(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!")
replyMessage(ws, "当前 user 无效,请使用合法的 user 登录!", false)
return err
}
if token.MaxCalls > 0 && token.RemainingCalls <= 0 {
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用或者联系管理员!")
if user.MaxCalls > 0 && user.RemainingCalls <= 0 {
replyMessage(ws, "当前 user 点数已经用尽,请充值后再使用或者联系管理员!", false)
return nil
}
var r = types.ApiRequest{
@ -91,7 +93,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
r.Messages = append(context, types.Message{
Role: "user",
Content: text,
Content: prompt,
})
requestBody, err := json.Marshal(r)
@ -107,7 +109,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
}
request.Header.Add("Content-Type", "application/json")
var retryCount = 3
var retryCount = 5
var response *http.Response
var failedKey = ""
var failedProxyURL = ""
@ -145,10 +147,11 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
// 如果三次请求都失败的话,则返回对应的错误信息
if err != nil {
replyError(ws, ErrorMsg)
replyMessage(ws, ErrorMsg, false)
return err
}
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var responseBody = types.ApiResponse{}
@ -161,58 +164,59 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
}
if line == "" {
replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
break
} else if len(line) < 20 {
continue
}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil {
logger.Error(line)
replyError(ws, ErrorMsg)
if err != nil { // 数据解析出错
logger.Error(err, line)
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
break
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: false}, ws)
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: false})
continue
} else if responseBody.Choices[0].FinishReason != "" { // 输出完成或者输出中断了
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: false}, ws)
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
break
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, content)
replyMessage(types.WsMessage{
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: responseBody.Choices[0].Delta.Content,
IsHelloMsg: false,
}, ws)
})
}
}
// 当前 Token 调用次数减 1
if token.MaxCalls > 0 {
token.RemainingCalls -= 1
_ = PutToken(*token)
_ = response.Body.Close() // 关闭资源
// 当前 Username 调用次数减 1
if user.MaxCalls > 0 {
user.RemainingCalls -= 1
_ = PutUser(*user)
}
// 追加历史消息
context = append(context, types.Message{
Role: "user",
Content: text,
})
// 追加上下文消息
useMsg := types.Message{Role: "user", Content: prompt}
context = append(context, useMsg)
message.Content = strings.Join(contents, "")
context = append(context, message)
// 保存上下文
s.ChatContext[key] = context
_ = response.Body.Close() // 关闭资源
return nil
}
func replyError(ws Client, message string) {
replyMessage(types.WsMessage{Type: types.WsStart}, ws)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: message}, ws)
replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
// 追加历史消息
if user.EnableHistory {
err = AppendChatHistory(user.Name, role.Key, useMsg)
if err != nil {
return err
}
err = AppendChatHistory(user.Name, role.Key, message)
}
return err
}
// 随机获取一个 API Key如果请求失败则更换 API Key 重试
@ -267,8 +271,8 @@ func (s *Server) getProxyURL(failedProxyURL string) string {
return ""
}
// 回复客户端消息
func replyMessage(message types.WsMessage, client Client) {
// 回复客户片段端消息
func replyChunkMessage(client Client, message types.WsMessage) {
msg, err := json.Marshal(message)
if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error())
@ -279,3 +283,61 @@ func replyMessage(message types.WsMessage, client Client) {
logger.Errorf("Error for reply message: %v", err.Error())
}
}
// 回复客户端一条完整的消息
func replyMessage(ws Client, message string, isHelloMsg bool) {
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: isHelloMsg})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message, IsHelloMsg: isHelloMsg})
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: isHelloMsg})
}
func (s *Server) GetChatHistoryHandle(c *gin.Context) {
sessionId := c.GetHeader(types.TokenName)
var data struct {
Role string `json:"role"`
}
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return
}
session := s.ChatSession[sessionId]
history, err := GetChatHistory(session.Username, data.Role)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "No history message"})
return
}
var messages = make([]types.HistoryMessage, 0)
role, err := GetChatRole(data.Role)
if err == nil {
// 先将打招呼的消息追加上去
messages = append(messages, types.HistoryMessage{
Type: "reply",
Id: utils.RandString(32),
Icon: role.Icon,
Content: role.HelloMsg,
})
for _, v := range history {
if v.Role == "user" {
messages = append(messages, types.HistoryMessage{
Type: "prompt",
Id: utils.RandString(32),
Icon: "images/avatar/user.png",
Content: v.Content,
})
} else if v.Role == "assistant" {
messages = append(messages, types.HistoryMessage{
Type: "reply",
Id: utils.RandString(32),
Icon: role.Icon,
Content: v.Content,
})
}
}
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: messages})
}

View File

@ -37,13 +37,13 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
s.Config.Chat.Temperature = float32(v)
}
// max_tokens
// max_users
if maxTokens, ok := data["max_tokens"]; ok {
v, err := strconv.Atoi(maxTokens)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{
Code: types.InvalidParams,
Message: "max_tokens must be a int parameter",
Message: "max_users must be a int parameter",
})
return
}
@ -86,8 +86,8 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
}
// SetDebug 开启/关闭调试模式
func (s *Server) SetDebug(c *gin.Context) {
// SetDebugHandle 开启/关闭调试模式
func (s *Server) SetDebugHandle(c *gin.Context) {
var data struct {
Debug bool `json:"debug"`
}
@ -101,9 +101,9 @@ func (s *Server) SetDebug(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
}
// AddToken 添加 Token
func (s *Server) AddToken(c *gin.Context) {
var data types.Token
// AddUserHandle 添加 Username
func (s *Server) AddUserHandle(c *gin.Context) {
var data types.User
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
@ -116,25 +116,25 @@ func (s *Server) AddToken(c *gin.Context) {
return
}
// 检查当前要添加的 token 是否已经存在
_, err = GetToken(data.Name)
// 检查当前要添加的 Username 是否已经存在
_, err = GetUser(data.Name)
if err == nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + data.Name + " already exists"})
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Username " + data.Name + " already exists"})
return
}
token := types.Token{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls}
err = PutToken(token)
user := types.User{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls}
err = PutUser(user)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token})
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: user})
}
// BatchAddToken 批量生成 Token
func (s *Server) BatchAddToken(c *gin.Context) {
// BatchAddUserHandle 批量生成 Username
func (s *Server) BatchAddUserHandle(c *gin.Context) {
var data struct {
Number int `json:"number"`
MaxCalls int `json:"max_calls"`
@ -145,24 +145,24 @@ func (s *Server) BatchAddToken(c *gin.Context) {
return
}
var tokens = make([]string, 0)
var users = make([]string, 0)
for i := 0; i < data.Number; i++ {
name := utils.RandString(12)
_, err := GetToken(name)
_, err := GetUser(name)
for err == nil {
name = utils.RandString(12)
}
err = PutToken(types.Token{Name: name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls})
err = PutUser(types.User{Name: name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls})
if err == nil {
tokens = append(tokens, name)
users = append(users, name)
}
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: tokens})
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: users})
}
func (s *Server) SetToken(c *gin.Context) {
var data types.Token
func (s *Server) SetUserHandle(c *gin.Context) {
var data types.User
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
@ -175,49 +175,50 @@ func (s *Server) SetToken(c *gin.Context) {
return
}
token, err := GetToken(data.Name)
user, err := GetUser(data.Name)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"})
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Username not found"})
return
}
if data.MaxCalls > 0 {
token.RemainingCalls += data.MaxCalls - token.MaxCalls
if token.RemainingCalls < 0 {
token.RemainingCalls = 0
user.RemainingCalls += data.MaxCalls - user.MaxCalls
if user.RemainingCalls < 0 {
user.RemainingCalls = 0
}
}
token.MaxCalls = data.MaxCalls
user.MaxCalls = data.MaxCalls
user.EnableHistory = data.EnableHistory
err = PutToken(*token)
err = PutUser(*user)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token})
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: user})
}
// RemoveToken 删除 Token
func (s *Server) RemoveToken(c *gin.Context) {
var data types.Token
// RemoveUserHandle 删除 Username
func (s *Server) RemoveUserHandle(c *gin.Context) {
var data types.User
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return
}
err = RemoveToken(data.Name)
err = RemoveUser(data.Name)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetUsers()})
}
// AddApiKey 添加一个 API key
func (s *Server) AddApiKey(c *gin.Context) {
// AddApiKeyHandle 添加一个 API key
func (s *Server) AddApiKeyHandle(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
@ -239,8 +240,8 @@ func (s *Server) AddApiKey(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
}
// RemoveApiKey 移除一个 API key
func (s *Server) RemoveApiKey(c *gin.Context) {
// RemoveApiKeyHandle 移除一个 API key
func (s *Server) RemoveApiKeyHandle(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
@ -267,12 +268,13 @@ func (s *Server) RemoveApiKey(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
}
// ListApiKeys 获取 API key 列表
func (s *Server) ListApiKeys(c *gin.Context) {
// ListApiKeysHandle 获取 API key 列表
func (s *Server) ListApiKeysHandle(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
}
func (s *Server) GetChatRoleList(c *gin.Context) {
// GetChatRoleListHandle 获取聊天角色列表
func (s *Server) GetChatRoleListHandle(c *gin.Context) {
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
var res = make([]interface{}, 0)
var roles = GetChatRoles()
@ -292,8 +294,8 @@ func (s *Server) GetChatRoleList(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: res})
}
// UpdateChatRole 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作
func (s *Server) UpdateChatRole(c *gin.Context) {
// UpdateChatRoleHandle 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作
func (s *Server) UpdateChatRoleHandle(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
@ -345,8 +347,8 @@ func (s *Server) UpdateChatRole(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: role})
}
// AddProxy 添加一个代理
func (s *Server) AddProxy(c *gin.Context) {
// AddProxyHandle 添加一个代理
func (s *Server) AddProxyHandle(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
@ -371,7 +373,8 @@ func (s *Server) AddProxy(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.ProxyURL})
}
func (s *Server) RemoveProxy(c *gin.Context) {
// RemoveProxyHandle 删除一个代理
func (s *Server) RemoveProxyHandle(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {

View File

@ -7,7 +7,7 @@ import (
)
const (
TokenPrefix = "chat/tokens/"
UserPrefix = "chat/users/"
ChatRolePrefix = "chat/roles/"
ChatHistoryPrefix = "chat/history/"
)
@ -22,45 +22,45 @@ func init() {
db = leveldb
}
// GetTokens 获取 token 信息
// chat/tokens
func GetTokens() []types.Token {
items := db.Search(TokenPrefix)
var tokens = make([]types.Token, 0)
// GetUsers 获取 user 信息
// chat/users
func GetUsers() []types.User {
items := db.Search(UserPrefix)
var users = make([]types.User, 0)
for _, v := range items {
var token types.Token
err := json.Unmarshal([]byte(v), &token)
var user types.User
err := json.Unmarshal([]byte(v), &user)
if err != nil {
continue
}
tokens = append(tokens, token)
users = append(users, user)
}
return tokens
return users
}
func PutToken(token types.Token) error {
key := TokenPrefix + token.Name
return db.Put(key, token)
func PutUser(user types.User) error {
key := UserPrefix + user.Name
return db.Put(key, user)
}
func GetToken(name string) (*types.Token, error) {
key := TokenPrefix + name
func GetUser(username string) (*types.User, error) {
key := UserPrefix + username
bytes, err := db.Get(key)
if err != nil {
return nil, err
}
var token types.Token
err = json.Unmarshal(bytes, &token)
var user types.User
err = json.Unmarshal(bytes, &user)
if err != nil {
return nil, err
}
return &token, nil
return &user, nil
}
func RemoveToken(token string) error {
key := TokenPrefix + token
func RemoveUser(username string) error {
key := UserPrefix + username
return db.Delete(key)
}
@ -86,7 +86,7 @@ func PutChatRole(role types.ChatRole) error {
}
func GetChatRole(key string) (*types.ChatRole, error) {
key = ChatHistoryPrefix + key
key = ChatRolePrefix + key
bytes, err := db.Get(key)
if err != nil {
return nil, err
@ -102,7 +102,37 @@ func GetChatRole(key string) (*types.ChatRole, error) {
}
// GetChatHistory 获取聊天历史记录
// chat/history/{token}/{role}
func GetChatHistory() []types.Message {
return nil
// chat/history/{user}/{role}
func GetChatHistory(user string, role string) ([]types.Message, error) {
key := ChatHistoryPrefix + user + "/" + role
bytes, err := db.Get(key)
if err != nil {
return nil, err
}
var message []types.Message
err = json.Unmarshal(bytes, &message)
if err != nil {
return nil, err
}
return message, nil
}
// AppendChatHistory 追加聊天记录
func AppendChatHistory(user string, role string, message types.Message) error {
messages, err := GetChatHistory(user, role)
if err != nil {
messages = make([]types.Message, 0)
}
messages = append(messages, message)
key := ChatHistoryPrefix + user + "/" + role
return db.Put(key, messages)
}
// ClearChatHistory 清空某个角色下的聊天记录
func ClearChatHistory(user string, role string) error {
key := ChatHistoryPrefix + user + "/" + role
return db.Delete(key)
}

View File

@ -7,7 +7,6 @@ import (
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"io/fs"
"log"
"net/http"
logger2 "openai/logger"
"openai/types"
@ -36,7 +35,7 @@ type Server struct {
ConfigPath string
ChatContext map[string][]types.Message // 聊天上下文 [SessionID] => []Messages
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次
// 保存 Websocket 会话 Username, 每个 Username 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
ChatSession map[string]types.ChatSession
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
@ -83,19 +82,21 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
engine.GET("/api/session/get", s.GetSessionHandle)
engine.POST("/api/login", s.LoginHandle)
engine.Any("/api/chat", s.ChatHandle)
engine.POST("api/chat/history", s.GetChatHistoryHandle)
engine.POST("/api/config/set", s.ConfigSetHandle)
engine.GET("/api/config/chat-roles/get", s.GetChatRoleList)
engine.POST("api/config/token/add", s.AddToken)
engine.POST("api/config/token/batch-add", s.BatchAddToken)
engine.POST("api/config/token/set", s.SetToken)
engine.POST("api/config/token/remove", s.RemoveToken)
engine.POST("api/config/apikey/add", s.AddApiKey)
engine.POST("api/config/apikey/remove", s.RemoveApiKey)
engine.POST("api/config/apikey/list", s.ListApiKeys)
engine.POST("api/config/role/set", s.UpdateChatRole)
engine.POST("api/config/proxy/add", s.AddProxy)
engine.POST("api/config/proxy/remove", s.RemoveProxy)
engine.POST("api/config/debug", s.SetDebug)
engine.GET("/api/config/chat-roles/get", s.GetChatRoleListHandle)
engine.POST("api/config/user/add", s.AddUserHandle)
engine.POST("api/config/user/batch-add", s.BatchAddUserHandle)
engine.POST("api/config/user/set", s.SetUserHandle)
engine.POST("api/config/user/remove", s.RemoveUserHandle)
engine.POST("api/config/apikey/add", s.AddApiKeyHandle)
engine.POST("api/config/apikey/remove", s.RemoveApiKeyHandle)
engine.POST("api/config/apikey/list", s.ListApiKeysHandle)
engine.POST("api/config/role/set", s.UpdateChatRoleHandle)
engine.POST("api/config/proxy/add", s.AddProxyHandle)
engine.POST("api/config/proxy/remove", s.RemoveProxyHandle)
engine.POST("api/config/debug", s.SetDebugHandle)
engine.NoRoute(func(c *gin.Context) {
if c.Request.URL.Path == "/favicon.ico" {
@ -122,7 +123,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
func Recover(c *gin.Context) {
defer func() {
if r := recover(); r != nil {
log.Printf("panic: %v\n", r)
logger.Error("panic: %v\n", r)
debug.PrintStack()
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
c.Abort()
@ -156,7 +157,7 @@ func corsMiddleware() gin.HandlerFunc {
c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
//允许跨域设置可以返回其他子段,可以自定义字段
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-Token")
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-Username")
// 允许浏览器(客户端)可以解析的头部 (重要)
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
//设置缓存时间
@ -171,7 +172,7 @@ func corsMiddleware() gin.HandlerFunc {
defer func() {
if err := recover(); err != nil {
log.Printf("Panic info is: %v", err)
logger.Info("Panic info is: %v", err)
}
}()
@ -242,27 +243,28 @@ func (s *Server) GetSessionHandle(c *gin.Context) {
}
func (s *Server) LoginHandle(c *gin.Context) {
var data map[string]string
var data struct {
Token string `json:"token"`
}
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
return
}
token := data["token"]
if !utils.ContainToken(GetTokens(), token) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid token"})
if !utils.Containuser(GetUsers(), data.Token) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid user"})
return
}
sessionId := utils.RandString(42)
session := sessions.Default(c)
session.Set(sessionId, token)
session.Set(sessionId, data.Token)
err = session.Save()
if err != nil {
logger.Error("Error for save session: ", err)
}
// 记录客户端 IP 地址
s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Token: token, SessionId: sessionId}
s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Username: data.Token, SessionId: sessionId}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: sessionId})
}

View File

@ -14,6 +14,14 @@ type Message struct {
Content string `json:"content"`
}
// HistoryMessage 历史聊天消息
type HistoryMessage struct {
Type string `json:"type"`
Id string `json:"id"`
Icon string `json:"icon"`
Content string `json:"content"`
}
type ApiResponse struct {
Choices []ChoiceItem `json:"choices"`
}
@ -37,7 +45,7 @@ type ChatRole struct {
type ChatSession struct {
SessionId string `json:"session_id"`
ClientIP string `json:"client_ip"` // 客户端 IP
Token string `json:"token"` // 当前登录的 token
Username string `json:"user"` // 当前登录的 user
}
func GetDefaultChatRole() map[string]ChatRole {

View File

@ -13,10 +13,11 @@ type Config struct {
Chat Chat
}
type Token struct {
type User struct {
Name string `json:"name"`
MaxCalls int `json:"max_calls"` // 最多调用次数,如果为 0 则表示不限制
RemainingCalls int `json:"remaining_calls"` // 剩余调用次数
EnableHistory bool `json:"enable_history"` // 是否启用聊天记录
}
// Chat configs struct

View File

@ -6,7 +6,7 @@ type BizVo struct {
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
Total int `json:"total,omitempty"`
Message string `json:"message"`
Message string `json:"message,omitempty"`
Data interface{} `json:"data,omitempty"`
}
@ -36,5 +36,5 @@ const (
ErrorMsg = "系统开小差了"
)
const TokenName = "ChatGPT-Token"
const TokenName = "ChatGPT-TOKEN"
const SessionKey = "WEB_SSH_SESSION"

View File

@ -41,9 +41,9 @@ func ContainsStr(slice []string, item string) bool {
return false
}
func ContainToken(slice []types.Token, token string) bool {
func Containuser(slice []types.User, user string) bool {
for _, e := range slice {
if e.Name == token {
if e.Name == user {
return true
}
}

View File

@ -10,7 +10,7 @@ axios.defaults.headers.post['Content-Type'] = 'application/json'
axios.interceptors.request.use(
config => {
// set token
config.headers['ChatGPT-Token'] = getSessionId();
config.headers['ChatGPT-TOKEN'] = getSessionId();
return config
}, error => {
return Promise.reject(error)

View File

@ -185,10 +185,13 @@ export default defineComponent({
window.addEventListener("resize", () => {
this.chatBoxHeight = window.innerHeight - this.toolBoxHeight;
this.inputBoxWidth = window.innerWidth - 20;
});
this.connect();
this.fetchChatHistory();
},
methods: {
@ -300,6 +303,17 @@ export default defineComponent({
break;
}
}
this.fetchChatHistory();
},
//
fetchChatHistory: function () {
httpPost("/api/chat/history", {role: this.role}).then((res) => {
this.chatData = res.data
}).catch((e) => {
console.error(e.message)
})
},
inputKeyDown: function (e) {