实现 API Key 负载均衡,修复 WebSocket session 失效问题

This commit is contained in:
RockYang 2023-03-22 13:51:27 +08:00
parent 97acfe57e7
commit 20bdf12180
10 changed files with 118 additions and 42 deletions

View File

@ -6,9 +6,9 @@
* [ ] 使用 level DB 保存用户聊天的上下文 * [ ] 使用 level DB 保存用户聊天的上下文
* [ ] 使用 MySQL 保存用户的聊天的历史记录 * [ ] 使用 MySQL 保存用户的聊天的历史记录
* [ ] 用户聊天鉴权,设置口令模式 * [x] 用户聊天鉴权,设置口令模式
* [ ] 每次连接自动加载历史记录 * [ ] 每次连接自动加载历史记录
* [ ] OpenAI API 负载均衡,限制每个 API Key 每分钟之内调用次数不超过 15次防止被封 * [x] OpenAI API 负载均衡,限制每个 API Key 每分钟之内调用次数不超过 15次防止被封
* [ ] 角色设定,预设一些角色,比如程序员,产品经理,医生,作家,老师... * [ ] 角色设定,预设一些角色,比如程序员,产品经理,医生,作家,老师...
* [ ] markdown 语法解析 * [ ] markdown 语法解析
* [ ] 用户配置界面 * [ ] 用户配置界面

View File

@ -11,7 +11,7 @@ import (
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
//go:embed web //go:embed dist
var webRoot embed.FS var webRoot embed.FS
func main() { func main() {
@ -42,5 +42,5 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
s.Run(webRoot, "web") s.Run(webRoot, "dist")
} }

11
main_test.go Normal file
View File

@ -0,0 +1,11 @@
package main
import (
"fmt"
"testing"
"time"
)
func TestTime(t *testing.T) {
fmt.Println(time.Now().Unix())
}

View File

@ -22,6 +22,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
logger.Fatal(err) logger.Fatal(err)
return return
} }
token := c.Query("token")
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr) logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
client := NewWsClient(ws) client := NewWsClient(ws)
go func() { go func() {
@ -34,8 +35,8 @@ func (s *Server) ChatHandle(c *gin.Context) {
} }
logger.Info(string(message)) logger.Info(string(message))
// TODO: 根据会话请求,传入不同的用户 ID // TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
err = s.sendMessage("test", string(message), client) err = s.sendMessage(token, string(message), client)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
@ -54,7 +55,6 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
var history []types.Message var history []types.Message
if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext { if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext {
history = v history = v
//logger.Infof("上下文历史消息:%+v", history)
} else { } else {
history = make([]types.Message, 0) history = make([]types.Message, 0)
} }
@ -74,14 +74,16 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
} }
request.Header.Add("Content-Type", "application/json") request.Header.Add("Content-Type", "application/json")
// 随机获取一个 API Key如果请求失败则更换 API Key 重试
// TODO: 需要将失败的 Key 移除列表
rand.Seed(time.Now().UnixNano())
var retryCount = 3 var retryCount = 3
var response *http.Response var response *http.Response
var failedKey = ""
for retryCount > 0 { for retryCount > 0 {
index := rand.Intn(len(s.Config.Chat.ApiKeys)) apiKey := s.getApiKey(failedKey)
apiKey := s.Config.Chat.ApiKeys[index] if apiKey == "" {
logger.Info("Too many requests, all Api Key is not available")
time.Sleep(time.Second)
continue
}
logger.Infof("Use API KEY: %s", apiKey) logger.Infof("Use API KEY: %s", apiKey)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
response, err = s.Client.Do(request) response, err = s.Client.Do(request)
@ -89,6 +91,7 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
break break
} else { } else {
logger.Error(err) logger.Error(err)
failedKey = apiKey
} }
retryCount-- retryCount--
} }
@ -148,6 +151,34 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
return nil return nil
} }
// 随机获取一个 API Key如果请求失败则更换 API Key 重试
func (s *Server) getApiKey(failedKey string) string {
var keys = make([]string, 0)
for _, v := range s.Config.Chat.ApiKeys {
// 过滤掉刚刚失败的 Key
if v == failedKey {
continue
}
// 获取 API Key 的上次调用时间,控制调用频率
var lastAccess int64
if t, ok := s.ApiKeyAccessStat[v]; ok {
lastAccess = t
}
// 保持每分钟访问不超过 15 次
if time.Now().Unix()-lastAccess <= 4 {
continue
}
keys = append(keys, v)
}
rand.Seed(time.Now().UnixNano())
if len(keys) > 0 {
return keys[rand.Intn(len(keys))]
}
return ""
}
// 回复客户端消息 // 回复客户端消息
func replyMessage(message types.WsMessage, client Client) { func replyMessage(message types.WsMessage, client Client) {
msg, err := json.Marshal(message) msg, err := json.Marshal(message)

View File

@ -5,6 +5,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"openai/types" "openai/types"
"openai/utils"
"strconv" "strconv"
) )
@ -91,8 +92,10 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
} }
if token, ok := data["token"]; ok { if token, ok := data["token"]; ok {
if !utils.ContainsItem(s.Config.Tokens, token) {
s.Config.Tokens = append(s.Config.Tokens, token) s.Config.Tokens = append(s.Config.Tokens, token)
} }
}
// 保存配置文件 // 保存配置文件
logger.Infof("Config: %+v", s.Config) logger.Infof("Config: %+v", s.Config)

