refactor: change the OpenAI request body structure to support function-call requests

RockYang 2023-07-15 18:00:40 +08:00
parent a5ad9648bf
commit cc1b56501d
11 changed files with 335 additions and 214 deletions
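
The core change is that the request messages (and the cached chat context) switch from []types.Message to []interface{}, so plain chat messages and function-call style messages with different shapes can share one slice. A minimal sketch of the idea, assuming the struct fields shown in this diff and the message shape the OpenAI function-calling API expects:

// Sketch only: mixing a plain message with a function-result message in the
// new []interface{} slice used by ApiRequest.Messages.
messages := make([]interface{}, 0)
messages = append(messages, types.Message{Role: "user", Content: "What is in the news today?"})
messages = append(messages, map[string]interface{}{ // hypothetical function-result turn
	"role":    "function",
	"name":    "zao_bao",
	"content": "1. ... 2. ...",
})
req := types.ApiRequest{
	Model:     "gpt-3.5-turbo", // model name assumed for illustration
	Stream:    true,
	Messages:  messages,
	Functions: types.InnerFunctions,
}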

View File

@ -22,7 +22,7 @@ type AppServer struct {
Debug bool
Config *types.AppConfig
Engine *gin.Engine
ChatContexts *types.LMap[string, []types.Message] // chat context map: [chatId] => []Message
ChatContexts *types.LMap[string, []interface{}] // chat context map: [chatId] => []Message
ChatConfig *types.ChatConfig // chat configuration
// store the Websocket session UserId; each UserId may only connect once
@ -39,7 +39,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
Debug: false,
Config: appConfig,
Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []types.Message](),
ChatContexts: types.NewLMap[string, []interface{}](),
ChatSession: types.NewLMap[string, types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),

View File

@ -6,13 +6,13 @@ type ApiRequest struct {
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"`
Messages []Message `json:"messages"`
Messages []interface{} `json:"messages"`
Functions []Function `json:"functions"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
FunctionCall FunctionCall `json:"function_call"`
}
type ApiResponse struct {
@ -21,10 +21,17 @@ type ApiResponse struct {
// ChoiceItem API response entity
type ChoiceItem struct {
Delta Message `json:"delta"`
Delta Delta `json:"delta"`
FinishReason string `json:"finish_reason"`
}
type Delta struct {
Role string `json:"role"`
Name string `json:"name"`
Content interface{} `json:"content"`
FunctionCall FunctionCall `json:"function_call,omitempty"`
}
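
With Content typed as interface{} and FunctionCall moved onto the streamed delta, a chunk that announces a function call (where content arrives as null) can be decoded without losing information. A hedged sketch of such a chunk and how it decodes, assuming ApiResponse exposes the Choices slice used by the chat handler below and reusing the repo's utils.JsonDecode helper:

// Sketch: decoding a streamed chunk whose delta announces a function call.
chunk := `{"choices":[{"delta":{"role":"assistant","content":null,"function_call":{"name":"zao_bao","arguments":""}},"finish_reason":null}]}`
var res types.ApiResponse
if err := utils.JsonDecode(chunk, &res); err == nil {
	delta := res.Choices[0].Delta
	fmt.Println(delta.FunctionCall.Name) // "zao_bao"; Content stays nil rather than becoming ""
}
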
// ChatSession chat session object
type ChatSession struct {
SessionId string `json:"session_id"`

View File

@ -6,13 +6,71 @@ type FunctionCall struct {
}
type Function struct {
Name string
Description string
Parameters []Parameter
Name string `json:"name"`
Description string `json:"description"`
Parameters Parameters `json:"parameters"`
}
type Parameter struct {
Type string
Required []string
Properties map[string]interface{}
type Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]Property `json:"properties"`
}
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
}
var InnerFunctions = []Function{
{
Name: "zao_bao",
Description: "每日早报,获取当天全球的热门新闻事件列表",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: "weibo_hot",
Description: "新浪微博热搜榜,微博当日热搜榜单",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
{
Name: "zhihu_top",
Description: "知乎热榜,知乎当日话题讨论榜单",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"text": {
Type: "string",
Description: "",
},
},
Required: []string{},
},
},
}
var FunctionNameMap = map[string]string{
"zao_bao": "每日早报",
"weibo_hot": "微博热搜",
"zhihu_top": "知乎热榜",
}
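
For reference, with the json tags added above, each InnerFunctions entry now marshals to the JSON schema shape the OpenAI functions API expects. A sketch of the serialized form of the first entry (output wrapped here for readability):

// Sketch: what InnerFunctions[0] serializes to for the "functions" field of the request.
js, _ := json.Marshal(InnerFunctions[0])
fmt.Println(string(js))
// {"name":"zao_bao","description":"每日早报,获取当天全球的热门新闻事件列表",
//  "parameters":{"type":"object","required":[],"properties":{"text":{"type":"string","description":""}}}}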

View File

@ -9,7 +9,7 @@ type MKey interface {
string | int
}
type MValue interface {
*WsClient | ChatSession | []Message | context.CancelFunc
*WsClient | ChatSession | context.CancelFunc | []interface{}
}
type LMap[K MKey, T MValue] struct {
lock sync.RWMutex

View File

@ -157,10 +157,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
Temperature: userVo.ChatConfig.Temperature,
MaxTokens: userVo.ChatConfig.MaxTokens,
Stream: true,
Functions: types.InnerFunctions,
}
// load the chat context
var chatCtx []types.Message
var chatCtx []interface{}
if userVo.ChatConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
@ -169,11 +170,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
var messages []types.Message
err := utils.JsonDecode(role.Context, &messages)
if err == nil {
chatCtx = messages
for _, v := range messages {
chatCtx = append(chatCtx, v)
}
// TODO: by default the last 4 chat records are loaded as context; this should be made configurable later
}
// TODO: by default the last 2 chat records are loaded as context; this should be made configurable later
var historyMessages []model.HistoryMessage
res := h.db.Where("chat_id = ?", session.ChatId).Limit(4).Order("created_at desc").Find(&historyMessages)
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
if res.Error == nil {
for _, msg := range historyMessages {
ms := types.Message{Role: "user", Content: msg.Content}
@ -189,12 +192,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
logger.Info("聊天上下文:", chatCtx)
}
}
req.Messages = append(chatCtx, types.Message{
Role: "user",
Content: prompt,
reqMgs := make([]interface{}, 0)
for _, m := range chatCtx {
reqMgs = append(reqMgs, m)
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
})
var apiKey string
response, err := h.fakeRequest(ctx, userVo, &apiKey, req)
response, err := h.doRequest(ctx, userVo, &apiKey, req)
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
@ -213,8 +221,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
defer response.Body.Close()
}
//contentType := response.Header.Get("Content-Type")
//if strings.Contains(contentType, "text/event-stream") || true {
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
if true {
replyCreatedAt := time.Now()
// read chunk messages in a loop
@ -229,7 +237,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
} else if err != io.EOF {
logger.Error(err)
}
break
@ -257,7 +265,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
functionCall = true
functionName = fun.Name
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 %s 作答 ...\n\n", functionName)})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
continue
}
@ -274,15 +282,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
break // output finished or was interrupted
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, content)
contents = append(contents, utils.InterfaceToString(content))
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: responseBody.Choices[0].Delta.Content,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
} // end for
if functionCall { // complete the task by calling the function
logger.Info(functionName)
logger.Info(arguments)
// TODO: complete the task by calling the function
data, err := h.funcZaoBao.Fetch()
if err != nil {
@ -313,6 +323,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// calculate the total number of tokens consumed by this round of conversation
req.Messages = append(req.Messages, message)
totalTokens := getTotalTokens(req)
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
// update the context messages; no update is needed when a function was called
if userVo.ChatConfig.EnableContext && functionCall == false {
chatCtx = append(chatCtx, useMsg) // the prompt message
@ -322,6 +337,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
// append the chat history
if userVo.ChatConfig.EnableHistory {
useContext := true
if functionCall {
useContext = false
}
// for prompt
token, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
@ -335,6 +355,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
Icon: user.Avatar,
Content: prompt,
Tokens: token,
UseContext: useContext,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
@ -356,6 +377,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
Icon: role.Icon,
Content: message.Content,
Tokens: token,
UseContext: useContext,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
@ -385,6 +407,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
h.db.Create(&chatItem)
}
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
@ -469,13 +492,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
return client.Do(request)
}
func (h *ChatHandler) fakeRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) {
link := "https://img.r9it.com/chatgpt/response"
client := &http.Client{}
request, _ := http.NewRequest(http.MethodGet, link, nil)
return client.Do(request)
}
// reply a chunk message to the client
func replyChunkMessage(client types.Client, message types.WsMessage) {
msg, err := json.Marshal(message)
@ -509,6 +525,26 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
resp.SUCCESS(c, tokens)
}
func getTotalTokens(req types.ApiRequest) int {
encode := utils.JsonEncode(req.Messages)
var items []map[string]interface{}
err := utils.JsonDecode(encode, &items)
if err != nil {
return 0
}
tokens := 0
for _, item := range items {
content, ok := item["content"]
if ok && !utils.IsEmptyValue(content) {
t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
if err == nil {
tokens += t
}
}
}
return tokens
}
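
A short usage sketch for getTotalTokens: the messages are re-decoded into maps and entries with an empty content are skipped via utils.IsEmptyValue, so a pure function_call message (content null) adds nothing to the count. The model name and contents below are illustrative only:

// Illustrative only: only the non-empty "content" fields are tokenized.
req := types.ApiRequest{
	Model: "gpt-3.5-turbo",
	Messages: []interface{}{
		map[string]interface{}{"role": "user", "content": "hello"},
		map[string]interface{}{"role": "assistant", "content": nil,
			"function_call": map[string]string{"name": "zao_bao", "arguments": "{}"}},
	},
}
fmt.Println(getTotalTokens(req)) // counts tokens for "hello" only
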
// StopGenerate stops the generation
func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id")

View File

@ -11,6 +11,7 @@ import (
"chatplus/store"
"context"
"embed"
"errors"
"io"
"log"
"os"
@ -101,9 +102,12 @@ func main() {
}),
// create functions
fx.Provide(func() *function.FuncZaoBao {
fx.Provide(func() (*function.FuncZaoBao, error) {
token := os.Getenv("AL_API_TOKEN")
return function.NewZaoBao(token)
if token == "" {
return nil, errors.New("invalid AL api token")
}
return function.NewZaoBao(token), nil
}),
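
Returning (*function.FuncZaoBao, error) from the constructor lets fx abort startup instead of wiring up a broken dependency: uber-go/fx treats a non-nil trailing error from a provider as fatal for anything that needs that value. A minimal standalone sketch of the behavior (the import path of the function package is assumed here for illustration):

package main

import (
	"context"
	"errors"
	"log"

	"chatplus/core/function" // assumed import path, for illustration only
	"go.uber.org/fx"
)

func main() {
	app := fx.New(
		// same constructor shape as in this diff: (*T, error)
		fx.Provide(func() (*function.FuncZaoBao, error) {
			return nil, errors.New("invalid AL api token")
		}),
		fx.Invoke(func(z *function.FuncZaoBao) {}), // forces construction
	)
	if err := app.Start(context.Background()); err != nil {
		log.Fatal(err) // startup fails fast when AL_API_TOKEN is missing
	}
}
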
// create controllers

View File

@ -9,6 +9,7 @@ type HistoryMessage struct {
Icon string
Tokens int
Content string
UseContext bool // whether the message can be used as chat context
}
func (HistoryMessage) TableName() string {

View File

@ -9,6 +9,7 @@ type HistoryMessage struct {
Icon string `json:"icon"`
Tokens int `json:"tokens"`
Content string `json:"content"`
UseContext bool `json:"use_context"`
}
func (HistoryMessage) TableName() string {

View File

@ -81,3 +81,10 @@ func JsonEncode(value interface{}) string {
func JsonDecode(src string, dest interface{}) error {
return json.Unmarshal([]byte(src), dest)
}
func InterfaceToString(value interface{}) string {
if str, ok := value.(string); ok {
return str
}
return JsonEncode(value)
}
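
A usage sketch for InterfaceToString: since Delta.Content is now interface{}, the chat handler may receive a plain string, a null, or in principle any JSON value, and this helper normalizes all of them to a string (non-strings fall back to their JSON encoding):

// Sketch of the three cases the chat handler can run into.
fmt.Println(InterfaceToString("hello"))                     // hello
fmt.Println(InterfaceToString(nil))                         // null  (JSON-encoded)
fmt.Println(InterfaceToString(map[string]string{"a": "b"})) // {"a":"b"}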

database/plugins.sql Normal file
View File

@ -0,0 +1 @@
ALTER TABLE `chatgpt_chat_history` ADD `use_context` TINYINT(1) NOT NULL COMMENT '是否允许作为上下文语料' AFTER `tokens`;

View File

@ -94,6 +94,7 @@
<el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)">
<div class="chat-head">
<div class="chat-config">
<span class="role-select-label">聊天角色</span>
<el-select v-model="roleId" filterable placeholder="角色" class="role-select">
<el-option
v-for="item in roles"
@ -210,7 +211,8 @@ import {
Check,
Close,
Delete,
Edit, Iphone,
Edit,
Iphone,
Plus,
Promotion,
RefreshRight,
@ -920,6 +922,10 @@ $borderColor = #4676d0;
justify-content center;
padding-top 10px;
.role-select-label {
color #ffffff
}
.el-select {
//max-width 150px;
margin-right 10px;