refactor the AI chat message struct, let users set whether the AI responds as a stream, and add compatibility with the GPT-o1 model

RockYang
2024-09-14 17:06:13 +08:00
parent 131efd6ba5
commit 0385e60ce1
18 changed files with 245 additions and 245 deletions

View File

@@ -205,6 +205,7 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" ||
c.Request.URL.Path == "/api/app/list" ||
c.Request.URL.Path == "/api/app/type/list" ||
c.Request.URL.Path == "/api/app/list/user" ||
c.Request.URL.Path == "/api/model/list" ||
c.Request.URL.Path == "/api/mj/imgWall" ||
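Note: the new /api/app/type/list route (registered in main.go further down in this commit) is whitelisted here so it can be called without logging in. A minimal sketch of an anonymous call; the host and port are assumptions, not part of the commit:

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// host and port are hypothetical; after this change the route needs no session cookie
	resp, err := http.Get("http://localhost:5678/api/app/type/list")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body)) // JSON list of enabled app types
}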

View File

@@ -57,6 +57,7 @@ type ChatSession struct {
ClientIP string `json:"client_ip"` // client IP
ChatId string `json:"chat_id"` // client chat session ID, dedicated field for multi-session mode
Model ChatModel `json:"model"` // GPT model
Start int64 `json:"start"` // request start timestamp
Tools []int `json:"tools"` // list of tool functions
Stream bool `json:"stream"` // whether to use streaming output
}
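The new Start field records when the request was issued (used later to report how long an o1 model spent thinking), and Stream lets the caller opt out of streamed output. A minimal sketch of populating the session, assuming the geekai/core/types import path; the placeholder values are illustrative only:

package main

import (
	"strings"
	"time"

	"geekai/core/types" // import path assumed from the repository layout
)

func newSession(modelValue string) types.ChatSession {
	return types.ChatSession{
		ChatId: "chat-001",                            // placeholder session ID
		Stream: !strings.HasPrefix(modelValue, "o1-"), // o1 models are forced to non-streaming
		Start:  time.Now().Unix(),                     // request start timestamp
	}
}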

View File

@@ -26,10 +26,9 @@ type ReplyMessage struct {
type WsMsgType string
const (
WsStart = WsMsgType("start")
WsMiddle = WsMsgType("middle")
WsEnd = WsMsgType("end")
WsErr = WsMsgType("error")
WsContent = WsMsgType("content") // output content
WsEnd = WsMsgType("end")
WsErr = WsMsgType("error")
)
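WsMiddle is renamed to WsContent, and (as the loops below show) the handlers no longer emit a WsStart frame before each streamed reply, so a client mainly switches on content, end, and error frames. A hedged dispatch sketch; render, finalize, and showError are hypothetical client callbacks:

func dispatch(msg types.ReplyMessage) {
	switch msg.Type {
	case types.WsContent:
		render(msg.Content) // append a streamed chunk to the current reply
	case types.WsEnd:
		finalize() // the reply is complete
	case types.WsErr:
		showError(msg.Content) // surface the error to the user
	}
}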
// InputMessage is the chat input message structure

View File

@@ -6,6 +6,7 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -22,7 +23,7 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType
var appTypes = make([]vo.AppType, 0)
err := h.DB.Order("sort_num ASC").Find(&items).Error
err := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
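The added Where clause restricts the listing to enabled app types. GORM treats a bare column name with a single argument as an equality condition, so with an explicit placeholder the same query would read (table name assumed):

err := h.DB.Where("enabled = ?", true).Order("sort_num ASC").Find(&items).Error
// roughly: SELECT * FROM app_types WHERE enabled = true ORDER BY sort_num ASC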

View File

@@ -202,15 +202,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
var req = types.ApiRequest{
Model: session.Model.Value,
Temperature: session.Model.Temperature,
Model: session.Model.Value,
}
// compatible with the GPT-o1 model
if strings.HasPrefix(session.Model.Value, "o1-") {
req.MaxCompletionTokens = session.Model.MaxTokens
utils.ReplyContent(ws, "AI 正在思考...\n")
req.Stream = false
session.Start = time.Now().Unix()
} else {
req.MaxTokens = session.Model.MaxTokens
req.Temperature = session.Model.Temperature
req.Stream = session.Stream
}
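For context: the o1- models reject temperature, max_tokens, and streaming, which is why the request is assembled per branch. A self-contained sketch of the two payload shapes; the struct is a trimmed, assumed mirror of types.ApiRequest:

package main

import (
	"encoding/json"
	"fmt"
)

type apiRequest struct {
	Model               string  `json:"model"`
	Temperature         float32 `json:"temperature,omitempty"`
	MaxTokens           int     `json:"max_tokens,omitempty"`
	MaxCompletionTokens int     `json:"max_completion_tokens,omitempty"`
	Stream              bool    `json:"stream"`
}

func main() {
	o1, _ := json.Marshal(apiRequest{Model: "o1-mini", MaxCompletionTokens: 1024})
	std, _ := json.Marshal(apiRequest{Model: "gpt-4o", Temperature: 1.0, MaxTokens: 1024, Stream: true})
	fmt.Println(string(o1))  // {"model":"o1-mini","max_completion_tokens":1024,"stream":false}
	fmt.Println(string(std)) // {"model":"gpt-4o","temperature":1,"max_tokens":1024,"stream":true}
}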
@@ -449,7 +450,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
if err != nil {
return nil, err
}
logger.Debugf(utils.JsonEncode(req))
logger.Debugf("对话请求消息体:%+v", req)
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// create the HttpClient request object
@@ -499,14 +500,6 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
}
}
type Usage struct {
Prompt string
Content string
PromptTokens int
CompletionTokens int
TotalTokens int
}
func (h *ChatHandler) saveChatHistory(
req types.ApiRequest,
usage Usage,
@@ -517,12 +510,8 @@ func (h *ChatHandler) saveChatHistory(
userVo vo.User,
promptCreatedAt time.Time,
replyCreatedAt time.Time) {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = usage.Content
useMsg := types.Message{Role: "user", Content: usage.Prompt}
useMsg := types.Message{Role: "user", Content: usage.Prompt}
// update the context messages; no need to update the context when calling a function
if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // the user's question message
@@ -573,7 +562,7 @@ func (h *ChatHandler) saveChatHistory(
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Content: usage.Content,
Tokens: replyTokens,
TotalTokens: totalTokens,
UseContext: true,

View File

@@ -23,7 +23,15 @@ import (
"time"
)
type respVo struct {
type Usage struct {
Prompt string `json:"prompt,omitempty"`
Content string `json:"content,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIResVo struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
@@ -38,11 +46,7 @@ type respVo struct {
Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
Usage Usage `json:"usage"`
}
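The usage block is lifted into a reusable Usage struct, and respVo becomes the exported OpenAIResVo used by the new non-streaming branch. A self-contained decode sketch; the JSON body is abbreviated sample data, not real API output:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	body := []byte(`{"choices":[{"message":{"role":"assistant","content":"hello"}}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`)
	// anonymous struct shaped like OpenAIResVo above, trimmed to the fields used here
	var res struct {
		Choices []struct {
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		Usage struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &res); err != nil {
		panic(err)
	}
	fmt.Println(res.Choices[0].Message.Content, res.Usage.TotalTokens) // hello 7
}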
// OpenAI message sending implementation
@@ -73,19 +77,19 @@ func (h *ChatHandler) sendOpenAiMessage(
if response.StatusCode != 200 {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, body)
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // record the reply time
// read the chunk messages in a loop
var message = types.Message{}
var message = types.Message{Role: "assistant"}
var contents = make([]string, 0)
var function model.Function
var toolCall = false
var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -132,8 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: callMsg})
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: callMsg})
contents = append(contents, callMsg)
}
continue
@@ -150,12 +153,8 @@ func (h *ChatHandler) sendOpenAiMessage(
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
if isNew {
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(ws, types.ReplyMessage{
Type: types.WsMiddle,
Type: types.WsContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
@@ -188,13 +187,13 @@ func (h *ChatHandler) sendOpenAiMessage(
if errMsg != "" || apiRes.Code != types.Success {
msg := "调用函数工具出错:" + apiRes.Message + errMsg
utils.ReplyChunkMessage(ws, types.ReplyMessage{
Type: types.WsMiddle,
Type: types.WsContent,
Content: msg,
})
contents = append(contents, msg)
} else {
utils.ReplyChunkMessage(ws, types.ReplyMessage{
Type: types.WsMiddle,
Type: types.WsContent,
Content: apiRes.Data,
})
contents = append(contents, utils.InterfaceToString(apiRes.Data))
@@ -210,10 +209,27 @@ func (h *ChatHandler) sendOpenAiMessage(
CompletionTokens: 0,
TotalTokens: 0,
}
message.Content = usage.Content
h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else { // non-streaming output
var respVo OpenAIResVo
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("读取响应失败:%v", body)
}
err = json.Unmarshal(body, &respVo)
if err != nil {
return fmt.Errorf("解析响应失败:%v", body)
}
content := respVo.Choices[0].Message.Content
if strings.HasPrefix(req.Model, "o1-") {
content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
}
utils.ReplyMessage(ws, content)
respVo.Usage.Prompt = prompt
respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
}
return nil

View File

@@ -86,6 +86,8 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
if err != nil {
logger.Error(err)
utils.ReplyErrorMessage(client, err.Error())
} else {
utils.ReplyMessage(client, types.ReplyMessage{Type: types.WsEnd})
}
}
@@ -148,7 +150,6 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
if strings.Contains(contentType, "text/event-stream") {
// read the chunk messages in a loop
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -169,12 +170,8 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
break
}
if isNew {
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(client, types.ReplyMessage{
Type: types.WsMiddle,
Type: types.WsContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} // end for

View File

@@ -512,6 +512,11 @@ func main() {
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
}),
fx.Provide(handler.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
group := s.Engine.Group("/api/app/type")
group.GET("list", h.List)
}),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test")

View File

@@ -33,11 +33,14 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
// ReplyMessage sends a complete message back to the client
func ReplyMessage(ws *types.WsClient, message interface{}) {
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: message})
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd})
}
func ReplyContent(ws *types.WsClient, message interface{}) {
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
}
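ReplyMessage wraps content in a start/content/end envelope, while the new ReplyContent sends a single bare content frame; the o1 path above uses it for its progress notice before the blocking request returns. A short usage sketch, assuming ws is a connected *types.WsClient:

utils.ReplyMessage(ws, "done")             // client receives start, content, end
utils.ReplyContent(ws, "AI 正在思考...\n") // bare chunk, e.g. an interim notice while o1 thinks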
// ReplyErrorMessage sends an error message to the client
func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message})