diff --git a/README.md b/README.md
new file mode 100644
index 00000000..0c08ab5f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,10 @@
+# Wechat-GPT
+
+A chat application built on ChatGPT
+
+## TODOLIST
+
+* [ ] Use LevelDB to persist each user's chat context
+* [ ] Use MySQL to store users' chat history
+* [ ] Authenticate users before they can chat
+
diff --git a/server/chat_handler.go b/server/chat_handler.go
index 25efaec2..4762d99d 100644
--- a/server/chat_handler.go
+++ b/server/chat_handler.go
@@ -12,6 +12,7 @@ import (
 	"net/http"
 	"net/url"
 	"openai/types"
+	"strings"
 	"time"
 )
@@ -52,8 +53,9 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
 		Stream: true,
 	}
 	var history []types.Message
-	if v, ok := s.History[userId]; ok {
+	if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext {
 		history = v
+		//logger.Infof("context history: %+v", history)
 	} else {
 		history = make([]types.Message, 0)
 	}
@@ -62,22 +64,22 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
 		Content: text,
 	})
 
-	logger.Info("context history: %+v", s.History[userId])
 	requestBody, err := json.Marshal(r)
 	if err != nil {
 		return err
 	}
 
-	request, err := http.NewRequest(http.MethodPost, s.Config.OpenAi.ApiURL, bytes.NewBuffer(requestBody))
+	request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewBuffer(requestBody))
 	if err != nil {
 		return err
 	}
 
 	// TODO: load-balance across API keys
 	rand.Seed(time.Now().UnixNano())
-	index := rand.Intn(len(s.Config.OpenAi.ApiKeys))
+	index := rand.Intn(len(s.Config.Chat.ApiKeys))
+	logger.Infof("Use API KEY: %s", s.Config.Chat.ApiKeys[index])
 	request.Header.Add("Content-Type", "application/json")
-	request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKeys[index]))
+	request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.Chat.ApiKeys[index]))
 
 	uri := url.URL{}
 	proxy, _ := uri.Parse(s.Config.ProxyURL)
@@ -104,11 +106,12 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
 	for {
 		line, err := reader.ReadString('\n')
 		if err != nil && err != io.EOF {
-			fmt.Println(err)
+			logger.Error(err)
 			break
 		}
 
 		if line == "" {
+			replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
 			break
 		} else if len(line) < 20 {
 			continue
@@ -116,29 +119,47 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
 
 		err = json.Unmarshal([]byte(line[6:]), &responseBody)
 		if err != nil {
-			fmt.Println(err)
+			logger.Error(err)
 			continue
 		}
 
 		// initialize the role
 		if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
 			message.Role = responseBody.Choices[0].Delta.Role
+			replyMessage(types.WsMessage{Type: types.WsStart}, ws)
 			continue
-		} else {
-			contents = append(contents, responseBody.Choices[0].Delta.Content)
-		}
-		// push the message to the client
-		err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content))
-		if err != nil {
-			logger.Error(err)
-		}
-		fmt.Print(responseBody.Choices[0].Delta.Content)
-		if responseBody.Choices[0].FinishReason != "" {
+		} else if responseBody.Choices[0].FinishReason != "" { // generation finished or was interrupted
+			replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
 			break
+		} else {
+			content := responseBody.Choices[0].Delta.Content
+			contents = append(contents, content)
+			replyMessage(types.WsMessage{
+				Type:    types.WsMiddle,
+				Content: responseBody.Choices[0].Delta.Content,
+			}, ws)
 		}
 	}
 
 	// append to the chat history
+	history = append(history, types.Message{
+		Role:    "user",
+		Content: text,
+	})
+	message.Content = strings.Join(contents, "")
 	history = append(history, message)
 	s.History[userId] = history
 	return nil
}
+
+// replyMessage sends a message back to the client
+func replyMessage(message types.WsMessage, client Client) {
+	msg, err := json.Marshal(message)
+	if err != nil {
+		logger.Errorf("Failed to encode JSON message: %v", err.Error())
+		return
+	}
+	err = client.(*WsClient).Send(msg)
+	if err != nil {
+		logger.Errorf("Failed to send reply message: %v", err.Error())
+	}
+}
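`sendMessage` keeps each user's context in the in-memory `s.History` map, which is lost on restart; the README's first TODO item proposes moving it to LevelDB. Below is a minimal sketch of what that store might look like, assuming the `github.com/syndtr/goleveldb` driver; the `store` package, the `HistoryStore` type, and its methods are hypothetical and do not exist anywhere in this diff.

```go
// store/history.go — a hypothetical package; nothing below exists in this diff.
package store

import (
	"encoding/json"

	"github.com/syndtr/goleveldb/leveldb"

	"openai/types"
)

// HistoryStore persists per-user chat context in LevelDB instead of the
// in-memory s.History map used by sendMessage above.
type HistoryStore struct {
	db *leveldb.DB
}

// Open opens (or creates) the LevelDB database at the given path.
func Open(path string) (*HistoryStore, error) {
	db, err := leveldb.OpenFile(path, nil)
	if err != nil {
		return nil, err
	}
	return &HistoryStore{db: db}, nil
}

// Load returns the saved context for a user, or an empty slice for a new user.
func (s *HistoryStore) Load(userId string) ([]types.Message, error) {
	data, err := s.db.Get([]byte(userId), nil)
	if err == leveldb.ErrNotFound {
		return []types.Message{}, nil
	}
	if err != nil {
		return nil, err
	}
	var history []types.Message
	err = json.Unmarshal(data, &history)
	return history, err
}

// Save overwrites the user's context after each completed reply.
func (s *HistoryStore) Save(userId string, history []types.Message) error {
	data, err := json.Marshal(history)
	if err != nil {
		return err
	}
	return s.db.Put([]byte(userId), data, nil)
}
```

With such a store, the context would also survive process restarts, since it no longer lives only in server memory.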
diff --git a/server/config_handler.go b/server/config_handler.go
index a33b6778..bb6d80f9 100644
--- a/server/config_handler.go
+++ b/server/config_handler.go
@@ -19,7 +19,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
 	}
 	// API key
 	if key, ok := data["api_key"]; ok && len(key) > 20 {
-		s.Config.OpenAi.ApiKeys = append(s.Config.OpenAi.ApiKeys, key)
+		s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys, key)
 	}
 
 	// proxy URL
@@ -29,7 +29,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
 
 	// Model
 	if model, ok := data["model"]; ok {
-		s.Config.OpenAi.Model = model
+		s.Config.Chat.Model = model
 	}
 
 	// Temperature
@@ -42,7 +42,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
 			})
 			return
 		}
-		s.Config.OpenAi.Temperature = float32(v)
+		s.Config.Chat.Temperature = float32(v)
 	}
 
 	// max_tokens
@@ -55,8 +55,20 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
 			})
 			return
 		}
-		s.Config.OpenAi.MaxTokens = v
+		s.Config.Chat.MaxTokens = v
+	}
+
+	// enable context
+	if enableContext, ok := data["enable_context"]; ok {
+		v, err := strconv.ParseBool(enableContext)
+		if err != nil {
+			c.JSON(http.StatusOK, types.BizVo{
+				Code:    types.InvalidParams,
+				Message: "enable_context must be a boolean parameter",
+			})
+			return
+		}
+		s.Config.Chat.EnableContext = v
 	}
 
 	// save the config file
diff --git a/types/config.go b/types/config.go
index 90ebc285..a0dbac50 100644
--- a/types/config.go
+++ b/types/config.go
@@ -2,9 +2,9 @@ package types
 
 import (
 	"bytes"
-	"fmt"
 	"github.com/BurntSushi/toml"
 	"net/http"
+	logger2 "openai/logger"
 	"openai/utils"
 	"os"
 )
@@ -13,16 +13,17 @@ type Config struct {
 	Listen   string
 	Session  Session
 	ProxyURL string
-	OpenAi   OpenAi
+	Chat     Chat
 }
 
-// OpenAi configs struct
-type OpenAi struct {
-	ApiURL      string
-	ApiKeys     []string
-	Model       string
-	Temperature float32
-	MaxTokens   int
+// Chat configs struct
+type Chat struct {
+	ApiURL        string
+	ApiKeys       []string
+	Model         string
+	Temperature   float32
+	MaxTokens     int
+	EnableContext bool // whether to keep the chat context
 }
 
 // Session configs struct
@@ -51,21 +52,24 @@ func NewDefaultConfig() *Config {
 			HttpOnly: false,
 			SameSite: http.SameSiteLaxMode,
 		},
-		OpenAi: OpenAi{
-			ApiURL:      "https://api.openai.com/v1/chat/completions",
-			ApiKeys:     []string{""},
-			Model:       "gpt-3.5-turbo",
-			MaxTokens:   1024,
-			Temperature: 1.0,
+		Chat: Chat{
+			ApiURL:        "https://api.openai.com/v1/chat/completions",
+			ApiKeys:       []string{""},
+			Model:         "gpt-3.5-turbo",
+			MaxTokens:     1024,
+			Temperature:   1.0,
+			EnableContext: true,
 		},
 	}
 }
 
+var logger = logger2.GetLogger()
+
 func LoadConfig(configFile string) (*Config, error) {
 	var config *Config
 	_, err := os.Stat(configFile)
 	if err != nil {
-		fmt.Errorf("Error: %s", err.Error())
+		logger.Errorf("Error opening config file: %s", err.Error())
 		config = NewDefaultConfig()
 		// save config
 		err := SaveConfig(config, configFile)
@@ -76,7 +80,6 @@ func LoadConfig(configFile string) (*Config, error) {
 		return config, nil
 	}
 	_, err = toml.DecodeFile(configFile, &config)
-	fmt.Println(config)
 	if err != nil {
 		return nil, err
 	}
diff --git a/types/web.go b/types/web.go
index 858b51fe..10eaac84 100644
--- a/types/web.go
+++ b/types/web.go
@@ -10,11 +10,18 @@ type BizVo struct {
 	Data    interface{} `json:"data,omitempty"`
 }
 
-// WsVo Websocket message VO
-type WsVo struct {
-	Stop    bool
-	Content string
+// WsMessage Websocket message
+type WsMessage struct {
+	Type    WsMsgType `json:"type"` // message type: start, middle, end
+	Content string    `json:"content"`
 }
 
+type WsMsgType string
+
+const (
+	WsStart  = WsMsgType("start")
+	WsMiddle = WsMsgType("middle")
+	WsEnd    = WsMsgType("end")
+)
+
 type BizCode int
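The new `WsMessage` envelope replaces the raw text frames with a typed start/middle/end protocol, so a client can tell where one streamed reply begins and ends. Here is a minimal Go consumer sketch of that flow, assuming `github.com/gorilla/websocket`; the `ws://localhost:8080/api/chat` URL is an assumption, since the diff does not show where the chat route is registered.

```go
package main

import (
	"encoding/json"
	"log"

	"github.com/gorilla/websocket"
)

// WsMessage mirrors types.WsMessage from the diff above.
type WsMessage struct {
	Type    string `json:"type"`
	Content string `json:"content"`
}

func main() {
	// The URL is illustrative only; the real route and port are not in this diff.
	conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:8080/api/chat", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	for {
		_, data, err := conn.ReadMessage()
		if err != nil {
			log.Fatal(err)
		}
		var msg WsMessage
		if err := json.Unmarshal(data, &msg); err != nil {
			continue // skip frames that are not WsMessage envelopes
		}
		switch msg.Type {
		case "start":
			log.Println("--- reply started ---")
		case "middle":
			log.Print(msg.Content) // append the streamed fragment
		case "end":
			log.Println("--- reply finished ---")
			return
		}
	}
}
```

The project's real client is the Vue `ChatReply` component touched below; the sketch only illustrates the message flow the server now emits.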
diff --git a/web/README.md b/web/README.md
deleted file mode 100644
index 5ac2fc15..00000000
--- a/web/README.md
+++ /dev/null
@@ -1,21 +0,0 @@
-# qark-webssh
-
-## Project setup
-```
-npm install
-```
-
-### Compiles and hot-reloads for development
-```
-npm run serve
-```
-
-### Compiles and minifies for production
-```
-npm run build
-```
-
-### Lints and fixes files
-```
-npm run lint
-```
diff --git a/web/src/components/ChatReply.vue b/web/src/components/ChatReply.vue
index c77e5db3..9031c4b9 100644
--- a/web/src/components/ChatReply.vue
+++ b/web/src/components/ChatReply.vue
@@ -6,7 +6,10 @@