mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
支持 TOKEN 设置最大调用次数
This commit is contained in:
parent
a6bab7b12d
commit
5f702d92dc
2
go.mod
2
go.mod
@ -8,6 +8,7 @@ require (
|
||||
github.com/gin-gonic/gin v1.7.7
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/mitchellh/go-homedir v1.1.0
|
||||
github.com/syndtr/goleveldb v1.0.0
|
||||
go.uber.org/zap v1.21.0
|
||||
)
|
||||
|
||||
@ -26,7 +27,6 @@ require (
|
||||
github.com/mattn/go-isatty v0.0.12 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 // indirect
|
||||
github.com/syndtr/goleveldb v1.0.0 // indirect
|
||||
github.com/ugorji/go/codec v1.1.7 // indirect
|
||||
go.uber.org/atomic v1.7.0 // indirect
|
||||
go.uber.org/multierr v1.6.0 // indirect
|
||||
|
7
go.sum
7
go.sum
@ -35,6 +35,7 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns=
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
@ -54,7 +55,9 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJ
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs=
|
||||
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU=
|
||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
@ -90,6 +93,7 @@ golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73r
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -109,6 +113,7 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
@ -119,7 +124,9 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
@ -16,6 +16,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const ErrorMsg = "抱歉,AI 助手开小差了,我马上找人去盘它。"
|
||||
|
||||
// ChatHandle 处理聊天 WebSocket 请求
|
||||
func (s *Server) ChatHandle(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
@ -24,16 +26,19 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
sessionId := c.Query("sessionId")
|
||||
role := c.Query("role")
|
||||
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
|
||||
roleKey := c.Query("role")
|
||||
session := s.ChatSession[sessionId]
|
||||
logger.Infof("New websocket connected, IP: %s, Token: %s", c.Request.RemoteAddr, session.Token)
|
||||
client := NewWsClient(ws)
|
||||
if !s.ChatRoles[role].Enable { // 角色未启用
|
||||
var roles = GetChatRoles()
|
||||
var chatRole = roles[roleKey]
|
||||
if !chatRole.Enable { // 角色未启用
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// 发送打招呼信息
|
||||
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: chatRole.HelloMsg, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client)
|
||||
go func() {
|
||||
for {
|
||||
@ -46,7 +51,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
|
||||
err = s.sendMessage(sessionId, role, string(message), client)
|
||||
err = s.sendMessage(session, chatRole, string(message), client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@ -55,7 +60,17 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
||||
func (s *Server) sendMessage(sessionId string, role string, text string, ws Client) error {
|
||||
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, text string, ws Client) error {
|
||||
token, err := GetToken(session.Token)
|
||||
if err != nil {
|
||||
replyError(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!")
|
||||
return err
|
||||
}
|
||||
|
||||
if token.MaxCalls > 0 && token.RemainingCalls <= 0 {
|
||||
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用!")
|
||||
return nil
|
||||
}
|
||||
var r = types.ApiRequest{
|
||||
Model: s.Config.Chat.Model,
|
||||
Temperature: s.Config.Chat.Temperature,
|
||||
@ -63,11 +78,11 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
Stream: true,
|
||||
}
|
||||
var context []types.Message
|
||||
var key = sessionId + role
|
||||
var key = session.SessionId + role.Name
|
||||
if v, ok := s.ChatContext[key]; ok && s.Config.Chat.EnableContext {
|
||||
context = v
|
||||
} else {
|
||||
context = s.ChatRoles[role].Context
|
||||
context = role.Context
|
||||
}
|
||||
|
||||
if s.DebugMode {
|
||||
@ -130,7 +145,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
|
||||
// 如果三次请求都失败的话,则返回对应的错误信息
|
||||
if err != nil {
|
||||
replyError(ws)
|
||||
replyError(ws, ErrorMsg)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -155,7 +170,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil {
|
||||
logger.Error(line)
|
||||
replyError(ws)
|
||||
replyError(ws, ErrorMsg)
|
||||
break
|
||||
}
|
||||
// 初始化 role
|
||||
@ -176,7 +191,10 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
}, ws)
|
||||
}
|
||||
}
|
||||
|
||||
// 当前 Token 调用次数减 1
|
||||
if token.MaxCalls > 0 {
|
||||
token.RemainingCalls -= 1
|
||||
}
|
||||
// 追加历史消息
|
||||
context = append(context, types.Message{
|
||||
Role: "user",
|
||||
@ -190,9 +208,9 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
return nil
|
||||
}
|
||||
|
||||
func replyError(ws Client) {
|
||||
func replyError(ws Client, message string) {
|
||||
replyMessage(types.WsMessage{Type: types.WsStart}, ws)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: "抱歉,AI 助手开小差了,我马上找人去盘它。"}, ws)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: message}, ws)
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
|
||||
}
|
||||
|
||||
|
@ -96,6 +96,7 @@ func (s *Server) AddToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 参数处理
|
||||
var name = data["name"]
|
||||
var maxCalls = data["max_calls"]
|
||||
if name == "" || maxCalls == "" {
|
||||
@ -112,8 +113,9 @@ func (s *Server) AddToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var tokens = GetTokens()
|
||||
if utils.ContainToken(tokens, name) {
|
||||
// 检查当前要添加的 token 是否已经存在
|
||||
_, err = GetToken(name)
|
||||
if err == nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + name + " already exists"})
|
||||
return
|
||||
}
|
||||
@ -127,6 +129,50 @@ func (s *Server) AddToken(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
|
||||
}
|
||||
|
||||
func (s *Server) SetToken(c *gin.Context) {
|
||||
var data map[string]string
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||
if err != nil {
|
||||
logger.Errorf("Error decode json data: %s", err.Error())
|
||||
c.JSON(http.StatusBadRequest, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// 参数处理
|
||||
var name = data["name"]
|
||||
var maxCalls = data["max_calls"]
|
||||
if name == "" || maxCalls == "" {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
|
||||
return
|
||||
}
|
||||
|
||||
token, err := GetToken(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"})
|
||||
return
|
||||
}
|
||||
|
||||
n, err := strconv.Atoi(maxCalls)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{
|
||||
Code: types.InvalidParams,
|
||||
Message: "enable_auth must be a int parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token.RemainingCalls += n - token.MaxCalls
|
||||
token.MaxCalls = n
|
||||
|
||||
err = PutToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
|
||||
}
|
||||
|
||||
// RemoveToken 删除 Token
|
||||
func (s *Server) RemoveToken(c *gin.Context) {
|
||||
var data map[string]string
|
||||
@ -205,22 +251,23 @@ func (s *Server) ListApiKeys(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (s *Server) GetChatRoles(c *gin.Context) {
|
||||
//var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
|
||||
//var roles = make([]interface{}, 0)
|
||||
//for _, k := range rolesOrder {
|
||||
// if v, ok := s.Config.ChatRoles[k]; ok && v.Enable {
|
||||
// roles = append(roles, struct {
|
||||
// Key string `json:"key"`
|
||||
// Name string `json:"name"`
|
||||
// Icon string `json:"icon"`
|
||||
// }{
|
||||
// Key: v.Key,
|
||||
// Name: v.Name,
|
||||
// Icon: v.Icon,
|
||||
// })
|
||||
// }
|
||||
//}
|
||||
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles})
|
||||
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
|
||||
var res = make([]interface{}, 0)
|
||||
var roles = GetChatRoles()
|
||||
for _, k := range rolesOrder {
|
||||
if v, ok := roles[k]; ok && v.Enable {
|
||||
res = append(res, struct {
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
}{
|
||||
Key: v.Key,
|
||||
Name: v.Name,
|
||||
Icon: v.Icon,
|
||||
})
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: res})
|
||||
}
|
||||
|
||||
// UpdateChatRole 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作
|
||||
@ -238,39 +285,43 @@ func (s *Server) UpdateChatRole(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
//role := s.Config.ChatRoles[key]
|
||||
//if enable, ok := data["enable"]; ok {
|
||||
// v, err := strconv.ParseBool(enable)
|
||||
// if err != nil {
|
||||
// c.JSON(http.StatusOK, types.BizVo{
|
||||
// Code: types.InvalidParams,
|
||||
// Message: "enable must be a bool parameter",
|
||||
// })
|
||||
// return
|
||||
// }
|
||||
// role.Enable = v
|
||||
//}
|
||||
roles := GetChatRoles()
|
||||
role := roles[key]
|
||||
if role.Key == "" {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Role key not exists"})
|
||||
return
|
||||
}
|
||||
|
||||
//if name, ok := data["name"]; ok {
|
||||
// role.Name = name
|
||||
//}
|
||||
//if helloMsg, ok := data["hello_msg"]; ok {
|
||||
// role.HelloMsg = helloMsg
|
||||
//}
|
||||
//if icon, ok := data["icon"]; ok {
|
||||
// role.Icon = icon
|
||||
//}
|
||||
//
|
||||
//s.Config.ChatRoles[key] = role
|
||||
//
|
||||
//// 保存配置文件
|
||||
//err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
//if err != nil {
|
||||
// c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
|
||||
if enable, ok := data["enable"]; ok {
|
||||
v, err := strconv.ParseBool(enable)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{
|
||||
Code: types.InvalidParams,
|
||||
Message: "enable must be a bool parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
role.Enable = v
|
||||
}
|
||||
|
||||
if name, ok := data["name"]; ok {
|
||||
role.Name = name
|
||||
}
|
||||
if helloMsg, ok := data["hello_msg"]; ok {
|
||||
role.HelloMsg = helloMsg
|
||||
}
|
||||
if icon, ok := data["icon"]; ok {
|
||||
role.Icon = icon
|
||||
}
|
||||
|
||||
// 保存到 leveldb
|
||||
err = PutChatRole(role)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: role})
|
||||
}
|
||||
|
||||
// AddProxy 添加一个代理
|
||||
|
27
server/db.go
27
server/db.go
@ -43,6 +43,16 @@ func PutToken(token types.Token) error {
|
||||
return db.Put(key, token)
|
||||
}
|
||||
|
||||
func GetToken(name string) (types.Token, error) {
|
||||
key := TokenPrefix + name
|
||||
token, err := db.Get(key)
|
||||
if err != nil {
|
||||
return types.Token{}, err
|
||||
}
|
||||
|
||||
return token.(types.Token), nil
|
||||
}
|
||||
|
||||
func RemoveToken(token string) error {
|
||||
key := TokenPrefix + token
|
||||
return db.Delete(key)
|
||||
@ -51,7 +61,22 @@ func RemoveToken(token string) error {
|
||||
// GetChatRoles 获取聊天角色
|
||||
// chat/roles
|
||||
func GetChatRoles() map[string]types.ChatRole {
|
||||
return nil
|
||||
items := db.Search(ChatRolePrefix)
|
||||
var roles = make(map[string]types.ChatRole)
|
||||
for _, v := range items {
|
||||
var role types.ChatRole
|
||||
err := json.Unmarshal([]byte(v), &role)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
roles[role.Key] = role
|
||||
}
|
||||
return roles
|
||||
}
|
||||
|
||||
func PutChatRole(role types.ChatRole) error {
|
||||
key := ChatRolePrefix + role.Key
|
||||
return db.Put(key, role)
|
||||
}
|
||||
|
||||
// GetChatHistory 获取聊天历史记录
|
||||
|
@ -37,10 +37,9 @@ type Server struct {
|
||||
|
||||
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次
|
||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||
WsSession map[string]string
|
||||
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
|
||||
DebugMode bool // 是否开启调试模式
|
||||
ChatRoles map[string]types.ChatRole // 保存预设角色信息
|
||||
ChatSession map[string]types.ChatSession
|
||||
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
|
||||
DebugMode bool // 是否开启调试模式
|
||||
}
|
||||
|
||||
func NewServer(configPath string) (*Server, error) {
|
||||
@ -49,18 +48,22 @@ func NewServer(configPath string) (*Server, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roles := GetChatRoles()
|
||||
if roles == nil {
|
||||
if len(roles) == 0 { // 初始化默认聊天角色到 leveldb
|
||||
roles = types.GetDefaultChatRole()
|
||||
for _, v := range roles {
|
||||
err := PutChatRole(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return &Server{
|
||||
Config: config,
|
||||
ConfigPath: configPath,
|
||||
ChatContext: make(map[string][]types.Message, 16),
|
||||
WsSession: make(map[string]string),
|
||||
ChatSession: make(map[string]types.ChatSession),
|
||||
ApiKeyAccessStat: make(map[string]int64),
|
||||
ChatRoles: roles,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -81,6 +84,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
|
||||
engine.POST("/api/config/set", s.ConfigSetHandle)
|
||||
engine.GET("/api/config/chat-roles/get", s.GetChatRoles)
|
||||
engine.POST("api/config/token/add", s.AddToken)
|
||||
engine.POST("api/config/token/set", s.SetToken)
|
||||
engine.POST("api/config/token/remove", s.RemoveToken)
|
||||
engine.POST("api/config/apikey/add", s.AddApiKey)
|
||||
engine.POST("api/config/apikey/remove", s.RemoveApiKey)
|
||||
@ -182,10 +186,8 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
|
||||
|
||||
// WebSocket 连接请求验证
|
||||
if c.Request.URL.Path == "/api/chat" {
|
||||
tokenName := c.Query("token")
|
||||
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
|
||||
// 每个令牌只能连接一次
|
||||
//delete(s.WsSession, tokenName)
|
||||
sessionId := c.Query("sessionId")
|
||||
if session, ok := s.ChatSession[sessionId]; ok && session.ClientIP == c.ClientIP() {
|
||||
c.Next()
|
||||
} else {
|
||||
c.Abort()
|
||||
@ -210,9 +212,9 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
func (s *Server) GetSessionHandle(c *gin.Context) {
|
||||
tokenName := c.GetHeader(types.TokenName)
|
||||
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: addr})
|
||||
sessionId := c.GetHeader(types.TokenName)
|
||||
if session, ok := s.ChatSession[sessionId]; ok && session.ClientIP == c.ClientIP() {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, types.BizVo{
|
||||
Code: types.NotAuthorized,
|
||||
@ -243,7 +245,7 @@ func (s *Server) LoginHandle(c *gin.Context) {
|
||||
logger.Error("Error for save session: ", err)
|
||||
}
|
||||
// 记录客户端 IP 地址
|
||||
s.WsSession[sessionId] = c.ClientIP()
|
||||
s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Token: token, SessionId: sessionId}
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: sessionId})
|
||||
}
|
||||
|
||||
|
@ -33,6 +33,13 @@ type ChatRole struct {
|
||||
Enable bool `json:"enable"` // 是否启用被启用
|
||||
}
|
||||
|
||||
// ChatSession 聊天会话对象
|
||||
type ChatSession struct {
|
||||
SessionId string `json:"session_id"`
|
||||
ClientIP string `json:"client_ip"` // 客户端 IP
|
||||
Token string `json:"token"` // 当前登录的 token
|
||||
}
|
||||
|
||||
func GetDefaultChatRole() map[string]ChatRole {
|
||||
return map[string]ChatRole{
|
||||
"gpt": {
|
||||
@ -108,7 +115,7 @@ func GetDefaultChatRole() map[string]ChatRole {
|
||||
},
|
||||
HelloMsg: "你好,我是中颂福的销售代表颂福。中颂福酒,好喝不上头,是人民的福酒。",
|
||||
Icon: "images/avatar/seller.jpg",
|
||||
Enable: false,
|
||||
Enable: true,
|
||||
},
|
||||
|
||||
"english_trainer": {
|
||||
|
Loading…
Reference in New Issue
Block a user