Mirror of https://github.com/yangjian102621/geekai.git
Refactor the AI chat message struct, let users choose whether the AI responds as a stream, and add compatibility with the GPT-o1 model.
@@ -205,6 +205,7 @@ func needLogin(c *gin.Context) bool {
 		c.Request.URL.Path == "/api/chat/detail" ||
 		c.Request.URL.Path == "/api/chat/list" ||
 		c.Request.URL.Path == "/api/app/list" ||
+		c.Request.URL.Path == "/api/app/type/list" ||
 		c.Request.URL.Path == "/api/app/list/user" ||
 		c.Request.URL.Path == "/api/model/list" ||
 		c.Request.URL.Path == "/api/mj/imgWall" ||
@@ -57,6 +57,7 @@ type ChatSession struct {
 	ClientIP string    `json:"client_ip"` // client IP
 	ChatId   string    `json:"chat_id"`   // client chat session ID, used only in multi-session mode
 	Model    ChatModel `json:"model"`     // GPT model
 	Start    int64     `json:"start"`     // request start timestamp
 	Tools    []int     `json:"tools"`     // list of tool functions
+	Stream   bool      `json:"stream"`    // whether to respond as a stream
 }
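The new `Stream` field is the user-facing toggle from the commit message: clients can now ask for one complete reply instead of chunked output. A minimal round-trip sketch, assuming JSON payloads shaped like the struct above; the handler plumbing around it is hypothetical:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-in for types.ChatSession; field names and JSON tags
// follow the diff above, unrelated fields are omitted.
type ChatSession struct {
	ClientIP string `json:"client_ip"`
	ChatId   string `json:"chat_id"`
	Stream   bool   `json:"stream"` // the new per-session toggle
}

func main() {
	// A client that wants a single complete reply sends stream:false.
	payload := []byte(`{"client_ip":"127.0.0.1","chat_id":"c-42","stream":false}`)
	var s ChatSession
	if err := json.Unmarshal(payload, &s); err != nil {
		panic(err)
	}
	fmt.Printf("streaming: %v\n", s.Stream) // streaming: false
}
```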
@@ -26,10 +26,9 @@ type ReplyMessage struct {
 type WsMsgType string

 const (
-	WsStart  = WsMsgType("start")
-	WsMiddle = WsMsgType("middle")
-	WsEnd    = WsMsgType("end")
-	WsErr    = WsMsgType("error")
+	WsContent = WsMsgType("content") // output content
+	WsEnd     = WsMsgType("end")
+	WsErr     = WsMsgType("error")
 )

 // InputMessage chat input message structure
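With `WsStart` and `WsMiddle` gone, a reply is now just zero or more `content` frames followed by `end` (or `error`), so clients no longer need to track an explicit start frame. A client-side sketch of the reduced dispatch, assuming frames shaped like `types.ReplyMessage`; the loop itself is hypothetical:

```go
package main

import (
	"fmt"
	"strings"
)

// Stand-in for types.ReplyMessage as serialized over the websocket.
type ReplyMessage struct {
	Type    string      `json:"type"`
	Content interface{} `json:"content"`
}

// handle consumes one frame: "content" appends text, "end" closes the
// turn, "error" aborts; there is no longer a "start" or "middle" case.
func handle(msg ReplyMessage, out *strings.Builder) (done bool, err error) {
	switch msg.Type {
	case "content":
		out.WriteString(fmt.Sprint(msg.Content))
	case "end":
		done = true
	case "error":
		err = fmt.Errorf("server error: %v", msg.Content)
	}
	return
}

func main() {
	var out strings.Builder
	for _, f := range []ReplyMessage{
		{Type: "content", Content: "Hello, "},
		{Type: "content", Content: "world."},
		{Type: "end"},
	} {
		if done, err := handle(f, &out); err != nil {
			panic(err)
		} else if done {
			break
		}
	}
	fmt.Println(out.String()) // Hello, world.
}
```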
@@ -6,6 +6,7 @@ import (
 	"geekai/store/vo"
+	"geekai/utils"
 	"geekai/utils/resp"

 	"github.com/gin-gonic/gin"
 	"gorm.io/gorm"
 )
@@ -22,7 +23,7 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
 func (h *ChatAppTypeHandler) List(c *gin.Context) {
 	var items []model.AppType
 	var appTypes = make([]vo.AppType, 0)
-	err := h.DB.Order("sort_num ASC").Find(&items).Error
+	err := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items).Error
 	if err != nil {
 		resp.ERROR(c, err.Error())
 		return
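`List` now filters disabled app types in SQL instead of returning everything. A self-contained sketch of the same query, assuming a `model.AppType` with `enabled` and `sort_num` columns; it spells out `enabled = ?`, which expresses the same enabled-only filter as the diff's shorthand `Where("enabled", true)`, and pulls in the pure-Go `github.com/glebarez/sqlite` driver purely for the demo:

```go
package main

import (
	"fmt"

	"github.com/glebarez/sqlite" // pure-Go SQLite driver, used here only for the demo
	"gorm.io/gorm"
)

// Assumed shape of model.AppType: an enabled flag plus a sort column.
type AppType struct {
	Id      uint `gorm:"primarykey"`
	Name    string
	Enabled bool
	SortNum int
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&AppType{}); err != nil {
		panic(err)
	}
	db.Create(&[]AppType{
		{Name: "chat", Enabled: true, SortNum: 2},
		{Name: "draft", Enabled: false, SortNum: 1},
	})

	var items []AppType
	// Disabled rows never leave the database, mirroring the new List query.
	db.Where("enabled = ?", true).Order("sort_num ASC").Find(&items)
	fmt.Println(len(items), items[0].Name) // 1 chat
}
```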
@@ -202,15 +202,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 	}

 	var req = types.ApiRequest{
-		Model:       session.Model.Value,
-		Temperature: session.Model.Temperature,
+		Model: session.Model.Value,
 	}
+	// compatibility with the GPT-o1 models
+	if strings.HasPrefix(session.Model.Value, "o1-") {
+		req.MaxCompletionTokens = session.Model.MaxTokens
+		utils.ReplyContent(ws, "AI 正在思考...\n")
+		req.Stream = false
+		session.Start = time.Now().Unix()
+	} else {
+		req.MaxTokens = session.Model.MaxTokens
+		req.Temperature = session.Model.Temperature
+		req.Stream = session.Stream
+	}
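The o1 branch exists because the o1 family rejects the classic knobs: at the time of this commit those models required `max_completion_tokens` instead of `max_tokens`, ignored `temperature`, and did not support streaming, so the handler forces `Stream = false` and records `session.Start` to report the thinking time later. A standalone sketch of the request-shaping rule, with an `ApiRequest` whose fields are inferred from the diff; the JSON tags are assumptions based on the OpenAI wire format, not copied from the repository:

```go
package main

import (
	"fmt"
	"strings"
)

// ApiRequest fields inferred from the diff; tags are assumptions.
type ApiRequest struct {
	Model               string  `json:"model"`
	Temperature         float32 `json:"temperature,omitempty"`
	MaxTokens           int     `json:"max_tokens,omitempty"`
	MaxCompletionTokens int     `json:"max_completion_tokens,omitempty"`
	Stream              bool    `json:"stream"`
}

// buildRequest mirrors the branch in sendMessage: o1-* models get
// max_completion_tokens and forced non-streaming; every other model keeps
// temperature and max_tokens and honors the per-session stream flag.
func buildRequest(model string, temperature float32, maxTokens int, stream bool) ApiRequest {
	req := ApiRequest{Model: model}
	if strings.HasPrefix(model, "o1-") {
		req.MaxCompletionTokens = maxTokens
		req.Stream = false
	} else {
		req.Temperature = temperature
		req.MaxTokens = maxTokens
		req.Stream = stream
	}
	return req
}

func main() {
	fmt.Printf("%+v\n", buildRequest("o1-mini", 1.0, 1024, true))
	fmt.Printf("%+v\n", buildRequest("gpt-4o", 1.0, 1024, true))
}
```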
@@ -449,7 +450,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
 	if err != nil {
 		return nil, err
 	}
-	logger.Debugf(utils.JsonEncode(req))
+	logger.Debugf("对话请求消息体:%+v", req)

 	apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
 	// create the HttpClient request object
@@ -499,14 +500,6 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
 	}
 }

-type Usage struct {
-	Prompt           string
-	Content          string
-	PromptTokens     int
-	CompletionTokens int
-	TotalTokens      int
-}
-
 func (h *ChatHandler) saveChatHistory(
 	req types.ApiRequest,
 	usage Usage,
@@ -517,12 +510,8 @@ func (h *ChatHandler) saveChatHistory(
 	userVo vo.User,
 	promptCreatedAt time.Time,
 	replyCreatedAt time.Time) {
-	if message.Role == "" {
-		message.Role = "assistant"
-	}
-	message.Content = usage.Content
-	useMsg := types.Message{Role: "user", Content: usage.Prompt}
+	useMsg := types.Message{Role: "user", Content: usage.Prompt}
 	// update the context messages; a function call does not need a context update
 	if h.App.SysConfig.EnableContext {
 		chatCtx = append(chatCtx, useMsg) // question message
@@ -573,7 +562,7 @@ func (h *ChatHandler) saveChatHistory(
 		RoleId:      role.Id,
 		Type:        types.ReplyMsg,
 		Icon:        role.Icon,
-		Content:     message.Content,
+		Content:     usage.Content,
 		Tokens:      replyTokens,
 		TotalTokens: totalTokens,
 		UseContext:  true,
@@ -23,7 +23,15 @@ import (
 	"time"
 )

-type respVo struct {
+type Usage struct {
+	Prompt           string `json:"prompt,omitempty"`
+	Content          string `json:"content,omitempty"`
+	PromptTokens     int    `json:"prompt_tokens"`
+	CompletionTokens int    `json:"completion_tokens"`
+	TotalTokens      int    `json:"total_tokens"`
+}
+
+type OpenAIResVo struct {
 	Id      string `json:"id"`
 	Object  string `json:"object"`
 	Created int    `json:"created"`
@@ -38,11 +46,7 @@ type respVo struct {
 		Logprobs     interface{} `json:"logprobs"`
 		FinishReason string      `json:"finish_reason"`
 	} `json:"choices"`
-	Usage struct {
-		PromptTokens     int `json:"prompt_tokens"`
-		CompletionTokens int `json:"completion_tokens"`
-		TotalTokens      int `json:"total_tokens"`
-	} `json:"usage"`
+	Usage Usage `json:"usage"`
 }

 // OpenAI message sending implementation
@@ -73,19 +77,19 @@ func (h *ChatHandler) sendOpenAiMessage(
 	if response.StatusCode != 200 {
 		body, _ := io.ReadAll(response.Body)
-		return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, body)
+		return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
 	}

 	contentType := response.Header.Get("Content-Type")
 	if strings.Contains(contentType, "text/event-stream") {
 		replyCreatedAt := time.Now() // record the reply time
 		// read chunk messages in a loop
-		var message = types.Message{}
+		var message = types.Message{Role: "assistant"}
 		var contents = make([]string, 0)
 		var function model.Function
 		var toolCall = false
 		var arguments = make([]string, 0)
 		scanner := bufio.NewScanner(response.Body)
-		var isNew = true
 		for scanner.Scan() {
 			line := scanner.Text()
 			if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -132,8 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 			if res.Error == nil {
 				toolCall = true
 				callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
-				utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
-				utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: callMsg})
+				utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: callMsg})
 				contents = append(contents, callMsg)
 			}
 			continue
@@ -150,12 +153,8 @@ func (h *ChatHandler) sendOpenAiMessage(
 			} else {
 				content := responseBody.Choices[0].Delta.Content
 				contents = append(contents, utils.InterfaceToString(content))
-				if isNew {
-					utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
-					isNew = false
-				}
 				utils.ReplyChunkMessage(ws, types.ReplyMessage{
-					Type:    types.WsMiddle,
+					Type:    types.WsContent,
 					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 				})
 			}
@@ -188,13 +187,13 @@ func (h *ChatHandler) sendOpenAiMessage(
 		if errMsg != "" || apiRes.Code != types.Success {
 			msg := "调用函数工具出错:" + apiRes.Message + errMsg
 			utils.ReplyChunkMessage(ws, types.ReplyMessage{
-				Type:    types.WsMiddle,
+				Type:    types.WsContent,
 				Content: msg,
 			})
 			contents = append(contents, msg)
 		} else {
 			utils.ReplyChunkMessage(ws, types.ReplyMessage{
-				Type:    types.WsMiddle,
+				Type:    types.WsContent,
 				Content: apiRes.Data,
 			})
 			contents = append(contents, utils.InterfaceToString(apiRes.Data))
@@ -210,10 +209,27 @@ func (h *ChatHandler) sendOpenAiMessage(
 				CompletionTokens: 0,
 				TotalTokens:      0,
 			}
+			message.Content = usage.Content
 			h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 		}
-	} else {
+	} else { // non-streaming output
+		var respVo OpenAIResVo
+		body, err := io.ReadAll(response.Body)
+		if err != nil {
+			return fmt.Errorf("读取响应失败:%v", body)
+		}
+		err = json.Unmarshal(body, &respVo)
+		if err != nil {
+			return fmt.Errorf("解析响应失败:%v", body)
+		}
+		content := respVo.Choices[0].Message.Content
+		if strings.HasPrefix(req.Model, "o1-") {
+			content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
+		}
+		utils.ReplyMessage(ws, content)
+		respVo.Usage.Prompt = prompt
+		respVo.Usage.Content = content
+		h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
 	}

 	return nil
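In the non-streaming path the handler buffers the whole body and decodes it in one shot, then prepends the thinking-time notice for o1 models. A self-contained sketch of that parse step against a hypothetical sample body, using the response shape from the diff plus an empty-`choices` guard that the code above skips:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// OpenAIResVo trimmed to the fields this sketch reads.
type OpenAIResVo struct {
	Id      string `json:"id"`
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage Usage `json:"usage"`
}

func main() {
	// Hypothetical non-streaming response body.
	body := []byte(`{"id":"chatcmpl-1","choices":[{"message":{"role":"assistant","content":"hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":1,"total_tokens":4}}`)
	var res OpenAIResVo
	if err := json.Unmarshal(body, &res); err != nil {
		panic(err)
	}
	// Guard against empty choices before indexing, unlike the hot path above.
	if len(res.Choices) == 0 {
		panic("empty choices")
	}
	fmt.Println(res.Choices[0].Message.Content, res.Usage.TotalTokens) // hello 4
}
```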
@@ -86,6 +86,8 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
 		if err != nil {
 			logger.Error(err)
 			utils.ReplyErrorMessage(client, err.Error())
+		} else {
+			utils.ReplyMessage(client, types.ReplyMessage{Type: types.WsEnd})
 		}
 	}
@@ -148,7 +150,6 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 	if strings.Contains(contentType, "text/event-stream") {
 		// read chunk messages in a loop
 		scanner := bufio.NewScanner(response.Body)
-		var isNew = true
 		for scanner.Scan() {
 			line := scanner.Text()
 			if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -169,12 +170,8 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 				break
 			}

-			if isNew {
-				utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsStart})
-				isNew = false
-			}
 			utils.ReplyChunkMessage(client, types.ReplyMessage{
-				Type:    types.WsMiddle,
+				Type:    types.WsContent,
 				Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 			})
 		} // end for
@@ -512,6 +512,11 @@ func main() {
 			group.POST("enable", h.Enable)
 			group.POST("sort", h.Sort)
 		}),
+		fx.Provide(handler.NewChatAppTypeHandler),
+		fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
+			group := s.Engine.Group("/api/app/type")
+			group.GET("list", h.List)
+		}),
 		fx.Provide(handler.NewTestHandler),
 		fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
 			group := s.Engine.Group("/api/test")
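The route registration follows the file's existing uber-go/fx pattern: `fx.Provide` registers the constructor and `fx.Invoke` wires the route group at startup. A minimal runnable sketch with hypothetical stand-ins for `core.AppServer` and the handler (fx runs invoke functions inside `fx.New`, so the demo needs no `Run()`):

```go
package main

import (
	"fmt"

	"go.uber.org/fx"
)

// Hypothetical stand-ins for core.AppServer and handler.ChatAppTypeHandler.
type AppServer struct{ routes []string }
type ChatAppTypeHandler struct{}

func NewAppServer() *AppServer                   { return &AppServer{} }
func NewChatAppTypeHandler() *ChatAppTypeHandler { return &ChatAppTypeHandler{} }

func main() {
	fx.New(
		fx.NopLogger, // silence fx's own event log for the demo
		fx.Provide(NewAppServer, NewChatAppTypeHandler),
		fx.Invoke(func(s *AppServer, h *ChatAppTypeHandler) {
			// Stands in for s.Engine.Group("/api/app/type").GET("list", h.List).
			s.routes = append(s.routes, "GET /api/app/type/list")
			fmt.Println(s.routes)
		}),
	)
}
```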
@@ -33,11 +33,14 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {

 // ReplyMessage replies to the client with one complete message
 func ReplyMessage(ws *types.WsClient, message interface{}) {
-	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
-	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: message})
+	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
 	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd})
 }

+func ReplyContent(ws *types.WsClient, message interface{}) {
+	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
+}
+
 // ReplyErrorMessage sends an error message to the client
 func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
 	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message})
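After this change `ReplyMessage` frames a complete reply as exactly two messages (`content`, then `end`), while the new `ReplyContent` emits a single `content` frame and leaves the turn open, which is how the o1 path shows its interim "thinking" notice before the real answer arrives. A toy sketch of the emitted frame sequences, with stand-ins for the real websocket writers:

```go
package main

import "fmt"

// Toy stand-in for types.ReplyMessage; the real helpers write JSON
// frames to a websocket instead of appending to a slice.
type ReplyMessage struct {
	Type    string
	Content interface{}
}

type frameLog struct{ frames []ReplyMessage }

func (l *frameLog) send(m ReplyMessage) { l.frames = append(l.frames, m) }

// replyMessage mirrors the new ReplyMessage: content frame, then end frame.
func (l *frameLog) replyMessage(content interface{}) {
	l.send(ReplyMessage{Type: "content", Content: content})
	l.send(ReplyMessage{Type: "end"})
}

// replyContent mirrors ReplyContent: one content frame, turn stays open.
func (l *frameLog) replyContent(content interface{}) {
	l.send(ReplyMessage{Type: "content", Content: content})
}

func main() {
	var l frameLog
	l.replyContent("AI 正在思考...") // interim notice, no end frame yet
	l.replyMessage("final answer")  // content + end closes the turn
	for _, f := range l.frames {
		fmt.Printf("%s %v\n", f.Type, f.Content)
	}
}
```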