mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
refactor chat message body struct
This commit is contained in:
parent
866564370d
commit
131efd6ba5
@ -3,6 +3,8 @@
|
|||||||
* 功能优化:用户文件列表组件增加分页功能支持
|
* 功能优化:用户文件列表组件增加分页功能支持
|
||||||
* Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码
|
* Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码
|
||||||
* 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
|
* 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
|
||||||
|
* 功能新增:给 AI 应用(角色)增加分类
|
||||||
|
* 功能优化:允许用户在聊天页面设置是否使用流式输出或者一次性输出,兼容 GPT-O1 模型。
|
||||||
|
|
||||||
## v4.1.3
|
## v4.1.3
|
||||||
* 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
|
* 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
|
||||||
|
@ -9,12 +9,12 @@ package types
|
|||||||
|
|
||||||
// ApiRequest API 请求实体
|
// ApiRequest API 请求实体
|
||||||
type ApiRequest struct {
|
type ApiRequest struct {
|
||||||
Model string `json:"model,omitempty"` // 兼容百度文心一言
|
Model string `json:"model,omitempty"`
|
||||||
Temperature float32 `json:"temperature"`
|
Temperature float32 `json:"temperature"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty"` // 兼容百度文心一言
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
Stream bool `json:"stream"`
|
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
Messages []interface{} `json:"messages,omitempty"`
|
Messages []interface{} `json:"messages,omitempty"`
|
||||||
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
|
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
||||||
|
|
||||||
@ -57,7 +57,8 @@ type ChatSession struct {
|
|||||||
ClientIP string `json:"client_ip"` // 客户端 IP
|
ClientIP string `json:"client_ip"` // 客户端 IP
|
||||||
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
||||||
Model ChatModel `json:"model"` // GPT 模型
|
Model ChatModel `json:"model"` // GPT 模型
|
||||||
Tools string `json:"tools"` // 函数
|
Tools []int `json:"tools"` // 工具函数列表
|
||||||
|
Stream bool `json:"stream"` // 是否采用流式输出
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatModel struct {
|
type ChatModel struct {
|
||||||
|
@ -17,8 +17,8 @@ type BizVo struct {
|
|||||||
Data interface{} `json:"data,omitempty"`
|
Data interface{} `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// WsMessage Websocket message
|
// ReplyMessage 对话回复消息结构
|
||||||
type WsMessage struct {
|
type ReplyMessage struct {
|
||||||
Type WsMsgType `json:"type"` // 消息类别,start, end, img
|
Type WsMsgType `json:"type"` // 消息类别,start, end, img
|
||||||
Content interface{} `json:"content"`
|
Content interface{} `json:"content"`
|
||||||
}
|
}
|
||||||
@ -32,6 +32,13 @@ const (
|
|||||||
WsErr = WsMsgType("error")
|
WsErr = WsMsgType("error")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// InputMessage 对话输入消息结构
|
||||||
|
type InputMessage struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Tools []int `json:"tools"` // 允许调用工具列表
|
||||||
|
Stream bool `json:"stream"` // 是否采用流式输出
|
||||||
|
}
|
||||||
|
|
||||||
type BizCode int
|
type BizCode int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -73,13 +73,12 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
roleId := h.GetInt(c, "role_id", 0)
|
roleId := h.GetInt(c, "role_id", 0)
|
||||||
chatId := c.Query("chat_id")
|
chatId := c.Query("chat_id")
|
||||||
modelId := h.GetInt(c, "model_id", 0)
|
modelId := h.GetInt(c, "model_id", 0)
|
||||||
tools := c.Query("tools")
|
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
client := types.NewWsClient(ws)
|
||||||
var chatRole model.ChatRole
|
var chatRole model.ChatRole
|
||||||
res := h.DB.First(&chatRole, roleId)
|
res := h.DB.First(&chatRole, roleId)
|
||||||
if res.Error != nil || !chatRole.Enable {
|
if res.Error != nil || !chatRole.Enable {
|
||||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
utils.ReplyErrorMessage(client, "当前聊天角色不存在或者未启用,对话已关闭!!!")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -91,7 +90,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
var chatModel model.ChatModel
|
var chatModel model.ChatModel
|
||||||
res = h.DB.First(&chatModel, modelId)
|
res = h.DB.First(&chatModel, modelId)
|
||||||
if res.Error != nil || chatModel.Enabled == false {
|
if res.Error != nil || chatModel.Enabled == false {
|
||||||
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
utils.ReplyErrorMessage(client, "当前AI模型暂未启用,对话已关闭!!!")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -100,7 +99,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
SessionId: sessionId,
|
SessionId: sessionId,
|
||||||
ClientIP: c.ClientIP(),
|
ClientIP: c.ClientIP(),
|
||||||
UserId: h.GetLoginUserId(c),
|
UserId: h.GetLoginUserId(c),
|
||||||
Tools: tools,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// use old chat data override the chat model and role ID
|
// use old chat data override the chat model and role ID
|
||||||
@ -137,20 +135,16 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var message types.WsMessage
|
var message types.InputMessage
|
||||||
err = utils.JsonDecode(string(msg), &message)
|
err = utils.JsonDecode(string(msg), &message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 心跳消息
|
logger.Infof("Receive a message:%+v", message)
|
||||||
if message.Type == "heartbeat" {
|
|
||||||
logger.Debug("收到 Chat 心跳消息:", message.Content)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Receive a message: ", message.Content)
|
|
||||||
|
|
||||||
|
session.Tools = message.Tools
|
||||||
|
session.Stream = message.Stream
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
h.ReqCancelFunc.Put(sessionId, cancel)
|
h.ReqCancelFunc.Put(sessionId, cancel)
|
||||||
// 回复消息
|
// 回复消息
|
||||||
@ -159,7 +153,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
utils.ReplyMessage(client, err.Error())
|
utils.ReplyMessage(client, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
|
||||||
logger.Infof("回答完毕: %v", message.Content)
|
logger.Infof("回答完毕: %v", message.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -209,15 +203,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
|
|
||||||
var req = types.ApiRequest{
|
var req = types.ApiRequest{
|
||||||
Model: session.Model.Value,
|
Model: session.Model.Value,
|
||||||
Stream: true,
|
Temperature: session.Model.Temperature,
|
||||||
}
|
}
|
||||||
req.Temperature = session.Model.Temperature
|
// 兼容 GPT-O1 模型
|
||||||
|
if strings.HasPrefix(session.Model.Value, "o1-") {
|
||||||
|
req.MaxCompletionTokens = session.Model.MaxTokens
|
||||||
|
req.Stream = false
|
||||||
|
} else {
|
||||||
req.MaxTokens = session.Model.MaxTokens
|
req.MaxTokens = session.Model.MaxTokens
|
||||||
|
req.Stream = session.Stream
|
||||||
|
}
|
||||||
|
|
||||||
if session.Tools != "" {
|
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
|
||||||
toolIds := strings.Split(session.Tools, ",")
|
|
||||||
var items []model.Function
|
var items []model.Function
|
||||||
res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items)
|
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
var tools = make([]types.Tool, 0)
|
var tools = make([]types.Tool, 0)
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
@ -279,7 +278,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
|
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
v := messages[i]
|
v := messages[i]
|
||||||
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
tks, _ = utils.CalcTokens(v.Content, req.Model)
|
||||||
// 上下文 token 超出了模型的最大上下文长度
|
// 上下文 token 超出了模型的最大上下文长度
|
||||||
if tokens+tks >= session.Model.MaxContext {
|
if tokens+tks >= session.Model.MaxContext {
|
||||||
break
|
break
|
||||||
@ -500,10 +499,17 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Usage struct {
|
||||||
|
Prompt string
|
||||||
|
Content string
|
||||||
|
PromptTokens int
|
||||||
|
CompletionTokens int
|
||||||
|
TotalTokens int
|
||||||
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) saveChatHistory(
|
func (h *ChatHandler) saveChatHistory(
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
prompt string,
|
usage Usage,
|
||||||
contents []string,
|
|
||||||
message types.Message,
|
message types.Message,
|
||||||
chatCtx []types.Message,
|
chatCtx []types.Message,
|
||||||
session *types.ChatSession,
|
session *types.ChatSession,
|
||||||
@ -514,8 +520,8 @@ func (h *ChatHandler) saveChatHistory(
|
|||||||
if message.Role == "" {
|
if message.Role == "" {
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
}
|
}
|
||||||
message.Content = strings.Join(contents, "")
|
message.Content = usage.Content
|
||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: usage.Prompt}
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if h.App.SysConfig.EnableContext {
|
if h.App.SysConfig.EnableContext {
|
||||||
@ -526,32 +532,41 @@ func (h *ChatHandler) saveChatHistory(
|
|||||||
|
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
// for prompt
|
// for prompt
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
var promptTokens, replyTokens, totalTokens int
|
||||||
if err != nil {
|
if usage.PromptTokens > 0 {
|
||||||
logger.Error(err)
|
promptTokens = usage.PromptTokens
|
||||||
|
} else {
|
||||||
|
promptTokens, _ = utils.CalcTokens(usage.Content, req.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
historyUserMsg := model.ChatMessage{
|
historyUserMsg := model.ChatMessage{
|
||||||
UserId: userVo.Id,
|
UserId: userVo.Id,
|
||||||
ChatId: session.ChatId,
|
ChatId: session.ChatId,
|
||||||
RoleId: role.Id,
|
RoleId: role.Id,
|
||||||
Type: types.PromptMsg,
|
Type: types.PromptMsg,
|
||||||
Icon: userVo.Avatar,
|
Icon: userVo.Avatar,
|
||||||
Content: template.HTMLEscapeString(prompt),
|
Content: template.HTMLEscapeString(usage.Prompt),
|
||||||
Tokens: promptToken,
|
Tokens: promptTokens,
|
||||||
|
TotalTokens: promptTokens,
|
||||||
UseContext: true,
|
UseContext: true,
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
}
|
}
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
err = h.DB.Save(&historyUserMsg).Error
|
err := h.DB.Save(&historyUserMsg).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("failed to save prompt history message: ", err)
|
logger.Error("failed to save prompt history message: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// for reply
|
// for reply
|
||||||
// 计算本次对话消耗的总 token 数量
|
// 计算本次对话消耗的总 token 数量
|
||||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
if usage.CompletionTokens > 0 {
|
||||||
totalTokens := replyTokens + getTotalTokens(req)
|
replyTokens = usage.CompletionTokens
|
||||||
|
totalTokens = usage.TotalTokens
|
||||||
|
} else {
|
||||||
|
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||||
|
totalTokens = replyTokens + getTotalTokens(req)
|
||||||
|
}
|
||||||
historyReplyMsg := model.ChatMessage{
|
historyReplyMsg := model.ChatMessage{
|
||||||
UserId: userVo.Id,
|
UserId: userVo.Id,
|
||||||
ChatId: session.ChatId,
|
ChatId: session.ChatId,
|
||||||
@ -559,7 +574,8 @@ func (h *ChatHandler) saveChatHistory(
|
|||||||
Type: types.ReplyMsg,
|
Type: types.ReplyMsg,
|
||||||
Icon: role.Icon,
|
Icon: role.Icon,
|
||||||
Content: message.Content,
|
Content: message.Content,
|
||||||
Tokens: totalTokens,
|
Tokens: replyTokens,
|
||||||
|
TotalTokens: totalTokens,
|
||||||
UseContext: true,
|
UseContext: true,
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
}
|
}
|
||||||
@ -572,7 +588,7 @@ func (h *ChatHandler) saveChatHistory(
|
|||||||
|
|
||||||
// 更新用户算力
|
// 更新用户算力
|
||||||
if session.Model.Power > 0 {
|
if session.Model.Power > 0 {
|
||||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
h.subUserPower(userVo, session, promptTokens, replyTokens)
|
||||||
}
|
}
|
||||||
// 保存当前会话
|
// 保存当前会话
|
||||||
var chatItem model.ChatItem
|
var chatItem model.ChatItem
|
||||||
@ -582,10 +598,10 @@ func (h *ChatHandler) saveChatHistory(
|
|||||||
chatItem.UserId = userVo.Id
|
chatItem.UserId = userVo.Id
|
||||||
chatItem.RoleId = role.Id
|
chatItem.RoleId = role.Id
|
||||||
chatItem.ModelId = session.Model.Id
|
chatItem.ModelId = session.Model.Id
|
||||||
if utf8.RuneCountInString(prompt) > 30 {
|
if utf8.RuneCountInString(usage.Prompt) > 30 {
|
||||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
|
||||||
} else {
|
} else {
|
||||||
chatItem.Title = prompt
|
chatItem.Title = usage.Prompt
|
||||||
}
|
}
|
||||||
chatItem.Model = req.Model
|
chatItem.Model = req.Model
|
||||||
err = h.DB.Create(&chatItem).Error
|
err = h.DB.Create(&chatItem).Error
|
||||||
|
@ -23,6 +23,28 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type respVo struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
Choices []struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
Logprobs interface{} `json:"logprobs"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
} `json:"choices"`
|
||||||
|
Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
// OPenAI 消息发送实现
|
// OPenAI 消息发送实现
|
||||||
func (h *ChatHandler) sendOpenAiMessage(
|
func (h *ChatHandler) sendOpenAiMessage(
|
||||||
chatCtx []types.Message,
|
chatCtx []types.Message,
|
||||||
@ -49,6 +71,10 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if response.StatusCode != 200 {
|
||||||
|
body, _ := io.ReadAll(response.Body)
|
||||||
|
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, body)
|
||||||
|
}
|
||||||
contentType := response.Header.Get("Content-Type")
|
contentType := response.Header.Get("Content-Type")
|
||||||
if strings.Contains(contentType, "text/event-stream") {
|
if strings.Contains(contentType, "text/event-stream") {
|
||||||
replyCreatedAt := time.Now() // 记录回复时间
|
replyCreatedAt := time.Now() // 记录回复时间
|
||||||
@ -106,8 +132,8 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
toolCall = true
|
toolCall = true
|
||||||
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
|
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: callMsg})
|
||||||
contents = append(contents, callMsg)
|
contents = append(contents, callMsg)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@ -125,10 +151,10 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
content := responseBody.Choices[0].Delta.Content
|
content := responseBody.Choices[0].Delta.Content
|
||||||
contents = append(contents, utils.InterfaceToString(content))
|
contents = append(contents, utils.InterfaceToString(content))
|
||||||
if isNew {
|
if isNew {
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
|
||||||
isNew = false
|
isNew = false
|
||||||
}
|
}
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.ReplyMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||||
})
|
})
|
||||||
@ -161,13 +187,13 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
}
|
}
|
||||||
if errMsg != "" || apiRes.Code != types.Success {
|
if errMsg != "" || apiRes.Code != types.Success {
|
||||||
msg := "调用函数工具出错:" + apiRes.Message + errMsg
|
msg := "调用函数工具出错:" + apiRes.Message + errMsg
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.ReplyMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: msg,
|
Content: msg,
|
||||||
})
|
})
|
||||||
contents = append(contents, msg)
|
contents = append(contents, msg)
|
||||||
} else {
|
} else {
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.ReplyMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: apiRes.Data,
|
Content: apiRes.Data,
|
||||||
})
|
})
|
||||||
@ -177,11 +203,17 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
usage := Usage{
|
||||||
|
Prompt: prompt,
|
||||||
|
Content: strings.Join(contents, ""),
|
||||||
|
PromptTokens: 0,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: 0,
|
||||||
}
|
}
|
||||||
} else {
|
h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||||
body, _ := io.ReadAll(response.Body)
|
}
|
||||||
return fmt.Errorf("请求 OpenAI API 失败:%s", body)
|
} else { // 非流式输出
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -73,7 +73,7 @@ func (h *DallJobHandler) Client(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var message types.WsMessage
|
var message types.ReplyMessage
|
||||||
err = utils.JsonDecode(string(msg), &message)
|
err = utils.JsonDecode(string(msg), &message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
|
@ -64,7 +64,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var message types.WsMessage
|
var message types.ReplyMessage
|
||||||
err = utils.JsonDecode(string(msg), &message)
|
err = utils.JsonDecode(string(msg), &message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
@ -85,7 +85,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
|
|||||||
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
|
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
|
utils.ReplyErrorMessage(client, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -170,16 +170,16 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isNew {
|
if isNew {
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsStart})
|
||||||
isNew = false
|
isNew = false
|
||||||
}
|
}
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{
|
utils.ReplyChunkMessage(client, types.ReplyMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||||
})
|
})
|
||||||
} // end for
|
} // end for
|
||||||
|
|
||||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
body, _ := io.ReadAll(response.Body)
|
body, _ := io.ReadAll(response.Body)
|
||||||
|
@ -11,6 +11,7 @@ type ChatMessage struct {
|
|||||||
Type string
|
Type string
|
||||||
Icon string
|
Icon string
|
||||||
Tokens int
|
Tokens int
|
||||||
|
TotalTokens int // 总 token 消耗
|
||||||
Content string
|
Content string
|
||||||
UseContext bool // 是否可以作为聊天上下文
|
UseContext bool // 是否可以作为聊天上下文
|
||||||
DeletedAt gorm.DeletedAt
|
DeletedAt gorm.DeletedAt
|
||||||
|
@ -5,7 +5,7 @@ import "geekai/core/types"
|
|||||||
type ChatRole struct {
|
type ChatRole struct {
|
||||||
BaseVo
|
BaseVo
|
||||||
Key string `json:"key"` // 角色唯一标识
|
Key string `json:"key"` // 角色唯一标识
|
||||||
Tid uint `json:"tid"`
|
Tid int `json:"tid"`
|
||||||
Name string `json:"name"` // 角色名称
|
Name string `json:"name"` // 角色名称
|
||||||
Context []types.Message `json:"context"` // 角色语料信息
|
Context []types.Message `json:"context"` // 角色语料信息
|
||||||
HelloMsg string `json:"hello_msg"` // 打招呼的消息
|
HelloMsg string `json:"hello_msg"` // 打招呼的消息
|
||||||
|
@ -33,9 +33,14 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
|
|||||||
|
|
||||||
// ReplyMessage 回复客户端一条完整的消息
|
// ReplyMessage 回复客户端一条完整的消息
|
||||||
func ReplyMessage(ws *types.WsClient, message interface{}) {
|
func ReplyMessage(ws *types.WsClient, message interface{}) {
|
||||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
|
||||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: message})
|
||||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplyErrorMessage 向客户端发送错误消息
|
||||||
|
func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
|
||||||
|
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message})
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownloadImage(imageURL string, proxy string) ([]byte, error) {
|
func DownloadImage(imageURL string, proxy string) ([]byte, error) {
|
||||||
|
@ -10,3 +10,5 @@ CREATE TABLE `chatgpt_app_types` (
|
|||||||
ALTER TABLE `chatgpt_app_types`ADD PRIMARY KEY (`id`);
|
ALTER TABLE `chatgpt_app_types`ADD PRIMARY KEY (`id`);
|
||||||
ALTER TABLE `chatgpt_app_types` MODIFY `id` int NOT NULL AUTO_INCREMENT;
|
ALTER TABLE `chatgpt_app_types` MODIFY `id` int NOT NULL AUTO_INCREMENT;
|
||||||
ALTER TABLE `chatgpt_chat_roles` ADD `tid` INT NOT NULL COMMENT '分类ID' AFTER `name`;
|
ALTER TABLE `chatgpt_chat_roles` ADD `tid` INT NOT NULL COMMENT '分类ID' AFTER `name`;
|
||||||
|
|
||||||
|
ALTER TABLE `chatgpt_chat_history` ADD `total_tokens` INT NOT NULL COMMENT '消耗总Token长度' AFTER `tokens`;
|
Loading…
Reference in New Issue
Block a user