mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
使用 leveldb 存储用户 token 和聊天记录
This commit is contained in:
parent
6a38de7eaa
commit
a6bab7b12d
1
.gitignore
vendored
1
.gitignore
vendored
@ -25,3 +25,4 @@ dist-ssr
|
||||
tmp
|
||||
bin
|
||||
web/.env.development
|
||||
data
|
||||
|
@ -5,13 +5,13 @@
|
||||
## TODOLIST
|
||||
|
||||
* [ ] 使用 level DB 保存用户聊天的上下文
|
||||
* [ ] 使用 MySQL 保存用户的聊天的历史记录
|
||||
* [x] 用户聊天鉴权,设置口令模式
|
||||
* [ ] 定期清理不在线的会话 sessionID 和聊天上下文记录
|
||||
* [ ] 给 Token 设置调用次数
|
||||
* [x] OpenAI API 负载均衡,限制每个 API Key 每分钟之内调用次数不超过 15次,防止被封
|
||||
* [ ] 角色设定,预设一些角色,比如程序员,客服,作家,老师,艺术家...
|
||||
* [x] 角色设定,预设一些角色,比如程序员,客服,作家,老师,艺术家...
|
||||
* [x] markdown 语法解析和代码高亮
|
||||
* [ ] 用户配置界面,配置用户的使用习惯
|
||||
* [ ] 用户配置界面,配置用户的使用习惯,可以让用户配置自己的 API KEY,调用自己的 API Key,将不记 Token 的使用次数
|
||||
* [ ] 嵌入 AI 绘画功能,支持根据描述词生成图片
|
||||
* [ ] 增加 Buffer 层,将相同的问题答案缓存起来,相同问题直接返回答案。
|
||||
|
||||
|
@ -4,7 +4,7 @@ build_name: runner-build
|
||||
build_log: runner-build-errors.log
|
||||
valid_ext: .go, .tpl, .tmpl, .html
|
||||
no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue
|
||||
ignored: assets, tmp, web, .git, .idea, test
|
||||
ignored: assets, tmp, web, .git, .idea, test, data
|
||||
build_delay: 600
|
||||
colors: 1
|
||||
log_color_main: cyan
|
||||
|
2
go.mod
2
go.mod
@ -17,6 +17,7 @@ require (
|
||||
github.com/go-playground/universal-translator v0.17.0 // indirect
|
||||
github.com/go-playground/validator/v10 v10.4.1 // indirect
|
||||
github.com/golang/protobuf v1.3.3 // indirect
|
||||
github.com/golang/snappy v0.0.1 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
@ -25,6 +26,7 @@ require (
|
||||
github.com/mattn/go-isatty v0.0.12 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 // indirect
|
||||
github.com/syndtr/goleveldb v1.0.0 // indirect
|
||||
github.com/ugorji/go/codec v1.1.7 // indirect
|
||||
go.uber.org/atomic v1.7.0 // indirect
|
||||
go.uber.org/multierr v1.6.0 // indirect
|
||||
|
17
go.sum
17
go.sum
@ -5,6 +5,7 @@ github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE=
|
||||
github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
@ -19,8 +20,12 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87
|
||||
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
|
||||
github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE=
|
||||
github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I=
|
||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
|
||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||
@ -30,6 +35,7 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns=
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
@ -47,6 +53,9 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@ -56,6 +65,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
|
||||
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
|
||||
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
|
||||
github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs=
|
||||
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
|
||||
@ -75,12 +86,15 @@ golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqt
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@ -105,6 +119,9 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
@ -23,14 +23,17 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
logger.Fatal(err)
|
||||
return
|
||||
}
|
||||
token := c.Query("token")
|
||||
sessionId := c.Query("sessionId")
|
||||
role := c.Query("role")
|
||||
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
|
||||
client := NewWsClient(ws)
|
||||
// TODO: 这里需要先判断一下角色是否存在,并且角色是被启用的
|
||||
if !s.ChatRoles[role].Enable { // 角色未启用
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// 发送打招呼信息
|
||||
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.Config.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client)
|
||||
go func() {
|
||||
for {
|
||||
@ -43,7 +46,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
|
||||
err = s.sendMessage(token, role, string(message), client)
|
||||
err = s.sendMessage(sessionId, role, string(message), client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@ -64,9 +67,13 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
if v, ok := s.ChatContext[key]; ok && s.Config.Chat.EnableContext {
|
||||
context = v
|
||||
} else {
|
||||
context = s.Config.ChatRoles[role].Context
|
||||
context = s.ChatRoles[role].Context
|
||||
}
|
||||
|
||||
if s.DebugMode {
|
||||
logger.Infof("会话上下文:%+v", context)
|
||||
}
|
||||
|
||||
r.Messages = append(context, types.Message{
|
||||
Role: "user",
|
||||
Content: text,
|
||||
@ -179,6 +186,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
context = append(context, message)
|
||||
// 保存上下文
|
||||
s.ChatContext[key] = context
|
||||
_ = response.Body.Close() // 关闭资源
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -77,7 +77,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
@ -86,6 +86,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
|
||||
}
|
||||
|
||||
// AddToken 添加 Token
|
||||
func (s *Server) AddToken(c *gin.Context) {
|
||||
var data map[string]string
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||
@ -95,22 +96,38 @@ func (s *Server) AddToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if token, ok := data["token"]; ok {
|
||||
if !utils.ContainsItem(s.Config.Tokens, token) {
|
||||
s.Config.Tokens = append(s.Config.Tokens, token)
|
||||
}
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
var name = data["name"]
|
||||
var maxCalls = data["max_calls"]
|
||||
if name == "" || maxCalls == "" {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
|
||||
n, err := strconv.Atoi(maxCalls)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{
|
||||
Code: types.InvalidParams,
|
||||
Message: "enable_auth must be a int parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var tokens = GetTokens()
|
||||
if utils.ContainToken(tokens, name) {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + name + " already exists"})
|
||||
return
|
||||
}
|
||||
|
||||
err = PutToken(types.Token{Name: name, MaxCalls: n, RemainingCalls: n})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
|
||||
}
|
||||
|
||||
// RemoveToken 删除 Token
|
||||
func (s *Server) RemoveToken(c *gin.Context) {
|
||||
var data map[string]string
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||
@ -121,22 +138,14 @@ func (s *Server) RemoveToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if token, ok := data["token"]; ok {
|
||||
for i, v := range s.Config.Tokens {
|
||||
if v == token {
|
||||
s.Config.Tokens = append(s.Config.Tokens[:i], s.Config.Tokens[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = RemoveToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
|
||||
}
|
||||
|
||||
// AddApiKey 添加一个 API key
|
||||
@ -153,7 +162,7 @@ func (s *Server) AddApiKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
@ -181,7 +190,7 @@ func (s *Server) RemoveApiKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
@ -196,22 +205,22 @@ func (s *Server) ListApiKeys(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (s *Server) GetChatRoles(c *gin.Context) {
|
||||
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
|
||||
var roles = make([]interface{}, 0)
|
||||
for _, k := range rolesOrder {
|
||||
if v, ok := s.Config.ChatRoles[k]; ok && v.Enable {
|
||||
roles = append(roles, struct {
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
}{
|
||||
Key: v.Key,
|
||||
Name: v.Name,
|
||||
Icon: v.Icon,
|
||||
})
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles})
|
||||
//var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
|
||||
//var roles = make([]interface{}, 0)
|
||||
//for _, k := range rolesOrder {
|
||||
// if v, ok := s.Config.ChatRoles[k]; ok && v.Enable {
|
||||
// roles = append(roles, struct {
|
||||
// Key string `json:"key"`
|
||||
// Name string `json:"name"`
|
||||
// Icon string `json:"icon"`
|
||||
// }{
|
||||
// Key: v.Key,
|
||||
// Name: v.Name,
|
||||
// Icon: v.Icon,
|
||||
// })
|
||||
// }
|
||||
//}
|
||||
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles})
|
||||
}
|
||||
|
||||
// UpdateChatRole 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作
|
||||
@ -229,39 +238,39 @@ func (s *Server) UpdateChatRole(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
role := s.Config.ChatRoles[key]
|
||||
if enable, ok := data["enable"]; ok {
|
||||
v, err := strconv.ParseBool(enable)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{
|
||||
Code: types.InvalidParams,
|
||||
Message: "enable must be a bool parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
role.Enable = v
|
||||
}
|
||||
//role := s.Config.ChatRoles[key]
|
||||
//if enable, ok := data["enable"]; ok {
|
||||
// v, err := strconv.ParseBool(enable)
|
||||
// if err != nil {
|
||||
// c.JSON(http.StatusOK, types.BizVo{
|
||||
// Code: types.InvalidParams,
|
||||
// Message: "enable must be a bool parameter",
|
||||
// })
|
||||
// return
|
||||
// }
|
||||
// role.Enable = v
|
||||
//}
|
||||
|
||||
if name, ok := data["name"]; ok {
|
||||
role.Name = name
|
||||
}
|
||||
if helloMsg, ok := data["hello_msg"]; ok {
|
||||
role.HelloMsg = helloMsg
|
||||
}
|
||||
if icon, ok := data["icon"]; ok {
|
||||
role.Icon = icon
|
||||
}
|
||||
|
||||
s.Config.ChatRoles[key] = role
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
|
||||
//if name, ok := data["name"]; ok {
|
||||
// role.Name = name
|
||||
//}
|
||||
//if helloMsg, ok := data["hello_msg"]; ok {
|
||||
// role.HelloMsg = helloMsg
|
||||
//}
|
||||
//if icon, ok := data["icon"]; ok {
|
||||
// role.Icon = icon
|
||||
//}
|
||||
//
|
||||
//s.Config.ChatRoles[key] = role
|
||||
//
|
||||
//// 保存配置文件
|
||||
//err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
//if err != nil {
|
||||
// c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
|
||||
}
|
||||
|
||||
// AddProxy 添加一个代理
|
||||
@ -275,13 +284,13 @@ func (s *Server) AddProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
if proxy, ok := data["proxy"]; ok {
|
||||
if !utils.ContainsItem(s.Config.ProxyURL, proxy) {
|
||||
if !utils.ContainsStr(s.Config.ProxyURL, proxy) {
|
||||
s.Config.ProxyURL = append(s.Config.ProxyURL, proxy)
|
||||
}
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
@ -309,7 +318,7 @@ func (s *Server) RemoveProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存配置文件
|
||||
err = types.SaveConfig(s.Config, s.ConfigPath)
|
||||
err = utils.SaveConfig(s.Config, s.ConfigPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
|
||||
return
|
||||
|
61
server/db.go
Normal file
61
server/db.go
Normal file
@ -0,0 +1,61 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"openai/types"
|
||||
"openai/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenPrefix = "chat/tokens/"
|
||||
ChatRolePrefix = "chat/roles/"
|
||||
ChatHistoryPrefix = "chat/history/"
|
||||
)
|
||||
|
||||
var db *utils.LevelDB
|
||||
|
||||
func init() {
|
||||
leveldb, err := utils.NewLevelDB("data")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db = leveldb
|
||||
}
|
||||
|
||||
// GetTokens 获取 token 信息
|
||||
// chat/tokens
|
||||
func GetTokens() []types.Token {
|
||||
items := db.Search(TokenPrefix)
|
||||
var tokens = make([]types.Token, 0)
|
||||
for _, v := range items {
|
||||
var token types.Token
|
||||
err := json.Unmarshal([]byte(v), &token)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func PutToken(token types.Token) error {
|
||||
key := TokenPrefix + token.Name
|
||||
return db.Put(key, token)
|
||||
}
|
||||
|
||||
func RemoveToken(token string) error {
|
||||
key := TokenPrefix + token
|
||||
return db.Delete(key)
|
||||
}
|
||||
|
||||
// GetChatRoles 获取聊天角色
|
||||
// chat/roles
|
||||
func GetChatRoles() map[string]types.ChatRole {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatHistory 获取聊天历史记录
|
||||
// chat/history/{token}/{role}
|
||||
func GetChatHistory() []types.Message {
|
||||
return nil
|
||||
}
|
@ -39,28 +39,33 @@ type Server struct {
|
||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||
WsSession map[string]string
|
||||
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
|
||||
DebugMode bool // 是否开启调试模式
|
||||
ChatRoles map[string]types.ChatRole // 保存预设角色信息
|
||||
}
|
||||
|
||||
func NewServer(configPath string) (*Server, error) {
|
||||
// load service configs
|
||||
config, err := types.LoadConfig(configPath)
|
||||
if config.ChatRoles == nil {
|
||||
config.ChatRoles = types.GetDefaultChatRole()
|
||||
}
|
||||
config, err := utils.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roles := GetChatRoles()
|
||||
if roles == nil {
|
||||
roles = types.GetDefaultChatRole()
|
||||
}
|
||||
return &Server{
|
||||
Config: config,
|
||||
ConfigPath: configPath,
|
||||
ChatContext: make(map[string][]types.Message, 16),
|
||||
WsSession: make(map[string]string),
|
||||
ApiKeyAccessStat: make(map[string]int64),
|
||||
ChatRoles: roles,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
|
||||
s.DebugMode = debug
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
engine := gin.Default()
|
||||
if debug {
|
||||
@ -225,7 +230,7 @@ func (s *Server) LoginHandle(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
token := data["token"]
|
||||
if !utils.ContainsItem(s.Config.Tokens, token) {
|
||||
if !utils.ContainToken(GetTokens(), token) {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid token"})
|
||||
return
|
||||
}
|
||||
|
20
test/test.go
20
test/test.go
@ -2,10 +2,26 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"openai/utils"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var data = make(map[string]string)
|
||||
fmt.Println(data["key"] == "")
|
||||
// leveldb 测试
|
||||
db, err := utils.NewLevelDB("data")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Put("name", "xiaoming")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
name, err := db.Get("name")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println("name: ", name)
|
||||
}
|
||||
|
@ -1,12 +1,7 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/BurntSushi/toml"
|
||||
"net/http"
|
||||
logger2 "openai/logger"
|
||||
"openai/utils"
|
||||
"os"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
@ -15,8 +10,12 @@ type Config struct {
|
||||
ProxyURL []string
|
||||
Chat Chat
|
||||
EnableAuth bool // 是否开启鉴权
|
||||
Tokens []string // 授权的白名单列表 TODO: 后期要存储到 LevelDB 或者 Mysql 数据库
|
||||
ChatRoles map[string]ChatRole // 保存预设角色信息
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
Name string `json:"name"`
|
||||
MaxCalls int `json:"max_calls"` // 最多调用次数,如果为 0 则表示不限制
|
||||
RemainingCalls int `json:"remaining_calls"` // 剩余调用次数
|
||||
}
|
||||
|
||||
// Chat configs struct
|
||||
@ -40,65 +39,3 @@ type Session struct {
|
||||
HttpOnly bool
|
||||
SameSite http.SameSite
|
||||
}
|
||||
|
||||
func NewDefaultConfig() *Config {
|
||||
return &Config{
|
||||
Listen: "0.0.0.0:5678",
|
||||
ProxyURL: make([]string, 0),
|
||||
|
||||
Session: Session{
|
||||
SecretKey: utils.RandString(64),
|
||||
Name: "CHAT_SESSION_ID",
|
||||
Domain: "",
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
Secure: true,
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
},
|
||||
Chat: Chat{
|
||||
ApiURL: "https://api.openai.com/v1/chat/completions",
|
||||
ApiKeys: []string{""},
|
||||
Model: "gpt-3.5-turbo",
|
||||
MaxTokens: 1024,
|
||||
Temperature: 0.9,
|
||||
EnableContext: true,
|
||||
},
|
||||
EnableAuth: true,
|
||||
ChatRoles: GetDefaultChatRole(),
|
||||
}
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func LoadConfig(configFile string) (*Config, error) {
|
||||
var config *Config
|
||||
_, err := os.Stat(configFile)
|
||||
if err != nil {
|
||||
logger.Errorf("Error open config file: %s", err.Error())
|
||||
config = NewDefaultConfig()
|
||||
// save config
|
||||
err := SaveConfig(config, configFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
_, err = toml.DecodeFile(configFile, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, err
|
||||
}
|
||||
|
||||
func SaveConfig(config *Config, configFile string) error {
|
||||
buf := new(bytes.Buffer)
|
||||
encoder := toml.NewEncoder(buf)
|
||||
if err := encoder.Encode(&config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(configFile, buf.Bytes(), 0644)
|
||||
}
|
||||
|
71
utils/config.go
Normal file
71
utils/config.go
Normal file
@ -0,0 +1,71 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/BurntSushi/toml"
|
||||
"net/http"
|
||||
logger2 "openai/logger"
|
||||
"openai/types"
|
||||
"os"
|
||||
)
|
||||
|
||||
func NewDefaultConfig() *types.Config {
|
||||
return &types.Config{
|
||||
Listen: "0.0.0.0:5678",
|
||||
ProxyURL: make([]string, 0),
|
||||
|
||||
Session: types.Session{
|
||||
SecretKey: RandString(64),
|
||||
Name: "CHAT_SESSION_ID",
|
||||
Domain: "",
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
Secure: true,
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
},
|
||||
Chat: types.Chat{
|
||||
ApiURL: "https://api.openai.com/v1/chat/completions",
|
||||
ApiKeys: []string{""},
|
||||
Model: "gpt-3.5-turbo",
|
||||
MaxTokens: 1024,
|
||||
Temperature: 0.9,
|
||||
EnableContext: true,
|
||||
},
|
||||
EnableAuth: true,
|
||||
}
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func LoadConfig(configFile string) (*types.Config, error) {
|
||||
var config *types.Config
|
||||
_, err := os.Stat(configFile)
|
||||
if err != nil {
|
||||
logger.Errorf("Error open config file: %s", err.Error())
|
||||
config = NewDefaultConfig()
|
||||
// save config
|
||||
err := SaveConfig(config, configFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
_, err = toml.DecodeFile(configFile, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, err
|
||||
}
|
||||
|
||||
func SaveConfig(config *types.Config, configFile string) error {
|
||||
buf := new(bytes.Buffer)
|
||||
encoder := toml.NewEncoder(buf)
|
||||
if err := encoder.Encode(&config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(configFile, buf.Bytes(), 0644)
|
||||
}
|
63
utils/leveldb.go
Normal file
63
utils/leveldb.go
Normal file
@ -0,0 +1,63 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/syndtr/goleveldb/leveldb"
|
||||
"github.com/syndtr/goleveldb/leveldb/util"
|
||||
)
|
||||
|
||||
type LevelDB struct {
|
||||
driver *leveldb.DB
|
||||
}
|
||||
|
||||
func NewLevelDB(path string) (*LevelDB, error) {
|
||||
db, err := leveldb.OpenFile(path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LevelDB{
|
||||
driver: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (db *LevelDB) Put(key string, value interface{}) error {
|
||||
bytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.driver.Put([]byte(key), bytes, nil)
|
||||
}
|
||||
|
||||
func (db *LevelDB) Get(key string) (interface{}, error) {
|
||||
bytes, err := db.driver.Get([]byte(key), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var value interface{}
|
||||
err = json.Unmarshal(bytes, &value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (db *LevelDB) Search(prefix string) []string {
|
||||
var items = make([]string, 0)
|
||||
iter := db.driver.NewIterator(util.BytesPrefix([]byte(prefix)), nil)
|
||||
for iter.Next() {
|
||||
items = append(items, string(iter.Value()))
|
||||
}
|
||||
iter.Release()
|
||||
return items
|
||||
}
|
||||
|
||||
func (db *LevelDB) Delete(key string) error {
|
||||
return db.driver.Delete([]byte(key), nil)
|
||||
}
|
||||
|
||||
// Close release resources
|
||||
func (db *LevelDB) Close() error {
|
||||
return db.driver.Close()
|
||||
}
|
@ -2,6 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"openai/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -31,7 +32,7 @@ func IsBlank(value string) bool {
|
||||
return len(strings.TrimSpace(value)) == 0
|
||||
}
|
||||
|
||||
func ContainsItem(slice []string, item string) bool {
|
||||
func ContainsStr(slice []string, item string) bool {
|
||||
for _, e := range slice {
|
||||
if e == item {
|
||||
return true
|
||||
@ -39,3 +40,12 @@ func ContainsItem(slice []string, item string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ContainToken(slice []types.Token, token string) bool {
|
||||
for _, e := range slice {
|
||||
if e.Name == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -195,8 +195,8 @@ export default defineComponent({
|
||||
// 创建 socket 会话连接
|
||||
connect: function () {
|
||||
// 初始化 WebSocket 对象
|
||||
const token = getSessionId();
|
||||
const socket = new WebSocket(process.env.VUE_APP_WS_HOST + `/api/chat?token=${token}&role=${this.role}`);
|
||||
const sessionId = getSessionId();
|
||||
const socket = new WebSocket(process.env.VUE_APP_WS_HOST + `/api/chat?sessionId=${sessionId}&role=${this.role}`);
|
||||
socket.addEventListener('open', () => {
|
||||
// 获取聊天角色
|
||||
httpGet("/api/config/chat-roles/get").then((res) => {
|
||||
@ -219,11 +219,6 @@ export default defineComponent({
|
||||
reader.readAsText(event.data, "UTF-8");
|
||||
reader.onload = () => {
|
||||
const data = JSON.parse(String(reader.result));
|
||||
// 过滤掉重复的打招呼信息
|
||||
if (data['is_hello_msg'] && this.chatData.length > 1) {
|
||||
return
|
||||
}
|
||||
|
||||
if (data.type === 'start') {
|
||||
this.chatData.push({
|
||||
type: "reply",
|
||||
|
Loading…
Reference in New Issue
Block a user