diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 4e90afaa..3827b860 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -61,15 +61,15 @@ type ChatSession struct { } type ChatModel struct { - Id uint `json:"id"` - Platform Platform `json:"platform"` - Name string `json:"name"` - Value string `json:"value"` - Power int `json:"power"` - MaxTokens int `json:"max_tokens"` // 最大响应长度 - MaxContext int `json:"max_context"` // 最大上下文长度 - Temperature float32 `json:"temperature"` // 模型温度 - KeyId int `json:"key_id"` // 绑定 API KEY + Id uint `json:"id"` + Platform string `json:"platform"` + Name string `json:"name"` + Value string `json:"value"` + Power int `json:"power"` + MaxTokens int `json:"max_tokens"` // 最大响应长度 + MaxContext int `json:"max_context"` // 最大上下文长度 + Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id"` // 绑定 API KEY } type ApiError struct { diff --git a/api/core/types/config.go b/api/core/types/config.go index b0ec7ebe..8536cd51 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -137,14 +137,44 @@ func (c RedisConfig) Url() string { return fmt.Sprintf("%s:%d", c.Host, c.Port) } -type Platform string +type Platform struct { + Name string `json:"name"` + Value string `json:"value"` + ChatURL string `json:"chat_url"` + ImgURL string `json:"img_url"` +} -const OpenAI = Platform("OpenAI") -const Azure = Platform("Azure") -const ChatGLM = Platform("ChatGLM") -const Baidu = Platform("Baidu") -const XunFei = Platform("XunFei") -const QWen = Platform("QWen") +var OpenAI = Platform{ + Name: "OpenAI - GPT", + Value: "OpenAI", + ChatURL: "https://api.chat-plus.net/v1/chat/completions", + ImgURL: "https://api.chat-plus.net/v1/images/generations", +} +var Azure = Platform{ + Name: "微软 - Azure", + Value: "Azure", + ChatURL: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15", +} +var ChatGLM = Platform{ + Name: "智谱 - ChatGLM", + Value: "ChatGLM", + ChatURL: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke", +} +var Baidu = Platform{ + Name: "百度 - 文心大模型", + Value: "Baidu", + ChatURL: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}", +} +var XunFei = Platform{ + Name: "讯飞 - 星火大模型", + Value: "XunFei", + ChatURL: "wss://spark-api.xf-yun.com/{version}/chat", +} +var QWen = Platform{ + Name: "阿里 - 通义千问", + Value: "QWen", + ChatURL: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", +} type SystemConfig struct { Title string `json:"title,omitempty"` diff --git a/api/handler/admin/config_handler.go b/api/handler/admin/config_handler.go index 6cf571e2..584b026b 100644 --- a/api/handler/admin/config_handler.go +++ b/api/handler/admin/config_handler.go @@ -28,8 +28,8 @@ type ConfigHandler struct { handler.BaseHandler levelDB *store.LevelDB licenseService *service.LicenseService - mjServicePool *mj.ServicePool - sdServicePool *sd.ServicePool + mjServicePool *mj.ServicePool + sdServicePool *sd.ServicePool } func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler { @@ -140,12 +140,13 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) { resp.SUCCESS(c, license) } -// GetDrawingConfig 获取AI绘画配置 -func (h *ConfigHandler) GetDrawingConfig(c *gin.Context) { +// GetAppConfig 获取内置配置 +func (h *ConfigHandler) GetAppConfig(c *gin.Context) { resp.SUCCESS(c, gin.H{ - "mj_plus": h.App.Config.MjPlusConfigs, - "mj_proxy": h.App.Config.MjProxyConfigs, - "sd": h.App.Config.SdConfigs, + "mj_plus": h.App.Config.MjPlusConfigs, + "mj_proxy": h.App.Config.MjProxyConfigs, + "sd": h.App.Config.SdConfigs, + "platforms": Platforms, }) } diff --git a/api/handler/admin/types.go b/api/handler/admin/types.go new file mode 100644 index 00000000..c06139ba --- /dev/null +++ b/api/handler/admin/types.go @@ -0,0 +1,12 @@ +package admin + +import "geekai/core/types" + +var Platforms = []types.Platform{ + types.OpenAI, + types.QWen, + types.XunFei, + types.ChatGLM, + types.Baidu, + types.Azure, +} diff --git a/api/handler/chatimpl/azure_handler.go b/api/handler/chatimpl/azure_handler.go index ac5ad7ff..bd28d720 100644 --- a/api/handler/chatimpl/azure_handler.go +++ b/api/handler/chatimpl/azure_handler.go @@ -17,11 +17,9 @@ import ( "geekai/store/model" "geekai/store/vo" "geekai/utils" - "html/template" "io" "strings" "time" - "unicode/utf8" ) // 微软 Azure 模型消息发送实现 @@ -101,104 +99,12 @@ func (h *ChatHandler) sendAzureMessage( // 消息发送成功 if len(contents) > 0 { - - if message.Role == "" { - message.Role = "assistant" - } - message.Content = strings.Join(contents, "") - useMsg := types.Message{Role: "user", Content: prompt} - - // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.SysConfig.EnableContext { - chatCtx = append(chatCtx, useMsg) // 提问消息 - chatCtx = append(chatCtx, message) // 回复消息 - h.App.ChatContexts.Put(session.ChatId, chatCtx) - } - - // 追加聊天记录 - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.DB.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - replyTokens += getTotalTokens(req) - - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: replyTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.DB.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) - - // 保存当前会话 - var chatItem model.ChatItem - res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) - if res.Error != nil { - chatItem.ChatId = session.ChatId - chatItem.UserId = session.UserId - chatItem.RoleId = role.Id - chatItem.ModelId = session.Model.Id - if utf8.RuneCountInString(prompt) > 30 { - chatItem.Title = string([]rune(prompt)[:30]) + "..." - } else { - chatItem.Title = prompt - } - chatItem.Model = req.Model - h.DB.Create(&chatItem) - } + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) } + } else { - body, err := io.ReadAll(response.Body) - if err != nil { - return fmt.Errorf("error with reading response: %v", err) - } - var res types.ApiError - err = json.Unmarshal(body, &res) - if err != nil { - return fmt.Errorf("error with decode response: %v", err) - } - - if strings.Contains(res.Error.Message, "maximum context length") { - logger.Error(res.Error.Message) - h.App.ChatContexts.Delete(session.ChatId) - return h.sendMessage(ctx, session, role, prompt, ws) - } else { - return fmt.Errorf("请求 Azure API 失败:%v", res.Error) - } + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求大模型 API 失败:%s", body) } return nil diff --git a/api/handler/chatimpl/baidu_handler.go b/api/handler/chatimpl/baidu_handler.go index 6e591a33..783ac3e9 100644 --- a/api/handler/chatimpl/baidu_handler.go +++ b/api/handler/chatimpl/baidu_handler.go @@ -17,12 +17,10 @@ import ( "geekai/store/model" "geekai/store/vo" "geekai/utils" - "html/template" "io" "net/http" "strings" "time" - "unicode/utf8" ) type baiduResp struct { @@ -130,99 +128,11 @@ func (h *ChatHandler) sendBaiduMessage( // 消息发送成功 if len(contents) > 0 { - if message.Role == "" { - message.Role = "assistant" - } - message.Content = strings.Join(contents, "") - useMsg := types.Message{Role: "user", Content: prompt} - - // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.SysConfig.EnableContext { - chatCtx = append(chatCtx, useMsg) // 提问消息 - chatCtx = append(chatCtx, message) // 回复消息 - h.App.ChatContexts.Put(session.ChatId, chatCtx) - } - - // 追加聊天记录 - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.DB.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // for reply - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyTokens + getTotalTokens(req) - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.DB.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) - - // 保存当前会话 - var chatItem model.ChatItem - res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) - if res.Error != nil { - chatItem.ChatId = session.ChatId - chatItem.UserId = session.UserId - chatItem.RoleId = role.Id - chatItem.ModelId = session.Model.Id - if utf8.RuneCountInString(prompt) > 30 { - chatItem.Title = string([]rune(prompt)[:30]) + "..." - } else { - chatItem.Title = prompt - } - chatItem.Model = req.Model - h.DB.Create(&chatItem) - } + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) } } else { - body, err := io.ReadAll(response.Body) - if err != nil { - return fmt.Errorf("error with reading response: %v", err) - } - - var res struct { - Code int `json:"error_code"` - Msg string `json:"error_msg"` - } - err = json.Unmarshal(body, &res) - if err != nil { - return fmt.Errorf("error with decode response: %v", err) - } - utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg) + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求大模型 API 失败:%s", body) } return nil diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 47a23acf..894a347a 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -23,11 +23,13 @@ import ( "geekai/store/vo" "geekai/utils" "geekai/utils/resp" + "html/template" "net/http" "net/url" "regexp" "strings" "time" + "unicode/utf8" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" @@ -122,7 +124,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { MaxContext: chatModel.MaxContext, Temperature: chatModel.Temperature, KeyId: chatModel.KeyId, - Platform: types.Platform(chatModel.Platform)} + Platform: chatModel.Platform} logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username) // 保存会话连接 @@ -218,11 +220,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio Stream: true, } switch session.Model.Platform { - case types.Azure, types.ChatGLM, types.Baidu, types.XunFei: + case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value: req.Temperature = session.Model.Temperature req.MaxTokens = session.Model.MaxTokens break - case types.OpenAI: + case types.OpenAI.Value: req.Temperature = session.Model.Temperature req.MaxTokens = session.Model.MaxTokens // OpenAI 支持函数功能 @@ -261,7 +263,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio req.Tools = tools req.ToolChoice = "auto" } - case types.QWen: + case types.QWen.Value: req.Parameters = map[string]interface{}{ "max_tokens": session.Model.MaxTokens, "temperature": session.Model.Temperature, @@ -325,14 +327,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio reqMgs = append(reqMgs, m) } - if session.Model.Platform == types.QWen { + if session.Model.Platform == types.QWen.Value { req.Input = make(map[string]interface{}) reqMgs = append(reqMgs, types.Message{ Role: "user", Content: prompt, }) req.Input["messages"] = reqMgs - } else if session.Model.Platform == types.OpenAI { // extract image for gpt-vision model + } else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model imgURLs := utils.ExtractImgURL(prompt) logger.Debugf("detected IMG: %+v", imgURLs) var content interface{} @@ -370,17 +372,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio logger.Debugf("%+v", req.Messages) switch session.Model.Platform { - case types.Azure: + case types.Azure.Value: return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.OpenAI: + case types.OpenAI.Value: return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.ChatGLM: + case types.ChatGLM.Value: return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.Baidu: + case types.Baidu.Value: return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.XunFei: + case types.XunFei.Value: return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.QWen: + case types.QWen.Value: return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) } @@ -467,7 +469,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi } // ONLY allow apiURL in blank list - if session.Model.Platform == types.OpenAI { + if session.Model.Platform == types.OpenAI.Value { err := h.licenseService.IsValidApiURL(apiKey.ApiURL) if err != nil { return nil, err @@ -476,19 +478,19 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi var apiURL string switch session.Model.Platform { - case types.Azure: + case types.Azure.Value: md := strings.Replace(req.Model, ".", "", 1) apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1) break - case types.ChatGLM: + case types.ChatGLM.Value: apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段 req.Messages = nil break - case types.Baidu: + case types.Baidu.Value: apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) break - case types.QWen: + case types.QWen.Value: apiURL = apiKey.ApiURL req.Messages = nil break @@ -498,7 +500,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi // 更新 API KEY 的最后使用时间 h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) // 百度文心,需要串接 access_token - if session.Model.Platform == types.Baidu { + if session.Model.Platform == types.Baidu.Value { token, err := h.getBaiduToken(apiKey.Value) if err != nil { return nil, err @@ -534,22 +536,22 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi } logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) switch session.Model.Platform { - case types.Azure: + case types.Azure.Value: request.Header.Set("api-key", apiKey.Value) break - case types.ChatGLM: + case types.ChatGLM.Value: token, err := h.getChatGLMToken(apiKey.Value) if err != nil { return nil, err } request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) break - case types.Baidu: + case types.Baidu.Value: request.RequestURI = "" - case types.OpenAI: + case types.OpenAI.Value: request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) break - case types.QWen: + case types.QWen.Value: request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) request.Header.Set("X-DashScope-SSE", "enable") break @@ -583,6 +585,97 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p } +func (h *ChatHandler) saveChatHistory( + req types.ApiRequest, + prompt string, + contents []string, + message types.Message, + chatCtx []types.Message, + session *types.ChatSession, + role model.ChatRole, + userVo vo.User, + promptCreatedAt time.Time, + replyCreatedAt time.Time) { + if message.Role == "" { + message.Role = "assistant" + } + message.Content = strings.Join(contents, "") + useMsg := types.Message{Role: "user", Content: prompt} + + // 更新上下文消息,如果是调用函数则不需要更新上下文 + if h.App.SysConfig.EnableContext { + chatCtx = append(chatCtx, useMsg) // 提问消息 + chatCtx = append(chatCtx, message) // 回复消息 + h.App.ChatContexts.Put(session.ChatId, chatCtx) + } + + // 追加聊天记录 + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) + } + historyUserMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(prompt), + Tokens: promptToken, + UseContext: true, + Model: req.Model, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.DB.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // for reply + // 计算本次对话消耗的总 token 数量 + replyTokens, _ := utils.CalcTokens(message.Content, req.Model) + totalTokens := replyTokens + getTotalTokens(req) + historyReplyMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: totalTokens, + UseContext: true, + Model: req.Model, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.DB.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", res.Error) + } + + // 更新用户算力 + h.subUserPower(userVo, session, promptToken, replyTokens) + + // 保存当前会话 + var chatItem model.ChatItem + res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) + if res.Error != nil { + chatItem.ChatId = session.ChatId + chatItem.UserId = session.UserId + chatItem.RoleId = role.Id + chatItem.ModelId = session.Model.Id + if utf8.RuneCountInString(prompt) > 30 { + chatItem.Title = string([]rune(prompt)[:30]) + "..." + } else { + chatItem.Title = prompt + } + chatItem.Model = req.Model + h.DB.Create(&chatItem) + } +} + // 将AI回复消息中生成的图片链接下载到本地 func (h *ChatHandler) extractImgUrl(text string) string { pattern := `!\[([^\]]*)]\(([^)]+)\)` diff --git a/api/handler/chatimpl/chatglm_handler.go b/api/handler/chatimpl/chatglm_handler.go index 53e83bcb..0192abc8 100644 --- a/api/handler/chatimpl/chatglm_handler.go +++ b/api/handler/chatimpl/chatglm_handler.go @@ -10,7 +10,6 @@ package chatimpl import ( "bufio" "context" - "encoding/json" "errors" "fmt" "geekai/core/types" @@ -18,11 +17,9 @@ import ( "geekai/store/vo" "geekai/utils" "github.com/golang-jwt/jwt/v5" - "html/template" "io" "strings" "time" - "unicode/utf8" ) // 清华大学 ChatGML 消息发送实现 @@ -108,103 +105,11 @@ func (h *ChatHandler) sendChatGLMMessage( // 消息发送成功 if len(contents) > 0 { - if message.Role == "" { - message.Role = "assistant" - } - message.Content = strings.Join(contents, "") - useMsg := types.Message{Role: "user", Content: prompt} - - // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.SysConfig.EnableContext { - chatCtx = append(chatCtx, useMsg) // 提问消息 - chatCtx = append(chatCtx, message) // 回复消息 - h.App.ChatContexts.Put(session.ChatId, chatCtx) - } - - // 追加聊天记录 - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.DB.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // for reply - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyTokens + getTotalTokens(req) - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.DB.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) - - // 保存当前会话 - var chatItem model.ChatItem - res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) - if res.Error != nil { - chatItem.ChatId = session.ChatId - chatItem.UserId = session.UserId - chatItem.RoleId = role.Id - chatItem.ModelId = session.Model.Id - if utf8.RuneCountInString(prompt) > 30 { - chatItem.Title = string([]rune(prompt)[:30]) + "..." - } else { - chatItem.Title = prompt - } - chatItem.Model = req.Model - h.DB.Create(&chatItem) - } + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) } } else { - body, err := io.ReadAll(response.Body) - if err != nil { - return fmt.Errorf("error with reading response: %v", err) - } - - var res struct { - Code int `json:"code"` - Success bool `json:"success"` - Msg string `json:"msg"` - } - err = json.Unmarshal(body, &res) - if err != nil { - return fmt.Errorf("error with decode response: %v", err) - } - if !res.Success { - utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg) - } + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求大模型 API 失败:%s", body) } return nil diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go index 3878c46c..fb953b79 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -17,13 +17,10 @@ import ( "geekai/store/model" "geekai/store/vo" "geekai/utils" - "html/template" + req2 "github.com/imroc/req/v3" "io" "strings" "time" - "unicode/utf8" - - req2 "github.com/imroc/req/v3" ) // OPenAI 消息发送实现 @@ -178,126 +175,11 @@ func (h *ChatHandler) sendOpenAiMessage( // 消息发送成功 if len(contents) > 0 { - if message.Role == "" { - message.Role = "assistant" - } - message.Content = strings.Join(contents, "") - useMsg := types.Message{Role: "user", Content: prompt} - - // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.SysConfig.EnableContext && toolCall == false { - chatCtx = append(chatCtx, useMsg) // 提问消息 - chatCtx = append(chatCtx, message) // 回复消息 - h.App.ChatContexts.Put(session.ChatId, chatCtx) - } - - // 追加聊天记录 - useContext := true - if toolCall { - useContext = false - } - - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: useContext, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.DB.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // 计算本次对话消耗的总 token 数量 - var replyTokens = 0 - if toolCall { // prompt + 函数名 + 参数 token - tokens, _ := utils.CalcTokens(function.Name, req.Model) - replyTokens += tokens - tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) - replyTokens += tokens - } else { - replyTokens, _ = utils.CalcTokens(message.Content, req.Model) - } - replyTokens += getTotalTokens(req) - - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: h.extractImgUrl(message.Content), - Tokens: replyTokens, - UseContext: useContext, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.DB.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) - - // 保存当前会话 - var chatItem model.ChatItem - res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) - if res.Error != nil { - chatItem.ChatId = session.ChatId - chatItem.UserId = session.UserId - chatItem.RoleId = role.Id - chatItem.ModelId = session.Model.Id - if utf8.RuneCountInString(prompt) > 30 { - chatItem.Title = string([]rune(prompt)[:30]) + "..." - } else { - chatItem.Title = prompt - } - chatItem.Model = req.Model - h.DB.Create(&chatItem) - } + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) } } else { - body, err := io.ReadAll(response.Body) - if err != nil { - utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error()) - return fmt.Errorf("error with reading response: %v", err) - } - var res types.ApiError - err = json.Unmarshal(body, &res) - if err != nil { - utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```") - return fmt.Errorf("error with decode response: %v", err) - } - - // OpenAI API 调用异常处理 - if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { - utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") - // 移除当前 API key - h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{}) - } else if strings.Contains(res.Error.Message, "You exceeded your current quota") { - utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") - } else if strings.Contains(res.Error.Message, "This model's maximum context length") { - logger.Error(res.Error.Message) - utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!") - h.App.ChatContexts.Delete(session.ChatId) - return h.sendMessage(ctx, session, role, prompt, ws) - } else { - utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message) - } + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求 OpenAI API 失败:%s", body) } return nil diff --git a/api/handler/chatimpl/qwen_handler.go b/api/handler/chatimpl/qwen_handler.go index 7b7eb1ac..28bf66bb 100644 --- a/api/handler/chatimpl/qwen_handler.go +++ b/api/handler/chatimpl/qwen_handler.go @@ -10,18 +10,15 @@ package chatimpl import ( "bufio" "context" - "encoding/json" "fmt" "geekai/core/types" "geekai/store/model" "geekai/store/vo" "geekai/utils" "github.com/syndtr/goleveldb/leveldb/errors" - "html/template" "io" "strings" "time" - "unicode/utf8" ) type qWenResp struct { @@ -142,100 +139,11 @@ func (h *ChatHandler) sendQWenMessage( // 消息发送成功 if len(contents) > 0 { - if message.Role == "" { - message.Role = "assistant" - } - message.Content = strings.Join(contents, "") - useMsg := types.Message{Role: "user", Content: prompt} - - // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.SysConfig.EnableContext { - chatCtx = append(chatCtx, useMsg) // 提问消息 - chatCtx = append(chatCtx, message) // 回复消息 - h.App.ChatContexts.Put(session.ChatId, chatCtx) - } - - // 追加聊天记录 - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.DB.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // for reply - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyTokens + getTotalTokens(req) - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.DB.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) - - // 保存当前会话 - var chatItem model.ChatItem - res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) - if res.Error != nil { - chatItem.ChatId = session.ChatId - chatItem.UserId = session.UserId - chatItem.RoleId = role.Id - chatItem.ModelId = session.Model.Id - if utf8.RuneCountInString(prompt) > 30 { - chatItem.Title = string([]rune(prompt)[:30]) + "..." - } else { - chatItem.Title = prompt - } - chatItem.Model = req.Model - h.DB.Create(&chatItem) - } + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) } } else { - body, err := io.ReadAll(response.Body) - if err != nil { - return fmt.Errorf("error with reading response: %v", err) - } - - var res struct { - Code int `json:"error_code"` - Msg string `json:"error_msg"` - } - err = json.Unmarshal(body, &res) - if err != nil { - return fmt.Errorf("error with decode response: %v", err) - } - utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg) + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求大模型 API 失败:%s", body) } return nil diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index afc1c165..e4a081fc 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -21,13 +21,11 @@ import ( "geekai/utils" "github.com/gorilla/websocket" "gorm.io/gorm" - "html/template" "io" "net/http" "net/url" "strings" "time" - "unicode/utf8" ) type xunFeiResp struct { @@ -181,89 +179,10 @@ func (h *ChatHandler) sendXunFeiMessage( } } - // 消息发送成功 if len(contents) > 0 { - if message.Role == "" { - message.Role = "assistant" - } - message.Content = strings.Join(contents, "") - useMsg := types.Message{Role: "user", Content: prompt} - - // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.SysConfig.EnableContext { - chatCtx = append(chatCtx, useMsg) // 提问消息 - chatCtx = append(chatCtx, message) // 回复消息 - h.App.ChatContexts.Put(session.ChatId, chatCtx) - } - - // 追加聊天记录 - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.DB.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // for reply - // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyTokens + getTotalTokens(req) - historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - Model: req.Model, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.DB.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - - // 更新用户算力 - h.subUserPower(userVo, session, promptToken, replyTokens) - - // 保存当前会话 - var chatItem model.ChatItem - res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) - if res.Error != nil { - chatItem.ChatId = session.ChatId - chatItem.UserId = session.UserId - chatItem.RoleId = role.Id - chatItem.ModelId = session.Model.Id - if utf8.RuneCountInString(prompt) > 30 { - chatItem.Title = string([]rune(prompt)[:30]) + "..." - } else { - chatItem.Title = prompt - } - chatItem.Model = req.Model - h.DB.Create(&chatItem) - } + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) } - return nil } diff --git a/api/main.go b/api/main.go index 55a7075a..9b4458cf 100644 --- a/api/main.go +++ b/api/main.go @@ -304,7 +304,7 @@ func main() { group.GET("config/get", h.Get) group.POST("active", h.Active) group.GET("config/get/license", h.GetLicense) - group.GET("config/get/draw", h.GetDrawingConfig) + group.GET("config/get/app", h.GetAppConfig) group.POST("config/update/draw", h.SaveDrawingConfig) }), fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { diff --git a/web/src/components/admin/AdminSidebar.vue b/web/src/components/admin/AdminSidebar.vue index df24ddc8..4c0aebda 100644 --- a/web/src/components/admin/AdminSidebar.vue +++ b/web/src/components/admin/AdminSidebar.vue @@ -92,9 +92,9 @@ const items = [ }, { - icon: 'role', - index: '/admin/role', - title: '角色管理', + icon: 'menu', + index: '/admin/app', + title: '应用管理', }, { icon: 'api-key', diff --git a/web/src/router.js b/web/src/router.js index a03c5ec6..64fdecb6 100644 --- a/web/src/router.js +++ b/web/src/router.js @@ -46,7 +46,7 @@ const routes = [ component: () => import('@/views/Member.vue'), }, { - name: 'chat-role', + name: 'chat-app', path: '/apps', meta: {title: '应用中心'}, component: () => import('@/views/ChatApps.vue'), @@ -139,10 +139,10 @@ const routes = [ component: () => import('@/views/admin/Users.vue'), }, { - path: '/admin/role', - name: 'admin-role', - meta: {title: '角色管理'}, - component: () => import('@/views/admin/Roles.vue'), + path: '/admin/app', + name: 'admin-app', + meta: {title: '应用管理'}, + component: () => import('@/views/admin/Apps.vue'), }, { path: '/admin/apikey', diff --git a/web/src/views/admin/AIDrawing.vue b/web/src/views/admin/AIDrawing.vue index b193924c..66836b6b 100644 --- a/web/src/views/admin/AIDrawing.vue +++ b/web/src/views/admin/AIDrawing.vue @@ -119,12 +119,12 @@ const mjModels = ref([ {name: "急速(Turbo)", value: "turbo"}, ]) -httpGet("/api/admin/config/get/draw").then(res => { +httpGet("/api/admin/config/get/app").then(res => { sdConfigs.value = res.data.sd mjPlusConfigs.value = res.data.mj_plus mjProxyConfigs.value = res.data.mj_proxy }).catch(e =>{ - ElMessage.error("获取AI绘画配置失败:"+e.message) + ElMessage.error("获取配置失败:"+e.message) }) const addConfig = (configs) => { diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue index fec5d14c..aacc05a2 100644 --- a/web/src/views/admin/ApiKey.vue +++ b/web/src/views/admin/ApiKey.vue @@ -87,7 +87,7 @@ - + {{ item.name }} @@ -99,7 +99,8 @@ + placeholder="必须填土完整的 Chat API URL,如:https://api.openai.com/v1/chat/completions"/> +
如果你使用了第三方中转,这里就填写中转地址
@@ -126,7 +127,7 @@ import {onMounted, onUnmounted, reactive, ref} from "vue"; import {httpGet, httpPost} from "@/utils/http"; import {ElMessage} from "element-plus"; import {dateFormat, removeArrayItem, substr} from "@/utils/libs"; -import {DocumentCopy, Plus, ShoppingCart} from "@element-plus/icons-vue"; +import {DocumentCopy, Plus, ShoppingCart, InfoFilled} from "@element-plus/icons-vue"; import ClipboardJS from "clipboard"; // 变量定义 @@ -142,39 +143,7 @@ const rules = reactive({ const loading = ref(true) const formRef = ref(null) const title = ref("") -const platforms = ref([ - { - name: "【OpenAI/中转】ChatGPT", - value: "OpenAI", - api_url: "https://api.chat-plus.net/v1/chat/completions", - img_url: "https://api.chat-plus.net/v1/images/generations" - }, - { - name: "【讯飞】星火大模型", - value: "XunFei", - api_url: "wss://spark-api.xf-yun.com/{version}/chat" - }, - { - name: "【清华智普】ChatGLM", - value: "ChatGLM", - api_url: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke" - }, - { - name: "【百度】文心一言", - value: "Baidu", - api_url: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}" - }, - { - name: "【微软】Azure", - value: "Azure", - api_url: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15" - }, - { - name: "【阿里】千义通问", - value: "QWen", - api_url: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" - }, -]) +const platforms = ref([]) const types = ref([ {name: "聊天", value: "chat"}, {name: "绘画", value: "img"}, @@ -191,6 +160,12 @@ onMounted(() => { clipboard.value.on('error', () => { ElMessage.error('复制失败!'); }) + + httpGet("/api/admin/config/get/app").then(res => { + platforms.value = res.data.platforms + }).catch(e =>{ + ElMessage.error("获取配置失败:"+e.message) + }) }) onUnmounted(() => { @@ -263,21 +238,24 @@ const set = (filed, row) => { }) } -const changePlatform = () => { - let platform = null +const selectedPlatform = ref(null) +const changePlatform = (value) => { + console.log(value) for (let v of platforms.value) { - if (v.value === item.value.platform) { - platform = v - break + if (v.value === value) { + selectedPlatform.value = v + item.value.api_url = v.chat_url } } - if (platform !== null) { - if (item.value.type === "img" && platform.img_url) { - item.value.api_url = platform.img_url - } else { - item.value.api_url = platform.api_url - } +} +const changeType = (value) => { + if (selectedPlatform.value) { + if(value === 'img') { + item.value.api_url = selectedPlatform.value.img_url + } else { + item.value.api_url = selectedPlatform.value.chat_url + } } } @@ -306,7 +284,9 @@ const changePlatform = () => { .el-form { .el-form-item__content { - + .info { + color #999999 + } .el-icon { padding-left: 10px; } diff --git a/web/src/views/admin/Roles.vue b/web/src/views/admin/Apps.vue similarity index 92% rename from web/src/views/admin/Roles.vue rename to web/src/views/admin/Apps.vue index b32de7c7..c85baf77 100644 --- a/web/src/views/admin/Roles.vue +++ b/web/src/views/admin/Apps.vue @@ -9,25 +9,25 @@ - + - + - + @@ -36,7 +36,7 @@