diff --git a/README.md b/README.md new file mode 100644 index 00000000..0c08ab5f --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ +# Wechat-GPT + +基于 ChatGPT 的聊天应用 + +## TODOLIST + +* [ ] 使用 level DB 保存用户聊天的上下文 +* [ ] 使用 MySQL 保存用户的聊天的历史记录 +* [ ] 用户聊天鉴权 + diff --git a/server/chat_handler.go b/server/chat_handler.go index 25efaec2..4762d99d 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -12,6 +12,7 @@ import ( "net/http" "net/url" "openai/types" + "strings" "time" ) @@ -52,8 +53,9 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { Stream: true, } var history []types.Message - if v, ok := s.History[userId]; ok { + if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext { history = v + //logger.Infof("上下文历史消息:%+v", history) } else { history = make([]types.Message, 0) } @@ -62,22 +64,22 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { Content: text, }) - logger.Info("上下文历史消息:%+v", s.History[userId]) requestBody, err := json.Marshal(r) if err != nil { return err } - request, err := http.NewRequest(http.MethodPost, s.Config.OpenAi.ApiURL, bytes.NewBuffer(requestBody)) + request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewBuffer(requestBody)) if err != nil { return err } // TODO: API KEY 负载均衡 rand.Seed(time.Now().UnixNano()) - index := rand.Intn(len(s.Config.OpenAi.ApiKeys)) + index := rand.Intn(len(s.Config.Chat.ApiKeys)) + logger.Infof("Use API KEY: %s", s.Config.Chat.ApiKeys[index]) request.Header.Add("Content-Type", "application/json") - request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKeys[index])) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.Chat.ApiKeys[index])) uri := url.URL{} proxy, _ := uri.Parse(s.Config.ProxyURL) @@ -104,11 +106,12 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { for { line, err := reader.ReadString('\n') if err != nil && err != io.EOF { - fmt.Println(err) + logger.Error(err) break } if line == "" { + replyMessage(types.WsMessage{Type: types.WsEnd}, ws) break } else if len(line) < 20 { continue @@ -116,29 +119,47 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { err = json.Unmarshal([]byte(line[6:]), &responseBody) if err != nil { - fmt.Println(err) + logger.Error(err) continue } // 初始化 role if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { message.Role = responseBody.Choices[0].Delta.Role + replyMessage(types.WsMessage{Type: types.WsStart}, ws) continue - } else { - contents = append(contents, responseBody.Choices[0].Delta.Content) - } - // 推送消息到客户端 - err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content)) - if err != nil { - logger.Error(err) - } - fmt.Print(responseBody.Choices[0].Delta.Content) - if responseBody.Choices[0].FinishReason != "" { + } else if responseBody.Choices[0].FinishReason != "" { // 输出完成或者输出中断了 + replyMessage(types.WsMessage{Type: types.WsEnd}, ws) break + } else { + content := responseBody.Choices[0].Delta.Content + contents = append(contents, content) + replyMessage(types.WsMessage{ + Type: types.WsMiddle, + Content: responseBody.Choices[0].Delta.Content, + }, ws) } } // 追加历史消息 + history = append(history, types.Message{ + Role: "user", + Content: text, + }) + message.Content = strings.Join(contents, "") history = append(history, message) s.History[userId] = history return nil } + +// 回复客户端消息 +func replyMessage(message types.WsMessage, client Client) { + msg, err := json.Marshal(message) + if err != nil { + logger.Errorf("Error for decoding json data: %v", err.Error()) + return + } + err = client.(*WsClient).Send(msg) + if err != nil { + logger.Errorf("Error for reply message: %v", err.Error()) + } +} diff --git a/server/config_handler.go b/server/config_handler.go index a33b6778..bb6d80f9 100644 --- a/server/config_handler.go +++ b/server/config_handler.go @@ -19,7 +19,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { } // API key if key, ok := data["api_key"]; ok && len(key) > 20 { - s.Config.OpenAi.ApiKeys = append(s.Config.OpenAi.ApiKeys, key) + s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys, key) } // proxy URL @@ -29,7 +29,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { // Model if model, ok := data["model"]; ok { - s.Config.OpenAi.Model = model + s.Config.Chat.Model = model } // Temperature @@ -42,7 +42,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { }) return } - s.Config.OpenAi.Temperature = float32(v) + s.Config.Chat.Temperature = float32(v) } // max_tokens @@ -55,8 +55,20 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { }) return } - s.Config.OpenAi.MaxTokens = v + s.Config.Chat.MaxTokens = v + } + // enable Context + if enableContext, ok := data["enable_context"]; ok { + v, err := strconv.ParseBool(enableContext) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{ + Code: types.InvalidParams, + Message: "enable_context must be a bool parameter", + }) + return + } + s.Config.Chat.EnableContext = v } // 保存配置文件 diff --git a/types/config.go b/types/config.go index 90ebc285..a0dbac50 100644 --- a/types/config.go +++ b/types/config.go @@ -2,9 +2,9 @@ package types import ( "bytes" - "fmt" "github.com/BurntSushi/toml" "net/http" + logger2 "openai/logger" "openai/utils" "os" ) @@ -13,16 +13,17 @@ type Config struct { Listen string Session Session ProxyURL string - OpenAi OpenAi + Chat Chat } -// OpenAi configs struct -type OpenAi struct { - ApiURL string - ApiKeys []string - Model string - Temperature float32 - MaxTokens int +// Chat configs struct +type Chat struct { + ApiURL string + ApiKeys []string + Model string + Temperature float32 + MaxTokens int + EnableContext bool // 是否保持聊天上下文 } // Session configs struct @@ -51,21 +52,24 @@ func NewDefaultConfig() *Config { HttpOnly: false, SameSite: http.SameSiteLaxMode, }, - OpenAi: OpenAi{ - ApiURL: "https://api.openai.com/v1/chat/completions", - ApiKeys: []string{""}, - Model: "gpt-3.5-turbo", - MaxTokens: 1024, - Temperature: 1.0, + Chat: Chat{ + ApiURL: "https://api.openai.com/v1/chat/completions", + ApiKeys: []string{""}, + Model: "gpt-3.5-turbo", + MaxTokens: 1024, + Temperature: 1.0, + EnableContext: true, }, } } +var logger = logger2.GetLogger() + func LoadConfig(configFile string) (*Config, error) { var config *Config _, err := os.Stat(configFile) if err != nil { - fmt.Errorf("Error: %s", err.Error()) + logger.Errorf("Error open config file: %s", err.Error()) config = NewDefaultConfig() // save config err := SaveConfig(config, configFile) @@ -76,7 +80,6 @@ func LoadConfig(configFile string) (*Config, error) { return config, nil } _, err = toml.DecodeFile(configFile, &config) - fmt.Println(config) if err != nil { return nil, err } diff --git a/types/web.go b/types/web.go index 858b51fe..10eaac84 100644 --- a/types/web.go +++ b/types/web.go @@ -10,11 +10,18 @@ type BizVo struct { Data interface{} `json:"data,omitempty"` } -// WsVo Websocket 信息 VO -type WsVo struct { - Stop bool - Content string +// WsMessage Websocket message +type WsMessage struct { + Type WsMsgType `json:"type"` // 消息类别,start, end + Content string `json:"content"` } +type WsMsgType string + +const ( + WsStart = WsMsgType("start") + WsMiddle = WsMsgType("middle") + WsEnd = WsMsgType("end") +) type BizCode int diff --git a/web/README.md b/web/README.md deleted file mode 100644 index 5ac2fc15..00000000 --- a/web/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# qark-webssh - -## Project setup -``` -npm install -``` - -### Compiles and hot-reloads for development -``` -npm run serve -``` - -### Compiles and minifies for production -``` -npm run build -``` - -### Lints and fixes files -``` -npm run lint -``` diff --git a/web/src/components/ChatReply.vue b/web/src/components/ChatReply.vue index c77e5db3..9031c4b9 100644 --- a/web/src/components/ChatReply.vue +++ b/web/src/components/ChatReply.vue @@ -6,7 +6,10 @@
-
{{ content }}
+
+ + +
@@ -24,6 +27,10 @@ export default defineComponent({ icon: { type: String, default: 'images/gpt-icon.png', + }, + cursor: { + type: Boolean, + default: true } }, data() { @@ -68,6 +75,21 @@ export default defineComponent({ background-color: #fff; font-size: var(--content-font-size); border-radius: 5px; + + .cursor { + height 24px; + border-left 1px solid black; + + animation: cursorImg 1s infinite steps(1, start); + @keyframes cursorImg { + 0%, 100% { + opacity: 0; + } + 50% { + opacity: 1; + } + } + } } } } diff --git a/web/src/views/Chat.vue b/web/src/views/Chat.vue index e4413ddf..d2318ff6 100644 --- a/web/src/views/Chat.vue +++ b/web/src/views/Chat.vue @@ -9,6 +9,7 @@ :content="chat.content"/> @@ -42,6 +43,7 @@ import {defineComponent, nextTick} from 'vue' import ChatPrompt from "@/components/ChatPrompt.vue"; import ChatReply from "@/components/ChatReply.vue"; import {randString} from "@/utils/libs"; +import {ElMessage} from 'element-plus' export default defineComponent({ name: "XChat", @@ -49,20 +51,7 @@ export default defineComponent({ data() { return { title: "ChatGPT 控制台", - chatData: [ - { - id: "1", - type: 'prompt', - icon: 'images/user-icon.png', - content: '请问棒球棒可以放进人的耳朵里面吗' - }, - { - id: "2", - type: 'reply', - icon: 'images/gpt-icon.png', - content: '不可以。棒球棒的直径通常都比人的耳道大得多,而且人的耳朵是非常敏感和易受伤的,如果硬塞棒球棒可能会导致耳道损伤、出血和疼痛等问题。此外,塞入耳道的物体还可能引起耳屎的囤积和感染等问题,因此强烈建议不要将任何非耳朵医学用品的物品插入耳朵。如果您有耳道不适或者其他耳朵健康问题,应该咨询专业医生的建议。' - } - ], + chatData: [], inputBoxHeight: 63, inputBoxWidth: 0, inputValue: '', @@ -99,24 +88,34 @@ export default defineComponent({ window.addEventListener('resize', this.windowResize); // 初始化 WebSocket 对象 - const socket = new WebSocket(process.env.VUE_APP_WS_HOST+'/api/chat'); + const socket = new WebSocket(process.env.VUE_APP_WS_HOST + '/api/chat'); socket.addEventListener('open', () => { - console.log('WebSocket 连接已打开'); + ElMessage.success('创建会话成功!'); }); socket.addEventListener('message', event => { if (event.data instanceof Blob) { const reader = new FileReader(); reader.readAsText(event.data, "UTF-8"); reader.onload = () => { - // this.chatData.push({ - // type: "reply", - // id: randString(32), - // icon: 'images/gpt-icon.png', - // content: reader.result - // }); - this.chatData[this.chatData.length - 1]["content"] += reader.result - this.sending = false; - + const data = JSON.parse(String(reader.result)); + if (data.type === 'start') { + this.chatData.push({ + type: "reply", + id: randString(32), + icon: 'images/gpt-icon.png', + content: "", + cursor: true + }); + } else if (data.type === 'end') { + this.sending = false; + this.chatData[this.chatData.length - 1]["cursor"] = false; + } else { + let content = data.content; + if (content.indexOf("\n\n") >= 0) { + content = content.replace("\n\n", "
"); + } + this.chatData[this.chatData.length - 1]["content"] += content; + } // 将聊天框的滚动条滑动到最底部 nextTick(() => { document.getElementById('container').scrollTo(0, document.getElementById('container').scrollHeight) @@ -125,11 +124,11 @@ export default defineComponent({ } }); - socket.addEventListener('close', event => { - console.log('WebSocket 连接已关闭', event.reason); + socket.addEventListener('close', () => { + ElMessage.error('会话发生异常,请刷新页面后重试'); }); socket.addEventListener('error', event => { - console.error('WebSocket 连接发生错误', event); + ElMessage.error('WebSocket 连接发生错误: ' + event.message); }); this.socket = socket; @@ -240,6 +239,7 @@ export default defineComponent({ #container { overflow auto; + width 100%; .chat-box { // 变量定义