View File

@ -3,7 +3,6 @@ package server
import ( import (
"embed" "embed"
"encoding/json" "encoding/json"
"fmt"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -38,7 +37,10 @@ type Server struct {
Client *http.Client Client *http.Client
History map[string][]types.Message History map[string][]types.Message
WsSession map[string]string // 关闭 Websocket 会话 // 保存 Websocket 会话 Token, 每个 Token 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
WsSession map[string]string
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
} }
func NewServer(configPath string) (*Server, error) { func NewServer(configPath string) (*Server, error) {
@ -61,6 +63,7 @@ func NewServer(configPath string) (*Server, error) {
ConfigPath: configPath, ConfigPath: configPath,
History: make(map[string][]types.Message, 16), History: make(map[string][]types.Message, 16),
WsSession: make(map[string]string), WsSession: make(map[string]string),
ApiKeyAccessStat: make(map[string]int64),
}, nil }, nil
} }
@ -143,22 +146,32 @@ func corsMiddleware() gin.HandlerFunc {
// AuthorizeMiddleware 用户授权验证 // AuthorizeMiddleware 用户授权验证
func AuthorizeMiddleware(s *Server) gin.HandlerFunc { func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if !s.Config.EnableAuth || c.Request.URL.Path == "/api/login" || c.Request.URL.Path == "/api/config/set" { if !s.Config.EnableAuth ||
c.Request.URL.Path == "/api/login" ||
c.Request.URL.Path == "/api/config/set" ||
!strings.HasPrefix(c.Request.URL.Path, "/api") {
c.Next() c.Next()
return return
} }
// WebSocket 连接请求验证
if c.Request.URL.Path == "/api/chat" {
tokenName := c.Query("token") tokenName := c.Query("token")
if tokenName == "" {
tokenName = c.GetHeader(types.TokenName)
}
// TODO: 会话过期设置
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() { if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
session := sessions.Default(c) // 每个令牌只能连接一次
user := session.Get(tokenName) delete(s.WsSession, tokenName)
if user != nil { c.Next()
c.Set(types.SessionKey, user) } else {
c.Abort()
} }
return
}
tokenName := c.GetHeader(types.TokenName)
session := sessions.Default(c)
userInfo := session.Get(tokenName)
if userInfo != nil {
c.Set(types.SessionKey, userInfo)
c.Next() c.Next()
} else { } else {
c.Abort() c.Abort()
@ -171,8 +184,7 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
} }
func (s *Server) GetSessionHandle(c *gin.Context) { func (s *Server) GetSessionHandle(c *gin.Context) {
session := sessions.Default(c) c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: session.Get(types.TokenName)})
} }
func (s *Server) LoginHandle(c *gin.Context) { func (s *Server) LoginHandle(c *gin.Context) {
@ -201,5 +213,5 @@ func (s *Server) LoginHandle(c *gin.Context) {
} }
func Hello(c *gin.Context) { func Hello(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"code": 0, "message": fmt.Sprintf("HELLO, ChatGPT !!!")}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: "HELLO, ChatGPT !!!"})
} }

