mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
实现 API Key 负载均衡,修复 WebSocket session 失效问题
This commit is contained in:
parent
97acfe57e7
commit
20bdf12180
@ -6,9 +6,9 @@
|
|||||||
|
|
||||||
* [ ] 使用 level DB 保存用户聊天的上下文
|
* [ ] 使用 level DB 保存用户聊天的上下文
|
||||||
* [ ] 使用 MySQL 保存用户的聊天的历史记录
|
* [ ] 使用 MySQL 保存用户的聊天的历史记录
|
||||||
* [ ] 用户聊天鉴权,设置口令模式
|
* [x] 用户聊天鉴权,设置口令模式
|
||||||
* [ ] 每次连接自动加载历史记录
|
* [ ] 每次连接自动加载历史记录
|
||||||
* [ ] OpenAI API 负载均衡,限制每个 API Key 每分钟之内调用次数不超过 15次,防止被封
|
* [x] OpenAI API 负载均衡,限制每个 API Key 每分钟之内调用次数不超过 15次,防止被封
|
||||||
* [ ] 角色设定,预设一些角色,比如程序员,产品经理,医生,作家,老师...
|
* [ ] 角色设定,预设一些角色,比如程序员,产品经理,医生,作家,老师...
|
||||||
* [ ] markdown 语法解析
|
* [ ] markdown 语法解析
|
||||||
* [ ] 用户配置界面
|
* [ ] 用户配置界面
|
||||||
|
4
main.go
4
main.go
@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
//go:embed web
|
//go:embed dist
|
||||||
var webRoot embed.FS
|
var webRoot embed.FS
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -42,5 +42,5 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
s.Run(webRoot, "web")
|
s.Run(webRoot, "dist")
|
||||||
}
|
}
|
||||||
|
11
main_test.go
Normal file
11
main_test.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTime(t *testing.T) {
|
||||||
|
fmt.Println(time.Now().Unix())
|
||||||
|
}
|
@ -22,6 +22,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
|||||||
logger.Fatal(err)
|
logger.Fatal(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
token := c.Query("token")
|
||||||
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
|
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
|
||||||
client := NewWsClient(ws)
|
client := NewWsClient(ws)
|
||||||
go func() {
|
go func() {
|
||||||
@ -34,8 +35,8 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info(string(message))
|
logger.Info(string(message))
|
||||||
// TODO: 根据会话请求,传入不同的用户 ID
|
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
|
||||||
err = s.sendMessage("test", string(message), client)
|
err = s.sendMessage(token, string(message), client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
}
|
}
|
||||||
@ -54,7 +55,6 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
|||||||
var history []types.Message
|
var history []types.Message
|
||||||
if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext {
|
if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext {
|
||||||
history = v
|
history = v
|
||||||
//logger.Infof("上下文历史消息:%+v", history)
|
|
||||||
} else {
|
} else {
|
||||||
history = make([]types.Message, 0)
|
history = make([]types.Message, 0)
|
||||||
}
|
}
|
||||||
@ -74,14 +74,16 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Header.Add("Content-Type", "application/json")
|
request.Header.Add("Content-Type", "application/json")
|
||||||
// 随机获取一个 API Key,如果请求失败,则更换 API Key 重试
|
|
||||||
// TODO: 需要将失败的 Key 移除列表
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
var retryCount = 3
|
var retryCount = 3
|
||||||
var response *http.Response
|
var response *http.Response
|
||||||
|
var failedKey = ""
|
||||||
for retryCount > 0 {
|
for retryCount > 0 {
|
||||||
index := rand.Intn(len(s.Config.Chat.ApiKeys))
|
apiKey := s.getApiKey(failedKey)
|
||||||
apiKey := s.Config.Chat.ApiKeys[index]
|
if apiKey == "" {
|
||||||
|
logger.Info("Too many requests, all Api Key is not available")
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
logger.Infof("Use API KEY: %s", apiKey)
|
logger.Infof("Use API KEY: %s", apiKey)
|
||||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||||||
response, err = s.Client.Do(request)
|
response, err = s.Client.Do(request)
|
||||||
@ -89,6 +91,7 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
|||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
|
failedKey = apiKey
|
||||||
}
|
}
|
||||||
retryCount--
|
retryCount--
|
||||||
}
|
}
|
||||||
@ -148,6 +151,34 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 随机获取一个 API Key,如果请求失败,则更换 API Key 重试
|
||||||
|
func (s *Server) getApiKey(failedKey string) string {
|
||||||
|
var keys = make([]string, 0)
|
||||||
|
for _, v := range s.Config.Chat.ApiKeys {
|
||||||
|
// 过滤掉刚刚失败的 Key
|
||||||
|
if v == failedKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 API Key 的上次调用时间,控制调用频率
|
||||||
|
var lastAccess int64
|
||||||
|
if t, ok := s.ApiKeyAccessStat[v]; ok {
|
||||||
|
lastAccess = t
|
||||||
|
}
|
||||||
|
// 保持每分钟访问不超过 15 次
|
||||||
|
if time.Now().Unix()-lastAccess <= 4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
keys = append(keys, v)
|
||||||
|
}
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
if len(keys) > 0 {
|
||||||
|
return keys[rand.Intn(len(keys))]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// 回复客户端消息
|
// 回复客户端消息
|
||||||
func replyMessage(message types.WsMessage, client Client) {
|
func replyMessage(message types.WsMessage, client Client) {
|
||||||
msg, err := json.Marshal(message)
|
msg, err := json.Marshal(message)
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
"openai/types"
|
"openai/types"
|
||||||
|
"openai/utils"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -91,8 +92,10 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if token, ok := data["token"]; ok {
|
if token, ok := data["token"]; ok {
|
||||||
|
if !utils.ContainsItem(s.Config.Tokens, token) {
|
||||||
s.Config.Tokens = append(s.Config.Tokens, token)
|
s.Config.Tokens = append(s.Config.Tokens, token)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 保存配置文件
|
// 保存配置文件
|
||||||
logger.Infof("Config: %+v", s.Config)
|
logger.Infof("Config: %+v", s.Config)
|
||||||
|
@ -3,7 +3,6 @@ package server
|
|||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
"github.com/gin-contrib/sessions/cookie"
|
"github.com/gin-contrib/sessions/cookie"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -38,7 +37,10 @@ type Server struct {
|
|||||||
Client *http.Client
|
Client *http.Client
|
||||||
History map[string][]types.Message
|
History map[string][]types.Message
|
||||||
|
|
||||||
WsSession map[string]string // 关闭 Websocket 会话
|
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次
|
||||||
|
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||||
|
WsSession map[string]string
|
||||||
|
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(configPath string) (*Server, error) {
|
func NewServer(configPath string) (*Server, error) {
|
||||||
@ -61,6 +63,7 @@ func NewServer(configPath string) (*Server, error) {
|
|||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
History: make(map[string][]types.Message, 16),
|
History: make(map[string][]types.Message, 16),
|
||||||
WsSession: make(map[string]string),
|
WsSession: make(map[string]string),
|
||||||
|
ApiKeyAccessStat: make(map[string]int64),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,22 +146,32 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
// AuthorizeMiddleware 用户授权验证
|
// AuthorizeMiddleware 用户授权验证
|
||||||
func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
|
func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if !s.Config.EnableAuth || c.Request.URL.Path == "/api/login" || c.Request.URL.Path == "/api/config/set" {
|
if !s.Config.EnableAuth ||
|
||||||
|
c.Request.URL.Path == "/api/login" ||
|
||||||
|
c.Request.URL.Path == "/api/config/set" ||
|
||||||
|
!strings.HasPrefix(c.Request.URL.Path, "/api") {
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WebSocket 连接请求验证
|
||||||
|
if c.Request.URL.Path == "/api/chat" {
|
||||||
tokenName := c.Query("token")
|
tokenName := c.Query("token")
|
||||||
if tokenName == "" {
|
|
||||||
tokenName = c.GetHeader(types.TokenName)
|
|
||||||
}
|
|
||||||
// TODO: 会话过期设置
|
|
||||||
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
|
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
|
||||||
session := sessions.Default(c)
|
// 每个令牌只能连接一次
|
||||||
user := session.Get(tokenName)
|
delete(s.WsSession, tokenName)
|
||||||
if user != nil {
|
c.Next()
|
||||||
c.Set(types.SessionKey, user)
|
} else {
|
||||||
|
c.Abort()
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenName := c.GetHeader(types.TokenName)
|
||||||
|
session := sessions.Default(c)
|
||||||
|
userInfo := session.Get(tokenName)
|
||||||
|
if userInfo != nil {
|
||||||
|
c.Set(types.SessionKey, userInfo)
|
||||||
c.Next()
|
c.Next()
|
||||||
} else {
|
} else {
|
||||||
c.Abort()
|
c.Abort()
|
||||||
@ -171,8 +184,7 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GetSessionHandle(c *gin.Context) {
|
func (s *Server) GetSessionHandle(c *gin.Context) {
|
||||||
session := sessions.Default(c)
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: session.Get(types.TokenName)})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) LoginHandle(c *gin.Context) {
|
func (s *Server) LoginHandle(c *gin.Context) {
|
||||||
@ -201,5 +213,5 @@ func (s *Server) LoginHandle(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Hello(c *gin.Context) {
|
func Hello(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": fmt.Sprintf("HELLO, ChatGPT !!!")})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: "HELLO, ChatGPT !!!"})
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ func NewDefaultConfig() *Config {
|
|||||||
|
|
||||||
Session: Session{
|
Session: Session{
|
||||||
SecretKey: utils.RandString(64),
|
SecretKey: utils.RandString(64),
|
||||||
Name: "CHAT_GPT_SESSION_ID",
|
Name: "CHAT_SESSION_ID",
|
||||||
Domain: "",
|
Domain: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
MaxAge: 86400,
|
MaxAge: 86400,
|
||||||
|
@ -1,2 +1,2 @@
|
|||||||
VUE_APP_API_HOST=127.0.0.1:5678
|
VUE_APP_API_HOST=172.22.11.200:5678
|
||||||
VUE_APP_API_SECURE=false
|
VUE_APP_API_SECURE=false
|
@ -36,7 +36,7 @@
|
|||||||
<div class="btn-container">
|
<div class="btn-container">
|
||||||
<el-row>
|
<el-row>
|
||||||
<el-button type="success" class="send" :disabled="sending" v-on:click="sendMessage">发送</el-button>
|
<el-button type="success" class="send" :disabled="sending" v-on:click="sendMessage">发送</el-button>
|
||||||
<el-button type="info" class="config" circle @click="showConnectDialog = true">
|
<el-button type="info" class="config" ref="send-btn" circle @click="showConnectDialog = true">
|
||||||
<el-icon>
|
<el-icon>
|
||||||
<Tools/>
|
<Tools/>
|
||||||
</el-icon>
|
</el-icon>
|
||||||
@ -137,7 +137,11 @@ export default defineComponent({
|
|||||||
// 检查会话
|
// 检查会话
|
||||||
checkSession: function () {
|
checkSession: function () {
|
||||||
httpPost("/api/session/get").then(() => {
|
httpPost("/api/session/get").then(() => {
|
||||||
|
if (this.socket == null) {
|
||||||
this.connect();
|
this.connect();
|
||||||
|
}
|
||||||
|
// 发送心跳
|
||||||
|
setTimeout(() => this.checkSession(), 5000);
|
||||||
}).catch((res) => {
|
}).catch((res) => {
|
||||||
if (res.code === 400) {
|
if (res.code === 400) {
|
||||||
this.showLoginDialog = true;
|
this.showLoginDialog = true;
|
||||||
@ -230,7 +234,16 @@ export default defineComponent({
|
|||||||
},
|
},
|
||||||
|
|
||||||
// 发送消息
|
// 发送消息
|
||||||
sendMessage: function () {
|
sendMessage: function (e) {
|
||||||
|
// 强制按钮失去焦点
|
||||||
|
if (e) {
|
||||||
|
let target = e.target;
|
||||||
|
if (target.nodeName === "SPAN") {
|
||||||
|
target = e.target.parentNode;
|
||||||
|
}
|
||||||
|
target.blur();
|
||||||
|
}
|
||||||
|
|
||||||
if (this.sending || this.inputValue.trim().length === 0) {
|
if (this.sending || this.inputValue.trim().length === 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -248,7 +261,7 @@ export default defineComponent({
|
|||||||
this.$refs["text-input"].blur();
|
this.$refs["text-input"].blur();
|
||||||
this.inputValue = '';
|
this.inputValue = '';
|
||||||
// 等待 textarea 重新调整尺寸之后再自动获取焦点
|
// 等待 textarea 重新调整尺寸之后再自动获取焦点
|
||||||
setTimeout(() => this.$refs["text-input"].focus(), 100)
|
setTimeout(() => this.$refs["text-input"].focus(), 100);
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -377,6 +390,12 @@ export default defineComponent({
|
|||||||
.send {
|
.send {
|
||||||
width 60px;
|
width 60px;
|
||||||
height 40px;
|
height 40px;
|
||||||
|
background-color: var(--el-color-success)
|
||||||
|
}
|
||||||
|
|
||||||
|
.is-disabled {
|
||||||
|
background-color: var(--el-button-disabled-bg-color);
|
||||||
|
border-color: var(--el-button-disabled-border-color);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,10 +17,10 @@ module.exports = defineConfig({
|
|||||||
},
|
},
|
||||||
|
|
||||||
publicPath: process.env.NODE_ENV === 'production'
|
publicPath: process.env.NODE_ENV === 'production'
|
||||||
? '/web'
|
? '/chat'
|
||||||
: '/',
|
: '/',
|
||||||
|
|
||||||
outputDir: 'dist',
|
outputDir: '../dist',
|
||||||
crossorigin: "anonymous",
|
crossorigin: "anonymous",
|
||||||
devServer: {
|
devServer: {
|
||||||
allowedHosts: ['127.0.0.1:5678'],
|
allowedHosts: ['127.0.0.1:5678'],
|
||||||
|
Loading…
Reference in New Issue
Block a user