diff --git a/api/core/app_server.go b/api/core/app_server.go
index 16c864c4..83253a80 100644
--- a/api/core/app_server.go
+++ b/api/core/app_server.go
@@ -22,8 +22,8 @@ type AppServer struct {
     Debug        bool
     Config       *types.AppConfig
     Engine       *gin.Engine
-    ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
-    ChatConfig   *types.ChatConfig                     // 聊天配置
+    ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
+    ChatConfig   *types.ChatConfig                   // 聊天配置
 
     // 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
     // 防止第三方直接连接 socket 调用 OpenAI API
@@ -39,7 +39,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
         Debug:         false,
         Config:        appConfig,
         Engine:        gin.Default(),
-        ChatContexts:  types.NewLMap[string, []types.Message](),
+        ChatContexts:  types.NewLMap[string, []interface{}](),
         ChatSession:   types.NewLMap[string, types.ChatSession](),
         ChatClients:   types.NewLMap[string, *types.WsClient](),
         ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
diff --git a/api/core/types/chat.go b/api/core/types/chat.go
index 60858b84..237db83f 100644
--- a/api/core/types/chat.go
+++ b/api/core/types/chat.go
@@ -2,17 +2,17 @@ package types
 
 // ApiRequest API 请求实体
 type ApiRequest struct {
-    Model       string    `json:"model"`
-    Temperature float32   `json:"temperature"`
-    MaxTokens   int       `json:"max_tokens"`
-    Stream      bool      `json:"stream"`
-    Messages    []Message `json:"messages"`
+    Model       string        `json:"model"`
+    Temperature float32       `json:"temperature"`
+    MaxTokens   int           `json:"max_tokens"`
+    Stream      bool          `json:"stream"`
+    Messages    []interface{} `json:"messages"`
+    Functions   []Function    `json:"functions"`
 }
 
 type Message struct {
-    Role         string       `json:"role"`
-    Content      string       `json:"content"`
-    FunctionCall FunctionCall `json:"function_call"`
+    Role    string `json:"role"`
+    Content string `json:"content"`
 }
 
 type ApiResponse struct {
@@ -21,8 +21,15 @@ type ApiResponse struct {
 
 // ChoiceItem API 响应实体
 type ChoiceItem struct {
-    Delta        Message `json:"delta"`
-    FinishReason string  `json:"finish_reason"`
+    Delta        Delta  `json:"delta"`
+    FinishReason string `json:"finish_reason"`
+}
+
+type Delta struct {
+    Role         string       `json:"role"`
+    Name         string       `json:"name"`
+    Content      interface{}  `json:"content"`
+    FunctionCall FunctionCall `json:"function_call,omitempty"`
 }
 
 // ChatSession 聊天会话对象
diff --git a/api/core/types/function.go b/api/core/types/function.go
index 18d64715..b37cbeac 100644
--- a/api/core/types/function.go
+++ b/api/core/types/function.go
@@ -6,13 +6,71 @@ type FunctionCall struct {
 }
 
 type Function struct {
-    Name        string
-    Description string
-    Parameters  []Parameter
+    Name        string     `json:"name"`
+    Description string     `json:"description"`
+    Parameters  Parameters `json:"parameters"`
 }
 
-type Parameter struct {
-    Type       string
-    Required   []string
-    Properties map[string]interface{}
+type Parameters struct {
+    Type       string              `json:"type"`
+    Required   []string            `json:"required"`
+    Properties map[string]Property `json:"properties"`
+}
+
+type Property struct {
+    Type        string `json:"type"`
+    Description string `json:"description"`
+}
+
+var InnerFunctions = []Function{
+    {
+        Name:        "zao_bao",
+        Description: "每日早报,获取当天全球的热门新闻事件列表",
+        Parameters: Parameters{
+
+            Type: "object",
+            Properties: map[string]Property{
+                "text": {
+                    Type:        "string",
+                    Description: "",
+                },
+            },
+            Required: []string{},
+        },
+    },
+    {
+        Name:        "weibo_hot",
+        Description: "新浪微博热搜榜,微博当日热搜榜单",
+        Parameters: Parameters{
+            Type: "object",
+            Properties: map[string]Property{
+                "text": {
+                    Type:        "string",
+                    Description: "",
+                },
+            },
+            Required: []string{},
+        },
+    },
+
+    {
+        Name:        "zhihu_top",
+        Description: "知乎热榜,知乎当日话题讨论榜单",
+        Parameters: Parameters{
+            Type: "object",
+            Properties: map[string]Property{
+                "text": {
+                    Type:        "string",
+                    Description: "",
+                },
+            },
+            Required: []string{},
+        },
+    },
+}
+
+var FunctionNameMap = map[string]string{
+    "zao_bao":   "每日早报",
+    "weibo_hot": "微博热搜",
+    "zhihu_top": "知乎热榜",
 }
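Reviewer note (not part of the patch): the new `Function`/`Parameters`/`Property` types only matter for the JSON they produce in the `functions` field of a chat-completion request. Below is a minimal, self-contained sketch of that serialization, assuming nothing beyond `encoding/json`; the struct definitions are copied from this diff, while the `main` wrapper and printed output are illustrative only.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Copies of the types introduced in api/core/types/function.go.
type Property struct {
	Type        string `json:"type"`
	Description string `json:"description"`
}

type Parameters struct {
	Type       string              `json:"type"`
	Required   []string            `json:"required"`
	Properties map[string]Property `json:"properties"`
}

type Function struct {
	Name        string     `json:"name"`
	Description string     `json:"description"`
	Parameters  Parameters `json:"parameters"`
}

func main() {
	// One entry of InnerFunctions, serialized the way ApiRequest.Functions would be.
	fn := Function{
		Name:        "zao_bao",
		Description: "每日早报,获取当天全球的热门新闻事件列表",
		Parameters: Parameters{
			Type:     "object",
			Required: []string{},
			Properties: map[string]Property{
				"text": {Type: "string", Description: ""},
			},
		},
	}
	out, _ := json.MarshalIndent([]Function{fn}, "", "  ")
	fmt.Println(string(out)) // the value sent as the "functions" field of the chat request
}
```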
diff --git a/api/core/types/locked_map.go b/api/core/types/locked_map.go
index f05f5b3f..36ca48ff 100644
--- a/api/core/types/locked_map.go
+++ b/api/core/types/locked_map.go
@@ -9,7 +9,7 @@ type MKey interface {
     string | int
 }
 type MValue interface {
-    *WsClient | ChatSession | []Message | context.CancelFunc
+    *WsClient | ChatSession | context.CancelFunc | []interface{}
 }
 type LMap[K MKey, T MValue] struct {
     lock sync.RWMutex
diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go
index 2c7dec5d..d4792c61 100644
--- a/api/handler/chat_handler.go
+++ b/api/handler/chat_handler.go
@@ -157,10 +157,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
         Temperature: userVo.ChatConfig.Temperature,
         MaxTokens:   userVo.ChatConfig.MaxTokens,
         Stream:      true,
+        Functions:   types.InnerFunctions,
     }
 
     // 加载聊天上下文
-    var chatCtx []types.Message
+    var chatCtx []interface{}
     if userVo.ChatConfig.EnableContext {
         if h.App.ChatContexts.Has(session.ChatId) {
             chatCtx = h.App.ChatContexts.Get(session.ChatId)
@@ -169,11 +170,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
             var messages []types.Message
             err := utils.JsonDecode(role.Context, &messages)
             if err == nil {
-                chatCtx = messages
+                for _, v := range messages {
+                    chatCtx = append(chatCtx, v)
+                }
             }
-            // TODO: 这里默认加载最近 4 条聊天记录作为上下文,后期应该做成可配置的
+            // TODO: 这里默认加载最近 2 条聊天记录作为上下文,后期应该做成可配置的
             var historyMessages []model.HistoryMessage
-            res := h.db.Where("chat_id = ?", session.ChatId).Limit(4).Order("created_at desc").Find(&historyMessages)
+            res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
             if res.Error == nil {
                 for _, msg := range historyMessages {
                     ms := types.Message{Role: "user", Content: msg.Content}
@@ -189,12 +192,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
             logger.Info("聊天上下文:", chatCtx)
         }
     }
-    req.Messages = append(chatCtx, types.Message{
-        Role:    "user",
-        Content: prompt,
+    reqMgs := make([]interface{}, 0)
+    for _, m := range chatCtx {
+        reqMgs = append(reqMgs, m)
+    }
+
+    req.Messages = append(reqMgs, map[string]interface{}{
+        "role":    "user",
+        "content": prompt,
     })
     var apiKey string
-    response, err := h.fakeRequest(ctx, userVo, &apiKey, req)
+    response, err := h.doRequest(ctx, userVo, &apiKey, req)
     if err != nil {
         if strings.Contains(err.Error(), "context canceled") {
             logger.Info("用户取消了请求:", prompt)
@@ -213,176 +221,191 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
         defer response.Body.Close()
     }
 
-    //contentType := response.Header.Get("Content-Type")
-    //if strings.Contains(contentType, "text/event-stream") || true {
-    if true {
-        replyCreatedAt := time.Now()
-        // 循环读取 Chunk 消息
-        var message = types.Message{}
-        var contents = make([]string, 0)
-        var functionCall = false
-        var functionName string
-        var arguments = make([]string, 0)
-        reader := bufio.NewReader(response.Body)
-        for {
-            line, err := reader.ReadString('\n')
-            if err != nil {
-                if strings.Contains(err.Error(), "context canceled") {
-                    logger.Info("用户取消了请求:", prompt)
-                } else {
-                    logger.Error(err)
-                }
-                break
-            }
-            if !strings.Contains(line, "data:") || len(line) < 30 {
-                continue
-            }
-
-            var responseBody = types.ApiResponse{}
-            err = json.Unmarshal([]byte(line[6:]), &responseBody)
-            if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
-                logger.Error(err, line)
-                replyMessage(ws, ErrorMsg)
-                replyMessage(ws, "![](/images/wx.png)")
-                break
-            }
-
-            fun := responseBody.Choices[0].Delta.FunctionCall
-            if functionCall && fun.Name == "" {
-                arguments = append(arguments, fun.Arguments)
-                continue
-            }
-
-            if !utils.IsEmptyValue(fun) {
-                functionCall = true
-                functionName = fun.Name
-                replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
-                replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 %s 作答 ...\n\n", functionName)})
-                continue
-            }
-
-            if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
-                break
-            }
-
-            // 初始化 role
-            if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
-                message.Role = responseBody.Choices[0].Delta.Role
-                replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
-                continue
-            } else if responseBody.Choices[0].FinishReason != "" {
-                break // 输出完成或者输出中断了
-            } else {
-                content := responseBody.Choices[0].Delta.Content
-                contents = append(contents, content)
-                replyChunkMessage(ws, types.WsMessage{
-                    Type:    types.WsMiddle,
-                    Content: responseBody.Choices[0].Delta.Content,
-                })
-            }
-        } // end for
-
-        if functionCall { // 调用函数完成任务
-            // TODO 调用函数完成任务
-            data, err := h.funcZaoBao.Fetch()
-            if err != nil {
-                replyChunkMessage(ws, types.WsMessage{
-                    Type:    types.WsMiddle,
-                    Content: "调用函数出错",
-                })
-            } else {
-                replyChunkMessage(ws, types.WsMessage{
-                    Type:    types.WsMiddle,
-                    Content: data,
-                })
-            }
-            contents = append(contents, data)
-        }
-
-        // 消息发送成功
-        if len(contents) > 0 {
-            // 更新用户的对话次数
-            res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
-            if res.Error != nil {
-                return res.Error
-            }
-
-            if message.Role == "" {
-                message.Role = "assistant"
-            }
-            message.Content = strings.Join(contents, "")
-            useMsg := types.Message{Role: "user", Content: prompt}
-
-            // 更新上下文消息,如果是调用函数则不需要更新上下文
-            if userVo.ChatConfig.EnableContext && functionCall == false {
-                chatCtx = append(chatCtx, useMsg)  // 提问消息
-                chatCtx = append(chatCtx, message) // 回复消息
-                h.App.ChatContexts.Put(session.ChatId, chatCtx)
-            }
-
-            // 追加聊天记录
-            if userVo.ChatConfig.EnableHistory {
-                // for prompt
-                token, err := utils.CalcTokens(prompt, req.Model)
+    contentType := response.Header.Get("Content-Type")
+    if strings.Contains(contentType, "text/event-stream") {
+        if true {
+            replyCreatedAt := time.Now()
+            // 循环读取 Chunk 消息
+            var message = types.Message{}
+            var contents = make([]string, 0)
+            var functionCall = false
+            var functionName string
+            var arguments = make([]string, 0)
+            reader := bufio.NewReader(response.Body)
+            for {
+                line, err := reader.ReadString('\n')
                 if err != nil {
-                    logger.Error(err)
+                    if strings.Contains(err.Error(), "context canceled") {
+                        logger.Info("用户取消了请求:", prompt)
+                    } else if err != io.EOF {
+                        logger.Error(err)
+                    }
+                    break
                 }
-                historyUserMsg := model.HistoryMessage{
-                    UserId:  userVo.Id,
-                    ChatId:  session.ChatId,
-                    RoleId:  role.Id,
-                    Type:    types.PromptMsg,
-                    Icon:    user.Avatar,
-                    Content: prompt,
-                    Tokens:  token,
-                }
-                historyUserMsg.CreatedAt = promptCreatedAt
-                historyUserMsg.UpdatedAt = promptCreatedAt
-                res := h.db.Save(&historyUserMsg)
-                if res.Error != nil {
-                    logger.Error("failed to save prompt history message: ", res.Error)
+                if !strings.Contains(line, "data:") || len(line) < 30 {
+                    continue
                 }
 
-                // for reply
-                token, err = utils.CalcTokens(message.Content, req.Model)
+                var responseBody = types.ApiResponse{}
+                err = json.Unmarshal([]byte(line[6:]), &responseBody)
+                if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
+                    logger.Error(err, line)
+                    replyMessage(ws, ErrorMsg)
+                    replyMessage(ws, "![](/images/wx.png)")
+                    break
+                }
+
+                fun := responseBody.Choices[0].Delta.FunctionCall
+                if functionCall && fun.Name == "" {
+                    arguments = append(arguments, fun.Arguments)
+                    continue
+                }
+
+                if !utils.IsEmptyValue(fun) {
+                    functionCall = true
+                    functionName = fun.Name
+                    replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+                    replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
+                    continue
+                }
+
+                if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
+                    break
+                }
+
+                // 初始化 role
+                if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
+                    message.Role = responseBody.Choices[0].Delta.Role
+                    replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+                    continue
+                } else if responseBody.Choices[0].FinishReason != "" {
+                    break // 输出完成或者输出中断了
+                } else {
+                    content := responseBody.Choices[0].Delta.Content
+                    contents = append(contents, utils.InterfaceToString(content))
+                    replyChunkMessage(ws, types.WsMessage{
+                        Type:    types.WsMiddle,
+                        Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
+                    })
+                }
+            } // end for
+
+            if functionCall { // 调用函数完成任务
+                logger.Info(functionName)
+                logger.Info(arguments)
+                // TODO 调用函数完成任务
+                data, err := h.funcZaoBao.Fetch()
                 if err != nil {
-                    logger.Error(err)
+                    replyChunkMessage(ws, types.WsMessage{
+                        Type:    types.WsMiddle,
+                        Content: "调用函数出错",
+                    })
+                } else {
+                    replyChunkMessage(ws, types.WsMessage{
+                        Type:    types.WsMiddle,
+                        Content: data,
+                    })
                 }
-                historyReplyMsg := model.HistoryMessage{
-                    UserId:  userVo.Id,
-                    ChatId:  session.ChatId,
-                    RoleId:  role.Id,
-                    Type:    types.ReplyMsg,
-                    Icon:    role.Icon,
-                    Content: message.Content,
-                    Tokens:  token,
-                }
-                historyReplyMsg.CreatedAt = replyCreatedAt
-                historyReplyMsg.UpdatedAt = replyCreatedAt
-                res = h.db.Create(&historyReplyMsg)
-                if res.Error != nil {
-                    logger.Error("failed to save reply history message: ", res.Error)
-                }
-
-                // 统计用户 token 数量
-                h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
-                    historyUserMsg.Tokens+historyReplyMsg.Tokens))
+                contents = append(contents, data)
             }
 
-            // 保存当前会话
-            var chatItem model.ChatItem
-            res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
-            if res.Error != nil {
-                chatItem.ChatId = session.ChatId
-                chatItem.UserId = session.UserId
-                chatItem.RoleId = role.Id
-                chatItem.Model = session.Model
-                if utf8.RuneCountInString(prompt) > 30 {
-                    chatItem.Title = string([]rune(prompt)[:30]) + "..."
-                } else {
-                    chatItem.Title = prompt
+            // 消息发送成功
+            if len(contents) > 0 {
+                // 更新用户的对话次数
+                res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
+                if res.Error != nil {
+                    return res.Error
+                }
+
+                if message.Role == "" {
+                    message.Role = "assistant"
+                }
+                message.Content = strings.Join(contents, "")
+                useMsg := types.Message{Role: "user", Content: prompt}
+
+                // 计算本次对话消耗的总 token 数量
+                req.Messages = append(req.Messages, message)
+                totalTokens := getTotalTokens(req)
+                replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
+
+                // 更新上下文消息,如果是调用函数则不需要更新上下文
+                if userVo.ChatConfig.EnableContext && functionCall == false {
+                    chatCtx = append(chatCtx, useMsg)  // 提问消息
+                    chatCtx = append(chatCtx, message) // 回复消息
+                    h.App.ChatContexts.Put(session.ChatId, chatCtx)
+                }
+
+                // 追加聊天记录
+                if userVo.ChatConfig.EnableHistory {
+                    useContext := true
+                    if functionCall {
+                        useContext = false
+                    }
+
+                    // for prompt
+                    token, err := utils.CalcTokens(prompt, req.Model)
+                    if err != nil {
+                        logger.Error(err)
+                    }
+                    historyUserMsg := model.HistoryMessage{
+                        UserId:     userVo.Id,
+                        ChatId:     session.ChatId,
+                        RoleId:     role.Id,
+                        Type:       types.PromptMsg,
+                        Icon:       user.Avatar,
+                        Content:    prompt,
+                        Tokens:     token,
+                        UseContext: useContext,
+                    }
+                    historyUserMsg.CreatedAt = promptCreatedAt
+                    historyUserMsg.UpdatedAt = promptCreatedAt
+                    res := h.db.Save(&historyUserMsg)
+                    if res.Error != nil {
+                        logger.Error("failed to save prompt history message: ", res.Error)
+                    }
+
+                    // for reply
+                    token, err = utils.CalcTokens(message.Content, req.Model)
+                    if err != nil {
+                        logger.Error(err)
+                    }
+                    historyReplyMsg := model.HistoryMessage{
+                        UserId:     userVo.Id,
+                        ChatId:     session.ChatId,
+                        RoleId:     role.Id,
+                        Type:       types.ReplyMsg,
+                        Icon:       role.Icon,
+                        Content:    message.Content,
+                        Tokens:     token,
+                        UseContext: useContext,
+                    }
+                    historyReplyMsg.CreatedAt = replyCreatedAt
+                    historyReplyMsg.UpdatedAt = replyCreatedAt
+                    res = h.db.Create(&historyReplyMsg)
+                    if res.Error != nil {
+                        logger.Error("failed to save reply history message: ", res.Error)
+                    }
+
+                    // 统计用户 token 数量
+                    h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
+                        historyUserMsg.Tokens+historyReplyMsg.Tokens))
+                }
+
+                // 保存当前会话
+                var chatItem model.ChatItem
+                res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
+                if res.Error != nil {
+                    chatItem.ChatId = session.ChatId
+                    chatItem.UserId = session.UserId
+                    chatItem.RoleId = role.Id
+                    chatItem.Model = session.Model
+                    if utf8.RuneCountInString(prompt) > 30 {
+                        chatItem.Title = string([]rune(prompt)[:30]) + "..."
+                    } else {
+                        chatItem.Title = prompt
+                    }
+                    h.db.Create(&chatItem)
                 }
-                h.db.Create(&chatItem)
             }
         }
     } else {
@@ -469,13 +492,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
     return client.Do(request)
 }
 
-func (h *ChatHandler) fakeRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) {
-    link := "https://img.r9it.com/chatgpt/response"
-    client := &http.Client{}
-    request, _ := http.NewRequest(http.MethodGet, link, nil)
-    return client.Do(request)
-}
-
 // 回复客户片段端消息
 func replyChunkMessage(client types.Client, message types.WsMessage) {
     msg, err := json.Marshal(message)
@@ -509,6 +525,26 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
     resp.SUCCESS(c, tokens)
 }
 
+func getTotalTokens(req types.ApiRequest) int {
+    encode := utils.JsonEncode(req.Messages)
+    var items []map[string]interface{}
+    err := utils.JsonDecode(encode, &items)
+    if err != nil {
+        return 0
+    }
+    tokens := 0
+    for _, item := range items {
+        content, ok := item["content"]
+        if ok && !utils.IsEmptyValue(content) {
+            t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
+            if err == nil {
+                tokens += t
+            }
+        }
+    }
+    return tokens
+}
+
 // StopGenerate 停止生成
 func (h *ChatHandler) StopGenerate(c *gin.Context) {
     sessionId := c.Query("session_id")
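Reviewer note (not part of the patch): the streaming loop above boils down to decoding each `data:` line into an ApiResponse and branching on whether the delta carries a `function_call`. A simplified, self-contained sketch of that branch follows; the type definitions mirror api/core/types/chat.go (trimmed to what the sketch needs), the sample line is fabricated, and the plain `Name != ""` check stands in for `utils.IsEmptyValue`.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Trimmed mirrors of the response types in api/core/types/chat.go.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

type Delta struct {
	Role         string       `json:"role"`
	Content      interface{}  `json:"content"`
	FunctionCall FunctionCall `json:"function_call,omitempty"`
}

type ChoiceItem struct {
	Delta        Delta  `json:"delta"`
	FinishReason string `json:"finish_reason"`
}

type ApiResponse struct {
	Choices []ChoiceItem `json:"choices"`
}

func main() {
	// A fabricated SSE chunk of the kind the streaming loop reads line by line.
	line := `data: {"choices":[{"delta":{"function_call":{"name":"zao_bao","arguments":""}},"finish_reason":""}]}`

	payload := strings.TrimPrefix(line, "data: ") // the handler does the same with line[6:]
	var body ApiResponse
	if err := json.Unmarshal([]byte(payload), &body); err != nil || len(body.Choices) == 0 {
		fmt.Println("skip malformed chunk") // the handler replies with ErrorMsg and stops
		return
	}

	delta := body.Choices[0].Delta
	if delta.FunctionCall.Name != "" { // simplified stand-in for !utils.IsEmptyValue(fun)
		// switch to function-call mode and keep buffering arguments on later chunks
		fmt.Println("model wants to call:", delta.FunctionCall.Name)
		return
	}
	fmt.Println("plain content chunk:", delta.Content)
}
```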
diff --git a/api/main.go b/api/main.go
index 4eedaab9..07a6dc92 100644
--- a/api/main.go
+++ b/api/main.go
@@ -11,6 +11,7 @@ import (
     "chatplus/store"
     "context"
     "embed"
+    "errors"
     "io"
     "log"
     "os"
@@ -101,9 +102,12 @@ func main() {
         }),
 
         // 创建函数
-        fx.Provide(func() *function.FuncZaoBao {
+        fx.Provide(func() (*function.FuncZaoBao, error) {
             token := os.Getenv("AL_API_TOKEN")
-            return function.NewZaoBao(token)
+            if token == "" {
+                return nil, errors.New("invalid AL api token")
+            }
+            return function.NewZaoBao(token), nil
         }),
 
         // 创建控制器
diff --git a/api/store/model/chat_history.go b/api/store/model/chat_history.go
index 208e3afb..b1eb85e7 100644
--- a/api/store/model/chat_history.go
+++ b/api/store/model/chat_history.go
@@ -2,13 +2,14 @@ package model
 
 type HistoryMessage struct {
     BaseModel
-    ChatId  string // 会话 ID
-    UserId  uint   // 用户 ID
-    RoleId  uint   // 角色 ID
-    Type    string
-    Icon    string
-    Tokens  int
-    Content string
+    ChatId     string // 会话 ID
+    UserId     uint   // 用户 ID
+    RoleId     uint   // 角色 ID
+    Type       string
+    Icon       string
+    Tokens     int
+    Content    string
+    UseContext bool // 是否可以作为聊天上下文
 }
 
 func (HistoryMessage) TableName() string {
diff --git a/api/store/vo/chat_history.go b/api/store/vo/chat_history.go
index 19bb0130..60ca8838 100644
--- a/api/store/vo/chat_history.go
+++ b/api/store/vo/chat_history.go
@@ -2,13 +2,14 @@ package vo
 
 type HistoryMessage struct {
     BaseVo
-    ChatId  string `json:"chat_id"`
-    UserId  uint   `json:"user_id"`
-    RoleId  uint   `json:"role_id"`
-    Type    string `json:"type"`
-    Icon    string `json:"icon"`
-    Tokens  int    `json:"tokens"`
-    Content string `json:"content"`
+    ChatId     string `json:"chat_id"`
+    UserId     uint   `json:"user_id"`
+    RoleId     uint   `json:"role_id"`
+    Type       string `json:"type"`
+    Icon       string `json:"icon"`
+    Tokens     int    `json:"tokens"`
+    Content    string `json:"content"`
+    UseContext bool   `json:"use_context"`
 }
 
 func (HistoryMessage) TableName() string {
diff --git a/api/utils/strings.go b/api/utils/strings.go
index 7bdd1da1..a8fa3be6 100644
--- a/api/utils/strings.go
+++ b/api/utils/strings.go
@@ -81,3 +81,10 @@ func JsonEncode(value interface{}) string {
 func JsonDecode(src string, dest interface{}) error {
     return json.Unmarshal([]byte(src), dest)
 }
+
+func InterfaceToString(value interface{}) string {
+    if str, ok := value.(string); ok {
+        return str
+    }
+    return JsonEncode(value)
+}
diff --git a/database/plugins.sql b/database/plugins.sql
new file mode 100644
index 00000000..c5096aa9
--- /dev/null
+++ b/database/plugins.sql
@@ -0,0 +1 @@
+ALTER TABLE `chatgpt_chat_history` ADD `use_context` TINYINT(1) NOT NULL COMMENT '是否允许作为上下文语料' AFTER `tokens`;
\ No newline at end of file
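Reviewer note (not part of the patch): the api/main.go change relies on the constructor-with-error form supported by go.uber.org/fx — when a function passed to fx.Provide returns (T, error), a non-nil error fails application start-up rather than injecting a half-initialized value. A minimal sketch under that assumption; `Thing`/`NewThing` are hypothetical stand-ins for `function.FuncZaoBao`/`function.NewZaoBao`.

```go
package main

import (
	"errors"
	"fmt"
	"os"

	"go.uber.org/fx"
)

// Hypothetical stand-in for function.FuncZaoBao.
type Thing struct{ token string }

func NewThing(token string) *Thing { return &Thing{token: token} }

func main() {
	app := fx.New(
		// Constructor with a trailing error return: fx aborts start-up on a non-nil error.
		fx.Provide(func() (*Thing, error) {
			token := os.Getenv("AL_API_TOKEN")
			if token == "" {
				return nil, errors.New("invalid AL api token")
			}
			return NewThing(token), nil
		}),
		fx.Invoke(func(t *Thing) {
			fmt.Println("token loaded, length:", len(t.token))
		}),
	)
	if err := app.Err(); err != nil {
		fmt.Println("start-up aborted:", err) // what happens when AL_API_TOKEN is unset
	}
}
```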
diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue
index 8b4988fd..890c143d 100644
--- a/web/src/views/ChatPlus.vue
+++ b/web/src/views/ChatPlus.vue
@@ -94,6 +94,7 @@
+            聊天角色: