mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
refactor: 更改 OpenAI 请求 Body 数据结构,兼容函数调用请求
This commit is contained in:
parent
a5ad9648bf
commit
cc1b56501d
@ -22,8 +22,8 @@ type AppServer struct {
|
|||||||
Debug bool
|
Debug bool
|
||||||
Config *types.AppConfig
|
Config *types.AppConfig
|
||||||
Engine *gin.Engine
|
Engine *gin.Engine
|
||||||
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
|
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
|
||||||
ChatConfig *types.ChatConfig // 聊天配置
|
ChatConfig *types.ChatConfig // 聊天配置
|
||||||
|
|
||||||
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
|
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
|
||||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||||
@ -39,7 +39,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
|||||||
Debug: false,
|
Debug: false,
|
||||||
Config: appConfig,
|
Config: appConfig,
|
||||||
Engine: gin.Default(),
|
Engine: gin.Default(),
|
||||||
ChatContexts: types.NewLMap[string, []types.Message](),
|
ChatContexts: types.NewLMap[string, []interface{}](),
|
||||||
ChatSession: types.NewLMap[string, types.ChatSession](),
|
ChatSession: types.NewLMap[string, types.ChatSession](),
|
||||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
|
@ -2,17 +2,17 @@ package types
|
|||||||
|
|
||||||
// ApiRequest API 请求实体
|
// ApiRequest API 请求实体
|
||||||
type ApiRequest struct {
|
type ApiRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Temperature float32 `json:"temperature"`
|
Temperature float32 `json:"temperature"`
|
||||||
MaxTokens int `json:"max_tokens"`
|
MaxTokens int `json:"max_tokens"`
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []interface{} `json:"messages"`
|
||||||
|
Functions []Function `json:"functions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
FunctionCall FunctionCall `json:"function_call"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ApiResponse struct {
|
type ApiResponse struct {
|
||||||
@ -21,8 +21,15 @@ type ApiResponse struct {
|
|||||||
|
|
||||||
// ChoiceItem API 响应实体
|
// ChoiceItem API 响应实体
|
||||||
type ChoiceItem struct {
|
type ChoiceItem struct {
|
||||||
Delta Message `json:"delta"`
|
Delta Delta `json:"delta"`
|
||||||
FinishReason string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Delta struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Content interface{} `json:"content"`
|
||||||
|
FunctionCall FunctionCall `json:"function_call,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatSession 聊天会话对象
|
// ChatSession 聊天会话对象
|
||||||
|
@ -6,13 +6,71 @@ type FunctionCall struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Function struct {
|
type Function struct {
|
||||||
Name string
|
Name string `json:"name"`
|
||||||
Description string
|
Description string `json:"description"`
|
||||||
Parameters []Parameter
|
Parameters Parameters `json:"parameters"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Parameter struct {
|
type Parameters struct {
|
||||||
Type string
|
Type string `json:"type"`
|
||||||
Required []string
|
Required []string `json:"required"`
|
||||||
Properties map[string]interface{}
|
Properties map[string]Property `json:"properties"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Property struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var InnerFunctions = []Function{
|
||||||
|
{
|
||||||
|
Name: "zao_bao",
|
||||||
|
Description: "每日早报,获取当天全球的热门新闻事件列表",
|
||||||
|
Parameters: Parameters{
|
||||||
|
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]Property{
|
||||||
|
"text": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "weibo_hot",
|
||||||
|
Description: "新浪微博热搜榜,微博当日热搜榜单",
|
||||||
|
Parameters: Parameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]Property{
|
||||||
|
"text": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
Name: "zhihu_top",
|
||||||
|
Description: "知乎热榜,知乎当日话题讨论榜单",
|
||||||
|
Parameters: Parameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]Property{
|
||||||
|
"text": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var FunctionNameMap = map[string]string{
|
||||||
|
"zao_bao": "每日早报",
|
||||||
|
"weibo_hot": "微博热搜",
|
||||||
|
"zhihu_top": "知乎热榜",
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ type MKey interface {
|
|||||||
string | int
|
string | int
|
||||||
}
|
}
|
||||||
type MValue interface {
|
type MValue interface {
|
||||||
*WsClient | ChatSession | []Message | context.CancelFunc
|
*WsClient | ChatSession | context.CancelFunc | []interface{}
|
||||||
}
|
}
|
||||||
type LMap[K MKey, T MValue] struct {
|
type LMap[K MKey, T MValue] struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
|
@ -157,10 +157,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
Temperature: userVo.ChatConfig.Temperature,
|
Temperature: userVo.ChatConfig.Temperature,
|
||||||
MaxTokens: userVo.ChatConfig.MaxTokens,
|
MaxTokens: userVo.ChatConfig.MaxTokens,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
|
Functions: types.InnerFunctions,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 加载聊天上下文
|
// 加载聊天上下文
|
||||||
var chatCtx []types.Message
|
var chatCtx []interface{}
|
||||||
if userVo.ChatConfig.EnableContext {
|
if userVo.ChatConfig.EnableContext {
|
||||||
if h.App.ChatContexts.Has(session.ChatId) {
|
if h.App.ChatContexts.Has(session.ChatId) {
|
||||||
chatCtx = h.App.ChatContexts.Get(session.ChatId)
|
chatCtx = h.App.ChatContexts.Get(session.ChatId)
|
||||||
@ -169,11 +170,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
var messages []types.Message
|
var messages []types.Message
|
||||||
err := utils.JsonDecode(role.Context, &messages)
|
err := utils.JsonDecode(role.Context, &messages)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
chatCtx = messages
|
for _, v := range messages {
|
||||||
|
chatCtx = append(chatCtx, v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// TODO: 这里默认加载最近 4 条聊天记录作为上下文,后期应该做成可配置的
|
// TODO: 这里默认加载最近 2 条聊天记录作为上下文,后期应该做成可配置的
|
||||||
var historyMessages []model.HistoryMessage
|
var historyMessages []model.HistoryMessage
|
||||||
res := h.db.Where("chat_id = ?", session.ChatId).Limit(4).Order("created_at desc").Find(&historyMessages)
|
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for _, msg := range historyMessages {
|
for _, msg := range historyMessages {
|
||||||
ms := types.Message{Role: "user", Content: msg.Content}
|
ms := types.Message{Role: "user", Content: msg.Content}
|
||||||
@ -189,12 +192,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
logger.Info("聊天上下文:", chatCtx)
|
logger.Info("聊天上下文:", chatCtx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
req.Messages = append(chatCtx, types.Message{
|
reqMgs := make([]interface{}, 0)
|
||||||
Role: "user",
|
for _, m := range chatCtx {
|
||||||
Content: prompt,
|
reqMgs = append(reqMgs, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Messages = append(reqMgs, map[string]interface{}{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
})
|
})
|
||||||
var apiKey string
|
var apiKey string
|
||||||
response, err := h.fakeRequest(ctx, userVo, &apiKey, req)
|
response, err := h.doRequest(ctx, userVo, &apiKey, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "context canceled") {
|
if strings.Contains(err.Error(), "context canceled") {
|
||||||
logger.Info("用户取消了请求:", prompt)
|
logger.Info("用户取消了请求:", prompt)
|
||||||
@ -213,176 +221,191 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
//contentType := response.Header.Get("Content-Type")
|
contentType := response.Header.Get("Content-Type")
|
||||||
//if strings.Contains(contentType, "text/event-stream") || true {
|
if strings.Contains(contentType, "text/event-stream") {
|
||||||
if true {
|
if true {
|
||||||
replyCreatedAt := time.Now()
|
replyCreatedAt := time.Now()
|
||||||
// 循环读取 Chunk 消息
|
// 循环读取 Chunk 消息
|
||||||
var message = types.Message{}
|
var message = types.Message{}
|
||||||
var contents = make([]string, 0)
|
var contents = make([]string, 0)
|
||||||
var functionCall = false
|
var functionCall = false
|
||||||
var functionName string
|
var functionName string
|
||||||
var arguments = make([]string, 0)
|
var arguments = make([]string, 0)
|
||||||
reader := bufio.NewReader(response.Body)
|
reader := bufio.NewReader(response.Body)
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadString('\n')
|
line, err := reader.ReadString('\n')
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "context canceled") {
|
|
||||||
logger.Info("用户取消了请求:", prompt)
|
|
||||||
} else {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var responseBody = types.ApiResponse{}
|
|
||||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
|
||||||
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
|
|
||||||
logger.Error(err, line)
|
|
||||||
replyMessage(ws, ErrorMsg)
|
|
||||||
replyMessage(ws, "")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
fun := responseBody.Choices[0].Delta.FunctionCall
|
|
||||||
if functionCall && fun.Name == "" {
|
|
||||||
arguments = append(arguments, fun.Arguments)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !utils.IsEmptyValue(fun) {
|
|
||||||
functionCall = true
|
|
||||||
functionName = fun.Name
|
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 %s 作答 ...\n\n", functionName)})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化 role
|
|
||||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
|
||||||
message.Role = responseBody.Choices[0].Delta.Role
|
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
|
||||||
continue
|
|
||||||
} else if responseBody.Choices[0].FinishReason != "" {
|
|
||||||
break // 输出完成或者输出中断了
|
|
||||||
} else {
|
|
||||||
content := responseBody.Choices[0].Delta.Content
|
|
||||||
contents = append(contents, content)
|
|
||||||
replyChunkMessage(ws, types.WsMessage{
|
|
||||||
Type: types.WsMiddle,
|
|
||||||
Content: responseBody.Choices[0].Delta.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
} // end for
|
|
||||||
|
|
||||||
if functionCall { // 调用函数完成任务
|
|
||||||
// TODO 调用函数完成任务
|
|
||||||
data, err := h.funcZaoBao.Fetch()
|
|
||||||
if err != nil {
|
|
||||||
replyChunkMessage(ws, types.WsMessage{
|
|
||||||
Type: types.WsMiddle,
|
|
||||||
Content: "调用函数出错",
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
replyChunkMessage(ws, types.WsMessage{
|
|
||||||
Type: types.WsMiddle,
|
|
||||||
Content: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
contents = append(contents, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 消息发送成功
|
|
||||||
if len(contents) > 0 {
|
|
||||||
// 更新用户的对话次数
|
|
||||||
res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
|
|
||||||
if res.Error != nil {
|
|
||||||
return res.Error
|
|
||||||
}
|
|
||||||
|
|
||||||
if message.Role == "" {
|
|
||||||
message.Role = "assistant"
|
|
||||||
}
|
|
||||||
message.Content = strings.Join(contents, "")
|
|
||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
|
||||||
if userVo.ChatConfig.EnableContext && functionCall == false {
|
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 追加聊天记录
|
|
||||||
if userVo.ChatConfig.EnableHistory {
|
|
||||||
// for prompt
|
|
||||||
token, err := utils.CalcTokens(prompt, req.Model)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
if strings.Contains(err.Error(), "context canceled") {
|
||||||
|
logger.Info("用户取消了请求:", prompt)
|
||||||
|
} else if err != io.EOF {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
historyUserMsg := model.HistoryMessage{
|
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||||
UserId: userVo.Id,
|
continue
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.PromptMsg,
|
|
||||||
Icon: user.Avatar,
|
|
||||||
Content: prompt,
|
|
||||||
Tokens: token,
|
|
||||||
}
|
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
|
||||||
res := h.db.Save(&historyUserMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// for reply
|
var responseBody = types.ApiResponse{}
|
||||||
token, err = utils.CalcTokens(message.Content, req.Model)
|
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||||
|
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
|
||||||
|
logger.Error(err, line)
|
||||||
|
replyMessage(ws, ErrorMsg)
|
||||||
|
replyMessage(ws, "")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
fun := responseBody.Choices[0].Delta.FunctionCall
|
||||||
|
if functionCall && fun.Name == "" {
|
||||||
|
arguments = append(arguments, fun.Arguments)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !utils.IsEmptyValue(fun) {
|
||||||
|
functionCall = true
|
||||||
|
functionName = fun.Name
|
||||||
|
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
|
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化 role
|
||||||
|
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||||
|
message.Role = responseBody.Choices[0].Delta.Role
|
||||||
|
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
|
continue
|
||||||
|
} else if responseBody.Choices[0].FinishReason != "" {
|
||||||
|
break // 输出完成或者输出中断了
|
||||||
|
} else {
|
||||||
|
content := responseBody.Choices[0].Delta.Content
|
||||||
|
contents = append(contents, utils.InterfaceToString(content))
|
||||||
|
replyChunkMessage(ws, types.WsMessage{
|
||||||
|
Type: types.WsMiddle,
|
||||||
|
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} // end for
|
||||||
|
|
||||||
|
if functionCall { // 调用函数完成任务
|
||||||
|
logger.Info(functionName)
|
||||||
|
logger.Info(arguments)
|
||||||
|
// TODO 调用函数完成任务
|
||||||
|
data, err := h.funcZaoBao.Fetch()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
replyChunkMessage(ws, types.WsMessage{
|
||||||
|
Type: types.WsMiddle,
|
||||||
|
Content: "调用函数出错",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
replyChunkMessage(ws, types.WsMessage{
|
||||||
|
Type: types.WsMiddle,
|
||||||
|
Content: data,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
historyReplyMsg := model.HistoryMessage{
|
contents = append(contents, data)
|
||||||
UserId: userVo.Id,
|
|
||||||
ChatId: session.ChatId,
|
|
||||||
RoleId: role.Id,
|
|
||||||
Type: types.ReplyMsg,
|
|
||||||
Icon: role.Icon,
|
|
||||||
Content: message.Content,
|
|
||||||
Tokens: token,
|
|
||||||
}
|
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
|
||||||
res = h.db.Create(&historyReplyMsg)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 统计用户 token 数量
|
|
||||||
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
|
||||||
historyUserMsg.Tokens+historyReplyMsg.Tokens))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存当前会话
|
// 消息发送成功
|
||||||
var chatItem model.ChatItem
|
if len(contents) > 0 {
|
||||||
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
// 更新用户的对话次数
|
||||||
if res.Error != nil {
|
res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
|
||||||
chatItem.ChatId = session.ChatId
|
if res.Error != nil {
|
||||||
chatItem.UserId = session.UserId
|
return res.Error
|
||||||
chatItem.RoleId = role.Id
|
}
|
||||||
chatItem.Model = session.Model
|
|
||||||
if utf8.RuneCountInString(prompt) > 30 {
|
if message.Role == "" {
|
||||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
message.Role = "assistant"
|
||||||
} else {
|
}
|
||||||
chatItem.Title = prompt
|
message.Content = strings.Join(contents, "")
|
||||||
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
|
// 计算本次对话消耗的总 token 数量
|
||||||
|
req.Messages = append(req.Messages, message)
|
||||||
|
totalTokens := getTotalTokens(req)
|
||||||
|
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
|
||||||
|
|
||||||
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
|
if userVo.ChatConfig.EnableContext && functionCall == false {
|
||||||
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 追加聊天记录
|
||||||
|
if userVo.ChatConfig.EnableHistory {
|
||||||
|
useContext := true
|
||||||
|
if functionCall {
|
||||||
|
useContext = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// for prompt
|
||||||
|
token, err := utils.CalcTokens(prompt, req.Model)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
historyUserMsg := model.HistoryMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.PromptMsg,
|
||||||
|
Icon: user.Avatar,
|
||||||
|
Content: prompt,
|
||||||
|
Tokens: token,
|
||||||
|
UseContext: useContext,
|
||||||
|
}
|
||||||
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
|
res := h.db.Save(&historyUserMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for reply
|
||||||
|
token, err = utils.CalcTokens(message.Content, req.Model)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
historyReplyMsg := model.HistoryMessage{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Type: types.ReplyMsg,
|
||||||
|
Icon: role.Icon,
|
||||||
|
Content: message.Content,
|
||||||
|
Tokens: token,
|
||||||
|
UseContext: useContext,
|
||||||
|
}
|
||||||
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
|
res = h.db.Create(&historyReplyMsg)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计用户 token 数量
|
||||||
|
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
||||||
|
historyUserMsg.Tokens+historyReplyMsg.Tokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存当前会话
|
||||||
|
var chatItem model.ChatItem
|
||||||
|
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||||
|
if res.Error != nil {
|
||||||
|
chatItem.ChatId = session.ChatId
|
||||||
|
chatItem.UserId = session.UserId
|
||||||
|
chatItem.RoleId = role.Id
|
||||||
|
chatItem.Model = session.Model
|
||||||
|
if utf8.RuneCountInString(prompt) > 30 {
|
||||||
|
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||||
|
} else {
|
||||||
|
chatItem.Title = prompt
|
||||||
|
}
|
||||||
|
h.db.Create(&chatItem)
|
||||||
}
|
}
|
||||||
h.db.Create(&chatItem)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -469,13 +492,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
|||||||
return client.Do(request)
|
return client.Do(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) fakeRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) {
|
|
||||||
link := "https://img.r9it.com/chatgpt/response"
|
|
||||||
client := &http.Client{}
|
|
||||||
request, _ := http.NewRequest(http.MethodGet, link, nil)
|
|
||||||
return client.Do(request)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 回复客户片段端消息
|
// 回复客户片段端消息
|
||||||
func replyChunkMessage(client types.Client, message types.WsMessage) {
|
func replyChunkMessage(client types.Client, message types.WsMessage) {
|
||||||
msg, err := json.Marshal(message)
|
msg, err := json.Marshal(message)
|
||||||
@ -509,6 +525,26 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, tokens)
|
resp.SUCCESS(c, tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTotalTokens(req types.ApiRequest) int {
|
||||||
|
encode := utils.JsonEncode(req.Messages)
|
||||||
|
var items []map[string]interface{}
|
||||||
|
err := utils.JsonDecode(encode, &items)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
tokens := 0
|
||||||
|
for _, item := range items {
|
||||||
|
content, ok := item["content"]
|
||||||
|
if ok && !utils.IsEmptyValue(content) {
|
||||||
|
t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
|
||||||
|
if err == nil {
|
||||||
|
tokens += t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
// StopGenerate 停止生成
|
// StopGenerate 停止生成
|
||||||
func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||||
sessionId := c.Query("session_id")
|
sessionId := c.Query("session_id")
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"context"
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
@ -101,9 +102,12 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
|
|
||||||
// 创建函数
|
// 创建函数
|
||||||
fx.Provide(func() *function.FuncZaoBao {
|
fx.Provide(func() (*function.FuncZaoBao, error) {
|
||||||
token := os.Getenv("AL_API_TOKEN")
|
token := os.Getenv("AL_API_TOKEN")
|
||||||
return function.NewZaoBao(token)
|
if token == "" {
|
||||||
|
return nil, errors.New("invalid AL api token")
|
||||||
|
}
|
||||||
|
return function.NewZaoBao(token), nil
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 创建控制器
|
// 创建控制器
|
||||||
|
@ -2,13 +2,14 @@ package model
|
|||||||
|
|
||||||
type HistoryMessage struct {
|
type HistoryMessage struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
ChatId string // 会话 ID
|
ChatId string // 会话 ID
|
||||||
UserId uint // 用户 ID
|
UserId uint // 用户 ID
|
||||||
RoleId uint // 角色 ID
|
RoleId uint // 角色 ID
|
||||||
Type string
|
Type string
|
||||||
Icon string
|
Icon string
|
||||||
Tokens int
|
Tokens int
|
||||||
Content string
|
Content string
|
||||||
|
UseContext bool // 是否可以作为聊天上下文
|
||||||
}
|
}
|
||||||
|
|
||||||
func (HistoryMessage) TableName() string {
|
func (HistoryMessage) TableName() string {
|
||||||
|
@ -2,13 +2,14 @@ package vo
|
|||||||
|
|
||||||
type HistoryMessage struct {
|
type HistoryMessage struct {
|
||||||
BaseVo
|
BaseVo
|
||||||
ChatId string `json:"chat_id"`
|
ChatId string `json:"chat_id"`
|
||||||
UserId uint `json:"user_id"`
|
UserId uint `json:"user_id"`
|
||||||
RoleId uint `json:"role_id"`
|
RoleId uint `json:"role_id"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Icon string `json:"icon"`
|
Icon string `json:"icon"`
|
||||||
Tokens int `json:"tokens"`
|
Tokens int `json:"tokens"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
UseContext bool `json:"use_context"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (HistoryMessage) TableName() string {
|
func (HistoryMessage) TableName() string {
|
||||||
|
@ -81,3 +81,10 @@ func JsonEncode(value interface{}) string {
|
|||||||
func JsonDecode(src string, dest interface{}) error {
|
func JsonDecode(src string, dest interface{}) error {
|
||||||
return json.Unmarshal([]byte(src), dest)
|
return json.Unmarshal([]byte(src), dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InterfaceToString(value interface{}) string {
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
return JsonEncode(value)
|
||||||
|
}
|
||||||
|
1
database/plugins.sql
Normal file
1
database/plugins.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
ALTER TABLE `chatgpt_chat_history` ADD `use_context` TINYINT(1) NOT NULL COMMENT '是否允许作为上下文语料' AFTER `tokens`;
|
@ -94,6 +94,7 @@
|
|||||||
<el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)">
|
<el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)">
|
||||||
<div class="chat-head">
|
<div class="chat-head">
|
||||||
<div class="chat-config">
|
<div class="chat-config">
|
||||||
|
<span class="role-select-label">聊天角色:</span>
|
||||||
<el-select v-model="roleId" filterable placeholder="角色" class="role-select">
|
<el-select v-model="roleId" filterable placeholder="角色" class="role-select">
|
||||||
<el-option
|
<el-option
|
||||||
v-for="item in roles"
|
v-for="item in roles"
|
||||||
@ -210,7 +211,8 @@ import {
|
|||||||
Check,
|
Check,
|
||||||
Close,
|
Close,
|
||||||
Delete,
|
Delete,
|
||||||
Edit, Iphone,
|
Edit,
|
||||||
|
Iphone,
|
||||||
Plus,
|
Plus,
|
||||||
Promotion,
|
Promotion,
|
||||||
RefreshRight,
|
RefreshRight,
|
||||||
@ -920,6 +922,10 @@ $borderColor = #4676d0;
|
|||||||
justify-content center;
|
justify-content center;
|
||||||
padding-top 10px;
|
padding-top 10px;
|
||||||
|
|
||||||
|
.role-select-label {
|
||||||
|
color #ffffff
|
||||||
|
}
|
||||||
|
|
||||||
.el-select {
|
.el-select {
|
||||||
//max-width 150px;
|
//max-width 150px;
|
||||||
margin-right 10px;
|
margin-right 10px;
|
||||||
|
Loading…
Reference in New Issue
Block a user