diff --git a/README.md b/README.md
index 0c08ab5f..ca5b2840 100644
--- a/README.md
+++ b/README.md
@@ -6,5 +6,9 @@
* [ ] 使用 level DB 保存用户聊天的上下文
* [ ] 使用 MySQL 保存用户的聊天的历史记录
-* [ ] 用户聊天鉴权
+* [ ] 用户聊天鉴权,设置口令模式
+* [ ] 每次连接自动加载历史记录
+* [ ] 角色设定,预设一些角色,比如程序员,产品经理,医生,作家,老师...
+* [ ] markdown 语法解析
+* [ ] 用户配置界面
diff --git a/server/chat_handler.go b/server/chat_handler.go
index 4762d99d..58640423 100644
--- a/server/chat_handler.go
+++ b/server/chat_handler.go
@@ -10,7 +10,6 @@ import (
"io"
"math/rand"
"net/http"
- "net/url"
"openai/types"
"strings"
"time"
@@ -74,34 +73,32 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
return err
}
- // TODO: API KEY 负载均衡
- rand.Seed(time.Now().UnixNano())
- index := rand.Intn(len(s.Config.Chat.ApiKeys))
- logger.Infof("Use API KEY: %s", s.Config.Chat.ApiKeys[index])
request.Header.Add("Content-Type", "application/json")
- request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.Chat.ApiKeys[index]))
-
- uri := url.URL{}
- proxy, _ := uri.Parse(s.Config.ProxyURL)
- client := &http.Client{
- Transport: &http.Transport{
- Proxy: http.ProxyURL(proxy),
- },
- }
- response, err := client.Do(request)
+ // 随机获取一个 API Key,如果请求失败,则更换 API Key 重试
+ // TODO: 需要将失败的 Key 移除列表
+ rand.Seed(time.Now().UnixNano())
var retryCount = 3
- for err != nil {
- if retryCount <= 0 {
- return err
+ var response *http.Response
+ for retryCount > 0 {
+ index := rand.Intn(len(s.Config.Chat.ApiKeys))
+ apiKey := s.Config.Chat.ApiKeys[index]
+ logger.Infof("Use API KEY: %s", apiKey)
+ request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
+ response, err = s.Client.Do(request)
+ if err == nil {
+ break
+ } else {
+ logger.Error(err)
}
- response, err = client.Do(request)
retryCount--
}
+ if err != nil {
+ return err
+ }
var message = types.Message{}
var contents = make([]string, 0)
var responseBody = types.ApiResponse{}
-
reader := bufio.NewReader(response.Body)
for {
line, err := reader.ReadString('\n')
@@ -119,7 +116,7 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil {
- logger.Error(err)
+ logger.Error(line)
continue
}
// 初始化 role
diff --git a/server/config_handler.go b/server/config_handler.go
index bb6d80f9..c1c4bc44 100644
--- a/server/config_handler.go
+++ b/server/config_handler.go
@@ -73,6 +73,11 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
// 保存配置文件
logger.Infof("Config: %+v", s.Config)
- types.SaveConfig(s.Config, s.ConfigPath)
+ err = types.SaveConfig(s.Config, s.ConfigPath)
+ if err != nil {
+ c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
+ return
+ }
+
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
}
diff --git a/server/server.go b/server/server.go
index ba1fe235..cf712964 100644
--- a/server/server.go
+++ b/server/server.go
@@ -9,6 +9,7 @@ import (
"io/fs"
"log"
"net/http"
+ "net/url"
logger2 "openai/logger"
"openai/types"
"os"
@@ -32,6 +33,7 @@ func (s StaticFile) Open(name string) (fs.File, error) {
type Server struct {
Config *types.Config
ConfigPath string
+ Client *http.Client
History map[string][]types.Message
}
@@ -41,8 +43,17 @@ func NewServer(configPath string) (*Server, error) {
if err != nil {
return nil, err
}
+
+ uri := url.URL{}
+ proxy, _ := uri.Parse(config.ProxyURL)
+ client := &http.Client{
+ Transport: &http.Transport{
+ Proxy: http.ProxyURL(proxy),
+ },
+ }
return &Server{
Config: config,
+ Client: client,
ConfigPath: configPath,
History: make(map[string][]types.Message, 16)}, nil
}
diff --git a/web/.env.development b/web/.env.development
index c6d32f62..ee674e92 100644
--- a/web/.env.development
+++ b/web/.env.development
@@ -1 +1 @@
-VUE_APP_WS_HOST=ws://127.0.0.1:5678
\ No newline at end of file
+VUE_APP_WS_HOST=ws://172.22.11.200:5678
\ No newline at end of file
diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue
index 92f0c047..763d4150 100644
--- a/web/src/components/ChatPrompt.vue
+++ b/web/src/components/ChatPrompt.vue
@@ -34,7 +34,7 @@ export default defineComponent({
\ No newline at end of file
diff --git a/web/src/views/Chat.vue b/web/src/views/Chat.vue
index 0a2ec31a..f81c8327 100644
--- a/web/src/views/Chat.vue
+++ b/web/src/views/Chat.vue
@@ -17,16 +17,16 @@
+
@@ -52,7 +53,7 @@ import {defineComponent, nextTick} from 'vue'
import ChatPrompt from "@/components/ChatPrompt.vue";
import ChatReply from "@/components/ChatReply.vue";
import {randString} from "@/utils/libs";
-import {ElMessage} from 'element-plus'
+import {ElMessage, ElMessageBox} from 'element-plus'
import {Tools} from '@element-plus/icons-vue'
import ConfigDialog from '@/components/ConfigDialog.vue'
@@ -67,6 +68,7 @@ export default defineComponent({
chatBoxHeight: 0,
showDialog: false,
+ connectingMessageBox: null,
socket: null,
sending: false
}
@@ -78,53 +80,84 @@ export default defineComponent({
nextTick(() => {
this.chatBoxHeight = window.innerHeight - 61;
})
-
- // 初始化 WebSocket 对象
- const socket = new WebSocket(process.env.VUE_APP_WS_HOST + '/api/chat');
- socket.addEventListener('open', () => {
- ElMessage.success('创建会话成功!');
- });
- socket.addEventListener('message', event => {
- if (event.data instanceof Blob) {
- const reader = new FileReader();
- reader.readAsText(event.data, "UTF-8");
- reader.onload = () => {
- const data = JSON.parse(String(reader.result));
- if (data.type === 'start') {
- this.chatData.push({
- type: "reply",
- id: randString(32),
- icon: 'images/gpt-icon.png',
- content: "",
- cursor: true
- });
- } else if (data.type === 'end') {
- this.sending = false;
- this.chatData[this.chatData.length - 1]["cursor"] = false;
- } else {
- let content = data.content;
- if (content.indexOf("\n\n") >= 0) {
- content = content.replace("\n\n", "
");
- }
- this.chatData[this.chatData.length - 1]["content"] += content;
- }
- // 将聊天框的滚动条滑动到最底部
- nextTick(() => {
- document.getElementById('container').scrollTo(0, document.getElementById('container').scrollHeight)
- })
- };
- }
-
- });
- socket.addEventListener('error', () => {
- ElMessage.error('会话发生异常,请刷新页面后重试');
- });
-
- this.socket = socket;
+ this.connect();
},
methods: {
+ connect: function () {
+ if (this.online) {
+ return
+ }
+
+ // 初始化 WebSocket 对象
+ const socket = new WebSocket(process.env.VUE_APP_WS_HOST + '/api/chat');
+ socket.addEventListener('open', () => {
+ ElMessage.success('创建会话成功!');
+
+ if (this.connectingMessageBox != null) {
+ this.connectingMessageBox.close();
+ this.connectingMessageBox = null;
+ }
+ });
+
+ socket.addEventListener('message', event => {
+ if (event.data instanceof Blob) {
+ const reader = new FileReader();
+ reader.readAsText(event.data, "UTF-8");
+ reader.onload = () => {
+ const data = JSON.parse(String(reader.result));
+ if (data.type === 'start') {
+ this.chatData.push({
+ type: "reply",
+ id: randString(32),
+ icon: 'images/gpt-icon.png',
+ content: "",
+ cursor: true
+ });
+ } else if (data.type === 'end') {
+ this.sending = false;
+ this.chatData[this.chatData.length - 1]["cursor"] = false;
+ } else {
+ let content = data.content;
+ // 替换换行符
+ if (content.indexOf("\n\n") >= 0) {
+ content = content.replace("\n\n", "
");
+ }
+ this.chatData[this.chatData.length - 1]["content"] += content;
+ }
+ // 将聊天框的滚动条滑动到最底部
+ nextTick(() => {
+ document.getElementById('container').scrollTo(0, document.getElementById('container').scrollHeight)
+ })
+ };
+ }
+
+ });
+ socket.addEventListener('close', () => {
+ ElMessageBox.confirm(
+ '^_^ 会话发生异常,您已经从服务器断开连接!',
+ '注意:',
+ {
+ confirmButtonText: '重连会话',
+ cancelButtonText: '不聊了',
+ type: 'warning',
+ }
+ )
+ .then(() => {
+ this.connect();
+ })
+ .catch(() => {
+ ElMessage({
+ type: 'info',
+ message: '您关闭了会话',
+ })
+ })
+ });
+
+ this.socket = socket;
+ },
+
inputKeyDown: function (e) {
if (e.keyCode === 13) {
if (this.sending) {
@@ -152,7 +185,10 @@ export default defineComponent({
// TODO: 使用 websocket 提交数据到后端
this.sending = true;
this.socket.send(this.inputValue);
+ this.$refs["text-input"].blur();
this.inputValue = '';
+ // 等待 textarea 重新调整尺寸之后再自动获取焦点
+ setTimeout(() => this.$refs["text-input"].focus(), 100)
return true;
},
@@ -227,20 +263,12 @@ export default defineComponent({
background-color: rgba(255, 255, 255, 1);
padding: 5px 10px;
- .input-text {
- font-size: 16px;
- padding 0
- margin 0
- outline: none;
- width 100%;
- border none
- background #ffffff
- resize none
- line-height 24px;
- color #333;
+ .el-textarea__inner {
+ box-shadow: none
+ padding 5px 0
}
- .input-text::-webkit-scrollbar {
+ .el-textarea__inner::-webkit-scrollbar {
width: 0;
height: 0;
}
@@ -267,9 +295,19 @@ export default defineComponent({
width: 0;
height: 0;
}
-
}
}
+.el-message-box {
+ width 90%;
+ max-width 420px;
+}
+
+.el-message {
+ width 90%;
+ min-width: 300px;
+ max-width 600px;
+}
+