From 3bb681449376b0cc981ddd8dc85cb05ddd9b150c Mon Sep 17 00:00:00 2001 From: RockYang Date: Mon, 20 Mar 2023 15:02:42 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=20ChatGPT=20API=20=E9=87=8D?= =?UTF-8?q?=E8=AF=95=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 6 +- server/chat_handler.go | 39 +++---- server/config_handler.go | 7 +- server/server.go | 11 ++ web/.env.development | 2 +- web/src/components/ChatPrompt.vue | 2 +- web/src/components/ChatReply.vue | 35 +++--- web/src/components/ConfigDialog.vue | 6 +- web/src/views/Chat.vue | 172 +++++++++++++++++----------- 9 files changed, 167 insertions(+), 113 deletions(-) diff --git a/README.md b/README.md index 0c08ab5f..ca5b2840 100644 --- a/README.md +++ b/README.md @@ -6,5 +6,9 @@ * [ ] 使用 level DB 保存用户聊天的上下文 * [ ] 使用 MySQL 保存用户的聊天的历史记录 -* [ ] 用户聊天鉴权 +* [ ] 用户聊天鉴权,设置口令模式 +* [ ] 每次连接自动加载历史记录 +* [ ] 角色设定,预设一些角色,比如程序员,产品经理,医生,作家,老师... +* [ ] markdown 语法解析 +* [ ] 用户配置界面 diff --git a/server/chat_handler.go b/server/chat_handler.go index 4762d99d..58640423 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -10,7 +10,6 @@ import ( "io" "math/rand" "net/http" - "net/url" "openai/types" "strings" "time" @@ -74,34 +73,32 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { return err } - // TODO: API KEY 负载均衡 - rand.Seed(time.Now().UnixNano()) - index := rand.Intn(len(s.Config.Chat.ApiKeys)) - logger.Infof("Use API KEY: %s", s.Config.Chat.ApiKeys[index]) request.Header.Add("Content-Type", "application/json") - request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.Chat.ApiKeys[index])) - - uri := url.URL{} - proxy, _ := uri.Parse(s.Config.ProxyURL) - client := &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(proxy), - }, - } - response, err := client.Do(request) + // 随机获取一个 API Key,如果请求失败,则更换 API Key 重试 + // TODO: 需要将失败的 Key 移除列表 + rand.Seed(time.Now().UnixNano()) var retryCount = 3 - for err != nil { - if retryCount <= 0 { - return err + var response *http.Response + for retryCount > 0 { + index := rand.Intn(len(s.Config.Chat.ApiKeys)) + apiKey := s.Config.Chat.ApiKeys[index] + logger.Infof("Use API KEY: %s", apiKey) + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + response, err = s.Client.Do(request) + if err == nil { + break + } else { + logger.Error(err) } - response, err = client.Do(request) retryCount-- } + if err != nil { + return err + } var message = types.Message{} var contents = make([]string, 0) var responseBody = types.ApiResponse{} - reader := bufio.NewReader(response.Body) for { line, err := reader.ReadString('\n') @@ -119,7 +116,7 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error { err = json.Unmarshal([]byte(line[6:]), &responseBody) if err != nil { - logger.Error(err) + logger.Error(line) continue } // 初始化 role diff --git a/server/config_handler.go b/server/config_handler.go index bb6d80f9..c1c4bc44 100644 --- a/server/config_handler.go +++ b/server/config_handler.go @@ -73,6 +73,11 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { // 保存配置文件 logger.Infof("Config: %+v", s.Config) - types.SaveConfig(s.Config, s.ConfigPath) + err = types.SaveConfig(s.Config, s.ConfigPath) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"}) + return + } + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) } diff --git a/server/server.go b/server/server.go index ba1fe235..cf712964 100644 --- a/server/server.go +++ b/server/server.go @@ -9,6 +9,7 @@ import ( "io/fs" "log" "net/http" + "net/url" logger2 "openai/logger" "openai/types" "os" @@ -32,6 +33,7 @@ func (s StaticFile) Open(name string) (fs.File, error) { type Server struct { Config *types.Config ConfigPath string + Client *http.Client History map[string][]types.Message } @@ -41,8 +43,17 @@ func NewServer(configPath string) (*Server, error) { if err != nil { return nil, err } + + uri := url.URL{} + proxy, _ := uri.Parse(config.ProxyURL) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } return &Server{ Config: config, + Client: client, ConfigPath: configPath, History: make(map[string][]types.Message, 16)}, nil } diff --git a/web/.env.development b/web/.env.development index c6d32f62..ee674e92 100644 --- a/web/.env.development +++ b/web/.env.development @@ -1 +1 @@ -VUE_APP_WS_HOST=ws://127.0.0.1:5678 \ No newline at end of file +VUE_APP_WS_HOST=ws://172.22.11.200:5678 \ No newline at end of file diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue index 92f0c047..763d4150 100644 --- a/web/src/components/ChatPrompt.vue +++ b/web/src/components/ChatPrompt.vue @@ -34,7 +34,7 @@ export default defineComponent({ \ No newline at end of file diff --git a/web/src/views/Chat.vue b/web/src/views/Chat.vue index 0a2ec31a..f81c8327 100644 --- a/web/src/views/Chat.vue +++ b/web/src/views/Chat.vue @@ -17,16 +17,16 @@
- - - +
@@ -41,6 +41,7 @@
+ @@ -52,7 +53,7 @@ import {defineComponent, nextTick} from 'vue' import ChatPrompt from "@/components/ChatPrompt.vue"; import ChatReply from "@/components/ChatReply.vue"; import {randString} from "@/utils/libs"; -import {ElMessage} from 'element-plus' +import {ElMessage, ElMessageBox} from 'element-plus' import {Tools} from '@element-plus/icons-vue' import ConfigDialog from '@/components/ConfigDialog.vue' @@ -67,6 +68,7 @@ export default defineComponent({ chatBoxHeight: 0, showDialog: false, + connectingMessageBox: null, socket: null, sending: false } @@ -78,53 +80,84 @@ export default defineComponent({ nextTick(() => { this.chatBoxHeight = window.innerHeight - 61; }) - - // 初始化 WebSocket 对象 - const socket = new WebSocket(process.env.VUE_APP_WS_HOST + '/api/chat'); - socket.addEventListener('open', () => { - ElMessage.success('创建会话成功!'); - }); - socket.addEventListener('message', event => { - if (event.data instanceof Blob) { - const reader = new FileReader(); - reader.readAsText(event.data, "UTF-8"); - reader.onload = () => { - const data = JSON.parse(String(reader.result)); - if (data.type === 'start') { - this.chatData.push({ - type: "reply", - id: randString(32), - icon: 'images/gpt-icon.png', - content: "", - cursor: true - }); - } else if (data.type === 'end') { - this.sending = false; - this.chatData[this.chatData.length - 1]["cursor"] = false; - } else { - let content = data.content; - if (content.indexOf("\n\n") >= 0) { - content = content.replace("\n\n", "
"); - } - this.chatData[this.chatData.length - 1]["content"] += content; - } - // 将聊天框的滚动条滑动到最底部 - nextTick(() => { - document.getElementById('container').scrollTo(0, document.getElementById('container').scrollHeight) - }) - }; - } - - }); - socket.addEventListener('error', () => { - ElMessage.error('会话发生异常,请刷新页面后重试'); - }); - - this.socket = socket; + this.connect(); }, methods: { + connect: function () { + if (this.online) { + return + } + + // 初始化 WebSocket 对象 + const socket = new WebSocket(process.env.VUE_APP_WS_HOST + '/api/chat'); + socket.addEventListener('open', () => { + ElMessage.success('创建会话成功!'); + + if (this.connectingMessageBox != null) { + this.connectingMessageBox.close(); + this.connectingMessageBox = null; + } + }); + + socket.addEventListener('message', event => { + if (event.data instanceof Blob) { + const reader = new FileReader(); + reader.readAsText(event.data, "UTF-8"); + reader.onload = () => { + const data = JSON.parse(String(reader.result)); + if (data.type === 'start') { + this.chatData.push({ + type: "reply", + id: randString(32), + icon: 'images/gpt-icon.png', + content: "", + cursor: true + }); + } else if (data.type === 'end') { + this.sending = false; + this.chatData[this.chatData.length - 1]["cursor"] = false; + } else { + let content = data.content; + // 替换换行符 + if (content.indexOf("\n\n") >= 0) { + content = content.replace("\n\n", "
"); + } + this.chatData[this.chatData.length - 1]["content"] += content; + } + // 将聊天框的滚动条滑动到最底部 + nextTick(() => { + document.getElementById('container').scrollTo(0, document.getElementById('container').scrollHeight) + }) + }; + } + + }); + socket.addEventListener('close', () => { + ElMessageBox.confirm( + '^_^ 会话发生异常,您已经从服务器断开连接!', + '注意:', + { + confirmButtonText: '重连会话', + cancelButtonText: '不聊了', + type: 'warning', + } + ) + .then(() => { + this.connect(); + }) + .catch(() => { + ElMessage({ + type: 'info', + message: '您关闭了会话', + }) + }) + }); + + this.socket = socket; + }, + inputKeyDown: function (e) { if (e.keyCode === 13) { if (this.sending) { @@ -152,7 +185,10 @@ export default defineComponent({ // TODO: 使用 websocket 提交数据到后端 this.sending = true; this.socket.send(this.inputValue); + this.$refs["text-input"].blur(); this.inputValue = ''; + // 等待 textarea 重新调整尺寸之后再自动获取焦点 + setTimeout(() => this.$refs["text-input"].focus(), 100) return true; }, @@ -227,20 +263,12 @@ export default defineComponent({ background-color: rgba(255, 255, 255, 1); padding: 5px 10px; - .input-text { - font-size: 16px; - padding 0 - margin 0 - outline: none; - width 100%; - border none - background #ffffff - resize none - line-height 24px; - color #333; + .el-textarea__inner { + box-shadow: none + padding 5px 0 } - .input-text::-webkit-scrollbar { + .el-textarea__inner::-webkit-scrollbar { width: 0; height: 0; } @@ -267,9 +295,19 @@ export default defineComponent({ width: 0; height: 0; } - } } +.el-message-box { + width 90%; + max-width 420px; +} + +.el-message { + width 90%; + min-width: 300px; + max-width 600px; +} +