diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go
index 7261f14e..ea8ccc33 100644
--- a/api/handler/chat_handler.go
+++ b/api/handler/chat_handler.go
@@ -8,6 +8,7 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
+ "bufio"
"bytes"
"context"
"encoding/json"
@@ -32,10 +33,30 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
+ req2 "github.com/imroc/req/v3"
"github.com/sashabaranov/go-openai"
"gorm.io/gorm"
)
+const (
+ ChatEventStart = "start"
+ ChatEventEnd = "end"
+ ChatEventError = "error"
+ ChatEventMessageDelta = "message_delta"
+ ChatEventTitle = "title"
+)
+
+type ChatInput struct {
+ UserId uint `json:"user_id"`
+ RoleId int `json:"role_id"`
+ ModelId int `json:"model_id"`
+ ChatId string `json:"chat_id"`
+ Content string `json:"content"`
+ Tools []int `json:"tools"`
+ Stream bool `json:"stream"`
+ Files []vo.File `json:"files"`
+}
+
type ChatHandler struct {
BaseHandler
redis *redis.Client
@@ -58,7 +79,89 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
}
}
-func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
+// Chat 处理聊天请求
+func (h *ChatHandler) Chat(c *gin.Context) {
+ var data ChatInput
+ if err := c.ShouldBindJSON(&data); err != nil {
+ resp.ERROR(c, types.InvalidArgs)
+ return
+ }
+
+ // 设置SSE响应头
+ c.Header("Content-Type", "text/event-stream")
+ c.Header("Cache-Control", "no-cache")
+ c.Header("Connection", "keep-alive")
+ c.Header("X-Accel-Buffering", "no")
+
+ ctx, cancel := context.WithCancel(c.Request.Context())
+ defer cancel()
+
+ // 验证聊天角色
+ var chatRole model.ChatRole
+ err := h.DB.First(&chatRole, data.RoleId).Error
+ if err != nil || !chatRole.Enable {
+ pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
+ return
+ }
+
+ // 如果角色绑定了模型ID,使用角色的模型ID
+ if chatRole.ModelId > 0 {
+ data.ModelId = int(chatRole.ModelId)
+ }
+
+ // 获取模型信息
+ var chatModel model.ChatModel
+ err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
+ if err != nil || !chatModel.Enabled {
+ pushMessage(c, ChatEventError, "当前AI模型暂未启用,请更换模型后再发起对话!")
+ return
+ }
+
+ session := &types.ChatSession{
+ ClientIP: c.ClientIP(),
+ UserId: data.UserId,
+ ChatId: data.ChatId,
+ Tools: data.Tools,
+ Stream: data.Stream,
+ Model: types.ChatModel{
+ KeyId: data.ModelId,
+ },
+ }
+
+ // 使用旧的聊天数据覆盖模型和角色ID
+ var chat model.ChatItem
+ h.DB.Where("chat_id", data.ChatId).First(&chat)
+ if chat.Id > 0 {
+ chatModel.Id = chat.ModelId
+ data.RoleId = int(chat.RoleId)
+ }
+
+ // 复制模型数据
+ err = utils.CopyObject(chatModel, &session.Model)
+ if err != nil {
+ logger.Error(err, chatModel)
+ }
+ session.Model.Id = chatModel.Id
+
+ // 发送消息
+ err = h.sendMessage(ctx, session, chatRole, data.Content, c)
+ if err != nil {
+ pushMessage(c, ChatEventError, err.Error())
+ return
+ }
+
+ pushMessage(c, ChatEventEnd, "对话完成")
+}
+
+func pushMessage(c *gin.Context, msgType string, content interface{}) {
+ c.SSEvent("message", map[string]interface{}{
+ "type": msgType,
+ "body": content,
+ })
+ c.Writer.Flush()
+}
+
+func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, c *gin.Context) error {
var user model.User
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
@@ -254,7 +357,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
logger.Debugf("%+v", req.Messages)
- return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, ws)
+ return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, c)
}
// Tokens 统计 token 数量
@@ -584,3 +687,221 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
// 直接写入完整的音频数据到响应
c.Writer.Write(audioBytes)
}
+
+// OPenAI 消息发送实现
+func (h *ChatHandler) sendOpenAiMessage(
+ req types.ApiRequest,
+ userVo vo.User,
+ ctx context.Context,
+ session *types.ChatSession,
+ role model.ChatRole,
+ prompt string,
+ c *gin.Context) error {
+ promptCreatedAt := time.Now() // 记录提问时间
+ start := time.Now()
+ var apiKey = model.ApiKey{}
+ response, err := h.doRequest(ctx, req, session, &apiKey)
+ logger.Info("HTTP请求完成,耗时:", time.Since(start))
+ if err != nil {
+ if strings.Contains(err.Error(), "context canceled") {
+ return fmt.Errorf("用户取消了请求:%s", prompt)
+ } else if strings.Contains(err.Error(), "no available key") {
+ return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
+ }
+ return err
+ } else {
+ defer response.Body.Close()
+ }
+
+ if response.StatusCode != 200 {
+ body, _ := io.ReadAll(response.Body)
+ return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
+ }
+
+ contentType := response.Header.Get("Content-Type")
+ if strings.Contains(contentType, "text/event-stream") {
+ replyCreatedAt := time.Now() // 记录回复时间
+ // 循环读取 Chunk 消息
+ var message = types.Message{Role: "assistant"}
+ var contents = make([]string, 0)
+ var function model.Function
+ var toolCall = false
+ var arguments = make([]string, 0)
+ var reasoning = false
+
+ pushMessage(c, ChatEventStart, "开始响应")
+ scanner := bufio.NewScanner(response.Body)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if !strings.Contains(line, "data:") || len(line) < 30 {
+ continue
+ }
+ var responseBody = types.ApiResponse{}
+ err = json.Unmarshal([]byte(line[6:]), &responseBody)
+ if err != nil { // 数据解析出错
+ return errors.New(line)
+ }
+ if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
+ continue
+ }
+ if responseBody.Choices[0].Delta.Content == nil &&
+ responseBody.Choices[0].Delta.ToolCalls == nil &&
+ responseBody.Choices[0].Delta.ReasoningContent == "" {
+ continue
+ }
+
+ if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
+ pushMessage(c, ChatEventError, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
+ break
+ }
+
+ var tool types.ToolCall
+ if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
+ tool = responseBody.Choices[0].Delta.ToolCalls[0]
+ if toolCall && tool.Function.Name == "" {
+ arguments = append(arguments, tool.Function.Arguments)
+ continue
+ }
+ }
+
+ // 兼容 Function Call
+ fun := responseBody.Choices[0].Delta.FunctionCall
+ if fun.Name != "" {
+ tool = *new(types.ToolCall)
+ tool.Function.Name = fun.Name
+ } else if toolCall {
+ arguments = append(arguments, fun.Arguments)
+ continue
+ }
+
+ if !utils.IsEmptyValue(tool) {
+ res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
+ if res.Error == nil {
+ toolCall = true
+ callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
+ pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
+ "type": "text",
+ "content": callMsg,
+ })
+ contents = append(contents, callMsg)
+ }
+ continue
+ }
+
+ if responseBody.Choices[0].FinishReason == "tool_calls" ||
+ responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
+ break
+ }
+
+ // output stopped
+ if responseBody.Choices[0].FinishReason != "" {
+ break // 输出完成或者输出中断了
+ } else { // 正常输出结果
+ // 兼容思考过程
+ if responseBody.Choices[0].Delta.ReasoningContent != "" {
+ reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
+ if !reasoning {
+ reasoningContent = fmt.Sprintf("%s", reasoningContent)
+ reasoning = true
+ }
+
+ pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
+ "type": "text",
+ "content": reasoningContent,
+ })
+ contents = append(contents, reasoningContent)
+ } else if responseBody.Choices[0].Delta.Content != "" {
+ finalContent := responseBody.Choices[0].Delta.Content
+ if reasoning {
+ finalContent = fmt.Sprintf("%s", responseBody.Choices[0].Delta.Content)
+ reasoning = false
+ }
+ contents = append(contents, utils.InterfaceToString(finalContent))
+ pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
+ "type": "text",
+ "content": finalContent,
+ })
+ }
+ }
+ } // end for
+
+ if err := scanner.Err(); err != nil {
+ if strings.Contains(err.Error(), "context canceled") {
+ logger.Info("用户取消了请求:", prompt)
+ } else {
+ logger.Error("信息读取出错:", err)
+ }
+ }
+
+ if toolCall { // 调用函数完成任务
+ params := make(map[string]any)
+ _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
+ logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
+ params["user_id"] = userVo.Id
+ var apiRes types.BizVo
+ r, err := req2.C().R().SetHeader("Body-Type", "application/json").
+ SetHeader("Authorization", function.Token).
+ SetBody(params).Post(function.Action)
+ errMsg := ""
+ if err != nil {
+ errMsg = err.Error()
+ } else {
+ all, _ := io.ReadAll(r.Body)
+ err = json.Unmarshal(all, &apiRes)
+ if err != nil {
+ errMsg = err.Error()
+ } else if apiRes.Code != types.Success {
+ errMsg = apiRes.Message
+ }
+ }
+
+ if errMsg != "" {
+ errMsg = "调用函数工具出错:" + errMsg
+ contents = append(contents, errMsg)
+ } else {
+ errMsg = utils.InterfaceToString(apiRes.Data)
+ contents = append(contents, errMsg)
+ }
+ pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
+ "type": "text",
+ "content": errMsg,
+ })
+ }
+
+ // 消息发送成功
+ if len(contents) > 0 {
+ usage := Usage{
+ Prompt: prompt,
+ Content: strings.Join(contents, ""),
+ PromptTokens: 0,
+ CompletionTokens: 0,
+ TotalTokens: 0,
+ }
+ message.Content = usage.Content
+ h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
+ }
+ } else {
+ var respVo OpenAIResVo
+ body, err := io.ReadAll(response.Body)
+ if err != nil {
+ return fmt.Errorf("读取响应失败:%v", body)
+ }
+ err = json.Unmarshal(body, &respVo)
+ if err != nil {
+ return fmt.Errorf("解析响应失败:%v", body)
+ }
+ content := respVo.Choices[0].Message.Content
+ if strings.HasPrefix(req.Model, "o1-") {
+ content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
+ }
+ pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
+ "type": "text",
+ "content": content,
+ })
+ respVo.Usage.Prompt = prompt
+ respVo.Usage.Content = content
+ h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
+ }
+
+ return nil
+}
diff --git a/api/handler/chat_openai_handler.go b/api/handler/chat_openai_handler.go
index fea1a1e3..00c13093 100644
--- a/api/handler/chat_openai_handler.go
+++ b/api/handler/chat_openai_handler.go
@@ -1,253 +1,271 @@
package handler
-// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-// * Copyright 2023 The Geek-AI Authors. All rights reserved.
-// * Use of this source code is governed by a Apache-2.0 license
-// * that can be found in the LICENSE file.
-// * @Author yangjian102621@163.com
-// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+// // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+// // * Copyright 2023 The Geek-AI Authors. All rights reserved.
+// // * Use of this source code is governed by a Apache-2.0 license
+// // * that can be found in the LICENSE file.
+// // * @Author yangjian102621@163.com
+// // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-import (
- "bufio"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "geekai/core/types"
- "geekai/store/model"
- "geekai/store/vo"
- "geekai/utils"
- "io"
- "strings"
- "time"
+// import (
+// "bufio"
+// "context"
+// "encoding/json"
+// "errors"
+// "fmt"
+// "geekai/core/types"
+// "geekai/store/model"
+// "geekai/store/vo"
+// "geekai/utils"
+// "io"
+// "strings"
+// "time"
- req2 "github.com/imroc/req/v3"
-)
+// req2 "github.com/imroc/req/v3"
+// )
-type Usage struct {
- Prompt string `json:"prompt,omitempty"`
- Content string `json:"content,omitempty"`
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
-}
+// type Usage struct {
+// Prompt string `json:"prompt,omitempty"`
+// Content string `json:"content,omitempty"`
+// PromptTokens int `json:"prompt_tokens"`
+// CompletionTokens int `json:"completion_tokens"`
+// TotalTokens int `json:"total_tokens"`
+// }
-type OpenAIResVo struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int `json:"created"`
- Model string `json:"model"`
- SystemFingerprint string `json:"system_fingerprint"`
- Choices []struct {
- Index int `json:"index"`
- Message struct {
- Role string `json:"role"`
- Content string `json:"content"`
- } `json:"message"`
- Logprobs interface{} `json:"logprobs"`
- FinishReason string `json:"finish_reason"`
- } `json:"choices"`
- Usage Usage `json:"usage"`
-}
+// type OpenAIResVo struct {
+// Id string `json:"id"`
+// Object string `json:"object"`
+// Created int `json:"created"`
+// Model string `json:"model"`
+// SystemFingerprint string `json:"system_fingerprint"`
+// Choices []struct {
+// Index int `json:"index"`
+// Message struct {
+// Role string `json:"role"`
+// Content string `json:"content"`
+// } `json:"message"`
+// Logprobs interface{} `json:"logprobs"`
+// FinishReason string `json:"finish_reason"`
+// } `json:"choices"`
+// Usage Usage `json:"usage"`
+// }
-// OPenAI 消息发送实现
-func (h *ChatHandler) sendOpenAiMessage(
- req types.ApiRequest,
- userVo vo.User,
- ctx context.Context,
- session *types.ChatSession,
- role model.ChatRole,
- prompt string,
- ws *types.WsClient) error {
- promptCreatedAt := time.Now() // 记录提问时间
- start := time.Now()
- var apiKey = model.ApiKey{}
- response, err := h.doRequest(ctx, req, session, &apiKey)
- logger.Info("HTTP请求完成,耗时:", time.Since(start))
- if err != nil {
- if strings.Contains(err.Error(), "context canceled") {
- return fmt.Errorf("用户取消了请求:%s", prompt)
- } else if strings.Contains(err.Error(), "no available key") {
- return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
- }
- return err
- } else {
- defer response.Body.Close()
- }
+// // OPenAI 消息发送实现
+// func (h *ChatHandler) sendOpenAiMessage(
+// req types.ApiRequest,
+// userVo vo.User,
+// ctx context.Context,
+// session *types.ChatSession,
+// role model.ChatRole,
+// prompt string,
+// messageChan chan interface{}) error {
+// promptCreatedAt := time.Now() // 记录提问时间
+// start := time.Now()
+// var apiKey = model.ApiKey{}
+// response, err := h.doRequest(ctx, req, session, &apiKey)
+// logger.Info("HTTP请求完成,耗时:", time.Since(start))
+// if err != nil {
+// if strings.Contains(err.Error(), "context canceled") {
+// return fmt.Errorf("用户取消了请求:%s", prompt)
+// } else if strings.Contains(err.Error(), "no available key") {
+// return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
+// }
+// return err
+// } else {
+// defer response.Body.Close()
+// }
- if response.StatusCode != 200 {
- body, _ := io.ReadAll(response.Body)
- return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
- }
+// if response.StatusCode != 200 {
+// body, _ := io.ReadAll(response.Body)
+// return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
+// }
- contentType := response.Header.Get("Content-Type")
- if strings.Contains(contentType, "text/event-stream") {
- replyCreatedAt := time.Now() // 记录回复时间
- // 循环读取 Chunk 消息
- var message = types.Message{Role: "assistant"}
- var contents = make([]string, 0)
- var function model.Function
- var toolCall = false
- var arguments = make([]string, 0)
- var reasoning = false
+// contentType := response.Header.Get("Content-Type")
+// if strings.Contains(contentType, "text/event-stream") {
+// replyCreatedAt := time.Now() // 记录回复时间
+// // 循环读取 Chunk 消息
+// var message = types.Message{Role: "assistant"}
+// var contents = make([]string, 0)
+// var function model.Function
+// var toolCall = false
+// var arguments = make([]string, 0)
+// var reasoning = false
- scanner := bufio.NewScanner(response.Body)
- for scanner.Scan() {
- line := scanner.Text()
- if !strings.Contains(line, "data:") || len(line) < 30 {
- continue
- }
- var responseBody = types.ApiResponse{}
- err = json.Unmarshal([]byte(line[6:]), &responseBody)
- if err != nil { // 数据解析出错
- return errors.New(line)
- }
- if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
- continue
- }
- if responseBody.Choices[0].Delta.Content == nil &&
- responseBody.Choices[0].Delta.ToolCalls == nil &&
- responseBody.Choices[0].Delta.ReasoningContent == "" {
- continue
- }
+// scanner := bufio.NewScanner(response.Body)
+// for scanner.Scan() {
+// line := scanner.Text()
+// if !strings.Contains(line, "data:") || len(line) < 30 {
+// continue
+// }
+// var responseBody = types.ApiResponse{}
+// err = json.Unmarshal([]byte(line[6:]), &responseBody)
+// if err != nil { // 数据解析出错
+// return errors.New(line)
+// }
+// if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
+// continue
+// }
+// if responseBody.Choices[0].Delta.Content == nil &&
+// responseBody.Choices[0].Delta.ToolCalls == nil &&
+// responseBody.Choices[0].Delta.ReasoningContent == "" {
+// continue
+// }
- if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
- utils.SendChunkMsg(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
- break
- }
+// if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
+// messageChan <- map[string]interface{}{
+// "type": "text",
+// "body": "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。",
+// }
+// break
+// }
- var tool types.ToolCall
- if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
- tool = responseBody.Choices[0].Delta.ToolCalls[0]
- if toolCall && tool.Function.Name == "" {
- arguments = append(arguments, tool.Function.Arguments)
- continue
- }
- }
+// var tool types.ToolCall
+// if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
+// tool = responseBody.Choices[0].Delta.ToolCalls[0]
+// if toolCall && tool.Function.Name == "" {
+// arguments = append(arguments, tool.Function.Arguments)
+// continue
+// }
+// }
- // 兼容 Function Call
- fun := responseBody.Choices[0].Delta.FunctionCall
- if fun.Name != "" {
- tool = *new(types.ToolCall)
- tool.Function.Name = fun.Name
- } else if toolCall {
- arguments = append(arguments, fun.Arguments)
- continue
- }
+// // 兼容 Function Call
+// fun := responseBody.Choices[0].Delta.FunctionCall
+// if fun.Name != "" {
+// tool = *new(types.ToolCall)
+// tool.Function.Name = fun.Name
+// } else if toolCall {
+// arguments = append(arguments, fun.Arguments)
+// continue
+// }
- if !utils.IsEmptyValue(tool) {
- res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
- if res.Error == nil {
- toolCall = true
- callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
- utils.SendChunkMsg(ws, callMsg)
- contents = append(contents, callMsg)
- }
- continue
- }
+// if !utils.IsEmptyValue(tool) {
+// res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
+// if res.Error == nil {
+// toolCall = true
+// callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
+// messageChan <- map[string]interface{}{
+// "type": "text",
+// "body": callMsg,
+// }
+// contents = append(contents, callMsg)
+// }
+// continue
+// }
- if responseBody.Choices[0].FinishReason == "tool_calls" ||
- responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
- break
- }
+// if responseBody.Choices[0].FinishReason == "tool_calls" ||
+// responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
+// break
+// }
- // output stopped
- if responseBody.Choices[0].FinishReason != "" {
- break // 输出完成或者输出中断了
- } else { // 正常输出结果
- // 兼容思考过程
- if responseBody.Choices[0].Delta.ReasoningContent != "" {
- reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
- if !reasoning {
- reasoningContent = fmt.Sprintf("%s", reasoningContent)
- reasoning = true
- }
+// // output stopped
+// if responseBody.Choices[0].FinishReason != "" {
+// break // 输出完成或者输出中断了
+// } else { // 正常输出结果
+// // 兼容思考过程
+// if responseBody.Choices[0].Delta.ReasoningContent != "" {
+// reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
+// if !reasoning {
+// reasoningContent = fmt.Sprintf("%s", reasoningContent)
+// reasoning = true
+// }
- utils.SendChunkMsg(ws, reasoningContent)
- contents = append(contents, reasoningContent)
- } else if responseBody.Choices[0].Delta.Content != "" {
- finalContent := responseBody.Choices[0].Delta.Content
- if reasoning {
- finalContent = fmt.Sprintf("%s", responseBody.Choices[0].Delta.Content)
- reasoning = false
- }
- contents = append(contents, utils.InterfaceToString(finalContent))
- utils.SendChunkMsg(ws, finalContent)
- }
- }
- } // end for
+// messageChan <- map[string]interface{}{
+// "type": "text",
+// "body": reasoningContent,
+// }
+// contents = append(contents, reasoningContent)
+// } else if responseBody.Choices[0].Delta.Content != "" {
+// finalContent := responseBody.Choices[0].Delta.Content
+// if reasoning {
+// finalContent = fmt.Sprintf("%s", responseBody.Choices[0].Delta.Content)
+// reasoning = false
+// }
+// contents = append(contents, utils.InterfaceToString(finalContent))
+// messageChan <- map[string]interface{}{
+// "type": "text",
+// "body": finalContent,
+// }
+// }
+// }
+// } // end for
- if err := scanner.Err(); err != nil {
- if strings.Contains(err.Error(), "context canceled") {
- logger.Info("用户取消了请求:", prompt)
- } else {
- logger.Error("信息读取出错:", err)
- }
- }
+// if err := scanner.Err(); err != nil {
+// if strings.Contains(err.Error(), "context canceled") {
+// logger.Info("用户取消了请求:", prompt)
+// } else {
+// logger.Error("信息读取出错:", err)
+// }
+// }
- if toolCall { // 调用函数完成任务
- params := make(map[string]any)
- _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
- logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
- params["user_id"] = userVo.Id
- var apiRes types.BizVo
- r, err := req2.C().R().SetHeader("Body-Type", "application/json").
- SetHeader("Authorization", function.Token).
- SetBody(params).Post(function.Action)
- errMsg := ""
- if err != nil {
- errMsg = err.Error()
- } else {
- all, _ := io.ReadAll(r.Body)
- err = json.Unmarshal(all, &apiRes)
- if err != nil {
- errMsg = err.Error()
- } else if apiRes.Code != types.Success {
- errMsg = apiRes.Message
- }
- }
+// if toolCall { // 调用函数完成任务
+// params := make(map[string]any)
+// _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
+// logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
+// params["user_id"] = userVo.Id
+// var apiRes types.BizVo
+// r, err := req2.C().R().SetHeader("Body-Type", "application/json").
+// SetHeader("Authorization", function.Token).
+// SetBody(params).Post(function.Action)
+// errMsg := ""
+// if err != nil {
+// errMsg = err.Error()
+// } else {
+// all, _ := io.ReadAll(r.Body)
+// err = json.Unmarshal(all, &apiRes)
+// if err != nil {
+// errMsg = err.Error()
+// } else if apiRes.Code != types.Success {
+// errMsg = apiRes.Message
+// }
+// }
- if errMsg != "" {
- errMsg = "调用函数工具出错:" + errMsg
- contents = append(contents, errMsg)
- } else {
- errMsg = utils.InterfaceToString(apiRes.Data)
- contents = append(contents, errMsg)
- }
- utils.SendChunkMsg(ws, errMsg)
- }
+// if errMsg != "" {
+// errMsg = "调用函数工具出错:" + errMsg
+// contents = append(contents, errMsg)
+// } else {
+// errMsg = utils.InterfaceToString(apiRes.Data)
+// contents = append(contents, errMsg)
+// }
+// messageChan <- map[string]interface{}{
+// "type": "text",
+// "body": errMsg,
+// }
+// }
- // 消息发送成功
- if len(contents) > 0 {
- usage := Usage{
- Prompt: prompt,
- Content: strings.Join(contents, ""),
- PromptTokens: 0,
- CompletionTokens: 0,
- TotalTokens: 0,
- }
- message.Content = usage.Content
- h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
- }
- } else { // 非流式输出
- var respVo OpenAIResVo
- body, err := io.ReadAll(response.Body)
- if err != nil {
- return fmt.Errorf("读取响应失败:%v", body)
- }
- err = json.Unmarshal(body, &respVo)
- if err != nil {
- return fmt.Errorf("解析响应失败:%v", body)
- }
- content := respVo.Choices[0].Message.Content
- if strings.HasPrefix(req.Model, "o1-") {
- content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
- }
- utils.SendChunkMsg(ws, content)
- respVo.Usage.Prompt = prompt
- respVo.Usage.Content = content
- h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
- }
+// // 消息发送成功
+// if len(contents) > 0 {
+// usage := Usage{
+// Prompt: prompt,
+// Content: strings.Join(contents, ""),
+// PromptTokens: 0,
+// CompletionTokens: 0,
+// TotalTokens: 0,
+// }
+// message.Content = usage.Content
+// h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
+// }
+// } else { // 非流式输出
+// var respVo OpenAIResVo
+// body, err := io.ReadAll(response.Body)
+// if err != nil {
+// return fmt.Errorf("读取响应失败:%v", body)
+// }
+// err = json.Unmarshal(body, &respVo)
+// if err != nil {
+// return fmt.Errorf("解析响应失败:%v", body)
+// }
+// content := respVo.Choices[0].Message.Content
+// if strings.HasPrefix(req.Model, "o1-") {
+// content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
+// }
+// messageChan <- map[string]interface{}{
+// "type": "text",
+// "body": content,
+// }
+// respVo.Usage.Prompt = prompt
+// respVo.Usage.Content = content
+// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
+// }
- return nil
-}
+// return nil
+// }
diff --git a/api/handler/ws_handler.go b/api/handler/ws_handler.go
deleted file mode 100644
index 958a1266..00000000
--- a/api/handler/ws_handler.go
+++ /dev/null
@@ -1,153 +0,0 @@
-package handler
-
-// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-// * Copyright 2023 The Geek-AI Authors. All rights reserved.
-// * Use of this source code is governed by a Apache-2.0 license
-// * that can be found in the LICENSE file.
-// * @Author yangjian102621@163.com
-// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-
-import (
- "context"
- "geekai/core"
- "geekai/core/types"
- "geekai/service"
- "geekai/store/model"
- "geekai/utils"
- "net/http"
- "strings"
-
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- "gorm.io/gorm"
-)
-
-// Websocket 连接处理 handler
-
-type WebsocketHandler struct {
- BaseHandler
- wsService *service.WebsocketService
- chatHandler *ChatHandler
-}
-
-func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
- return &WebsocketHandler{
- BaseHandler: BaseHandler{App: app, DB: db},
- chatHandler: chatHandler,
- wsService: s,
- }
-}
-
-func (h *WebsocketHandler) Client(c *gin.Context) {
- clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
- ws, err := (&websocket.Upgrader{
- CheckOrigin: func(r *http.Request) bool { return true },
- Subprotocols: strings.Split(clientProtocols, ","),
- }).Upgrade(c.Writer, c.Request, nil)
- if err != nil {
- logger.Error(err)
- c.Abort()
- return
- }
-
- clientId := c.Query("client_id")
- client := types.NewWsClient(ws, clientId)
- userId := h.GetLoginUserId(c)
- if userId == 0 {
- _ = client.Send([]byte("Invalid user_id"))
- c.Abort()
- return
- }
- var user model.User
- if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
- _ = client.Send([]byte("Invalid user_id"))
- c.Abort()
- return
- }
-
- h.wsService.Clients.Put(clientId, client)
- logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
- go func() {
- for {
- _, msg, err := client.Receive()
- if err != nil {
- logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
- client.Close()
- h.wsService.Clients.Delete(clientId)
- break
- }
-
- var message types.InputMessage
- err = utils.JsonDecode(string(msg), &message)
- if err != nil {
- continue
- }
-
- logger.Debugf("Receive a message:%+v", message)
- if message.Type == types.MsgTypePing {
- utils.SendChannelMsg(client, types.ChPing, "pong")
- continue
- }
-
- // 当前只处理聊天消息,其他消息全部丢弃
- var chatMessage types.ChatMessage
- err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
- if err != nil || message.Channel != types.ChChat {
- logger.Warnf("invalid message body:%+v", message.Body)
- continue
- }
- var chatRole model.ChatRole
- err = h.DB.First(&chatRole, chatMessage.RoleId).Error
- if err != nil || !chatRole.Enable {
- utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
- continue
- }
- // if the role bind a model_id, use role's bind model_id
- if chatRole.ModelId > 0 {
- chatMessage.ModelId = int(chatRole.ModelId)
- }
- // get model info
- var chatModel model.ChatModel
- err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
- if err != nil || !chatModel.Enabled {
- utils.SendAndFlush(client, "当前AI模型暂未启用,请更换模型后再发起对话!!!")
- continue
- }
-
- session := &types.ChatSession{
- ClientIP: c.ClientIP(),
- UserId: userId,
- }
-
- // use old chat data override the chat model and role ID
- var chat model.ChatItem
- h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
- if chat.Id > 0 {
- chatModel.Id = chat.ModelId
- chatMessage.RoleId = int(chat.RoleId)
- }
-
- session.ChatId = chatMessage.ChatId
- session.Tools = chatMessage.Tools
- session.Stream = chatMessage.Stream
- session.Model.KeyId = chatMessage.ModelId
- // 复制模型数据
- err = utils.CopyObject(chatModel, &session.Model)
- if err != nil {
- logger.Error(err, chatModel)
- }
- session.Model.Id = chatModel.Id
- ctx, cancel := context.WithCancel(context.Background())
- h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
- err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
- if err != nil {
- logger.Error(err)
- utils.SendAndFlush(client, err.Error())
- } else {
- utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
- logger.Infof("回答完毕: %v", message.Body)
- }
-
- }
- }()
-}
diff --git a/api/main.go b/api/main.go
index 47910b97..3c6db9d2 100644
--- a/api/main.go
+++ b/api/main.go
@@ -243,6 +243,7 @@ func main() {
}),
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
group := s.Engine.Group("/api/chat/")
+ group.Any("message", h.Chat)
group.GET("list", h.List)
group.GET("detail", h.Detail)
group.POST("update", h.Update)
@@ -515,10 +516,6 @@ func main() {
group.Any("sse", h.PostTest, h.SseTest)
}),
fx.Provide(service.NewWebsocketService),
- fx.Provide(handler.NewWebsocketHandler),
- fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) {
- s.Engine.Any("/api/ws", h.Client)
- }),
fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt")
diff --git a/web/package.json b/web/package.json
index 0ab84fd6..8c40a18c 100644
--- a/web/package.json
+++ b/web/package.json
@@ -45,8 +45,8 @@
"vue": "^3.2.13",
"vue-router": "^4.0.15",
"unplugin-auto-import": "^0.18.5",
+ "@microsoft/fetch-event-source": "^2.0.1",
"vue-waterfall-plugin-next": "^2.6.5"
-
},
"devDependencies": {
"@vitejs/plugin-vue": "^5.2.4",
diff --git a/web/src/App.vue b/web/src/App.vue
index cb690973..2875a867 100644
--- a/web/src/App.vue
+++ b/web/src/App.vue
@@ -5,13 +5,12 @@