refactor: 更改 OpenAI 请求 Body 数据结构,兼容函数调用请求

This commit is contained in:
RockYang 2023-07-15 18:00:40 +08:00
parent a5ad9648bf
commit cc1b56501d
11 changed files with 335 additions and 214 deletions

View File

@ -22,7 +22,7 @@ type AppServer struct {
Debug bool Debug bool
Config *types.AppConfig Config *types.AppConfig
Engine *gin.Engine Engine *gin.Engine
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
ChatConfig *types.ChatConfig // 聊天配置 ChatConfig *types.ChatConfig // 聊天配置
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次 // 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
@ -39,7 +39,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
Debug: false, Debug: false,
Config: appConfig, Config: appConfig,
Engine: gin.Default(), Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []types.Message](), ChatContexts: types.NewLMap[string, []interface{}](),
ChatSession: types.NewLMap[string, types.ChatSession](), ChatSession: types.NewLMap[string, types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](), ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),

View File

@ -6,13 +6,13 @@ type ApiRequest struct {
Temperature float32 `json:"temperature"` Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
Messages []Message `json:"messages"` Messages []interface{} `json:"messages"`
Functions []Function `json:"functions"`
} }
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
FunctionCall FunctionCall `json:"function_call"`
} }
type ApiResponse struct { type ApiResponse struct {
@ -21,10 +21,17 @@ type ApiResponse struct {
// ChoiceItem API 响应实体 // ChoiceItem API 响应实体
type ChoiceItem struct { type ChoiceItem struct {
Delta Message `json:"delta"` Delta Delta `json:"delta"`
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} }
type Delta struct {
Role string `json:"role"`
Name string `json:"name"`
Content interface{} `json:"content"`
FunctionCall FunctionCall `json:"function_call,omitempty"`
}
// ChatSession 聊天会话对象 // ChatSession 聊天会话对象
type ChatSession struct { type ChatSession struct {
SessionId string `json:"session_id"` SessionId string `json:"session_id"`

View File

@ -6,13 +6,71 @@ type FunctionCall struct {
} }
type Function struct { type Function struct {
Name string Name string `json:"name"`
Description string Description string `json:"description"`
Parameters []Parameter Parameters Parameters `json:"parameters"`
} }
type Parameter struct { type Parameters struct {
Type string Type string `json:"type"`
Required []string Required []string `json:"required"`
Properties map[string]interface{} Properties map[string]Property `json:"properties"`
}
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
}
var InnerFunctions = []Function{
{
Name: "zao_bao",
Description: "每日早报,获取当天全球的热门新闻事件列表",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: "weibo_hot",
Description: "新浪微博热搜榜,微博当日热搜榜单",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: "zhihu_top",
Description: "知乎热榜,知乎当日话题讨论榜单",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
}
var FunctionNameMap = map[string]string{
"zao_bao": "每日早报",
"weibo_hot": "微博热搜",
"zhihu_top": "知乎热榜",
} }

View File

@ -9,7 +9,7 @@ type MKey interface {
string | int string | int
} }
type MValue interface { type MValue interface {
*WsClient | ChatSession | []Message | context.CancelFunc *WsClient | ChatSession | context.CancelFunc | []interface{}
} }
type LMap[K MKey, T MValue] struct { type LMap[K MKey, T MValue] struct {
lock sync.RWMutex lock sync.RWMutex

View File

@ -157,10 +157,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
Temperature: userVo.ChatConfig.Temperature, Temperature: userVo.ChatConfig.Temperature,
MaxTokens: userVo.ChatConfig.MaxTokens, MaxTokens: userVo.ChatConfig.MaxTokens,
Stream: true, Stream: true,
Functions: types.InnerFunctions,
} }
// 加载聊天上下文 // 加载聊天上下文
var chatCtx []types.Message var chatCtx []interface{}
if userVo.ChatConfig.EnableContext { if userVo.ChatConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) { if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId) chatCtx = h.App.ChatContexts.Get(session.ChatId)
@ -169,11 +170,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
var messages []types.Message var messages []types.Message
err := utils.JsonDecode(role.Context, &messages) err := utils.JsonDecode(role.Context, &messages)
if err == nil { if err == nil {
chatCtx = messages for _, v := range messages {
chatCtx = append(chatCtx, v)
} }
// TODO: 这里默认加载最近 4 条聊天记录作为上下文,后期应该做成可配置的 }
// TODO: 这里默认加载最近 2 条聊天记录作为上下文,后期应该做成可配置的
var historyMessages []model.HistoryMessage var historyMessages []model.HistoryMessage
res := h.db.Where("chat_id = ?", session.ChatId).Limit(4).Order("created_at desc").Find(&historyMessages) res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
if res.Error == nil { if res.Error == nil {
for _, msg := range historyMessages { for _, msg := range historyMessages {
ms := types.Message{Role: "user", Content: msg.Content} ms := types.Message{Role: "user", Content: msg.Content}
@ -189,12 +192,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
logger.Info("聊天上下文:", chatCtx) logger.Info("聊天上下文:", chatCtx)
} }
} }
req.Messages = append(chatCtx, types.Message{ reqMgs := make([]interface{}, 0)
Role: "user", for _, m := range chatCtx {
Content: prompt, reqMgs = append(reqMgs, m)
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
}) })
var apiKey string var apiKey string
response, err := h.fakeRequest(ctx, userVo, &apiKey, req) response, err := h.doRequest(ctx, userVo, &apiKey, req)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "context canceled") { if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt) logger.Info("用户取消了请求:", prompt)
@ -213,8 +221,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
defer response.Body.Close() defer response.Body.Close()
} }
//contentType := response.Header.Get("Content-Type") contentType := response.Header.Get("Content-Type")
//if strings.Contains(contentType, "text/event-stream") || true { if strings.Contains(contentType, "text/event-stream") {
if true { if true {
replyCreatedAt := time.Now() replyCreatedAt := time.Now()
// 循环读取 Chunk 消息 // 循环读取 Chunk 消息
@ -229,7 +237,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if err != nil { if err != nil {
if strings.Contains(err.Error(), "context canceled") { if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt) logger.Info("用户取消了请求:", prompt)
} else { } else if err != io.EOF {
logger.Error(err) logger.Error(err)
} }
break break
@ -257,7 +265,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
functionCall = true functionCall = true
functionName = fun.Name functionName = fun.Name
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 %s 作答 ...\n\n", functionName)}) replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
continue continue
} }
@ -274,15 +282,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
break // 输出完成或者输出中断了 break // 输出完成或者输出中断了
} else { } else {
content := responseBody.Choices[0].Delta.Content content := responseBody.Choices[0].Delta.Content
contents = append(contents, content) contents = append(contents, utils.InterfaceToString(content))
replyChunkMessage(ws, types.WsMessage{ replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle, Type: types.WsMiddle,
Content: responseBody.Choices[0].Delta.Content, Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
}) })
} }
} // end for } // end for
if functionCall { // 调用函数完成任务 if functionCall { // 调用函数完成任务
logger.Info(functionName)
logger.Info(arguments)
// TODO 调用函数完成任务 // TODO 调用函数完成任务
data, err := h.funcZaoBao.Fetch() data, err := h.funcZaoBao.Fetch()
if err != nil { if err != nil {
@ -313,6 +323,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
message.Content = strings.Join(contents, "") message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt} useMsg := types.Message{Role: "user", Content: prompt}
// 计算本次对话消耗的总 token 数量
req.Messages = append(req.Messages, message)
totalTokens := getTotalTokens(req)
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
// 更新上下文消息,如果是调用函数则不需要更新上下文 // 更新上下文消息,如果是调用函数则不需要更新上下文
if userVo.ChatConfig.EnableContext && functionCall == false { if userVo.ChatConfig.EnableContext && functionCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, useMsg) // 提问消息
@ -322,6 +337,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
// 追加聊天记录 // 追加聊天记录
if userVo.ChatConfig.EnableHistory { if userVo.ChatConfig.EnableHistory {
useContext := true
if functionCall {
useContext = false
}
// for prompt // for prompt
token, err := utils.CalcTokens(prompt, req.Model) token, err := utils.CalcTokens(prompt, req.Model)
if err != nil { if err != nil {
@ -335,6 +355,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
Icon: user.Avatar, Icon: user.Avatar,
Content: prompt, Content: prompt,
Tokens: token, Tokens: token,
UseContext: useContext,
} }
historyUserMsg.CreatedAt = promptCreatedAt historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt historyUserMsg.UpdatedAt = promptCreatedAt
@ -356,6 +377,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
Icon: role.Icon, Icon: role.Icon,
Content: message.Content, Content: message.Content,
Tokens: token, Tokens: token,
UseContext: useContext,
} }
historyReplyMsg.CreatedAt = replyCreatedAt historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt historyReplyMsg.UpdatedAt = replyCreatedAt
@ -385,6 +407,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
h.db.Create(&chatItem) h.db.Create(&chatItem)
} }
} }
}
} else { } else {
body, err := io.ReadAll(response.Body) body, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
@ -469,13 +492,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
return client.Do(request) return client.Do(request)
} }
func (h *ChatHandler) fakeRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) {
link := "https://img.r9it.com/chatgpt/response"
client := &http.Client{}
request, _ := http.NewRequest(http.MethodGet, link, nil)
return client.Do(request)
}
// 回复客户片段端消息 // 回复客户片段端消息
func replyChunkMessage(client types.Client, message types.WsMessage) { func replyChunkMessage(client types.Client, message types.WsMessage) {
msg, err := json.Marshal(message) msg, err := json.Marshal(message)
@ -509,6 +525,26 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
resp.SUCCESS(c, tokens) resp.SUCCESS(c, tokens)
} }
func getTotalTokens(req types.ApiRequest) int {
encode := utils.JsonEncode(req.Messages)
var items []map[string]interface{}
err := utils.JsonDecode(encode, &items)
if err != nil {
return 0
}
tokens := 0
for _, item := range items {
content, ok := item["content"]
if ok && !utils.IsEmptyValue(content) {
t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
if err == nil {
tokens += t
}
}
}
return tokens
}
// StopGenerate 停止生成 // StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) { func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id") sessionId := c.Query("session_id")

