mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 11:13:42 +08:00
refactor: 重构项目,为所有的 AI 工具都引入算力,采用算力统一结算各个工具的调用次数和权限
This commit is contained in:
@@ -103,8 +103,6 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
h.subUserCalls(userVo, session)
|
||||
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
@@ -145,8 +143,8 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens += getTotalTokens(req)
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
replyTokens += getTotalTokens(req)
|
||||
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
@@ -155,7 +153,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: totalTokens,
|
||||
Tokens: replyTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
@@ -166,8 +164,8 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 更新用户信息
|
||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
|
||||
@@ -128,9 +128,6 @@ func (h *ChatHandler) sendBaiduMessage(
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
h.subUserCalls(userVo, session)
|
||||
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
@@ -171,8 +168,8 @@ func (h *ChatHandler) sendBaiduMessage(
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyToken + getTotalTokens(req)
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
@@ -190,8 +187,8 @@ func (h *ChatHandler) sendBaiduMessage(
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
// 更新用户信息
|
||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
|
||||
@@ -111,7 +111,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
session.Model = types.ChatModel{
|
||||
Id: chatModel.Id,
|
||||
Value: chatModel.Value,
|
||||
Weight: chatModel.Weight,
|
||||
Power: chatModel.Power,
|
||||
Platform: types.Platform(chatModel.Platform)}
|
||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
||||
var chatRole model.ChatRole
|
||||
@@ -207,13 +207,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
return nil
|
||||
}
|
||||
|
||||
if userVo.Calls < session.Model.Weight {
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数(%d)已不足以支付当前模型的单次对话需要消耗的对话额度(%d)!", userVo.Calls, session.Model.Weight))
|
||||
if userVo.Power < session.Model.Power {
|
||||
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数(%d)已不足以支付当前模型的单次对话需要消耗的对话额度(%d)!", userVo.Power, session.Model.Power))
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return nil
|
||||
}
|
||||
|
||||
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
|
||||
if userVo.Power <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
|
||||
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
return nil
|
||||
@@ -533,23 +533,28 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
||||
return client.Do(request)
|
||||
}
|
||||
|
||||
// 扣减用户的对话次数
|
||||
func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) {
|
||||
// 仅当用户没有导入自己的 API KEY 时才进行扣减
|
||||
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
|
||||
num := 1
|
||||
if session.Model.Weight > 0 {
|
||||
num = session.Model.Weight
|
||||
}
|
||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num))
|
||||
// 扣减用户算力
|
||||
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
|
||||
power := 1
|
||||
if session.Model.Power > 0 {
|
||||
power = session.Model.Power
|
||||
}
|
||||
res := h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
|
||||
if res.Error == nil {
|
||||
// 记录算力消费日志
|
||||
h.db.Debug().Create(&model.PowerLog{
|
||||
UserId: userVo.Id,
|
||||
Username: userVo.Username,
|
||||
Type: types.PowerConsume,
|
||||
Amount: power,
|
||||
Mark: types.PowerSub,
|
||||
Balance: userVo.Power - power,
|
||||
Model: session.Model.Value,
|
||||
Remark: fmt.Sprintf("提问长度:%d,回复长度:%d", promptTokens, replyTokens),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ChatHandler) incUserTokenFee(userId uint, tokens int) {
|
||||
h.db.Model(&model.User{}).Where("id = ?", userId).
|
||||
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", tokens))
|
||||
h.db.Model(&model.User{}).Where("id = ?", userId).
|
||||
UpdateColumn("tokens", gorm.Expr("tokens + ?", tokens))
|
||||
}
|
||||
|
||||
// 将AI回复消息中生成的图片链接下载到本地
|
||||
|
||||
@@ -107,9 +107,6 @@ func (h *ChatHandler) sendChatGLMMessage(
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
h.subUserCalls(userVo, session)
|
||||
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
@@ -150,8 +147,8 @@ func (h *ChatHandler) sendChatGLMMessage(
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyToken + getTotalTokens(req)
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
@@ -169,8 +166,9 @@ func (h *ChatHandler) sendChatGLMMessage(
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
// 更新用户信息
|
||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
||||
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
|
||||
@@ -173,9 +173,6 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
h.subUserCalls(userVo, session)
|
||||
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
@@ -220,16 +217,16 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
var replyTokens = 0
|
||||
if toolCall { // prompt + 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(function.Name, req.Model)
|
||||
totalTokens += tokens
|
||||
replyTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
totalTokens += tokens
|
||||
replyTokens += tokens
|
||||
} else {
|
||||
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
}
|
||||
totalTokens += getTotalTokens(req)
|
||||
replyTokens += getTotalTokens(req)
|
||||
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
@@ -238,7 +235,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: h.extractImgUrl(message.Content),
|
||||
Tokens: totalTokens,
|
||||
Tokens: replyTokens,
|
||||
UseContext: useContext,
|
||||
Model: req.Model,
|
||||
}
|
||||
@@ -249,8 +246,8 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 更新用户信息
|
||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
|
||||
@@ -128,9 +128,6 @@ func (h *ChatHandler) sendQWenMessage(
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
h.subUserCalls(userVo, session)
|
||||
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
@@ -171,8 +168,8 @@ func (h *ChatHandler) sendQWenMessage(
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyToken + getTotalTokens(req)
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
@@ -190,8 +187,9 @@ func (h *ChatHandler) sendQWenMessage(
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
// 更新用户信息
|
||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
||||
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
|
||||
@@ -166,9 +166,6 @@ func (h *ChatHandler) sendXunFeiMessage(
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
h.subUserCalls(userVo, session)
|
||||
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
@@ -209,8 +206,8 @@ func (h *ChatHandler) sendXunFeiMessage(
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyToken + getTotalTokens(req)
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
@@ -228,8 +225,9 @@ func (h *ChatHandler) sendXunFeiMessage(
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
// 更新用户信息
|
||||
h.incUserTokenFee(userVo.Id, totalTokens)
|
||||
|
||||
// 更新用户算力
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
|
||||
Reference in New Issue
Block a user