From 20bdf121804085255e78f0b01f08e9f8c793848c Mon Sep 17 00:00:00 2001 From: RockYang Date: Wed, 22 Mar 2023 13:51:27 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20API=20Key=20=E8=B4=9F?= =?UTF-8?q?=E8=BD=BD=E5=9D=87=E8=A1=A1=EF=BC=8C=E4=BF=AE=E5=A4=8D=20WebSoc?= =?UTF-8?q?ket=20session=20=E5=A4=B1=E6=95=88=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 +-- main.go | 4 +-- main_test.go | 11 ++++++++ server/chat_handler.go | 47 ++++++++++++++++++++++++++++------ server/config_handler.go | 5 +++- server/server.go | 54 ++++++++++++++++++++++++---------------- types/config.go | 2 +- web/.env.development | 2 +- web/src/views/Chat.vue | 27 +++++++++++++++++--- web/vue.config.js | 4 +-- 10 files changed, 118 insertions(+), 42 deletions(-) create mode 100644 main_test.go diff --git a/README.md b/README.md index 1800621f..b1c2e062 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,9 @@ * [ ] 使用 level DB 保存用户聊天的上下文 * [ ] 使用 MySQL 保存用户的聊天的历史记录 -* [ ] 用户聊天鉴权,设置口令模式 +* [x] 用户聊天鉴权,设置口令模式 * [ ] 每次连接自动加载历史记录 -* [ ] OpenAI API 负载均衡,限制每个 API Key 每分钟之内调用次数不超过 15次,防止被封 +* [x] OpenAI API 负载均衡,限制每个 API Key 每分钟之内调用次数不超过 15次,防止被封 * [ ] 角色设定,预设一些角色,比如程序员,产品经理,医生,作家,老师... * [ ] markdown 语法解析 * [ ] 用户配置界面 diff --git a/main.go b/main.go index fc965081..cd1bc1cc 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,7 @@ import ( var logger = logger2.GetLogger() -//go:embed web +//go:embed dist var webRoot embed.FS func main() { @@ -42,5 +42,5 @@ func main() { if err != nil { panic(err) } - s.Run(webRoot, "web") + s.Run(webRoot, "dist") } diff --git a/main_test.go b/main_test.go new file mode 100644 index 00000000..157a8693 --- /dev/null +++ b/main_test.go @@ -0,0 +1,11 @@ +package main + +import ( + "fmt" + "testing" + "time" +) + +func TestTime(t *testing.T) { + fmt.Println(time.Now().Unix()) +} diff --git a/server/chat_handler.go b/server/chat_handler.go index a362b8c2..7ef7989f 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -22,6 +22,7 @@ func (s *Server) ChatHandle(c *gin.Context) { logger.Fatal(err) return } + token := c.Query("token") logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr) client := NewWsClient(ws) go func() { @@ -34,8 +35,8 @@ func (s *Server) ChatHandle(c *gin.Context) { } logger.Info(string(message)) - // TODO: 根据会话请求,传入不同的用户 ID - err = s.sendMessage("test", string(message), client) + // TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录 + err = s.sendMessage(token, string(message), client) if err != nil { logger.Error(err) } @@ -54,7 +55,6 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { var history []types.Message if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext { history = v - //logger.Infof("上下文历史消息:%+v", history) } else { history = make([]types.Message, 0) } @@ -74,14 +74,16 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { } request.Header.Add("Content-Type", "application/json") - // 随机获取一个 API Key,如果请求失败,则更换 API Key 重试 - // TODO: 需要将失败的 Key 移除列表 - rand.Seed(time.Now().UnixNano()) var retryCount = 3 var response *http.Response + var failedKey = "" for retryCount > 0 { - index := rand.Intn(len(s.Config.Chat.ApiKeys)) - apiKey := s.Config.Chat.ApiKeys[index] + apiKey := s.getApiKey(failedKey) + if apiKey == "" { + logger.Info("Too many requests, all Api Key is not available") + time.Sleep(time.Second) + continue + } logger.Infof("Use API KEY: %s", apiKey) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) response, err = s.Client.Do(request) @@ -89,6 +91,7 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { break } else { logger.Error(err) + failedKey = apiKey } retryCount-- } @@ -148,6 +151,34 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { return nil } +// 随机获取一个 API Key,如果请求失败,则更换 API Key 重试 +func (s *Server) getApiKey(failedKey string) string { + var keys = make([]string, 0) + for _, v := range s.Config.Chat.ApiKeys { + // 过滤掉刚刚失败的 Key + if v == failedKey { + continue + } + + // 获取 API Key 的上次调用时间,控制调用频率 + var lastAccess int64 + if t, ok := s.ApiKeyAccessStat[v]; ok { + lastAccess = t + } + // 保持每分钟访问不超过 15 次 + if time.Now().Unix()-lastAccess <= 4 { + continue + } + + keys = append(keys, v) + } + rand.Seed(time.Now().UnixNano()) + if len(keys) > 0 { + return keys[rand.Intn(len(keys))] + } + return "" +} + // 回复客户端消息 func replyMessage(message types.WsMessage, client Client) { msg, err := json.Marshal(message) diff --git a/server/config_handler.go b/server/config_handler.go index 806cda50..338d9423 100644 --- a/server/config_handler.go +++ b/server/config_handler.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "openai/types" + "openai/utils" "strconv" ) @@ -91,7 +92,9 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { } if token, ok := data["token"]; ok { - s.Config.Tokens = append(s.Config.Tokens, token) + if !utils.ContainsItem(s.Config.Tokens, token) { + s.Config.Tokens = append(s.Config.Tokens, token) + } } // 保存配置文件 diff --git a/server/server.go b/server/server.go index a3fe8c4a..67a00f33 100644 --- a/server/server.go +++ b/server/server.go @@ -3,7 +3,6 @@ package server import ( "embed" "encoding/json" - "fmt" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" @@ -38,7 +37,10 @@ type Server struct { Client *http.Client History map[string][]types.Message - WsSession map[string]string // 关闭 Websocket 会话 + // 保存 Websocket 会话 Token, 每个 Token 只能连接一次 + // 防止第三方直接连接 socket 调用 OpenAI API + WsSession map[string]string + ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 } func NewServer(configPath string) (*Server, error) { @@ -56,11 +58,12 @@ func NewServer(configPath string) (*Server, error) { }, } return &Server{ - Config: config, - Client: client, - ConfigPath: configPath, - History: make(map[string][]types.Message, 16), - WsSession: make(map[string]string), + Config: config, + Client: client, + ConfigPath: configPath, + History: make(map[string][]types.Message, 16), + WsSession: make(map[string]string), + ApiKeyAccessStat: make(map[string]int64), }, nil } @@ -143,22 +146,32 @@ func corsMiddleware() gin.HandlerFunc { // AuthorizeMiddleware 用户授权验证 func AuthorizeMiddleware(s *Server) gin.HandlerFunc { return func(c *gin.Context) { - if !s.Config.EnableAuth || c.Request.URL.Path == "/api/login" || c.Request.URL.Path == "/api/config/set" { + if !s.Config.EnableAuth || + c.Request.URL.Path == "/api/login" || + c.Request.URL.Path == "/api/config/set" || + !strings.HasPrefix(c.Request.URL.Path, "/api") { c.Next() return } - tokenName := c.Query("token") - if tokenName == "" { - tokenName = c.GetHeader(types.TokenName) - } - // TODO: 会话过期设置 - if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() { - session := sessions.Default(c) - user := session.Get(tokenName) - if user != nil { - c.Set(types.SessionKey, user) + // WebSocket 连接请求验证 + if c.Request.URL.Path == "/api/chat" { + tokenName := c.Query("token") + if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() { + // 每个令牌只能连接一次 + delete(s.WsSession, tokenName) + c.Next() + } else { + c.Abort() } + return + } + + tokenName := c.GetHeader(types.TokenName) + session := sessions.Default(c) + userInfo := session.Get(tokenName) + if userInfo != nil { + c.Set(types.SessionKey, userInfo) c.Next() } else { c.Abort() @@ -171,8 +184,7 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc { } func (s *Server) GetSessionHandle(c *gin.Context) { - session := sessions.Default(c) - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: session.Get(types.TokenName)}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success}) } func (s *Server) LoginHandle(c *gin.Context) { @@ -201,5 +213,5 @@ func (s *Server) LoginHandle(c *gin.Context) { } func Hello(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": fmt.Sprintf("HELLO, ChatGPT !!!")}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: "HELLO, ChatGPT !!!"}) } diff --git a/types/config.go b/types/config.go index 3837eb91..72a45b09 100644 --- a/types/config.go +++ b/types/config.go @@ -46,7 +46,7 @@ func NewDefaultConfig() *Config { Session: Session{ SecretKey: utils.RandString(64), - Name: "CHAT_GPT_SESSION_ID", + Name: "CHAT_SESSION_ID", Domain: "", Path: "/", MaxAge: 86400, diff --git a/web/.env.development b/web/.env.development index ceae57d5..3ea203f8 100644 --- a/web/.env.development +++ b/web/.env.development @@ -1,2 +1,2 @@ -VUE_APP_API_HOST=127.0.0.1:5678 +VUE_APP_API_HOST=172.22.11.200:5678 VUE_APP_API_SECURE=false \ No newline at end of file diff --git a/web/src/views/Chat.vue b/web/src/views/Chat.vue index 21860d4e..d3d7664e 100644 --- a/web/src/views/Chat.vue +++ b/web/src/views/Chat.vue @@ -36,7 +36,7 @@
发送 - + @@ -137,7 +137,11 @@ export default defineComponent({ // 检查会话 checkSession: function () { httpPost("/api/session/get").then(() => { - this.connect(); + if (this.socket == null) { + this.connect(); + } + // 发送心跳 + setTimeout(() => this.checkSession(), 5000); }).catch((res) => { if (res.code === 400) { this.showLoginDialog = true; @@ -230,7 +234,16 @@ export default defineComponent({ }, // 发送消息 - sendMessage: function () { + sendMessage: function (e) { + // 强制按钮失去焦点 + if (e) { + let target = e.target; + if (target.nodeName === "SPAN") { + target = e.target.parentNode; + } + target.blur(); + } + if (this.sending || this.inputValue.trim().length === 0) { return false; } @@ -248,7 +261,7 @@ export default defineComponent({ this.$refs["text-input"].blur(); this.inputValue = ''; // 等待 textarea 重新调整尺寸之后再自动获取焦点 - setTimeout(() => this.$refs["text-input"].focus(), 100) + setTimeout(() => this.$refs["text-input"].focus(), 100); return true; }, @@ -377,6 +390,12 @@ export default defineComponent({ .send { width 60px; height 40px; + background-color: var(--el-color-success) + } + + .is-disabled { + background-color: var(--el-button-disabled-bg-color); + border-color: var(--el-button-disabled-border-color); } } } diff --git a/web/vue.config.js b/web/vue.config.js index a56b2be9..f64cb61c 100644 --- a/web/vue.config.js +++ b/web/vue.config.js @@ -17,10 +17,10 @@ module.exports = defineConfig({ }, publicPath: process.env.NODE_ENV === 'production' - ? '/web' + ? '/chat' : '/', - outputDir: 'dist', + outputDir: '../dist', crossorigin: "anonymous", devServer: { allowedHosts: ['127.0.0.1:5678'],