View File

@ -46,7 +46,7 @@ func NewDefaultConfig() *Config {
Session: Session{ Session: Session{
SecretKey: utils.RandString(64), SecretKey: utils.RandString(64),
Name: "CHAT_GPT_SESSION_ID", Name: "CHAT_SESSION_ID",
Domain: "", Domain: "",
Path: "/", Path: "/",
MaxAge: 86400, MaxAge: 86400,

View File

@ -1,2 +1,2 @@
VUE_APP_API_HOST=127.0.0.1:5678 VUE_APP_API_HOST=172.22.11.200:5678
VUE_APP_API_SECURE=false VUE_APP_API_SECURE=false

View File

@ -36,7 +36,7 @@
<div class="btn-container"> <div class="btn-container">
<el-row> <el-row>
<el-button type="success" class="send" :disabled="sending" v-on:click="sendMessage">发送</el-button> <el-button type="success" class="send" :disabled="sending" v-on:click="sendMessage">发送</el-button>
<el-button type="info" class="config" circle @click="showConnectDialog = true"> <el-button type="info" class="config" ref="send-btn" circle @click="showConnectDialog = true">
<el-icon> <el-icon>
<Tools/> <Tools/>
</el-icon> </el-icon>
@ -137,7 +137,11 @@ export default defineComponent({
// //
checkSession: function () { checkSession: function () {
httpPost("/api/session/get").then(() => { httpPost("/api/session/get").then(() => {
if (this.socket == null) {
this.connect(); this.connect();
}
//
setTimeout(() => this.checkSession(), 5000);
}).catch((res) => { }).catch((res) => {
if (res.code === 400) { if (res.code === 400) {
this.showLoginDialog = true; this.showLoginDialog = true;
@ -230,7 +234,16 @@ export default defineComponent({
}, },
// //
sendMessage: function () { sendMessage: function (e) {
//
if (e) {
let target = e.target;
if (target.nodeName === "SPAN") {
target = e.target.parentNode;
}
target.blur();
}
if (this.sending || this.inputValue.trim().length === 0) { if (this.sending || this.inputValue.trim().length === 0) {
return false; return false;
} }
@ -248,7 +261,7 @@ export default defineComponent({
this.$refs["text-input"].blur(); this.$refs["text-input"].blur();
this.inputValue = ''; this.inputValue = '';
// textarea // textarea
setTimeout(() => this.$refs["text-input"].focus(), 100) setTimeout(() => this.$refs["text-input"].focus(), 100);
return true; return true;
}, },
@ -377,6 +390,12 @@ export default defineComponent({
.send { .send {
width 60px; width 60px;
height 40px; height 40px;
background-color: var(--el-color-success)
}
.is-disabled {
background-color: var(--el-button-disabled-bg-color);
border-color: var(--el-button-disabled-border-color);
} }
} }
} }

View File

@ -17,10 +17,10 @@ module.exports = defineConfig({
}, },
publicPath: process.env.NODE_ENV === 'production' publicPath: process.env.NODE_ENV === 'production'
? '/web' ? '/chat'
: '/', : '/',
outputDir: 'dist', outputDir: '../dist',
crossorigin: "anonymous", crossorigin: "anonymous",
devServer: { devServer: {
allowedHosts: ['127.0.0.1:5678'], allowedHosts: ['127.0.0.1:5678'],