View File

@ -11,6 +11,7 @@ import (
"chatplus/store" "chatplus/store"
"context" "context"
"embed" "embed"
"errors"
"io" "io"
"log" "log"
"os" "os"
@ -101,9 +102,12 @@ func main() {
}), }),
// 创建函数 // 创建函数
fx.Provide(func() *function.FuncZaoBao { fx.Provide(func() (*function.FuncZaoBao, error) {
token := os.Getenv("AL_API_TOKEN") token := os.Getenv("AL_API_TOKEN")
return function.NewZaoBao(token) if token == "" {
return nil, errors.New("invalid AL api token")
}
return function.NewZaoBao(token), nil
}), }),
// 创建控制器 // 创建控制器

View File

@ -9,6 +9,7 @@ type HistoryMessage struct {
Icon string Icon string
Tokens int Tokens int
Content string Content string
UseContext bool // 是否可以作为聊天上下文
} }
func (HistoryMessage) TableName() string { func (HistoryMessage) TableName() string {

View File

@ -9,6 +9,7 @@ type HistoryMessage struct {
Icon string `json:"icon"` Icon string `json:"icon"`
Tokens int `json:"tokens"` Tokens int `json:"tokens"`
Content string `json:"content"` Content string `json:"content"`
UseContext bool `json:"use_context"`
} }
func (HistoryMessage) TableName() string { func (HistoryMessage) TableName() string {

View File

@ -81,3 +81,10 @@ func JsonEncode(value interface{}) string {
func JsonDecode(src string, dest interface{}) error { func JsonDecode(src string, dest interface{}) error {
return json.Unmarshal([]byte(src), dest) return json.Unmarshal([]byte(src), dest)
} }
func InterfaceToString(value interface{}) string {
if str, ok := value.(string); ok {
return str
}
return JsonEncode(value)
}

1
database/plugins.sql Normal file
View File

@ -0,0 +1 @@
ALTER TABLE `chatgpt_chat_history` ADD `use_context` TINYINT(1) NOT NULL COMMENT '是否允许作为上下文语料' AFTER `tokens`;

View File

@ -94,6 +94,7 @@
<el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)"> <el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)">
<div class="chat-head"> <div class="chat-head">
<div class="chat-config"> <div class="chat-config">
<span class="role-select-label">聊天角色</span>
<el-select v-model="roleId" filterable placeholder="角色" class="role-select"> <el-select v-model="roleId" filterable placeholder="角色" class="role-select">
<el-option <el-option
v-for="item in roles" v-for="item in roles"
@ -210,7 +211,8 @@ import {
Check, Check,
Close, Close,
Delete, Delete,
Edit, Iphone, Edit,
Iphone,
Plus, Plus,
Promotion, Promotion,
RefreshRight, RefreshRight,
@ -920,6 +922,10 @@ $borderColor = #4676d0;
justify-content center; justify-content center;
padding-top 10px; padding-top 10px;
.role-select-label {
color #ffffff
}
.el-select { .el-select {
//max-width 150px; //max-width 150px;
margin-right 10px; margin-right 10px;