diff --git a/.gitignore b/.gitignore index 8ac5ced4..d67820c0 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,5 @@ dist-ssr *.sw? tmp bin -web/.env.development +web/.env.development +data diff --git a/README.md b/README.md index dbb32a81..6ab730c2 100644 --- a/README.md +++ b/README.md @@ -5,13 +5,13 @@ ## TODOLIST * [ ] 使用 level DB 保存用户聊天的上下文 -* [ ] 使用 MySQL 保存用户的聊天的历史记录 * [x] 用户聊天鉴权,设置口令模式 * [ ] 定期清理不在线的会话 sessionID 和聊天上下文记录 +* [ ] 给 Token 设置调用次数 * [x] OpenAI API 负载均衡,限制每个 API Key 每分钟之内调用次数不超过 15次,防止被封 -* [ ] 角色设定,预设一些角色,比如程序员,客服,作家,老师,艺术家... +* [x] 角色设定,预设一些角色,比如程序员,客服,作家,老师,艺术家... * [x] markdown 语法解析和代码高亮 -* [ ] 用户配置界面,配置用户的使用习惯 +* [ ] 用户配置界面,配置用户的使用习惯,可以让用户配置自己的 API KEY,调用自己的 API Key,将不记 Token 的使用次数 * [ ] 嵌入 AI 绘画功能,支持根据描述词生成图片 * [ ] 增加 Buffer 层,将相同的问题答案缓存起来,相同问题直接返回答案。 diff --git a/fresh.conf b/fresh.conf index bc79c11a..aac77afe 100644 --- a/fresh.conf +++ b/fresh.conf @@ -4,7 +4,7 @@ build_name: runner-build build_log: runner-build-errors.log valid_ext: .go, .tpl, .tmpl, .html no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue -ignored: assets, tmp, web, .git, .idea, test +ignored: assets, tmp, web, .git, .idea, test, data build_delay: 600 colors: 1 log_color_main: cyan diff --git a/go.mod b/go.mod index 210aa20e..59a9c58b 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/go-playground/universal-translator v0.17.0 // indirect github.com/go-playground/validator/v10 v10.4.1 // indirect github.com/golang/protobuf v1.3.3 // indirect + github.com/golang/snappy v0.0.1 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect @@ -25,6 +26,7 @@ require ( github.com/mattn/go-isatty v0.0.12 // indirect github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 // indirect + github.com/syndtr/goleveldb v1.0.0 // indirect github.com/ugorji/go/codec v1.1.7 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect diff --git a/go.sum b/go.sum index 60530692..5aef75a6 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,7 @@ github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE= github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -19,8 +20,12 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= @@ -30,6 +35,7 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -47,6 +53,9 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -56,6 +65,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= +github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= @@ -75,12 +86,15 @@ golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqt golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -105,6 +119,9 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/server/chat_handler.go b/server/chat_handler.go index 5aa708d7..e0ac9607 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -23,14 +23,17 @@ func (s *Server) ChatHandle(c *gin.Context) { logger.Fatal(err) return } - token := c.Query("token") + sessionId := c.Query("sessionId") role := c.Query("role") logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr) client := NewWsClient(ws) - // TODO: 这里需要先判断一下角色是否存在,并且角色是被启用的 + if !s.ChatRoles[role].Enable { // 角色未启用 + c.Abort() + return + } // 发送打招呼信息 replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client) - replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.Config.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client) + replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client) replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client) go func() { for { @@ -43,7 +46,7 @@ func (s *Server) ChatHandle(c *gin.Context) { logger.Info("Receive a message: ", string(message)) // TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录 - err = s.sendMessage(token, role, string(message), client) + err = s.sendMessage(sessionId, role, string(message), client) if err != nil { logger.Error(err) } @@ -64,9 +67,13 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie if v, ok := s.ChatContext[key]; ok && s.Config.Chat.EnableContext { context = v } else { - context = s.Config.ChatRoles[role].Context + context = s.ChatRoles[role].Context } - logger.Infof("会话上下文:%+v", context) + + if s.DebugMode { + logger.Infof("会话上下文:%+v", context) + } + r.Messages = append(context, types.Message{ Role: "user", Content: text, @@ -179,6 +186,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie context = append(context, message) // 保存上下文 s.ChatContext[key] = context + _ = response.Body.Close() // 关闭资源 return nil } diff --git a/server/config_handler.go b/server/config_handler.go index f0a51c6b..aa76bca8 100644 --- a/server/config_handler.go +++ b/server/config_handler.go @@ -77,7 +77,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { } // 保存配置文件 - err = types.SaveConfig(s.Config, s.ConfigPath) + err = utils.SaveConfig(s.Config, s.ConfigPath) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) return @@ -86,6 +86,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) } +// AddToken 添加 Token func (s *Server) AddToken(c *gin.Context) { var data map[string]string err := json.NewDecoder(c.Request.Body).Decode(&data) @@ -95,22 +96,38 @@ func (s *Server) AddToken(c *gin.Context) { return } - if token, ok := data["token"]; ok { - if !utils.ContainsItem(s.Config.Tokens, token) { - s.Config.Tokens = append(s.Config.Tokens, token) - } - } - - // 保存配置文件 - err = types.SaveConfig(s.Config, s.ConfigPath) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) + var name = data["name"] + var maxCalls = data["max_calls"] + if name == "" || maxCalls == "" { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) return } - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens}) + n, err := strconv.Atoi(maxCalls) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{ + Code: types.InvalidParams, + Message: "enable_auth must be a int parameter", + }) + return + } + + var tokens = GetTokens() + if utils.ContainToken(tokens, name) { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + name + " already exists"}) + return + } + + err = PutToken(types.Token{Name: name, MaxCalls: n, RemainingCalls: n}) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) + return + } + + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()}) } +// RemoveToken 删除 Token func (s *Server) RemoveToken(c *gin.Context) { var data map[string]string err := json.NewDecoder(c.Request.Body).Decode(&data) @@ -121,22 +138,14 @@ func (s *Server) RemoveToken(c *gin.Context) { } if token, ok := data["token"]; ok { - for i, v := range s.Config.Tokens { - if v == token { - s.Config.Tokens = append(s.Config.Tokens[:i], s.Config.Tokens[i+1:]...) - break - } + err = RemoveToken(token) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) + return } } - // 保存配置文件 - err = types.SaveConfig(s.Config, s.ConfigPath) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) - return - } - - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()}) } // AddApiKey 添加一个 API key @@ -153,7 +162,7 @@ func (s *Server) AddApiKey(c *gin.Context) { } // 保存配置文件 - err = types.SaveConfig(s.Config, s.ConfigPath) + err = utils.SaveConfig(s.Config, s.ConfigPath) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) return @@ -181,7 +190,7 @@ func (s *Server) RemoveApiKey(c *gin.Context) { } // 保存配置文件 - err = types.SaveConfig(s.Config, s.ConfigPath) + err = utils.SaveConfig(s.Config, s.ConfigPath) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) return @@ -196,22 +205,22 @@ func (s *Server) ListApiKeys(c *gin.Context) { } func (s *Server) GetChatRoles(c *gin.Context) { - var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"} - var roles = make([]interface{}, 0) - for _, k := range rolesOrder { - if v, ok := s.Config.ChatRoles[k]; ok && v.Enable { - roles = append(roles, struct { - Key string `json:"key"` - Name string `json:"name"` - Icon string `json:"icon"` - }{ - Key: v.Key, - Name: v.Name, - Icon: v.Icon, - }) - } - } - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles}) + //var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"} + //var roles = make([]interface{}, 0) + //for _, k := range rolesOrder { + // if v, ok := s.Config.ChatRoles[k]; ok && v.Enable { + // roles = append(roles, struct { + // Key string `json:"key"` + // Name string `json:"name"` + // Icon string `json:"icon"` + // }{ + // Key: v.Key, + // Name: v.Name, + // Icon: v.Icon, + // }) + // } + //} + //c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles}) } // UpdateChatRole 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作 @@ -229,39 +238,39 @@ func (s *Server) UpdateChatRole(c *gin.Context) { return } - role := s.Config.ChatRoles[key] - if enable, ok := data["enable"]; ok { - v, err := strconv.ParseBool(enable) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{ - Code: types.InvalidParams, - Message: "enable must be a bool parameter", - }) - return - } - role.Enable = v - } + //role := s.Config.ChatRoles[key] + //if enable, ok := data["enable"]; ok { + // v, err := strconv.ParseBool(enable) + // if err != nil { + // c.JSON(http.StatusOK, types.BizVo{ + // Code: types.InvalidParams, + // Message: "enable must be a bool parameter", + // }) + // return + // } + // role.Enable = v + //} - if name, ok := data["name"]; ok { - role.Name = name - } - if helloMsg, ok := data["hello_msg"]; ok { - role.HelloMsg = helloMsg - } - if icon, ok := data["icon"]; ok { - role.Icon = icon - } - - s.Config.ChatRoles[key] = role - - // 保存配置文件 - err = types.SaveConfig(s.Config, s.ConfigPath) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) - return - } - - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) + //if name, ok := data["name"]; ok { + // role.Name = name + //} + //if helloMsg, ok := data["hello_msg"]; ok { + // role.HelloMsg = helloMsg + //} + //if icon, ok := data["icon"]; ok { + // role.Icon = icon + //} + // + //s.Config.ChatRoles[key] = role + // + //// 保存配置文件 + //err = types.SaveConfig(s.Config, s.ConfigPath) + //if err != nil { + // c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) + // return + //} + // + //c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) } // AddProxy 添加一个代理 @@ -275,13 +284,13 @@ func (s *Server) AddProxy(c *gin.Context) { } if proxy, ok := data["proxy"]; ok { - if !utils.ContainsItem(s.Config.ProxyURL, proxy) { + if !utils.ContainsStr(s.Config.ProxyURL, proxy) { s.Config.ProxyURL = append(s.Config.ProxyURL, proxy) } } // 保存配置文件 - err = types.SaveConfig(s.Config, s.ConfigPath) + err = utils.SaveConfig(s.Config, s.ConfigPath) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) return @@ -309,7 +318,7 @@ func (s *Server) RemoveProxy(c *gin.Context) { } // 保存配置文件 - err = types.SaveConfig(s.Config, s.ConfigPath) + err = utils.SaveConfig(s.Config, s.ConfigPath) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) return diff --git a/server/db.go b/server/db.go new file mode 100644 index 00000000..81530c1b --- /dev/null +++ b/server/db.go @@ -0,0 +1,61 @@ +package server + +import ( + "encoding/json" + "openai/types" + "openai/utils" +) + +const ( + TokenPrefix = "chat/tokens/" + ChatRolePrefix = "chat/roles/" + ChatHistoryPrefix = "chat/history/" +) + +var db *utils.LevelDB + +func init() { + leveldb, err := utils.NewLevelDB("data") + if err != nil { + panic(err) + } + db = leveldb +} + +// GetTokens 获取 token 信息 +// chat/tokens +func GetTokens() []types.Token { + items := db.Search(TokenPrefix) + var tokens = make([]types.Token, 0) + for _, v := range items { + var token types.Token + err := json.Unmarshal([]byte(v), &token) + if err != nil { + continue + } + tokens = append(tokens, token) + } + return tokens +} + +func PutToken(token types.Token) error { + key := TokenPrefix + token.Name + return db.Put(key, token) +} + +func RemoveToken(token string) error { + key := TokenPrefix + token + return db.Delete(key) +} + +// GetChatRoles 获取聊天角色 +// chat/roles +func GetChatRoles() map[string]types.ChatRole { + return nil +} + +// GetChatHistory 获取聊天历史记录 +// chat/history/{token}/{role} +func GetChatHistory() []types.Message { + return nil +} diff --git a/server/server.go b/server/server.go index ba100fd8..78e2f876 100644 --- a/server/server.go +++ b/server/server.go @@ -38,29 +38,34 @@ type Server struct { // 保存 Websocket 会话 Token, 每个 Token 只能连接一次 // 防止第三方直接连接 socket 调用 OpenAI API WsSession map[string]string - ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 + ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 + DebugMode bool // 是否开启调试模式 + ChatRoles map[string]types.ChatRole // 保存预设角色信息 } func NewServer(configPath string) (*Server, error) { // load service configs - config, err := types.LoadConfig(configPath) - if config.ChatRoles == nil { - config.ChatRoles = types.GetDefaultChatRole() - } + config, err := utils.LoadConfig(configPath) if err != nil { return nil, err } + roles := GetChatRoles() + if roles == nil { + roles = types.GetDefaultChatRole() + } return &Server{ Config: config, ConfigPath: configPath, ChatContext: make(map[string][]types.Message, 16), WsSession: make(map[string]string), ApiKeyAccessStat: make(map[string]int64), + ChatRoles: roles, }, nil } func (s *Server) Run(webRoot embed.FS, path string, debug bool) { + s.DebugMode = debug gin.SetMode(gin.ReleaseMode) engine := gin.Default() if debug { @@ -225,7 +230,7 @@ func (s *Server) LoginHandle(c *gin.Context) { return } token := data["token"] - if !utils.ContainsItem(s.Config.Tokens, token) { + if !utils.ContainToken(GetTokens(), token) { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid token"}) return } diff --git a/test/test.go b/test/test.go index 3ca6dfa5..b8042e94 100644 --- a/test/test.go +++ b/test/test.go @@ -2,10 +2,26 @@ package main import ( "fmt" + "openai/utils" ) func main() { - var data = make(map[string]string) - fmt.Println(data["key"] == "") + // leveldb 测试 + db, err := utils.NewLevelDB("data") + if err != nil { + panic(err) + } + defer db.Close() + err = db.Put("name", "xiaoming") + if err != nil { + panic(err) + } + + name, err := db.Get("name") + if err != nil { + panic(err) + } + + fmt.Println("name: ", name) } diff --git a/types/config.go b/types/config.go index f058f833..1dbc81a8 100644 --- a/types/config.go +++ b/types/config.go @@ -1,12 +1,7 @@ package types import ( - "bytes" - "github.com/BurntSushi/toml" "net/http" - logger2 "openai/logger" - "openai/utils" - "os" ) type Config struct { @@ -14,9 +9,13 @@ type Config struct { Session Session ProxyURL []string Chat Chat - EnableAuth bool // 是否开启鉴权 - Tokens []string // 授权的白名单列表 TODO: 后期要存储到 LevelDB 或者 Mysql 数据库 - ChatRoles map[string]ChatRole // 保存预设角色信息 + EnableAuth bool // 是否开启鉴权 +} + +type Token struct { + Name string `json:"name"` + MaxCalls int `json:"max_calls"` // 最多调用次数,如果为 0 则表示不限制 + RemainingCalls int `json:"remaining_calls"` // 剩余调用次数 } // Chat configs struct @@ -40,65 +39,3 @@ type Session struct { HttpOnly bool SameSite http.SameSite } - -func NewDefaultConfig() *Config { - return &Config{ - Listen: "0.0.0.0:5678", - ProxyURL: make([]string, 0), - - Session: Session{ - SecretKey: utils.RandString(64), - Name: "CHAT_SESSION_ID", - Domain: "", - Path: "/", - MaxAge: 86400, - Secure: true, - HttpOnly: false, - SameSite: http.SameSiteLaxMode, - }, - Chat: Chat{ - ApiURL: "https://api.openai.com/v1/chat/completions", - ApiKeys: []string{""}, - Model: "gpt-3.5-turbo", - MaxTokens: 1024, - Temperature: 0.9, - EnableContext: true, - }, - EnableAuth: true, - ChatRoles: GetDefaultChatRole(), - } -} - -var logger = logger2.GetLogger() - -func LoadConfig(configFile string) (*Config, error) { - var config *Config - _, err := os.Stat(configFile) - if err != nil { - logger.Errorf("Error open config file: %s", err.Error()) - config = NewDefaultConfig() - // save config - err := SaveConfig(config, configFile) - if err != nil { - return nil, err - } - - return config, nil - } - _, err = toml.DecodeFile(configFile, &config) - if err != nil { - return nil, err - } - - return config, err -} - -func SaveConfig(config *Config, configFile string) error { - buf := new(bytes.Buffer) - encoder := toml.NewEncoder(buf) - if err := encoder.Encode(&config); err != nil { - return err - } - - return os.WriteFile(configFile, buf.Bytes(), 0644) -} diff --git a/utils/config.go b/utils/config.go new file mode 100644 index 00000000..0495d23a --- /dev/null +++ b/utils/config.go @@ -0,0 +1,71 @@ +package utils + +import ( + "bytes" + "github.com/BurntSushi/toml" + "net/http" + logger2 "openai/logger" + "openai/types" + "os" +) + +func NewDefaultConfig() *types.Config { + return &types.Config{ + Listen: "0.0.0.0:5678", + ProxyURL: make([]string, 0), + + Session: types.Session{ + SecretKey: RandString(64), + Name: "CHAT_SESSION_ID", + Domain: "", + Path: "/", + MaxAge: 86400, + Secure: true, + HttpOnly: false, + SameSite: http.SameSiteLaxMode, + }, + Chat: types.Chat{ + ApiURL: "https://api.openai.com/v1/chat/completions", + ApiKeys: []string{""}, + Model: "gpt-3.5-turbo", + MaxTokens: 1024, + Temperature: 0.9, + EnableContext: true, + }, + EnableAuth: true, + } +} + +var logger = logger2.GetLogger() + +func LoadConfig(configFile string) (*types.Config, error) { + var config *types.Config + _, err := os.Stat(configFile) + if err != nil { + logger.Errorf("Error open config file: %s", err.Error()) + config = NewDefaultConfig() + // save config + err := SaveConfig(config, configFile) + if err != nil { + return nil, err + } + + return config, nil + } + _, err = toml.DecodeFile(configFile, &config) + if err != nil { + return nil, err + } + + return config, err +} + +func SaveConfig(config *types.Config, configFile string) error { + buf := new(bytes.Buffer) + encoder := toml.NewEncoder(buf) + if err := encoder.Encode(&config); err != nil { + return err + } + + return os.WriteFile(configFile, buf.Bytes(), 0644) +} diff --git a/utils/leveldb.go b/utils/leveldb.go new file mode 100644 index 00000000..83d0a922 --- /dev/null +++ b/utils/leveldb.go @@ -0,0 +1,63 @@ +package utils + +import ( + "encoding/json" + "github.com/syndtr/goleveldb/leveldb" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type LevelDB struct { + driver *leveldb.DB +} + +func NewLevelDB(path string) (*LevelDB, error) { + db, err := leveldb.OpenFile(path, nil) + if err != nil { + return nil, err + } + return &LevelDB{ + driver: db, + }, nil +} + +func (db *LevelDB) Put(key string, value interface{}) error { + bytes, err := json.Marshal(value) + if err != nil { + return err + } + return db.driver.Put([]byte(key), bytes, nil) +} + +func (db *LevelDB) Get(key string) (interface{}, error) { + bytes, err := db.driver.Get([]byte(key), nil) + if err != nil { + return nil, err + } + + var value interface{} + err = json.Unmarshal(bytes, &value) + if err != nil { + return nil, err + } + + return value, nil +} + +func (db *LevelDB) Search(prefix string) []string { + var items = make([]string, 0) + iter := db.driver.NewIterator(util.BytesPrefix([]byte(prefix)), nil) + for iter.Next() { + items = append(items, string(iter.Value())) + } + iter.Release() + return items +} + +func (db *LevelDB) Delete(key string) error { + return db.driver.Delete([]byte(key), nil) +} + +// Close release resources +func (db *LevelDB) Close() error { + return db.driver.Close() +} diff --git a/utils/utils.go b/utils/utils.go index cde0d879..5d903596 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,6 +2,7 @@ package utils import ( "math/rand" + "openai/types" "strconv" "strings" "time" @@ -31,7 +32,7 @@ func IsBlank(value string) bool { return len(strings.TrimSpace(value)) == 0 } -func ContainsItem(slice []string, item string) bool { +func ContainsStr(slice []string, item string) bool { for _, e := range slice { if e == item { return true @@ -39,3 +40,12 @@ func ContainsItem(slice []string, item string) bool { } return false } + +func ContainToken(slice []types.Token, token string) bool { + for _, e := range slice { + if e.Name == token { + return true + } + } + return false +} diff --git a/web/src/views/Chat.vue b/web/src/views/Chat.vue index f5d79c6f..c1098f9f 100644 --- a/web/src/views/Chat.vue +++ b/web/src/views/Chat.vue @@ -195,8 +195,8 @@ export default defineComponent({ // 创建 socket 会话连接 connect: function () { // 初始化 WebSocket 对象 - const token = getSessionId(); - const socket = new WebSocket(process.env.VUE_APP_WS_HOST + `/api/chat?token=${token}&role=${this.role}`); + const sessionId = getSessionId(); + const socket = new WebSocket(process.env.VUE_APP_WS_HOST + `/api/chat?sessionId=${sessionId}&role=${this.role}`); socket.addEventListener('open', () => { // 获取聊天角色 httpGet("/api/config/chat-roles/get").then((res) => { @@ -219,11 +219,6 @@ export default defineComponent({ reader.readAsText(event.data, "UTF-8"); reader.onload = () => { const data = JSON.parse(String(reader.result)); - // 过滤掉重复的打招呼信息 - if (data['is_hello_msg'] && this.chatData.length > 1) { - return - } - if (data.type === 'start') { this.chatData.push({ type: "reply",