mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-12-26 18:15:57 +08:00
实现 API Key 负载均衡,修复 WebSocket session 失效问题
This commit is contained in:
@@ -22,6 +22,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
logger.Fatal(err)
|
||||
return
|
||||
}
|
||||
token := c.Query("token")
|
||||
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
|
||||
client := NewWsClient(ws)
|
||||
go func() {
|
||||
@@ -34,8 +35,8 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
logger.Info(string(message))
|
||||
// TODO: 根据会话请求,传入不同的用户 ID
|
||||
err = s.sendMessage("test", string(message), client)
|
||||
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
|
||||
err = s.sendMessage(token, string(message), client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@@ -54,7 +55,6 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||
var history []types.Message
|
||||
if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext {
|
||||
history = v
|
||||
//logger.Infof("上下文历史消息:%+v", history)
|
||||
} else {
|
||||
history = make([]types.Message, 0)
|
||||
}
|
||||
@@ -74,14 +74,16 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||
}
|
||||
|
||||
request.Header.Add("Content-Type", "application/json")
|
||||
// 随机获取一个 API Key,如果请求失败,则更换 API Key 重试
|
||||
// TODO: 需要将失败的 Key 移除列表
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
var retryCount = 3
|
||||
var response *http.Response
|
||||
var failedKey = ""
|
||||
for retryCount > 0 {
|
||||
index := rand.Intn(len(s.Config.Chat.ApiKeys))
|
||||
apiKey := s.Config.Chat.ApiKeys[index]
|
||||
apiKey := s.getApiKey(failedKey)
|
||||
if apiKey == "" {
|
||||
logger.Info("Too many requests, all Api Key is not available")
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
logger.Infof("Use API KEY: %s", apiKey)
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||||
response, err = s.Client.Do(request)
|
||||
@@ -89,6 +91,7 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||
break
|
||||
} else {
|
||||
logger.Error(err)
|
||||
failedKey = apiKey
|
||||
}
|
||||
retryCount--
|
||||
}
|
||||
@@ -148,6 +151,34 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 随机获取一个 API Key,如果请求失败,则更换 API Key 重试
|
||||
func (s *Server) getApiKey(failedKey string) string {
|
||||
var keys = make([]string, 0)
|
||||
for _, v := range s.Config.Chat.ApiKeys {
|
||||
// 过滤掉刚刚失败的 Key
|
||||
if v == failedKey {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取 API Key 的上次调用时间,控制调用频率
|
||||
var lastAccess int64
|
||||
if t, ok := s.ApiKeyAccessStat[v]; ok {
|
||||
lastAccess = t
|
||||
}
|
||||
// 保持每分钟访问不超过 15 次
|
||||
if time.Now().Unix()-lastAccess <= 4 {
|
||||
continue
|
||||
}
|
||||
|
||||
keys = append(keys, v)
|
||||
}
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
if len(keys) > 0 {
|
||||
return keys[rand.Intn(len(keys))]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// 回复客户端消息
|
||||
func replyMessage(message types.WsMessage, client Client) {
|
||||
msg, err := json.Marshal(message)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"openai/types"
|
||||
"openai/utils"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@@ -91,7 +92,9 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
if token, ok := data["token"]; ok {
|
||||
s.Config.Tokens = append(s.Config.Tokens, token)
|
||||
if !utils.ContainsItem(s.Config.Tokens, token) {
|
||||
s.Config.Tokens = append(s.Config.Tokens, token)
|
||||
}
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/cookie"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -38,7 +37,10 @@ type Server struct {
|
||||
Client *http.Client
|
||||
History map[string][]types.Message
|
||||
|
||||
WsSession map[string]string // 关闭 Websocket 会话
|
||||
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次
|
||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||
WsSession map[string]string
|
||||
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
|
||||
}
|
||||
|
||||
func NewServer(configPath string) (*Server, error) {
|
||||
@@ -56,11 +58,12 @@ func NewServer(configPath string) (*Server, error) {
|
||||
},
|
||||
}
|
||||
return &Server{
|
||||
Config: config,
|
||||
Client: client,
|
||||
ConfigPath: configPath,
|
||||
History: make(map[string][]types.Message, 16),
|
||||
WsSession: make(map[string]string),
|
||||
Config: config,
|
||||
Client: client,
|
||||
ConfigPath: configPath,
|
||||
History: make(map[string][]types.Message, 16),
|
||||
WsSession: make(map[string]string),
|
||||
ApiKeyAccessStat: make(map[string]int64),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -143,22 +146,32 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
// AuthorizeMiddleware 用户授权验证
|
||||
func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !s.Config.EnableAuth || c.Request.URL.Path == "/api/login" || c.Request.URL.Path == "/api/config/set" {
|
||||
if !s.Config.EnableAuth ||
|
||||
c.Request.URL.Path == "/api/login" ||
|
||||
c.Request.URL.Path == "/api/config/set" ||
|
||||
!strings.HasPrefix(c.Request.URL.Path, "/api") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
tokenName := c.Query("token")
|
||||
if tokenName == "" {
|
||||
tokenName = c.GetHeader(types.TokenName)
|
||||
}
|
||||
// TODO: 会话过期设置
|
||||
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
|
||||
session := sessions.Default(c)
|
||||
user := session.Get(tokenName)
|
||||
if user != nil {
|
||||
c.Set(types.SessionKey, user)
|
||||
// WebSocket 连接请求验证
|
||||
if c.Request.URL.Path == "/api/chat" {
|
||||
tokenName := c.Query("token")
|
||||
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
|
||||
// 每个令牌只能连接一次
|
||||
delete(s.WsSession, tokenName)
|
||||
c.Next()
|
||||
} else {
|
||||
c.Abort()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tokenName := c.GetHeader(types.TokenName)
|
||||
session := sessions.Default(c)
|
||||
userInfo := session.Get(tokenName)
|
||||
if userInfo != nil {
|
||||
c.Set(types.SessionKey, userInfo)
|
||||
c.Next()
|
||||
} else {
|
||||
c.Abort()
|
||||
@@ -171,8 +184,7 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
func (s *Server) GetSessionHandle(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: session.Get(types.TokenName)})
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
|
||||
}
|
||||
|
||||
func (s *Server) LoginHandle(c *gin.Context) {
|
||||
@@ -201,5 +213,5 @@ func (s *Server) LoginHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
func Hello(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": fmt.Sprintf("HELLO, ChatGPT !!!")})
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: "HELLO, ChatGPT !!!"})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user