diff --git a/api/core/app_server.go b/api/core/app_server.go index 6a09ce99..4f0b0c55 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -34,13 +34,10 @@ type AppServer struct { ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合 ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function Functions map[string]function.Function + MjTasks *types.LMap[string, types.MjTask] } -func NewServer( - appConfig *types.AppConfig, - funZaoBao function.FuncZaoBao, - funZhiHu function.FuncHeadlines, - funWeibo function.FuncWeiboHot) *AppServer { +func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer { gin.SetMode(gin.ReleaseMode) gin.DefaultWriter = io.Discard return &AppServer{ @@ -51,11 +48,8 @@ func NewServer( ChatSession: types.NewLMap[string, types.ChatSession](), ChatClients: types.NewLMap[string, *types.WsClient](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), - Functions: map[string]function.Function{ - types.FuncZaoBao: funZaoBao, - types.FuncWeibo: funWeibo, - types.FuncHeadLine: funZhiHu, - }, + MjTasks: types.NewLMap[string, types.MjTask](), + Functions: functions, } } @@ -186,7 +180,8 @@ func authorizeMiddleware(s *AppServer) gin.HandlerFunc { if c.Request.URL.Path == "/api/user/login" || c.Request.URL.Path == "/api/admin/login" || c.Request.URL.Path == "/api/user/register" || - c.Request.URL.Path == "/api/reward/push" || + c.Request.URL.Path == "/api/reward/notify" || + c.Request.URL.Path == "/api/mj/notify" || strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") || strings.HasPrefix(c.Request.URL.Path, "/static/") || diff --git a/api/core/config.go b/api/core/config.go index 7308c185..00c2f7a2 100644 --- a/api/core/config.go +++ b/api/core/config.go @@ -33,8 +33,8 @@ func NewDefaultConfig() *types.AppConfig { HttpOnly: false, SameSite: http.SameSiteLaxMode, }, - ApiConfig: types.ChatPlusApiConfig{}, - ChatPlusExtApiToken: utils.RandString(32), + ApiConfig: types.ChatPlusApiConfig{}, + ExtConfig: types.ChatPlusExtConfig{Token: utils.RandString(32)}, } } diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 04d57355..163a0aea 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -42,6 +42,16 @@ type ChatSession struct { Model string `json:"model"` // GPT 模型 } +type MjTask struct { + Client Client + ChatId string + MessageId string + MessageHash string + UserId uint + RoleId uint + Icon string +} + type ApiError struct { Error struct { Message string diff --git a/api/core/types/config.go b/api/core/types/config.go index 9a919528..f4b87be6 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -6,19 +6,19 @@ import ( ) type AppConfig struct { - Path string `toml:"-"` - Listen string - Session Session - ProxyURL string - MysqlDns string // mysql 连接地址 - Manager Manager // 后台管理员账户信息 - StaticDir string // 静态资源目录 - StaticUrl string // 静态资源 URL - Redis RedisConfig // redis 连接信息 - ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs - AesEncryptKey string - SmsConfig AliYunSmsConfig // AliYun send message service config - ChatPlusExtApiToken string // chatgpt-plus-exts callback api token + Path string `toml:"-"` + Listen string + Session Session + ProxyURL string + MysqlDns string // mysql 连接地址 + Manager Manager // 后台管理员账户信息 + StaticDir string // 静态资源目录 + StaticUrl string // 静态资源 URL + Redis RedisConfig // redis 连接信息 + ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs + AesEncryptKey string + SmsConfig AliYunSmsConfig // AliYun send message service config + ExtConfig ChatPlusExtConfig // ChatPlus extensions callback api config } type ChatPlusApiConfig struct { @@ -27,6 +27,11 @@ type ChatPlusApiConfig struct { Token string } +type ChatPlusExtConfig struct { + ApiURL string + Token string +} + type AliYunSmsConfig struct { AccessKey string AccessSecret string diff --git a/api/core/types/function.go b/api/core/types/function.go index 71e1e44d..a4361263 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -23,9 +23,10 @@ type Property struct { } const ( - FuncZaoBao = "zao_bao" // 每日早报 - FuncHeadLine = "headline" // 今日头条 - FuncWeibo = "weibo_hot" // 微博热搜 + FuncZaoBao = "zao_bao" // 每日早报 + FuncHeadLine = "headline" // 今日头条 + FuncWeibo = "weibo_hot" // 微博热搜 + FuncMidJourney = "mid_journey" // MJ 绘画 ) var InnerFunctions = []Function{ @@ -73,4 +74,23 @@ var InnerFunctions = []Function{ Required: []string{}, }, }, + + { + Name: FuncMidJourney, + Description: "AI 绘画工具,使用 MJ MidJourney API 进行 AI 绘画", + Parameters: Parameters{ + Type: "object", + Properties: map[string]Property{ + "prompt": { + Type: "string", + Description: "绘画内容描述,提示词,此参数需要翻译成英文", + }, + "ar": { + Type: "string", + Description: "图片长宽比,如 16:9", + }, + }, + Required: []string{}, + }, + }, } diff --git a/api/core/types/locked_map.go b/api/core/types/locked_map.go index 36ca48ff..11d19d84 100644 --- a/api/core/types/locked_map.go +++ b/api/core/types/locked_map.go @@ -9,7 +9,7 @@ type MKey interface { string | int } type MValue interface { - *WsClient | ChatSession | context.CancelFunc | []interface{} + *WsClient | ChatSession | context.CancelFunc | []interface{} | MjTask } type LMap[K MKey, T MValue] struct { lock sync.RWMutex diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index b908872e..b1e195f4 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -88,7 +88,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { var chatRole model.ChatRole res = h.db.First(&chatRole, roleId) if res.Error != nil || !chatRole.Enable { - replyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") + utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") c.Abort() return } @@ -98,7 +98,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { h.db.Where("marker", "chat").First(&config) err = utils.JsonDecode(config.Config, &chatConfig) if err != nil { - replyMessage(client, "加载系统配置失败,连接已关闭!!!") + utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!") c.Abort() return } @@ -116,7 +116,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { return } logger.Info("Receive a message: ", string(message)) - //replyMessage(client, "这是一条测试消息!") + //utils.ReplyMessage(client, "这是一条测试消息!") ctx, cancel := context.WithCancel(context.Background()) h.App.ReqCancelFunc.Put(sessionId, cancel) // 回复消息 @@ -124,7 +124,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { if err != nil { logger.Error(err) } else { - replyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) logger.Info("回答完毕: " + string(message)) } @@ -139,7 +139,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession var user model.User res := h.db.Model(&model.User{}).First(&user, session.UserId) if res.Error != nil { - replyMessage(ws, "非法用户,请联系管理员!") + utils.ReplyMessage(ws, "非法用户,请联系管理员!") return res.Error } var userVo vo.User @@ -150,20 +150,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession } if userVo.Status == false { - replyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!") - replyMessage(ws, "![](/images/wx.png)") + utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!") + utils.ReplyMessage(ws, "![](/images/wx.png)") return nil } if userVo.Calls <= 0 && userVo.ChatConfig.ApiKey == "" { - replyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!") - replyMessage(ws, "![](/images/wx.png)") + utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!") + utils.ReplyMessage(ws, "![](/images/wx.png)") return nil } if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() { - replyMessage(ws, "您的账号已经过期,请联系管理员!") - replyMessage(ws, "![](/images/wx.png)") + utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!") + utils.ReplyMessage(ws, "![](/images/wx.png)") return nil } var req = types.ApiRequest{ @@ -238,14 +238,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession logger.Info("用户取消了请求:", prompt) return nil } else if strings.Contains(err.Error(), "no available key") { - replyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑,您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏") + utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑,您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏") return nil } else { logger.Error(err) } - replyMessage(ws, ErrorMsg) - replyMessage(ws, "![](/images/wx.png)") + utils.ReplyMessage(ws, ErrorMsg) + utils.ReplyMessage(ws, "![](/images/wx.png)") return err } else { defer response.Body.Close() @@ -280,8 +280,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession err = json.Unmarshal([]byte(line[6:]), &responseBody) if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 logger.Error(err, line) - replyMessage(ws, ErrorMsg) - replyMessage(ws, "![](/images/wx.png)") + utils.ReplyMessage(ws, ErrorMsg) + utils.ReplyMessage(ws, "![](/images/wx.png)") break } @@ -295,8 +295,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession functionCall = true functionName = fun.Name f := h.App.Functions[functionName] - replyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())}) + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())}) continue } @@ -307,14 +307,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession // 初始化 role if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { message.Role = responseBody.Choices[0].Delta.Role - replyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) continue } else if responseBody.Choices[0].FinishReason != "" { break // 输出完成或者输出中断了 } else { content := responseBody.Choices[0].Delta.Content contents = append(contents, utils.InterfaceToString(content)) - replyChunkMessage(ws, types.WsMessage{ + utils.ReplyChunkMessage(ws, types.WsMessage{ Type: types.WsMiddle, Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), }) @@ -322,23 +322,39 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession } // end for if functionCall { // 调用函数完成任务 - logger.Info(functionName) - logger.Info(arguments) + logger.Info("函数名称:", functionName) + var params map[string]interface{} + _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms) + logger.Info("函数参数:", params) f := h.App.Functions[functionName] - data, err := f.Invoke(arguments) + data, err := f.Invoke(params) if err != nil { msg := "调用函数出错:" + err.Error() - replyChunkMessage(ws, types.WsMessage{ + utils.ReplyChunkMessage(ws, types.WsMessage{ Type: types.WsMiddle, Content: msg, }) contents = append(contents, msg) } else { - replyChunkMessage(ws, types.WsMessage{ + content := data + if functionName == types.FuncMidJourney { + key := utils.Sha256(data) + // add task for MidJourney + h.App.MjTasks.Put(key, types.MjTask{ + UserId: userVo.Id, + RoleId: role.Id, + Icon: role.Icon, + Client: ws, + ChatId: session.ChatId, + }) + content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) + } + + utils.ReplyChunkMessage(ws, types.WsMessage{ Type: types.WsMiddle, - Content: data, + Content: content, }) - contents = append(contents, data) + contents = append(contents, content) } } @@ -430,7 +446,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession } else { totalTokens = replyToken + getTotalTokens(req) } - //replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)}) + //utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)}) if userVo.ChatConfig.ApiKey != "" { // 调用自己的 API KEY 不计算 token 消耗 h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens)) @@ -468,18 +484,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession // OpenAI API 调用异常处理 // TODO: 是否考虑重发消息? if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { - replyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") + utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") // 移除当前 API key h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{}) } else if strings.Contains(res.Error.Message, "You exceeded your current quota") { - replyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") + utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") } else if strings.Contains(res.Error.Message, "This model's maximum context length") { logger.Error(res.Error.Message) - replyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!") + utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!") h.App.ChatContexts.Delete(session.ChatId) return h.sendMessage(ctx, session, role, prompt, ws) } else { - replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message) + utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message) } } @@ -534,26 +550,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin return client.Do(request) } -// 回复客户片段端消息 -func replyChunkMessage(client types.Client, message types.WsMessage) { - msg, err := json.Marshal(message) - if err != nil { - logger.Errorf("Error for decoding json data: %v", err.Error()) - return - } - err = client.(*types.WsClient).Send(msg) - if err != nil { - logger.Errorf("Error for reply message: %v", err.Error()) - } -} - -// 回复客户端一条完整的消息 -func replyMessage(ws types.Client, message string) { - replyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message}) - replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd}) -} - // Tokens 统计 token 数量 func (h *ChatHandler) Tokens(c *gin.Context) { text := c.Query("text") diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go new file mode 100644 index 00000000..07782df8 --- /dev/null +++ b/api/handler/mj_handler.go @@ -0,0 +1,61 @@ +package handler + +import ( + "chatplus/core" + "chatplus/core/types" + "chatplus/utils" + "chatplus/utils/resp" + "github.com/gin-gonic/gin" +) + +type TaskStatus string + +const ( + Start = TaskStatus("Started") + Running = TaskStatus("Running") + Stopped = TaskStatus("Stopped") + Finished = TaskStatus("Finished") +) + +type Image struct { + URL string `json:"url"` + ProxyURL string `json:"proxy_url"` + Filename string `json:"filename"` + Width int `json:"width"` + Height int `json:"height"` + Size int `json:"size"` +} + +type MidJourneyHandler struct { + BaseHandler +} + +func NewMidJourneyHandler(app *core.AppServer) *MidJourneyHandler { + h := MidJourneyHandler{} + h.App = app + return &h +} + +func (h *MidJourneyHandler) Notify(c *gin.Context) { + token := c.GetHeader("Authorization") + if token != h.App.Config.ExtConfig.Token { + resp.NotAuth(c) + return + } + + var data struct { + Image Image `json:"image"` + Content string `json:"content"` + Status TaskStatus `json:"status"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + sessionId := "u7blnft9zqisyrwidjb22j6b78iqc30lv9jtud3k9o" + wsClient := h.App.ChatClients.Get(sessionId) + utils.ReplyMessage(wsClient, "![](https://cdn.discordapp.com/attachments/1138713254718361633/1139482452579070053/lal603743923_A_Chinese_girl_walking_barefoot_on_the_beach_weari_df8b6dc0-3b13-478c-8dbb-983015d21661.png)") + logger.Infof("Data: %+v", data) + resp.ERROR(c, "Error with CallBack") +} diff --git a/api/handler/reward_handler.go b/api/handler/reward_handler.go index 0825d7ae..9673eb4f 100644 --- a/api/handler/reward_handler.go +++ b/api/handler/reward_handler.go @@ -21,9 +21,9 @@ func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler { return &h } -func (h *RewardHandler) Push(c *gin.Context) { - token := c.GetHeader("X-TOKEN") - if token != h.App.Config.ChatPlusExtApiToken { +func (h *RewardHandler) Notify(c *gin.Context) { + token := c.GetHeader("Authorization") + if token != h.App.Config.ExtConfig.Token { resp.NotAuth(c) return } diff --git a/api/main.go b/api/main.go index f5140a6e..64f61d2c 100644 --- a/api/main.go +++ b/api/main.go @@ -104,15 +104,7 @@ func main() { }), // 创建函数 - fx.Provide(func(config *types.AppConfig) (function.FuncZaoBao, error) { - return function.NewZaoBao(config.ApiConfig), nil - }), - fx.Provide(func(config *types.AppConfig) (function.FuncWeiboHot, error) { - return function.NewWeiboHot(config.ApiConfig), nil - }), - fx.Provide(func(config *types.AppConfig) (function.FuncHeadlines, error) { - return function.NewHeadLines(config.ApiConfig), nil - }), + fx.Provide(function.NewFunctions), // 创建控制器 fx.Provide(handler.NewChatRoleHandler), @@ -122,6 +114,7 @@ func main() { fx.Provide(handler.NewSmsHandler), fx.Provide(handler.NewRewardHandler), fx.Provide(handler.NewCaptchaHandler), + fx.Provide(handler.NewMidJourneyHandler), fx.Provide(admin.NewConfigHandler), fx.Provide(admin.NewAdminHandler), @@ -180,9 +173,12 @@ func main() { }), fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) { group := s.Engine.Group("/api/reward/") - group.POST("push", h.Push) + group.POST("notify", h.Notify) group.POST("verify", h.Verify) }), + fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { + s.Engine.POST("/api/mj/notify", h.Notify) + }), // 管理后台控制器 fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { diff --git a/api/service/function/function.go b/api/service/function/function.go index dad0d09f..d490ac1c 100644 --- a/api/service/function/function.go +++ b/api/service/function/function.go @@ -1,12 +1,17 @@ package function -import "chatplus/core/types" +import ( + "chatplus/core/types" + logger2 "chatplus/logger" +) type Function interface { - Invoke(...interface{}) (string, error) + Invoke(map[string]interface{}) (string, error) Name() string } +var logger = logger2.GetLogger() + type resVo struct { Code types.BizCode `json:"code"` Message string `json:"message"` @@ -22,3 +27,12 @@ type dataItem struct { Url string `json:"url"` Remark string `json:"remark"` } + +func NewFunctions(config *types.AppConfig) map[string]Function { + return map[string]Function{ + types.FuncZaoBao: NewZaoBao(config.ApiConfig), + types.FuncWeibo: NewWeiboHot(config.ApiConfig), + types.FuncHeadLine: NewHeadLines(config.ApiConfig), + types.FuncMidJourney: NewMidJourneyFunc(config.ExtConfig), + } +} diff --git a/api/service/function/mid_journey.go b/api/service/function/mid_journey.go new file mode 100644 index 00000000..4996466e --- /dev/null +++ b/api/service/function/mid_journey.go @@ -0,0 +1,60 @@ +package function + +import ( + "chatplus/core/types" + "chatplus/utils" + "errors" + "fmt" + "github.com/imroc/req/v3" + "time" +) + +// AI 绘画函数 + +type FuncMidJourney struct { + name string + config types.ChatPlusExtConfig + client *req.Client +} + +func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney { + return FuncMidJourney{ + name: "MidJourney AI 绘画", + config: config, + client: req.C().SetTimeout(10 * time.Second)} +} + +func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { + if f.config.Token == "" { + return "", errors.New("无效的 API Token") + } + + logger.Infof("MJ 绘画参数:%+v", params) + prompt := utils.InterfaceToString(params["prompt"]) + if !utils.IsEmptyValue(params["ar"]) { + prompt = prompt + fmt.Sprintf(" --ar %v", params["ar"]) + delete(params, "ar") + } + prompt = prompt + " --niji 5" + var res types.BizVo + r, err := f.client.R(). + SetHeader("Authorization", f.config.Token). + SetHeader("Content-Type", "application/json"). + SetBody(params). + SetSuccessResult(&res).Post(f.config.ApiURL) + if err != nil || r.IsErrorState() { + return "", fmt.Errorf("%v%v", r.String(), err) + } + + if res.Code != types.Success { + return "", errors.New(res.Message) + } + + return prompt, nil +} + +func (f FuncMidJourney) Name() string { + return f.name +} + +var _ Function = &FuncMidJourney{} diff --git a/api/service/function/tou_tiao.go b/api/service/function/tou_tiao.go index 54de2589..c77e2141 100644 --- a/api/service/function/tou_tiao.go +++ b/api/service/function/tou_tiao.go @@ -24,7 +24,7 @@ func NewHeadLines(config types.ChatPlusApiConfig) FuncHeadlines { client: req.C().SetTimeout(10 * time.Second)} } -func (f FuncHeadlines) Invoke(...interface{}) (string, error) { +func (f FuncHeadlines) Invoke(map[string]interface{}) (string, error) { if f.config.Token == "" { return "", errors.New("无效的 API Token") } @@ -35,11 +35,8 @@ func (f FuncHeadlines) Invoke(...interface{}) (string, error) { SetHeader("AppId", f.config.AppId). SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)). SetSuccessResult(&res).Get(url) - if err != nil { - return "", err - } - if r.IsErrorState() { - return "", r.Err + if err != nil || r.IsErrorState() { + return "", fmt.Errorf("%v%v", err, r.Err) } if res.Code != types.Success { @@ -57,3 +54,5 @@ func (f FuncHeadlines) Invoke(...interface{}) (string, error) { func (f FuncHeadlines) Name() string { return f.name } + +var _ Function = &FuncHeadlines{} diff --git a/api/service/function/weibo_hot.go b/api/service/function/weibo_hot.go index f0c818fc..95fccc27 100644 --- a/api/service/function/weibo_hot.go +++ b/api/service/function/weibo_hot.go @@ -24,7 +24,7 @@ func NewWeiboHot(config types.ChatPlusApiConfig) FuncWeiboHot { client: req.C().SetTimeout(10 * time.Second)} } -func (f FuncWeiboHot) Invoke(...interface{}) (string, error) { +func (f FuncWeiboHot) Invoke(map[string]interface{}) (string, error) { if f.config.Token == "" { return "", errors.New("无效的 API Token") } @@ -35,11 +35,8 @@ func (f FuncWeiboHot) Invoke(...interface{}) (string, error) { SetHeader("AppId", f.config.AppId). SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)). SetSuccessResult(&res).Get(url) - if err != nil { - return "", err - } - if r.IsErrorState() { - return "", r.Err + if err != nil || r.IsErrorState() { + return "", fmt.Errorf("%v%v", err, r.Err) } if res.Code != types.Success { @@ -57,3 +54,5 @@ func (f FuncWeiboHot) Invoke(...interface{}) (string, error) { func (f FuncWeiboHot) Name() string { return f.name } + +var _ Function = &FuncWeiboHot{} diff --git a/api/service/function/zao_bao.go b/api/service/function/zao_bao.go index 87fd5172..174c81c4 100644 --- a/api/service/function/zao_bao.go +++ b/api/service/function/zao_bao.go @@ -24,7 +24,7 @@ func NewZaoBao(config types.ChatPlusApiConfig) FuncZaoBao { client: req.C().SetTimeout(10 * time.Second)} } -func (f FuncZaoBao) Invoke(...interface{}) (string, error) { +func (f FuncZaoBao) Invoke(map[string]interface{}) (string, error) { if f.config.Token == "" { return "", errors.New("无效的 API Token") } @@ -35,11 +35,8 @@ func (f FuncZaoBao) Invoke(...interface{}) (string, error) { SetHeader("AppId", f.config.AppId). SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)). SetSuccessResult(&res).Get(url) - if err != nil { - return "", err - } - if r.IsErrorState() { - return "", r.Err + if err != nil || r.IsErrorState() { + return "", fmt.Errorf("%v%v", err, r.Err) } if res.Code != types.Success { @@ -58,3 +55,5 @@ func (f FuncZaoBao) Invoke(...interface{}) (string, error) { func (f FuncZaoBao) Name() string { return f.name } + +var _ Function = &FuncZaoBao{} diff --git a/api/utils/http.go b/api/utils/http.go deleted file mode 100644 index 80dfc7d1..00000000 --- a/api/utils/http.go +++ /dev/null @@ -1,68 +0,0 @@ -package utils - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "net/url" -) - -func HttpGet(uri string, proxy string) ([]byte, error) { - var client *http.Client - if proxy == "" { - client = &http.Client{} - } else { - proxy, _ := url.Parse(proxy) - client = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(proxy), - }, - } - } - - req, err := http.NewRequest("GET", uri, nil) - if err != nil { - return nil, err - } - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - return io.ReadAll(resp.Body) -} - -func HttpPost(uri string, params map[string]interface{}, proxy string) ([]byte, error) { - data, err := json.Marshal(params) - if err != nil { - return nil, err - } - - var client *http.Client - if proxy == "" { - client = &http.Client{} - } else { - proxy, _ := url.Parse(proxy) - client = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(proxy), - }, - } - } - - req, err := http.NewRequest("POST", uri, bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - return io.ReadAll(resp.Body) -} diff --git a/api/utils/websocket.go b/api/utils/websocket.go new file mode 100644 index 00000000..e161d97b --- /dev/null +++ b/api/utils/websocket.go @@ -0,0 +1,29 @@ +package utils + +import ( + "chatplus/core/types" + logger2 "chatplus/logger" + "encoding/json" +) + +var logger = logger2.GetLogger() + +// ReplyChunkMessage 回复客户片段端消息 +func ReplyChunkMessage(client types.Client, message types.WsMessage) { + msg, err := json.Marshal(message) + if err != nil { + logger.Errorf("Error for decoding json data: %v", err.Error()) + return + } + err = client.(*types.WsClient).Send(msg) + if err != nil { + logger.Errorf("Error for reply message: %v", err.Error()) + } +} + +// ReplyMessage 回复客户端一条完整的消息 +func ReplyMessage(ws types.Client, message string) { + ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message}) + ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd}) +}