diff --git a/api/core/app_server.go b/api/core/app_server.go index 73b5b0ac..e93e024f 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -30,8 +30,7 @@ type AppServer struct { Engine *gin.Engine ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message - ChatConfig *types.ChatConfig // chat config cache - SysConfig *types.SystemConfig // system config cache + SysConfig *types.SystemConfig // system config cache // 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次 // 防止第三方直接连接 socket 调用 OpenAI API @@ -69,23 +68,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) { } func (s *AppServer) Run(db *gorm.DB) error { - // load chat config from database - var chatConfig model.Config - res := db.Where("marker", "chat").First(&chatConfig) - if res.Error != nil { - return res.Error - } - err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig) - if err != nil { - return err - } // load system configs var sysConfig model.Config - res = db.Where("marker", "system").First(&sysConfig) + res := db.Where("marker", "system").First(&sysConfig) if res.Error != nil { return res.Error } - err = utils.JsonDecode(sysConfig.Config, &s.SysConfig) + err := utils.JsonDecode(sysConfig.Config, &s.SysConfig) if err != nil { return err } diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 83742852..70333339 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -54,10 +54,13 @@ type ChatSession struct { } type ChatModel struct { - Id uint `json:"id"` - Platform Platform `json:"platform"` - Value string `json:"value"` - Power int `json:"power"` + Id uint `json:"id"` + Platform Platform `json:"platform"` + Value string `json:"value"` + Power int `json:"power"` + MaxTokens int `json:"max_tokens"` // 最大响应长度 + MaxContext int `json:"max_context"` // 最大上下文长度 + Temperature float32 `json:"temperature"` // 模型温度 } type ApiError struct { @@ -72,27 +75,6 @@ type ApiError struct { const PromptMsg = "prompt" // prompt message const ReplyMsg = "reply" // reply message -var ModelToTokens = map[string]int{ - "gpt-3.5-turbo": 4096, - "gpt-3.5-turbo-16k": 16384, - "gpt-4": 8192, - "gpt-4-32k": 32768, - "chatglm_pro": 32768, // 清华智普 - "chatglm_std": 16384, - "chatglm_lite": 4096, - "ernie_bot_turbo": 8192, // 文心一言 - "general": 8192, // 科大讯飞 - "general2": 8192, - "general3": 8192, -} - -func GetModelMaxToken(model string) int { - if token, ok := ModelToTokens[model]; ok { - return token - } - return 4096 -} - // PowerType 算力日志类型 type PowerType int diff --git a/api/core/types/config.go b/api/core/types/config.go index a8bf061e..8a1a2e85 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -121,20 +121,6 @@ func (c RedisConfig) Url() string { return fmt.Sprintf("%s:%d", c.Host, c.Port) } -// ChatConfig 系统默认的聊天配置 -type ChatConfig struct { - OpenAI ModelAPIConfig `json:"open_ai"` - Azure ModelAPIConfig `json:"azure"` - ChatGML ModelAPIConfig `json:"chat_gml"` - Baidu ModelAPIConfig `json:"baidu"` - XunFei ModelAPIConfig `json:"xun_fei"` - - EnableContext bool `json:"enable_context"` // 是否开启聊天上下文 - EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录 - ContextDeep int `json:"context_deep"` // 上下文深度 - DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量 -} - type Platform string const OpenAI = Platform("OpenAI") @@ -144,16 +130,6 @@ const Baidu = Platform("Baidu") const XunFei = Platform("XunFei") const QWen = Platform("QWen") -// UserChatConfig 用户的聊天配置 -type UserChatConfig struct { - ApiKeys map[Platform]string `json:"api_keys"` -} - -type 
ModelAPIConfig struct { - Temperature float32 `json:"temperature"` - MaxTokens int `json:"max_tokens"` -} - type SystemConfig struct { Title string `json:"title"` AdminTitle string `json:"admin_title"` @@ -178,4 +154,7 @@ type SystemConfig struct { DallPower int `json:"dall_power"` // DALLE3 绘图消耗算力 WechatCardURL string `json:"wechat_card_url"` // 微信客服地址 + + EnableContext bool `json:"enable_context"` + ContextDeep int `json:"context_deep"` } diff --git a/api/handler/admin/api_key_handler.go b/api/handler/admin/api_key_handler.go index 52b95358..75ee50ba 100644 --- a/api/handler/admin/api_key_handler.go +++ b/api/handler/admin/api_key_handler.go @@ -32,7 +32,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { Value string `json:"value"` ApiURL string `json:"api_url"` Enabled bool `json:"enabled"` - UseProxy bool `json:"use_proxy"` + ProxyURL string `json:"proxy_url"` } if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) @@ -48,7 +48,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { apiKey.Type = data.Type apiKey.ApiURL = data.ApiURL apiKey.Enabled = data.Enabled - apiKey.UseProxy = data.UseProxy + apiKey.ProxyURL = data.ProxyURL apiKey.Name = data.Name res := h.db.Save(&apiKey) if res.Error != nil { diff --git a/api/handler/admin/chat_handler.go b/api/handler/admin/chat_handler.go index 4e06f2a5..f94c834e 100644 --- a/api/handler/admin/chat_handler.go +++ b/api/handler/admin/chat_handler.go @@ -24,14 +24,15 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler { } type chatItemVo struct { - Username string `json:"username"` - UserId uint `json:"user_id"` - ChatId string `json:"chat_id"` - Title string `json:"title"` - Model string `json:"model"` - Token int `json:"token"` - CreatedAt int64 `json:"created_at"` - MsgNum int `json:"msg_num"` // 消息数量 + Username string `json:"username"` + UserId uint `json:"user_id"` + ChatId string `json:"chat_id"` + Title string `json:"title"` + Role vo.ChatRole `json:"role"` + Model string `json:"model"` + Token int `json:"token"` + CreatedAt int64 `json:"created_at"` + MsgNum int `json:"msg_num"` // 消息数量 } func (h *ChatHandler) List(c *gin.Context) { @@ -78,18 +79,23 @@ func (h *ChatHandler) List(c *gin.Context) { if res.Error == nil { userIds := make([]uint, 0) chatIds := make([]string, 0) + roleIds := make([]uint, 0) for _, item := range items { userIds = append(userIds, item.UserId) chatIds = append(chatIds, item.ChatId) + roleIds = append(roleIds, item.RoleId) } var messages []model.ChatMessage var users []model.User + var roles []model.ChatRole h.db.Where("chat_id IN ?", chatIds).Find(&messages) h.db.Where("id IN ?", userIds).Find(&users) + h.db.Where("id IN ?", roleIds).Find(&roles) tokenMap := make(map[string]int) userMap := make(map[uint]string) msgMap := make(map[string]int) + roleMap := make(map[uint]vo.ChatRole) for _, msg := range messages { tokenMap[msg.ChatId] += msg.Tokens msgMap[msg.ChatId] += 1 @@ -97,6 +103,14 @@ func (h *ChatHandler) List(c *gin.Context) { for _, user := range users { userMap[user.Id] = user.Username } + for _, r := range roles { + var roleVo vo.ChatRole + err := utils.CopyObject(r, &roleVo) + if err != nil { + continue + } + roleMap[r.Id] = roleVo + } for _, item := range items { list = append(list, chatItemVo{ UserId: item.UserId, @@ -106,6 +120,7 @@ func (h *ChatHandler) List(c *gin.Context) { Model: item.Model, Token: tokenMap[item.ChatId], MsgNum: msgMap[item.ChatId], + Role: roleMap[item.RoleId], CreatedAt: item.CreatedAt.Unix(), }) } diff --git 
a/api/handler/admin/chat_model_handler.go b/api/handler/admin/chat_model_handler.go index 56d1ef19..a1bd0c73 100644 --- a/api/handler/admin/chat_model_handler.go +++ b/api/handler/admin/chat_model_handler.go @@ -26,15 +26,18 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler { func (h *ChatModelHandler) Save(c *gin.Context) { var data struct { - Id uint `json:"id"` - Name string `json:"name"` - Value string `json:"value"` - Enabled bool `json:"enabled"` - SortNum int `json:"sort_num"` - Open bool `json:"open"` - Platform string `json:"platform"` - Weight int `json:"weight"` - CreatedAt int64 `json:"created_at"` + Id uint `json:"id"` + Name string `json:"name"` + Value string `json:"value"` + Enabled bool `json:"enabled"` + SortNum int `json:"sort_num"` + Open bool `json:"open"` + Platform string `json:"platform"` + Power int `json:"power"` + MaxTokens int `json:"max_tokens"` // 最大响应长度 + MaxContext int `json:"max_context"` // 最大上下文长度 + Temperature string `json:"temperature"` // 模型温度 + CreatedAt int64 `json:"created_at"` } if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) @@ -42,13 +45,16 @@ func (h *ChatModelHandler) Save(c *gin.Context) { } item := model.ChatModel{ - Platform: data.Platform, - Name: data.Name, - Value: data.Value, - Enabled: data.Enabled, - SortNum: data.SortNum, - Open: data.Open, - Power: data.Weight} + Platform: data.Platform, + Name: data.Name, + Value: data.Value, + Enabled: data.Enabled, + SortNum: data.SortNum, + Open: data.Open, + MaxTokens: data.MaxTokens, + MaxContext: data.MaxContext, + Temperature: float32(utils.Str2Float(data.Temperature)), + Power: data.Power} item.Id = data.Id if item.Id > 0 { item.CreatedAt = time.Unix(data.CreatedAt, 0) @@ -145,19 +151,16 @@ func (h *ChatModelHandler) Sort(c *gin.Context) { } func (h *ChatModelHandler) Remove(c *gin.Context) { - var data struct { - Id uint - } - if err := c.ShouldBindJSON(&data); err != nil { + id := h.GetInt(c, "id", 0) + if id <= 0 { resp.ERROR(c, types.InvalidArgs) return } - if data.Id > 0 { - res := h.db.Where("id = ?", data.Id).Delete(&model.ChatModel{}) - if res.Error != nil { - resp.ERROR(c, "更新数据库失败!") - return - } + + res := h.db.Where("id = ?", id).Delete(&model.ChatModel{}) + if res.Error != nil { + resp.ERROR(c, "更新数据库失败!") + return } resp.SUCCESS(c) } diff --git a/api/handler/admin/config_handler.go b/api/handler/admin/config_handler.go index a82d3426..08c939ba 100644 --- a/api/handler/admin/config_handler.go +++ b/api/handler/admin/config_handler.go @@ -56,8 +56,6 @@ func (h *ConfigHandler) Update(c *gin.Context) { var err error if data.Key == "system" { err = utils.JsonDecode(cfg.Config, &h.App.SysConfig) - } else if data.Key == "chat" { - err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig) } if err != nil { resp.ERROR(c, "Failed to update config cache: "+err.Error()) diff --git a/api/handler/admin/user_handler.go b/api/handler/admin/user_handler.go index 580831e7..b9213da6 100644 --- a/api/handler/admin/user_handler.go +++ b/api/handler/admin/user_handler.go @@ -107,13 +107,6 @@ func (h *UserHandler) Save(c *gin.Context) { ChatRoles: utils.JsonEncode(data.ChatRoles), ChatModels: utils.JsonEncode(data.ChatModels), ExpiredTime: utils.Str2stamp(data.ExpiredTime), - ChatConfig: utils.JsonEncode(types.UserChatConfig{ - ApiKeys: map[types.Platform]string{ - types.OpenAI: "", - types.Azure: "", - types.ChatGLM: "", - }, - }), } res = h.db.Create(&u) _ = utils.CopyObject(u, &userVo) diff --git a/api/handler/chatimpl/azure_handler.go 
b/api/handler/chatimpl/azure_handler.go index 87bd4599..4c7151b4 100644 --- a/api/handler/chatimpl/azure_handler.go +++ b/api/handler/chatimpl/azure_handler.go @@ -111,66 +111,64 @@ func (h *ChatHandler) sendAzureMessage( useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.ChatConfig.EnableContext { + if h.App.SysConfig.EnableContext { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) } // 追加聊天记录 - if h.App.ChatConfig.EnableHistory { - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.db.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - replyTokens += getTotalTokens(req) - - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: replyTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.db.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) } + historyUserMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(prompt), + Tokens: promptToken, + UseContext: true, + Model: req.Model, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // 计算本次对话消耗的总 token 数量 + replyTokens, _ := utils.CalcTokens(message.Content, req.Model) + replyTokens += getTotalTokens(req) + + historyReplyMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: replyTokens, + UseContext: true, + Model: req.Model, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.db.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", res.Error) + } + + // 更新用户算力 + h.subUserPower(userVo, session, promptToken, replyTokens) // 保存当前会话 var chatItem model.ChatItem - res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) + res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) if res.Error != nil { chatItem.ChatId = session.ChatId chatItem.UserId = session.UserId diff --git a/api/handler/chatimpl/baidu_handler.go b/api/handler/chatimpl/baidu_handler.go index 7c227110..cce6bd3b 100644 --- a/api/handler/chatimpl/baidu_handler.go +++ 
b/api/handler/chatimpl/baidu_handler.go @@ -135,65 +135,63 @@ func (h *ChatHandler) sendBaiduMessage( useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.ChatConfig.EnableContext { + if h.App.SysConfig.EnableContext { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) } // 追加聊天记录 - if h.App.ChatConfig.EnableHistory { - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.db.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // for reply - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyTokens + getTotalTokens(req) - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.db.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) } + historyUserMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(prompt), + Tokens: promptToken, + UseContext: true, + Model: req.Model, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // for reply + // 计算本次对话消耗的总 token 数量 + replyTokens, _ := utils.CalcTokens(message.Content, req.Model) + totalTokens := replyTokens + getTotalTokens(req) + historyReplyMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: totalTokens, + UseContext: true, + Model: req.Model, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.db.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", res.Error) + } + // 更新用户算力 + h.subUserPower(userVo, session, promptToken, replyTokens) // 保存当前会话 var chatItem model.ChatItem - res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) + res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) if res.Error != nil { chatItem.ChatId = session.ChatId chatItem.UserId = session.UserId diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 8327abe2..967aa479 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -57,8 +57,6 @@ func (h *ChatHandler) 
Init() { } } -var chatConfig types.ChatConfig - // ChatHandle 处理聊天 WebSocket 请求 func (h *ChatHandler) ChatHandle(c *gin.Context) { ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) @@ -109,10 +107,13 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { session.ChatId = chatId session.Model = types.ChatModel{ - Id: chatModel.Id, - Value: chatModel.Value, - Power: chatModel.Power, - Platform: types.Platform(chatModel.Platform)} + Id: chatModel.Id, + Value: chatModel.Value, + Power: chatModel.Power, + MaxTokens: chatModel.MaxTokens, + MaxContext: chatModel.MaxContext, + Temperature: chatModel.Temperature, + Platform: types.Platform(chatModel.Platform)} logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username) var chatRole model.ChatRole res = h.db.First(&chatRole, roleId) @@ -122,15 +123,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { return } - // 初始化聊天配置 - var config model.Config - h.db.Where("marker", "chat").First(&config) - err = utils.JsonDecode(config.Config, &chatConfig) - if err != nil { - utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!") - c.Abort() - return - } h.Init() // 保存会话连接 @@ -213,7 +205,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio return nil } - if userVo.Power <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { + if userVo.Power <= 0 { utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!") utils.ReplyMessage(ws, ErrImg) return nil @@ -227,7 +219,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio // 检查 prompt 长度是否超过了当前模型允许的最大上下文长度 promptTokens, err := utils.CalcTokens(prompt, session.Model.Value) - if promptTokens > types.GetModelMaxToken(session.Model.Value) { + if promptTokens > session.Model.MaxContext { utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!") return nil } @@ -237,21 +229,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio Stream: true, } switch session.Model.Platform { - case types.Azure: - req.Temperature = h.App.ChatConfig.Azure.Temperature - req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens - break - case types.ChatGLM: - req.Temperature = h.App.ChatConfig.ChatGML.Temperature - req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens - break - case types.Baidu: - req.Temperature = h.App.ChatConfig.OpenAI.Temperature - // TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持 + case types.Azure, types.ChatGLM, types.Baidu, types.XunFei: + req.Temperature = session.Model.Temperature + req.MaxTokens = session.Model.MaxTokens break case types.OpenAI: - req.Temperature = h.App.ChatConfig.OpenAI.Temperature - req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens + req.Temperature = session.Model.Temperature + req.MaxTokens = session.Model.MaxTokens // OpenAI 支持函数功能 var items []model.Function res := h.db.Where("enabled", true).Find(&items) @@ -283,15 +267,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio req.Tools = tools req.ToolChoice = "auto" } - - case types.XunFei: - req.Temperature = h.App.ChatConfig.XunFei.Temperature - req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens - break case types.QWen: - req.Input = map[string]interface{}{"messages": []map[string]string{{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}}} - req.Parameters = map[string]interface{}{} + req.Parameters = map[string]interface{}{ + "max_tokens": session.Model.MaxTokens, + 
"temperature": session.Model.Temperature, + } break + default: utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!") utils.ReplyMessage(ws, ErrImg) @@ -301,14 +283,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio // 加载聊天上下文 chatCtx := make([]types.Message, 0) messages := make([]types.Message, 0) - if h.App.ChatConfig.EnableContext { + if h.App.SysConfig.EnableContext { if h.App.ChatContexts.Has(session.ChatId) { messages = h.App.ChatContexts.Get(session.ChatId) } else { _ = utils.JsonDecode(role.Context, &messages) - if chatConfig.ContextDeep > 0 { + if h.App.SysConfig.ContextDeep > 0 { var historyMessages []model.ChatMessage - res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages) + res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages) if res.Error == nil { for i := len(historyMessages) - 1; i >= 0; i-- { msg := historyMessages[i] @@ -331,12 +313,12 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio for _, v := range messages { tks, _ := utils.CalcTokens(v.Content, req.Model) // 上下文 token 超出了模型的最大上下文长度 - if tokens+tks >= types.GetModelMaxToken(req.Model) { + if tokens+tks >= session.Model.MaxContext { break } // 上下文的深度超出了模型的最大上下文深度 - if len(chatCtx) >= h.App.ChatConfig.ContextDeep { + if len(chatCtx) >= h.App.SysConfig.ContextDeep { break } @@ -351,10 +333,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio reqMgs = append(reqMgs, m) } - req.Messages = append(reqMgs, map[string]interface{}{ - "role": "user", - "content": prompt, - }) + if session.Model.Platform == types.QWen { + req.Input = map[string]interface{}{"prompt": prompt} + if len(reqMgs) > 0 { + req.Input["messages"] = reqMgs + } + } else { + req.Messages = append(reqMgs, map[string]interface{}{ + "role": "user", + "content": prompt, + }) + } switch session.Model.Platform { case types.Azure: @@ -497,9 +486,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf request = request.WithContext(ctx) request.Header.Set("Content-Type", "application/json") var proxyURL string - if h.App.Config.ProxyURL != "" && apiKey.UseProxy { // 使用代理 - proxyURL = h.App.Config.ProxyURL - proxy, _ := url.Parse(proxyURL) + if apiKey.ProxyURL != "" { // 使用代理 + proxy, _ := url.Parse(apiKey.ProxyURL) client = &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxy), @@ -542,7 +530,7 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p res := h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power)) if res.Error == nil { // 记录算力消费日志 - h.db.Debug().Create(&model.PowerLog{ + h.db.Create(&model.PowerLog{ UserId: userVo.Id, Username: userVo.Username, Type: types.PowerConsume, diff --git a/api/handler/chatimpl/chat_item_handler.go b/api/handler/chatimpl/chat_item_handler.go index 285894be..532c7939 100644 --- a/api/handler/chatimpl/chat_item_handler.go +++ b/api/handler/chatimpl/chat_item_handler.go @@ -126,7 +126,7 @@ func (h *ChatHandler) History(c *gin.Context) { chatId := c.Query("chat_id") // 会话 ID var items []model.ChatMessage var messages = make([]vo.HistoryMessage, 0) - res := h.db.Debug().Where("chat_id = ?", chatId).Find(&items) + res := h.db.Where("chat_id = ?", chatId).Find(&items) if res.Error != nil { resp.ERROR(c, "No history message") 
return diff --git a/api/handler/chatimpl/chatglm_handler.go b/api/handler/chatimpl/chatglm_handler.go index 3db444b8..602fc779 100644 --- a/api/handler/chatimpl/chatglm_handler.go +++ b/api/handler/chatimpl/chatglm_handler.go @@ -114,66 +114,64 @@ func (h *ChatHandler) sendChatGLMMessage( useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.ChatConfig.EnableContext { + if h.App.SysConfig.EnableContext { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) } // 追加聊天记录 - if h.App.ChatConfig.EnableHistory { - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.db.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // for reply - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyTokens + getTotalTokens(req) - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.db.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) } + historyUserMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(prompt), + Tokens: promptToken, + UseContext: true, + Model: req.Model, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // for reply + // 计算本次对话消耗的总 token 数量 + replyTokens, _ := utils.CalcTokens(message.Content, req.Model) + totalTokens := replyTokens + getTotalTokens(req) + historyReplyMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: totalTokens, + UseContext: true, + Model: req.Model, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.db.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", res.Error) + } + + // 更新用户算力 + h.subUserPower(userVo, session, promptToken, replyTokens) // 保存当前会话 var chatItem model.ChatItem - res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) + res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) if res.Error != nil { chatItem.ChatId = session.ChatId chatItem.UserId = session.UserId diff --git a/api/handler/chatimpl/openai_handler.go 
b/api/handler/chatimpl/openai_handler.go index 13df6707..01e60912 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -180,79 +180,77 @@ func (h *ChatHandler) sendOpenAiMessage( useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.ChatConfig.EnableContext && toolCall == false { + if h.App.SysConfig.EnableContext && toolCall == false { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) } // 追加聊天记录 - if h.App.ChatConfig.EnableHistory { - useContext := true - if toolCall { - useContext = false - } - - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: useContext, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.db.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // 计算本次对话消耗的总 token 数量 - var replyTokens = 0 - if toolCall { // prompt + 函数名 + 参数 token - tokens, _ := utils.CalcTokens(function.Name, req.Model) - replyTokens += tokens - tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) - replyTokens += tokens - } else { - replyTokens, _ = utils.CalcTokens(message.Content, req.Model) - } - replyTokens += getTotalTokens(req) - - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: h.extractImgUrl(message.Content), - Tokens: replyTokens, - UseContext: useContext, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.db.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) + useContext := true + if toolCall { + useContext = false } + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) + } + historyUserMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(prompt), + Tokens: promptToken, + UseContext: useContext, + Model: req.Model, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // 计算本次对话消耗的总 token 数量 + var replyTokens = 0 + if toolCall { // prompt + 函数名 + 参数 token + tokens, _ := utils.CalcTokens(function.Name, req.Model) + replyTokens += tokens + tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) + replyTokens += tokens + } else { + replyTokens, _ = utils.CalcTokens(message.Content, req.Model) + } + replyTokens += getTotalTokens(req) + + historyReplyMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: h.extractImgUrl(message.Content), + Tokens: replyTokens, + UseContext: useContext, + 
Model: req.Model, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.db.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", res.Error) + } + + // 更新用户算力 + h.subUserPower(userVo, session, promptToken, replyTokens) + // 保存当前会话 var chatItem model.ChatItem - res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) + res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) if res.Error != nil { chatItem.ChatId = session.ChatId chatItem.UserId = session.UserId diff --git a/api/handler/chatimpl/qwen_handler.go b/api/handler/chatimpl/qwen_handler.go index 0a6b7e9e..c8e87c07 100644 --- a/api/handler/chatimpl/qwen_handler.go +++ b/api/handler/chatimpl/qwen_handler.go @@ -20,13 +20,16 @@ type qWenResp struct { Output struct { FinishReason string `json:"finish_reason"` Text string `json:"text"` - } `json:"output"` + } `json:"output,omitempty"` Usage struct { TotalTokens int `json:"total_tokens"` InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` - } `json:"usage"` + } `json:"usage,omitempty"` RequestID string `json:"request_id"` + + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` } // 通义千问消息发送实现 @@ -70,6 +73,7 @@ func (h *ChatHandler) sendQWenMessage( scanner := bufio.NewScanner(response.Body) var content, lastText, newText string + var outPutStart = false for scanner.Scan() { line := scanner.Text() @@ -77,24 +81,32 @@ func (h *ChatHandler) sendQWenMessage( strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") { continue } + if strings.HasPrefix(line, "data:") { content = line[5:] } - // 处理代码换行 - if len(content) == 0 { - content = "\n" - } var resp qWenResp - err := utils.JsonDecode(content, &resp) - if err != nil { - logger.Error("error with parse data line: ", err) - utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) - break - } - if len(contents) == 0 { // 发送消息头 - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + if !outPutStart { + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + outPutStart = true + continue + } else { + // 处理代码换行 + content = "\n" + } + } else { + err := utils.JsonDecode(content, &resp) + if err != nil { + logger.Error("error with parse data line: ", content) + utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) + break + } + if resp.Message != "" { + utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message)) + break + } } //通过比较 lastText(上一次的文本)和 currentText(当前的文本), @@ -135,66 +147,64 @@ func (h *ChatHandler) sendQWenMessage( useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.ChatConfig.EnableContext { + if h.App.SysConfig.EnableContext { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) } // 追加聊天记录 - if h.App.ChatConfig.EnableHistory { - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.db.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save 
prompt history message: ", res.Error) - } - - // for reply - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyTokens + getTotalTokens(req) - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.db.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) } + historyUserMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(prompt), + Tokens: promptToken, + UseContext: true, + Model: req.Model, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // for reply + // 计算本次对话消耗的总 token 数量 + replyTokens, _ := utils.CalcTokens(message.Content, req.Model) + totalTokens := replyTokens + getTotalTokens(req) + historyReplyMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: totalTokens, + UseContext: true, + Model: req.Model, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.db.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", res.Error) + } + + // 更新用户算力 + h.subUserPower(userVo, session, promptToken, replyTokens) // 保存当前会话 var chatItem model.ChatItem - res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) + res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) if res.Error != nil { chatItem.ChatId = session.ChatId chatItem.UserId = session.UserId diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index c29d1fd1..595d4a5a 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -50,9 +50,10 @@ type xunFeiResp struct { } var Model2URL = map[string]string{ - "general": "v1.1", - "generalv2": "v2.1", - "generalv3": "v3.1", + "general": "v1.1", + "generalv2": "v2.1", + "generalv3": "v3.1", + "generalv3.5": "v3.5", } // 科大讯飞消息发送实现 @@ -86,6 +87,7 @@ func (h *ChatHandler) sendXunFeiMessage( } apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1) + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) wsURL, err := assembleAuthUrl(apiURL, key[1], key[2]) //握手并建立websocket 连接 conn, resp, err := d.Dial(wsURL, nil) @@ -173,66 +175,64 @@ func (h *ChatHandler) sendXunFeiMessage( useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.ChatConfig.EnableContext { + if h.App.SysConfig.EnableContext { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) } // 追加聊天记录 - 
if h.App.ChatConfig.EnableHistory { - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.db.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // for reply - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyTokens + getTotalTokens(req) - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.db.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) } + historyUserMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(prompt), + Tokens: promptToken, + UseContext: true, + Model: req.Model, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // for reply + // 计算本次对话消耗的总 token 数量 + replyTokens, _ := utils.CalcTokens(message.Content, req.Model) + totalTokens := replyTokens + getTotalTokens(req) + historyReplyMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: totalTokens, + UseContext: true, + Model: req.Model, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.db.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", res.Error) + } + + // 更新用户算力 + h.subUserPower(userVo, session, promptToken, replyTokens) // 保存当前会话 var chatItem model.ChatItem - res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) + res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) if res.Error != nil { chatItem.ChatId = session.ChatId chatItem.UserId = session.UserId @@ -260,7 +260,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} { "parameter": map[string]interface{}{ "chat": map[string]interface{}{ "domain": req.Model, - "temperature": float64(req.Temperature), + "temperature": req.Temperature, "top_k": int64(6), "max_tokens": int64(req.MaxTokens), "auditing": "default", diff --git a/api/handler/function_handler.go b/api/handler/function_handler.go index 488aa8b3..931ad891 100644 --- a/api/handler/function_handler.go +++ b/api/handler/function_handler.go @@ -22,7 +22,6 @@ type FunctionHandler struct { db *gorm.DB config types.ChatPlusApiConfig uploadManager 
*oss.UploaderManager - proxyURL string } func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler { @@ -33,7 +32,6 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo db: db, config: config.ApiConfig, uploadManager: manager, - proxyURL: config.ProxyURL, } } @@ -213,47 +211,28 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { return } - // get image generation api URL - var conf model.Config - var chatConfig types.ChatConfig - tx = h.db.Where("marker", "chat").First(&conf) - if tx.Error != nil { - resp.ERROR(c, "error with get chat configs:"+tx.Error.Error()) - return - } - - err := utils.JsonDecode(conf.Config, &chatConfig) - if err != nil { - resp.ERROR(c, "error with decode chat config: "+err.Error()) - return - } - // translate prompt const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" - pt, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, params["prompt"]), h.App.Config.ProxyURL) + pt, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, params["prompt"])) if err == nil { logger.Debugf("翻译绘画提示词,原文:%s,译文:%s", prompt, pt) prompt = pt } - imgNum := chatConfig.DallImgNum - if imgNum <= 0 { - imgNum = 1 - } var res imgRes var errRes ErrRes var request *req.Request - if apiKey.UseProxy && h.proxyURL != "" { - request = req.C().SetProxyURL(h.proxyURL).R() + if apiKey.ProxyURL != "" { + request = req.C().SetProxyURL(apiKey.ProxyURL).R() } else { request = req.C().R() } - logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, h.proxyURL) + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) r, err := request.SetHeader("Content-Type", "application/json"). SetHeader("Authorization", "Bearer "+apiKey.Value). SetBody(imgReq{ Model: "dall-e-3", Prompt: prompt, - N: imgNum, + N: 1, Size: "1024x1024", }). SetErrorResult(&errRes). 
diff --git a/api/handler/prompt_handler.go b/api/handler/prompt_handler.go index 13aea57a..f33b93f8 100644 --- a/api/handler/prompt_handler.go +++ b/api/handler/prompt_handler.go @@ -35,7 +35,7 @@ func (h *PromptHandler) Rewrite(c *gin.Context) { return } - content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(rewritePromptTemplate, data.Prompt), h.App.Config.ProxyURL) + content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(rewritePromptTemplate, data.Prompt)) if err != nil { resp.ERROR(c, err.Error()) return @@ -53,7 +53,7 @@ func (h *PromptHandler) Translate(c *gin.Context) { return } - content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, data.Prompt), h.App.Config.ProxyURL) + content, err := utils.OpenAIRequest(h.db, fmt.Sprintf(translatePromptTemplate, data.Prompt)) if err != nil { resp.ERROR(c, err.Error()) return diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index 92248dee..f28a00d1 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -95,14 +95,7 @@ func (h *UserHandler) Register(c *gin.Context) { Status: true, ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色 ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型 - ChatConfig: utils.JsonEncode(types.UserChatConfig{ - ApiKeys: map[types.Platform]string{ - types.OpenAI: "", - types.Azure: "", - types.ChatGLM: "", - }, - }), - Power: h.App.SysConfig.InitPower, + Power: h.App.SysConfig.InitPower, } res = h.db.Create(&user) @@ -245,14 +238,13 @@ func (h *UserHandler) Session(c *gin.Context) { } type userProfile struct { - Id uint `json:"id"` - Nickname string `json:"nickname"` - Username string `json:"username"` - Avatar string `json:"avatar"` - ChatConfig types.UserChatConfig `json:"chat_config"` - Power int `json:"power"` - ExpiredTime int64 `json:"expired_time"` - Vip bool `json:"vip"` + Id uint `json:"id"` + Nickname string `json:"nickname"` + Username string `json:"username"` + Avatar string `json:"avatar"` + Power int `json:"power"` + ExpiredTime int64 `json:"expired_time"` + Vip bool `json:"vip"` } func (h *UserHandler) Profile(c *gin.Context) { diff --git a/api/main.go b/api/main.go index 3173c5e6..c1de4d6a 100644 --- a/api/main.go +++ b/api/main.go @@ -315,7 +315,7 @@ func main() { group.GET("list", h.List) group.POST("set", h.Set) group.POST("sort", h.Sort) - group.POST("remove", h.Remove) + group.GET("remove", h.Remove) }), fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) { group := s.Engine.Group("/api/payment/") diff --git a/api/store/model/api_key.go b/api/store/model/api_key.go index 109cec77..fb7ae1d4 100644 --- a/api/store/model/api_key.go +++ b/api/store/model/api_key.go @@ -9,6 +9,6 @@ type ApiKey struct { Value string // API Key 的值 ApiURL string // 当前 KEY 的 API 地址 Enabled bool // 是否启用 - UseProxy bool // 是否使用代理访问 API URL + ProxyURL string // 代理地址 LastUsedAt int64 // 最后使用时间 } diff --git a/api/store/model/chat_model.go b/api/store/model/chat_model.go index 71a2ab69..8ddff961 100644 --- a/api/store/model/chat_model.go +++ b/api/store/model/chat_model.go @@ -2,11 +2,14 @@ package model type ChatModel struct { BaseModel - Platform string - Name string - Value string // API Key 的值 - SortNum int - Enabled bool - Power int // 每次对话消耗算力 - Open bool // 是否开放模型给所有人使用 + Platform string + Name string + Value string // API Key 的值 + SortNum int + Enabled bool + Power int // 每次对话消耗算力 + Open bool // 是否开放模型给所有人使用 + MaxTokens int // 最大响应长度 + MaxContext int // 最大上下文长度 + Temperature float32 // 模型温度 } diff --git 
a/api/store/vo/api_key.go b/api/store/vo/api_key.go
index d32233f3..7321b13f 100644
--- a/api/store/vo/api_key.go
+++ b/api/store/vo/api_key.go
@@ -9,6 +9,6 @@ type ApiKey struct {
     Value      string `json:"value"` // API Key 的值
     ApiURL     string `json:"api_url"`
     Enabled    bool   `json:"enabled"`
-    UseProxy   bool   `json:"use_proxy"`
+    ProxyURL   string `json:"proxy_url"`
     LastUsedAt int64  `json:"last_used_at"` // 最后使用时间
 }
diff --git a/api/store/vo/chat_model.go b/api/store/vo/chat_model.go
index 72e164fa..81fc18ca 100644
--- a/api/store/vo/chat_model.go
+++ b/api/store/vo/chat_model.go
@@ -2,11 +2,14 @@ package vo
 
 type ChatModel struct {
     BaseVo
-    Platform string `json:"platform"`
-    Name     string `json:"name"`
-    Value    string `json:"value"`
-    Enabled  bool   `json:"enabled"`
-    SortNum  int    `json:"sort_num"`
-    Weight   int    `json:"weight"`
-    Open     bool   `json:"open"`
+    Platform    string  `json:"platform"`
+    Name        string  `json:"name"`
+    Value       string  `json:"value"`
+    Enabled     bool    `json:"enabled"`
+    SortNum     int     `json:"sort_num"`
+    Power       int     `json:"power"`
+    Open        bool    `json:"open"`
+    MaxTokens   int     `json:"max_tokens"`  // 最大响应长度
+    MaxContext  int     `json:"max_context"` // 最大上下文长度
+    Temperature float32 `json:"temperature"` // 模型温度
 }
diff --git a/api/store/vo/config.go b/api/store/vo/config.go
index db8368fb..6b07094d 100644
--- a/api/store/vo/config.go
+++ b/api/store/vo/config.go
@@ -5,6 +5,5 @@ import "chatplus/core/types"
 type Config struct {
     Id           uint               `json:"id"`
     Key          string             `json:"key"`
-    ChatConfig   types.ChatConfig   `json:"chat_config"`
     SystemConfig types.SystemConfig `json:"system_config"`
 }
diff --git a/api/store/vo/user.go b/api/store/vo/user.go
index 30660673..0ba66169 100644
--- a/api/store/vo/user.go
+++ b/api/store/vo/user.go
@@ -1,20 +1,17 @@
 package vo
 
-import "chatplus/core/types"
-
 type User struct {
     BaseVo
-    Username    string               `json:"username"`
-    Nickname    string               `json:"nickname"`
-    Avatar      string               `json:"avatar"`
-    Salt        string               `json:"salt"`          // 密码盐
-    Power       int                  `json:"power"`         // 剩余算力
-    ChatConfig  types.UserChatConfig `json:"chat_config"`   // 聊天配置
-    ChatRoles   []string             `json:"chat_roles"`    // 聊天角色集合
-    ChatModels  []string             `json:"chat_models"`   // AI模型集合
-    ExpiredTime int64                `json:"expired_time"`  // 账户到期时间
-    Status      bool                 `json:"status"`        // 当前状态
-    LastLoginAt int64                `json:"last_login_at"` // 最后登录时间
-    LastLoginIp string               `json:"last_login_ip"` // 最后登录 IP
-    Vip         bool                 `json:"vip"`
+    Username    string   `json:"username"`
+    Nickname    string   `json:"nickname"`
+    Avatar      string   `json:"avatar"`
+    Salt        string   `json:"salt"`          // 密码盐
+    Power       int      `json:"power"`         // 剩余算力
+    ChatRoles   []string `json:"chat_roles"`    // 聊天角色集合
+    ChatModels  []string `json:"chat_models"`   // AI模型集合
+    ExpiredTime int64    `json:"expired_time"`  // 账户到期时间
+    Status      bool     `json:"status"`        // 当前状态
+    LastLoginAt int64    `json:"last_login_at"` // 最后登录时间
+    LastLoginIp string   `json:"last_login_ip"` // 最后登录 IP
+    Vip         bool     `json:"vip"`
 }
diff --git a/api/utils/net.go b/api/utils/net.go
index 58319dc6..39c6ddec 100644
--- a/api/utils/net.go
+++ b/api/utils/net.go
@@ -88,7 +88,7 @@ type apiErrRes struct {
     } `json:"error"`
 }
 
-func OpenAIRequest(db *gorm.DB, prompt string, proxy string) (string, error) {
+func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
     var apiKey model.ApiKey
     res := db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
     if res.Error != nil {
@@ -104,8 +104,8 @@ func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
     var response apiRes
     var errRes apiErrRes
     client := req.C()
-    if apiKey.UseProxy && proxy != "" {
-        client.SetProxyURL(proxy)
+    if apiKey.ProxyURL != "" {
+        client.SetProxyURL(apiKey.ProxyURL)
     }
     r, err := client.R().SetHeader("Content-Type", "application/json").
         SetHeader("Authorization", "Bearer "+apiKey.Value).
diff --git a/api/utils/strings.go b/api/utils/strings.go
index ac8d24ae..ccccbc97 100644
--- a/api/utils/strings.go
+++ b/api/utils/strings.go
@@ -4,6 +4,7 @@ import (
     "encoding/json"
     "fmt"
     "math/rand"
+    "strconv"
     "time"
 
     "golang.org/x/crypto/sha3"
 )
@@ -59,7 +60,7 @@ func Str2stamp(str string) int64 {
     if len(str) == 0 {
         return 0
     }
-
+
     layout := "2006-01-02 15:04:05"
     t, err := time.Parse(layout, str)
     if err != nil {
@@ -92,3 +93,11 @@ func InterfaceToString(value interface{}) string {
     }
     return JsonEncode(value)
 }
+
+func Str2Float(str string) float64 {
+    num, err := strconv.ParseFloat(str, 64)
+    if err != nil {
+        return 0
+    }
+    return num
+}
diff --git a/database/update-v4.0.0.sql b/database/update-v4.0.0.sql
index 1622e11b..9779ecc1 100644
--- a/database/update-v4.0.0.sql
+++ b/database/update-v4.0.0.sql
@@ -130,3 +130,8 @@ INSERT INTO `chatgpt_admin_permissions` VALUES (31, '权限配置', '', 2, 28, '2024-03-14 15:29:15', '2024-03-14 15:29:15');
 INSERT INTO `chatgpt_admin_permissions` VALUES (32, '角色配置', '', 3, 28, '2024-03-14 15:29:15', '2024-03-14 15:29:15');
 INSERT INTO `chatgpt_admin_permissions` VALUES (33, '列表', 'api_admin_sysPermission_list', 1, 31, '2024-03-14 15:29:52', '2024-03-14 15:29:52');
 INSERT INTO `chatgpt_admin_permissions` VALUES (34, '列表', 'api_admin_sysRole_list', 1, 32, '2024-03-14 15:30:21', '2024-03-14 15:30:21');
+
+
+ALTER TABLE `chatgpt_api_keys` CHANGE `use_proxy` `proxy_url` VARCHAR(100) NULL DEFAULT NULL COMMENT '代理地址';
+-- 重置 proxy_url
+UPDATE chatgpt_api_keys set proxy_url='';
\ No newline at end of file
diff --git a/new-ui/pnpm-lock.yaml b/new-ui/pnpm-lock.yaml
index 1a97e1b2..b91a1de8 100644
--- a/new-ui/pnpm-lock.yaml
+++ b/new-ui/pnpm-lock.yaml
@@ -1,4 +1,4 @@
-lockfileVersion: '6.0'
+lockfileVersion: '6.1'
 
 settings:
   autoInstallPeers: true
@@ -127,103 +127,6 @@ importers:
         specifier: ^1.8.27
         version: 1.8.27(typescript@5.3.3)
 
-  projects/mobile:
-    dependencies:
-      '@element-plus/icons-vue':
-        specifier: ^2.1.0
-        version: 2.3.1(vue@3.4.21)
-      axios:
-        specifier: ^0.27.2
-        version: 0.27.2
-      clipboard:
-        specifier: ^2.0.11
-        version: 2.0.11
-      compressorjs:
-        specifier: ^1.2.1
-        version: 1.2.1
-      core-js:
-        specifier: ^3.8.3
-        version: 3.36.0
-      element-plus:
-        specifier: ^2.3.0
-        version: 2.6.1(vue@3.4.21)
-      good-storage:
-        specifier: ^1.1.1
-        version: 1.1.1
-      highlight.js:
-        specifier: ^11.7.0
-        version: 11.9.0
-      json-bigint:
-        specifier: ^1.0.0
-        version: 1.0.0
-      lodash:
-        specifier: ^4.17.21
-        version: 4.17.21
-      markdown-it:
-        specifier: ^13.0.1
-        version: 13.0.2
-      markdown-it-latex2img:
-        specifier: ^0.0.6
-        version: 0.0.6
-      markdown-it-mathjax:
-        specifier: ^2.0.0
-        version: 2.0.0
-      md-editor-v3:
-        specifier: ^2.2.1
-        version: 2.11.3(vue@3.4.21)
-      pinia:
-        specifier: ^2.1.4
-        version: 2.1.7(typescript@5.3.3)(vue@3.4.21)
-      qrcode:
-        specifier: ^1.5.3
-        version: 1.5.3
-      qs:
-        specifier: ^6.11.1
-        version: 6.12.0
-      sortablejs:
-        specifier: ^1.15.0
-        version: 1.15.2
-      v3-waterfall:
-        specifier: ^1.2.1
-        version: 1.3.3
-      vant:
-        specifier: ^4.5.0
-        version: 4.8.5(vue@3.4.21)
-      vue:
-        specifier: ^3.2.13
-        version: 3.4.21(typescript@5.3.3)
-      vue-router:
-        specifier: ^4.0.15
-        version: 4.3.0(vue@3.4.21)
-    devDependencies:
-      '@babel/core':
-        specifier: 7.18.6
-        version: 7.18.6
-      '@babel/eslint-parser':
-        specifier: ^7.12.16
-        version: 7.23.10(@babel/core@7.18.6)(eslint@7.32.0)
-      '@vue/cli-plugin-babel':
-
specifier: ~5.0.0 - version: 5.0.8(@vue/cli-service@5.0.8)(core-js@3.36.0)(vue@3.4.21) - '@vue/cli-plugin-eslint': - specifier: ~5.0.0 - version: 5.0.8(@vue/cli-service@5.0.8)(eslint@7.32.0) - '@vue/cli-service': - specifier: ~5.0.0 - version: 5.0.8(lodash@4.17.21)(prettier@3.2.5)(stylus-loader@7.1.3)(vue@3.4.21) - eslint: - specifier: ^7.32.0 - version: 7.32.0 - eslint-plugin-vue: - specifier: ^8.0.3 - version: 8.7.1(eslint@7.32.0) - stylus: - specifier: ^0.58.1 - version: 0.58.1 - stylus-loader: - specifier: ^7.0.0 - version: 7.1.3(stylus@0.58.1)(webpack@5.90.3) - projects/web: dependencies: '@element-plus/icons-vue': diff --git a/new-ui/projects/admin/.vscode/extensions.json b/new-ui/projects/admin/.vscode/extensions.json deleted file mode 100644 index 4771846f..00000000 --- a/new-ui/projects/admin/.vscode/extensions.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "recommendations": [ - "Vue.volar", - "Vue.vscode-typescript-vue-plugin", - "dbaeumer.vscode-eslint" - ] -} diff --git a/new-ui/projects/admin/src/components/CustomUploader.vue b/new-ui/projects/admin/src/components/CustomUploader.vue new file mode 100644 index 00000000..4b98e924 --- /dev/null +++ b/new-ui/projects/admin/src/components/CustomUploader.vue @@ -0,0 +1,49 @@ + + + + + + + + + + + + + diff --git a/new-ui/projects/admin/src/components/SimpleTable/SimpleTable.vue b/new-ui/projects/admin/src/components/SimpleTable/SimpleTable.vue index 78e1a22c..28a58989 100644 --- a/new-ui/projects/admin/src/components/SimpleTable/SimpleTable.vue +++ b/new-ui/projects/admin/src/components/SimpleTable/SimpleTable.vue @@ -1,18 +1,17 @@ + + + 新增 + + + + + + 购买API-KEY + + + 编辑 删除 - - 新增 - - + {{ record[column.dataIndex] }} diff --git a/new-ui/projects/admin/src/views/ApiKey/ApiKeyForm.vue b/new-ui/projects/admin/src/views/ApiKey/ApiKeyForm.vue index eedd4dbf..0d064b73 100644 --- a/new-ui/projects/admin/src/views/ApiKey/ApiKeyForm.vue +++ b/new-ui/projects/admin/src/views/ApiKey/ApiKeyForm.vue @@ -1,9 +1,8 @@ - {{ - `注意:如果是百度文心一言平台,API-KEY 为 APIKey|SecretKey,中间用竖线(|)连接\n注意:如果是讯飞星火大模型,API-KEY 为 AppId|APIKey|APISecret,中间用竖线(|)连接` - }} + 注意:如果是百度文心一言平台,API-KEY 为 APIKey|SecretKey,中间用竖线(|)连接 + 注意:如果是讯飞星火大模型,API-KEY 为 AppId|APIKey|APISecret,中间用竖线(|)连接 - + - + - - - - + + + + + + @@ -86,7 +97,7 @@ defineExpose({ form, }); -const typeOPtions = [ +const typeOptions = [ { label: "聊天", value: "chart", @@ -96,6 +107,48 @@ const typeOPtions = [ value: "img", }, ]; + +const platformOptions = [ + { + label: "【OpenAI】ChatGPT", + value: "OpenAI", + api_url: "https://gpt.bemore.lol/v1/chat/completions", + img_url: "https://gpt.bemore.lol/v1/images/generations", + }, + { + label: "【讯飞】星火大模型", + value: "XunFei", + api_url: "wss://spark-api.xf-yun.com/{version}/chat", + }, + { + label: "【清华智普】ChatGLM", + value: "ChatGLM", + api_url: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke", + }, + { + label: "【百度】文心一言", + value: "Baidu", + api_url: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}", + }, + { + label: "【微软】Azure", + value: "Azure", + api_url: + "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15", + }, + { + label: "【阿里】千义通问", + value: "QWen", + api_url: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", + }, +]; + +const handlePlatformChange = () => { + const obj = platformOptions.find((item) => item.value === form.value.platform); + if (obj) { + form.value.api_url = form.value.type === "img" ? obj.img_url : obj.api_url; + } +};
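Note on the proxy change above: this diff drops the global ProxyURL / per-key UseProxy flag in favour of a per-key proxy_url, and doRequest (chat_handler.go), Dall3 (function_handler.go) and OpenAIRequest (net.go) each rebuild their HTTP client from apiKey.ProxyURL. The following is only a minimal Go sketch of that shared pattern; the helper name newHTTPClient, its placement in the utils package, and the "chatplus/store/model" import path are assumptions for illustration and are not part of this diff.

    package utils

    import (
        "net/http"
        "net/url"

        "chatplus/store/model"
    )

    // newHTTPClient builds a client that honours the key's own proxy setting:
    // requests go through apiKey.ProxyURL when it is set, otherwise directly.
    func newHTTPClient(apiKey model.ApiKey) *http.Client {
        if apiKey.ProxyURL == "" {
            return &http.Client{}
        }
        proxy, err := url.Parse(apiKey.ProxyURL)
        if err != nil {
            // Fall back to a direct connection if proxy_url is malformed.
            return &http.Client{}
        }
        return &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxy)}}
    }

Call sites that use the req library (Dall3, OpenAIRequest) get the same behaviour with req.C().SetProxyURL(apiKey.ProxyURL), as shown in the diff.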