refactor: 更改 OpenAI 请求 Body 数据结构,兼容函数调用请求

This commit is contained in:
RockYang 2023-07-15 18:00:40 +08:00
parent a5ad9648bf
commit cc1b56501d
11 changed files with 335 additions and 214 deletions

View File

@ -22,8 +22,8 @@ type AppServer struct {
Debug bool Debug bool
Config *types.AppConfig Config *types.AppConfig
Engine *gin.Engine Engine *gin.Engine
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
ChatConfig *types.ChatConfig // 聊天配置 ChatConfig *types.ChatConfig // 聊天配置
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次 // 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API // 防止第三方直接连接 socket 调用 OpenAI API
@ -39,7 +39,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
Debug: false, Debug: false,
Config: appConfig, Config: appConfig,
Engine: gin.Default(), Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []types.Message](), ChatContexts: types.NewLMap[string, []interface{}](),
ChatSession: types.NewLMap[string, types.ChatSession](), ChatSession: types.NewLMap[string, types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](), ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),

View File

@ -2,17 +2,17 @@ package types
// ApiRequest API 请求实体 // ApiRequest API 请求实体
type ApiRequest struct { type ApiRequest struct {
Model string `json:"model"` Model string `json:"model"`
Temperature float32 `json:"temperature"` Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
Messages []Message `json:"messages"` Messages []interface{} `json:"messages"`
Functions []Function `json:"functions"`
} }
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
FunctionCall FunctionCall `json:"function_call"`
} }
type ApiResponse struct { type ApiResponse struct {
@ -21,8 +21,15 @@ type ApiResponse struct {
// ChoiceItem API 响应实体 // ChoiceItem API 响应实体
type ChoiceItem struct { type ChoiceItem struct {
Delta Message `json:"delta"` Delta Delta `json:"delta"`
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
}
type Delta struct {
Role string `json:"role"`
Name string `json:"name"`
Content interface{} `json:"content"`
FunctionCall FunctionCall `json:"function_call,omitempty"`
} }
// ChatSession 聊天会话对象 // ChatSession 聊天会话对象

View File

@ -6,13 +6,71 @@ type FunctionCall struct {
} }
type Function struct { type Function struct {
Name string Name string `json:"name"`
Description string Description string `json:"description"`
Parameters []Parameter Parameters Parameters `json:"parameters"`
} }
type Parameter struct { type Parameters struct {
Type string Type string `json:"type"`
Required []string Required []string `json:"required"`
Properties map[string]interface{} Properties map[string]Property `json:"properties"`
}
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
}
var InnerFunctions = []Function{
{
Name: "zao_bao",
Description: "每日早报,获取当天全球的热门新闻事件列表",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: "weibo_hot",
Description: "新浪微博热搜榜,微博当日热搜榜单",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: "zhihu_top",
Description: "知乎热榜,知乎当日话题讨论榜单",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
}
var FunctionNameMap = map[string]string{
"zao_bao": "每日早报",
"weibo_hot": "微博热搜",
"zhihu_top": "知乎热榜",
} }

View File

@ -9,7 +9,7 @@ type MKey interface {
string | int string | int
} }
type MValue interface { type MValue interface {
*WsClient | ChatSession | []Message | context.CancelFunc *WsClient | ChatSession | context.CancelFunc | []interface{}
} }
type LMap[K MKey, T MValue] struct { type LMap[K MKey, T MValue] struct {
lock sync.RWMutex lock sync.RWMutex

View File

@ -157,10 +157,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
Temperature: userVo.ChatConfig.Temperature, Temperature: userVo.ChatConfig.Temperature,
MaxTokens: userVo.ChatConfig.MaxTokens, MaxTokens: userVo.ChatConfig.MaxTokens,
Stream: true, Stream: true,
Functions: types.InnerFunctions,
} }
// 加载聊天上下文 // 加载聊天上下文
var chatCtx []types.Message var chatCtx []interface{}
if userVo.ChatConfig.EnableContext { if userVo.ChatConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) { if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId) chatCtx = h.App.ChatContexts.Get(session.ChatId)
@ -169,11 +170,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
var messages []types.Message var messages []types.Message
err := utils.JsonDecode(role.Context, &messages) err := utils.JsonDecode(role.Context, &messages)
if err == nil { if err == nil {
chatCtx = messages for _, v := range messages {
chatCtx = append(chatCtx, v)
}
} }
// TODO: 这里默认加载最近 4 条聊天记录作为上下文,后期应该做成可配置的 // TODO: 这里默认加载最近 2 条聊天记录作为上下文,后期应该做成可配置的
var historyMessages []model.HistoryMessage var historyMessages []model.HistoryMessage
res := h.db.Where("chat_id = ?", session.ChatId).Limit(4).Order("created_at desc").Find(&historyMessages) res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
if res.Error == nil { if res.Error == nil {
for _, msg := range historyMessages { for _, msg := range historyMessages {
ms := types.Message{Role: "user", Content: msg.Content} ms := types.Message{Role: "user", Content: msg.Content}
@ -189,12 +192,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
logger.Info("聊天上下文:", chatCtx) logger.Info("聊天上下文:", chatCtx)
} }
} }
req.Messages = append(chatCtx, types.Message{ reqMgs := make([]interface{}, 0)
Role: "user", for _, m := range chatCtx {
Content: prompt, reqMgs = append(reqMgs, m)
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
}) })
var apiKey string var apiKey string
response, err := h.fakeRequest(ctx, userVo, &apiKey, req) response, err := h.doRequest(ctx, userVo, &apiKey, req)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "context canceled") { if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt) logger.Info("用户取消了请求:", prompt)
@ -213,176 +221,191 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
defer response.Body.Close() defer response.Body.Close()
} }
//contentType := response.Header.Get("Content-Type") contentType := response.Header.Get("Content-Type")
//if strings.Contains(contentType, "text/event-stream") || true { if strings.Contains(contentType, "text/event-stream") {
if true { if true {
replyCreatedAt := time.Now() replyCreatedAt := time.Now()
// 循环读取 Chunk 消息 // 循环读取 Chunk 消息
var message = types.Message{} var message = types.Message{}
var contents = make([]string, 0) var contents = make([]string, 0)
var functionCall = false var functionCall = false
var functionName string var functionName string
var arguments = make([]string, 0) var arguments = make([]string, 0)
reader := bufio.NewReader(response.Body) reader := bufio.NewReader(response.Body)
for { for {
line, err := reader.ReadString('\n') line, err := reader.ReadString('\n')
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error(err)
}
break
}
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
logger.Error(err, line)
replyMessage(ws, ErrorMsg)
replyMessage(ws, "![](/images/wx.png)")
break
}
fun := responseBody.Choices[0].Delta.FunctionCall
if functionCall && fun.Name == "" {
arguments = append(arguments, fun.Arguments)
continue
}
if !utils.IsEmptyValue(fun) {
functionCall = true
functionName = fun.Name
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 %s 作答 ...\n\n", functionName)})
continue
}
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, content)
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: responseBody.Choices[0].Delta.Content,
})
}
} // end for
if functionCall { // 调用函数完成任务
// TODO 调用函数完成任务
data, err := h.funcZaoBao.Fetch()
if err != nil {
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: "调用函数出错",
})
} else {
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: data,
})
}
contents = append(contents, data)
}
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
if res.Error != nil {
return res.Error
}
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if userVo.ChatConfig.EnableContext && functionCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if userVo.ChatConfig.EnableHistory {
// for prompt
token, err := utils.CalcTokens(prompt, req.Model)
if err != nil { if err != nil {
logger.Error(err) if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else if err != io.EOF {
logger.Error(err)
}
break
} }
historyUserMsg := model.HistoryMessage{ if !strings.Contains(line, "data:") || len(line) < 30 {
UserId: userVo.Id, continue
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: user.Avatar,
Content: prompt,
Tokens: token,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
} }
// for reply var responseBody = types.ApiResponse{}
token, err = utils.CalcTokens(message.Content, req.Model) err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
logger.Error(err, line)
replyMessage(ws, ErrorMsg)
replyMessage(ws, "![](/images/wx.png)")
break
}
fun := responseBody.Choices[0].Delta.FunctionCall
if functionCall && fun.Name == "" {
arguments = append(arguments, fun.Arguments)
continue
}
if !utils.IsEmptyValue(fun) {
functionCall = true
functionName = fun.Name
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
continue
}
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
} // end for
if functionCall { // 调用函数完成任务
logger.Info(functionName)
logger.Info(arguments)
// TODO 调用函数完成任务
data, err := h.funcZaoBao.Fetch()
if err != nil { if err != nil {
logger.Error(err) replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: "调用函数出错",
})
} else {
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: data,
})
} }
historyReplyMsg := model.HistoryMessage{ contents = append(contents, data)
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: token,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 统计用户 token 数量
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
historyUserMsg.Tokens+historyReplyMsg.Tokens))
} }
// 保存当前会话 // 消息发送成功
var chatItem model.ChatItem if len(contents) > 0 {
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) // 更新用户的对话次数
if res.Error != nil { res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
chatItem.ChatId = session.ChatId if res.Error != nil {
chatItem.UserId = session.UserId return res.Error
chatItem.RoleId = role.Id }
chatItem.Model = session.Model
if utf8.RuneCountInString(prompt) > 30 { if message.Role == "" {
chatItem.Title = string([]rune(prompt)[:30]) + "..." message.Role = "assistant"
} else { }
chatItem.Title = prompt message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 计算本次对话消耗的总 token 数量
req.Messages = append(req.Messages, message)
totalTokens := getTotalTokens(req)
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
// 更新上下文消息,如果是调用函数则不需要更新上下文
if userVo.ChatConfig.EnableContext && functionCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if userVo.ChatConfig.EnableHistory {
useContext := true
if functionCall {
useContext = false
}
// for prompt
token, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: user.Avatar,
Content: prompt,
Tokens: token,
UseContext: useContext,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
token, err = utils.CalcTokens(message.Content, req.Model)
if err != nil {
logger.Error(err)
}
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: token,
UseContext: useContext,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 统计用户 token 数量
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
historyUserMsg.Tokens+historyReplyMsg.Tokens))
}
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.Model = session.Model
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
} }
h.db.Create(&chatItem)
} }
} }
} else { } else {
@ -469,13 +492,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
return client.Do(request) return client.Do(request)
} }
func (h *ChatHandler) fakeRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) {
link := "https://img.r9it.com/chatgpt/response"
client := &http.Client{}
request, _ := http.NewRequest(http.MethodGet, link, nil)
return client.Do(request)
}
// 回复客户片段端消息 // 回复客户片段端消息
func replyChunkMessage(client types.Client, message types.WsMessage) { func replyChunkMessage(client types.Client, message types.WsMessage) {
msg, err := json.Marshal(message) msg, err := json.Marshal(message)
@ -509,6 +525,26 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
resp.SUCCESS(c, tokens) resp.SUCCESS(c, tokens)
} }
func getTotalTokens(req types.ApiRequest) int {
encode := utils.JsonEncode(req.Messages)
var items []map[string]interface{}
err := utils.JsonDecode(encode, &items)
if err != nil {
return 0
}
tokens := 0
for _, item := range items {
content, ok := item["content"]
if ok && !utils.IsEmptyValue(content) {
t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
if err == nil {
tokens += t
}
}
}
return tokens
}
// StopGenerate 停止生成 // StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) { func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id") sessionId := c.Query("session_id")

