mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
refactor: 更改 OpenAI 请求 Body 数据结构,兼容函数调用请求
This commit is contained in:
parent
a5ad9648bf
commit
cc1b56501d
@ -22,8 +22,8 @@ type AppServer struct {
|
||||
Debug bool
|
||||
Config *types.AppConfig
|
||||
Engine *gin.Engine
|
||||
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
|
||||
ChatConfig *types.ChatConfig // 聊天配置
|
||||
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
|
||||
ChatConfig *types.ChatConfig // 聊天配置
|
||||
|
||||
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
|
||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||
@ -39,7 +39,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
||||
Debug: false,
|
||||
Config: appConfig,
|
||||
Engine: gin.Default(),
|
||||
ChatContexts: types.NewLMap[string, []types.Message](),
|
||||
ChatContexts: types.NewLMap[string, []interface{}](),
|
||||
ChatSession: types.NewLMap[string, types.ChatSession](),
|
||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
|
@ -2,17 +2,17 @@ package types
|
||||
|
||||
// ApiRequest API 请求实体
|
||||
type ApiRequest struct {
|
||||
Model string `json:"model"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
Messages []Message `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
Messages []interface{} `json:"messages"`
|
||||
Functions []Function `json:"functions"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
FunctionCall FunctionCall `json:"function_call"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ApiResponse struct {
|
||||
@ -21,8 +21,15 @@ type ApiResponse struct {
|
||||
|
||||
// ChoiceItem API 响应实体
|
||||
type ChoiceItem struct {
|
||||
Delta Message `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Delta Delta `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type Delta struct {
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
Content interface{} `json:"content"`
|
||||
FunctionCall FunctionCall `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatSession 聊天会话对象
|
||||
|
@ -6,13 +6,71 @@ type FunctionCall struct {
|
||||
}
|
||||
|
||||
type Function struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters []Parameter
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters Parameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
Type string
|
||||
Required []string
|
||||
Properties map[string]interface{}
|
||||
type Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]Property `json:"properties"`
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
var InnerFunctions = []Function{
|
||||
{
|
||||
Name: "zao_bao",
|
||||
Description: "每日早报,获取当天全球的热门新闻事件列表",
|
||||
Parameters: Parameters{
|
||||
|
||||
Type: "object",
|
||||
Properties: map[string]Property{
|
||||
"text": {
|
||||
Type: "string",
|
||||
Description: "",
|
||||
},
|
||||
},
|
||||
Required: []string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "weibo_hot",
|
||||
Description: "新浪微博热搜榜,微博当日热搜榜单",
|
||||
Parameters: Parameters{
|
||||
Type: "object",
|
||||
Properties: map[string]Property{
|
||||
"text": {
|
||||
Type: "string",
|
||||
Description: "",
|
||||
},
|
||||
},
|
||||
Required: []string{},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: "zhihu_top",
|
||||
Description: "知乎热榜,知乎当日话题讨论榜单",
|
||||
Parameters: Parameters{
|
||||
Type: "object",
|
||||
Properties: map[string]Property{
|
||||
"text": {
|
||||
Type: "string",
|
||||
Description: "",
|
||||
},
|
||||
},
|
||||
Required: []string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var FunctionNameMap = map[string]string{
|
||||
"zao_bao": "每日早报",
|
||||
"weibo_hot": "微博热搜",
|
||||
"zhihu_top": "知乎热榜",
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ type MKey interface {
|
||||
string | int
|
||||
}
|
||||
type MValue interface {
|
||||
*WsClient | ChatSession | []Message | context.CancelFunc
|
||||
*WsClient | ChatSession | context.CancelFunc | []interface{}
|
||||
}
|
||||
type LMap[K MKey, T MValue] struct {
|
||||
lock sync.RWMutex
|
||||
|
@ -157,10 +157,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
Temperature: userVo.ChatConfig.Temperature,
|
||||
MaxTokens: userVo.ChatConfig.MaxTokens,
|
||||
Stream: true,
|
||||
Functions: types.InnerFunctions,
|
||||
}
|
||||
|
||||
// 加载聊天上下文
|
||||
var chatCtx []types.Message
|
||||
var chatCtx []interface{}
|
||||
if userVo.ChatConfig.EnableContext {
|
||||
if h.App.ChatContexts.Has(session.ChatId) {
|
||||
chatCtx = h.App.ChatContexts.Get(session.ChatId)
|
||||
@ -169,11 +170,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
var messages []types.Message
|
||||
err := utils.JsonDecode(role.Context, &messages)
|
||||
if err == nil {
|
||||
chatCtx = messages
|
||||
for _, v := range messages {
|
||||
chatCtx = append(chatCtx, v)
|
||||
}
|
||||
}
|
||||
// TODO: 这里默认加载最近 4 条聊天记录作为上下文,后期应该做成可配置的
|
||||
// TODO: 这里默认加载最近 2 条聊天记录作为上下文,后期应该做成可配置的
|
||||
var historyMessages []model.HistoryMessage
|
||||
res := h.db.Where("chat_id = ?", session.ChatId).Limit(4).Order("created_at desc").Find(&historyMessages)
|
||||
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
|
||||
if res.Error == nil {
|
||||
for _, msg := range historyMessages {
|
||||
ms := types.Message{Role: "user", Content: msg.Content}
|
||||
@ -189,12 +192,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
logger.Info("聊天上下文:", chatCtx)
|
||||
}
|
||||
}
|
||||
req.Messages = append(chatCtx, types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
reqMgs := make([]interface{}, 0)
|
||||
for _, m := range chatCtx {
|
||||
reqMgs = append(reqMgs, m)
|
||||
}
|
||||
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
})
|
||||
var apiKey string
|
||||
response, err := h.fakeRequest(ctx, userVo, &apiKey, req)
|
||||
response, err := h.doRequest(ctx, userVo, &apiKey, req)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
@ -213,176 +221,191 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
defer response.Body.Close()
|
||||
}
|
||||
|
||||
//contentType := response.Header.Get("Content-Type")
|
||||
//if strings.Contains(contentType, "text/event-stream") || true {
|
||||
if true {
|
||||
replyCreatedAt := time.Now()
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var functionCall = false
|
||||
var functionName string
|
||||
var arguments = make([]string, 0)
|
||||
reader := bufio.NewReader(response.Body)
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
break
|
||||
}
|
||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
continue
|
||||
}
|
||||
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
|
||||
logger.Error(err, line)
|
||||
replyMessage(ws, ErrorMsg)
|
||||
replyMessage(ws, "")
|
||||
break
|
||||
}
|
||||
|
||||
fun := responseBody.Choices[0].Delta.FunctionCall
|
||||
if functionCall && fun.Name == "" {
|
||||
arguments = append(arguments, fun.Arguments)
|
||||
continue
|
||||
}
|
||||
|
||||
if !utils.IsEmptyValue(fun) {
|
||||
functionCall = true
|
||||
functionName = fun.Name
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 %s 作答 ...\n\n", functionName)})
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
||||
break
|
||||
}
|
||||
|
||||
// 初始化 role
|
||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||
message.Role = responseBody.Choices[0].Delta.Role
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
continue
|
||||
} else if responseBody.Choices[0].FinishReason != "" {
|
||||
break // 输出完成或者输出中断了
|
||||
} else {
|
||||
content := responseBody.Choices[0].Delta.Content
|
||||
contents = append(contents, content)
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: responseBody.Choices[0].Delta.Content,
|
||||
})
|
||||
}
|
||||
} // end for
|
||||
|
||||
if functionCall { // 调用函数完成任务
|
||||
// TODO 调用函数完成任务
|
||||
data, err := h.funcZaoBao.Fetch()
|
||||
if err != nil {
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: "调用函数出错",
|
||||
})
|
||||
} else {
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: data,
|
||||
})
|
||||
}
|
||||
contents = append(contents, data)
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if userVo.ChatConfig.EnableContext && functionCall == false {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
if userVo.ChatConfig.EnableHistory {
|
||||
// for prompt
|
||||
token, err := utils.CalcTokens(prompt, req.Model)
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
if true {
|
||||
replyCreatedAt := time.Now()
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var functionCall = false
|
||||
var functionName string
|
||||
var arguments = make([]string, 0)
|
||||
reader := bufio.NewReader(response.Body)
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
} else if err != io.EOF {
|
||||
logger.Error(err)
|
||||
}
|
||||
break
|
||||
}
|
||||
historyUserMsg := model.HistoryMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: user.Avatar,
|
||||
Content: prompt,
|
||||
Tokens: token,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
res := h.db.Save(&historyUserMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
continue
|
||||
}
|
||||
|
||||
// for reply
|
||||
token, err = utils.CalcTokens(message.Content, req.Model)
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
|
||||
logger.Error(err, line)
|
||||
replyMessage(ws, ErrorMsg)
|
||||
replyMessage(ws, "")
|
||||
break
|
||||
}
|
||||
|
||||
fun := responseBody.Choices[0].Delta.FunctionCall
|
||||
if functionCall && fun.Name == "" {
|
||||
arguments = append(arguments, fun.Arguments)
|
||||
continue
|
||||
}
|
||||
|
||||
if !utils.IsEmptyValue(fun) {
|
||||
functionCall = true
|
||||
functionName = fun.Name
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
||||
break
|
||||
}
|
||||
|
||||
// 初始化 role
|
||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||
message.Role = responseBody.Choices[0].Delta.Role
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
continue
|
||||
} else if responseBody.Choices[0].FinishReason != "" {
|
||||
break // 输出完成或者输出中断了
|
||||
} else {
|
||||
content := responseBody.Choices[0].Delta.Content
|
||||
contents = append(contents, utils.InterfaceToString(content))
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
})
|
||||
}
|
||||
} // end for
|
||||
|
||||
if functionCall { // 调用函数完成任务
|
||||
logger.Info(functionName)
|
||||
logger.Info(arguments)
|
||||
// TODO 调用函数完成任务
|
||||
data, err := h.funcZaoBao.Fetch()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: "调用函数出错",
|
||||
})
|
||||
} else {
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: data,
|
||||
})
|
||||
}
|
||||
historyReplyMsg := model.HistoryMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: token,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
res = h.db.Create(&historyReplyMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 统计用户 token 数量
|
||||
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
||||
historyUserMsg.Tokens+historyReplyMsg.Tokens))
|
||||
contents = append(contents, data)
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||
if res.Error != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = session.UserId
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.Model = session.Model
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
req.Messages = append(req.Messages, message)
|
||||
totalTokens := getTotalTokens(req)
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if userVo.ChatConfig.EnableContext && functionCall == false {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
if userVo.ChatConfig.EnableHistory {
|
||||
useContext := true
|
||||
if functionCall {
|
||||
useContext = false
|
||||
}
|
||||
|
||||
// for prompt
|
||||
token, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
historyUserMsg := model.HistoryMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: user.Avatar,
|
||||
Content: prompt,
|
||||
Tokens: token,
|
||||
UseContext: useContext,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
res := h.db.Save(&historyUserMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// for reply
|
||||
token, err = utils.CalcTokens(message.Content, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
historyReplyMsg := model.HistoryMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: token,
|
||||
UseContext: useContext,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
res = h.db.Create(&historyReplyMsg)
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 统计用户 token 数量
|
||||
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
||||
historyUserMsg.Tokens+historyReplyMsg.Tokens))
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
|
||||
if res.Error != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.UserId = session.UserId
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.Model = session.Model
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
}
|
||||
h.db.Create(&chatItem)
|
||||
}
|
||||
h.db.Create(&chatItem)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -469,13 +492,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
||||
return client.Do(request)
|
||||
}
|
||||
|
||||
func (h *ChatHandler) fakeRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) {
|
||||
link := "https://img.r9it.com/chatgpt/response"
|
||||
client := &http.Client{}
|
||||
request, _ := http.NewRequest(http.MethodGet, link, nil)
|
||||
return client.Do(request)
|
||||
}
|
||||
|
||||
// 回复客户片段端消息
|
||||
func replyChunkMessage(client types.Client, message types.WsMessage) {
|
||||
msg, err := json.Marshal(message)
|
||||
@ -509,6 +525,26 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
||||
resp.SUCCESS(c, tokens)
|
||||
}
|
||||
|
||||
func getTotalTokens(req types.ApiRequest) int {
|
||||
encode := utils.JsonEncode(req.Messages)
|
||||
var items []map[string]interface{}
|
||||
err := utils.JsonDecode(encode, &items)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
tokens := 0
|
||||
for _, item := range items {
|
||||
content, ok := item["content"]
|
||||
if ok && !utils.IsEmptyValue(content) {
|
||||
t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
|
||||
if err == nil {
|
||||
tokens += t
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// StopGenerate 停止生成
|
||||
func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
sessionId := c.Query("session_id")
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"chatplus/store"
|
||||
"context"
|
||||
"embed"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
@ -101,9 +102,12 @@ func main() {
|
||||
}),
|
||||
|
||||
// 创建函数
|
||||
fx.Provide(func() *function.FuncZaoBao {
|
||||
fx.Provide(func() (*function.FuncZaoBao, error) {
|
||||
token := os.Getenv("AL_API_TOKEN")
|
||||
return function.NewZaoBao(token)
|
||||
if token == "" {
|
||||
return nil, errors.New("invalid AL api token")
|
||||
}
|
||||
return function.NewZaoBao(token), nil
|
||||
}),
|
||||
|
||||
// 创建控制器
|
||||
|
@ -2,13 +2,14 @@ package model
|
||||
|
||||
type HistoryMessage struct {
|
||||
BaseModel
|
||||
ChatId string // 会话 ID
|
||||
UserId uint // 用户 ID
|
||||
RoleId uint // 角色 ID
|
||||
Type string
|
||||
Icon string
|
||||
Tokens int
|
||||
Content string
|
||||
ChatId string // 会话 ID
|
||||
UserId uint // 用户 ID
|
||||
RoleId uint // 角色 ID
|
||||
Type string
|
||||
Icon string
|
||||
Tokens int
|
||||
Content string
|
||||
UseContext bool // 是否可以作为聊天上下文
|
||||
}
|
||||
|
||||
func (HistoryMessage) TableName() string {
|
||||
|
@ -2,13 +2,14 @@ package vo
|
||||
|
||||
type HistoryMessage struct {
|
||||
BaseVo
|
||||
ChatId string `json:"chat_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
RoleId uint `json:"role_id"`
|
||||
Type string `json:"type"`
|
||||
Icon string `json:"icon"`
|
||||
Tokens int `json:"tokens"`
|
||||
Content string `json:"content"`
|
||||
ChatId string `json:"chat_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
RoleId uint `json:"role_id"`
|
||||
Type string `json:"type"`
|
||||
Icon string `json:"icon"`
|
||||
Tokens int `json:"tokens"`
|
||||
Content string `json:"content"`
|
||||
UseContext bool `json:"use_context"`
|
||||
}
|
||||
|
||||
func (HistoryMessage) TableName() string {
|
||||
|
@ -81,3 +81,10 @@ func JsonEncode(value interface{}) string {
|
||||
func JsonDecode(src string, dest interface{}) error {
|
||||
return json.Unmarshal([]byte(src), dest)
|
||||
}
|
||||
|
||||
func InterfaceToString(value interface{}) string {
|
||||
if str, ok := value.(string); ok {
|
||||
return str
|
||||
}
|
||||
return JsonEncode(value)
|
||||
}
|
||||
|
1
database/plugins.sql
Normal file
1
database/plugins.sql
Normal file
@ -0,0 +1 @@
|
||||
ALTER TABLE `chatgpt_chat_history` ADD `use_context` TINYINT(1) NOT NULL COMMENT '是否允许作为上下文语料' AFTER `tokens`;
|
@ -94,6 +94,7 @@
|
||||
<el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)">
|
||||
<div class="chat-head">
|
||||
<div class="chat-config">
|
||||
<span class="role-select-label">聊天角色:</span>
|
||||
<el-select v-model="roleId" filterable placeholder="角色" class="role-select">
|
||||
<el-option
|
||||
v-for="item in roles"
|
||||
@ -210,7 +211,8 @@ import {
|
||||
Check,
|
||||
Close,
|
||||
Delete,
|
||||
Edit, Iphone,
|
||||
Edit,
|
||||
Iphone,
|
||||
Plus,
|
||||
Promotion,
|
||||
RefreshRight,
|
||||
@ -920,6 +922,10 @@ $borderColor = #4676d0;
|
||||
justify-content center;
|
||||
padding-top 10px;
|
||||
|
||||
.role-select-label {
|
||||
color #ffffff
|
||||
}
|
||||
|
||||
.el-select {
|
||||
//max-width 150px;
|
||||
margin-right 10px;
|
||||
|
Loading…
Reference in New Issue
Block a user