View File

@ -11,6 +11,7 @@ import (
"chatplus/store" "chatplus/store"
"context" "context"
"embed" "embed"
"errors"
"io" "io"
"log" "log"
"os" "os"
@ -101,9 +102,12 @@ func main() {
}), }),
// 创建函数 // 创建函数
fx.Provide(func() *function.FuncZaoBao { fx.Provide(func() (*function.FuncZaoBao, error) {
token := os.Getenv("AL_API_TOKEN") token := os.Getenv("AL_API_TOKEN")
return function.NewZaoBao(token) if token == "" {
return nil, errors.New("invalid AL api token")
}
return function.NewZaoBao(token), nil
}), }),
// 创建控制器 // 创建控制器

View File

@ -2,13 +2,14 @@ package model
type HistoryMessage struct { type HistoryMessage struct {
BaseModel BaseModel
ChatId string // 会话 ID ChatId string // 会话 ID
UserId uint // 用户 ID UserId uint // 用户 ID
RoleId uint // 角色 ID RoleId uint // 角色 ID
Type string Type string
Icon string Icon string
Tokens int Tokens int
Content string Content string
UseContext bool // 是否可以作为聊天上下文
} }
func (HistoryMessage) TableName() string { func (HistoryMessage) TableName() string {

View File

@ -2,13 +2,14 @@ package vo
type HistoryMessage struct { type HistoryMessage struct {
BaseVo BaseVo
ChatId string `json:"chat_id"` ChatId string `json:"chat_id"`
UserId uint `json:"user_id"` UserId uint `json:"user_id"`
RoleId uint `json:"role_id"` RoleId uint `json:"role_id"`
Type string `json:"type"` Type string `json:"type"`
Icon string `json:"icon"` Icon string `json:"icon"`
Tokens int `json:"tokens"` Tokens int `json:"tokens"`
Content string `json:"content"` Content string `json:"content"`
UseContext bool `json:"use_context"`
} }
func (HistoryMessage) TableName() string { func (HistoryMessage) TableName() string {

View File

@ -81,3 +81,10 @@ func JsonEncode(value interface{}) string {
func JsonDecode(src string, dest interface{}) error { func JsonDecode(src string, dest interface{}) error {
return json.Unmarshal([]byte(src), dest) return json.Unmarshal([]byte(src), dest)
} }
func InterfaceToString(value interface{}) string {
if str, ok := value.(string); ok {
return str
}
return JsonEncode(value)
}

1
database/plugins.sql Normal file
View File

@ -0,0 +1 @@
ALTER TABLE `chatgpt_chat_history` ADD `use_context` TINYINT(1) NOT NULL COMMENT '是否允许作为上下文语料' AFTER `tokens`;

View File

@ -94,6 +94,7 @@
<el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)"> <el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)">
<div class="chat-head"> <div class="chat-head">
<div class="chat-config"> <div class="chat-config">
<span class="role-select-label">聊天角色</span>
<el-select v-model="roleId" filterable placeholder="角色" class="role-select"> <el-select v-model="roleId" filterable placeholder="角色" class="role-select">
<el-option <el-option
v-for="item in roles" v-for="item in roles"
@ -210,7 +211,8 @@ import {
Check, Check,
Close, Close,
Delete, Delete,
Edit, Iphone, Edit,
Iphone,
Plus, Plus,
Promotion, Promotion,
RefreshRight, RefreshRight,
@ -920,6 +922,10 @@ $borderColor = #4676d0;
justify-content center; justify-content center;
padding-top 10px; padding-top 10px;
.role-select-label {
color #ffffff
}
.el-select { .el-select {
//max-width 150px; //max-width 150px;
margin-right 10px; margin-right 10px;