feat: add midjouney api implements, optimize function calls

This commit is contained in:
RockYang 2023-08-11 18:46:56 +08:00
parent 0b27890484
commit 302cb8a5be
17 changed files with 298 additions and 183 deletions

View File

@ -34,13 +34,10 @@ type AppServer struct {
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合 ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
Functions map[string]function.Function Functions map[string]function.Function
MjTasks *types.LMap[string, types.MjTask]
} }
func NewServer( func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer {
appConfig *types.AppConfig,
funZaoBao function.FuncZaoBao,
funZhiHu function.FuncHeadlines,
funWeibo function.FuncWeiboHot) *AppServer {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard gin.DefaultWriter = io.Discard
return &AppServer{ return &AppServer{
@ -51,11 +48,8 @@ func NewServer(
ChatSession: types.NewLMap[string, types.ChatSession](), ChatSession: types.NewLMap[string, types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](), ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
Functions: map[string]function.Function{ MjTasks: types.NewLMap[string, types.MjTask](),
types.FuncZaoBao: funZaoBao, Functions: functions,
types.FuncWeibo: funWeibo,
types.FuncHeadLine: funZhiHu,
},
} }
} }
@ -186,7 +180,8 @@ func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
if c.Request.URL.Path == "/api/user/login" || if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/admin/login" || c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/user/register" || c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/reward/push" || c.Request.URL.Path == "/api/reward/notify" ||
c.Request.URL.Path == "/api/mj/notify" ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") || strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
strings.HasPrefix(c.Request.URL.Path, "/static/") || strings.HasPrefix(c.Request.URL.Path, "/static/") ||

View File

@ -33,8 +33,8 @@ func NewDefaultConfig() *types.AppConfig {
HttpOnly: false, HttpOnly: false,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, },
ApiConfig: types.ChatPlusApiConfig{}, ApiConfig: types.ChatPlusApiConfig{},
ChatPlusExtApiToken: utils.RandString(32), ExtConfig: types.ChatPlusExtConfig{Token: utils.RandString(32)},
} }
} }

View File

@ -42,6 +42,16 @@ type ChatSession struct {
Model string `json:"model"` // GPT 模型 Model string `json:"model"` // GPT 模型
} }
type MjTask struct {
Client Client
ChatId string
MessageId string
MessageHash string
UserId uint
RoleId uint
Icon string
}
type ApiError struct { type ApiError struct {
Error struct { Error struct {
Message string Message string

View File

@ -6,19 +6,19 @@ import (
) )
type AppConfig struct { type AppConfig struct {
Path string `toml:"-"` Path string `toml:"-"`
Listen string Listen string
Session Session Session Session
ProxyURL string ProxyURL string
MysqlDns string // mysql 连接地址 MysqlDns string // mysql 连接地址
Manager Manager // 后台管理员账户信息 Manager Manager // 后台管理员账户信息
StaticDir string // 静态资源目录 StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息 Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
AesEncryptKey string AesEncryptKey string
SmsConfig AliYunSmsConfig // AliYun send message service config SmsConfig AliYunSmsConfig // AliYun send message service config
ChatPlusExtApiToken string // chatgpt-plus-exts callback api token ExtConfig ChatPlusExtConfig // ChatPlus extensions callback api config
} }
type ChatPlusApiConfig struct { type ChatPlusApiConfig struct {
@ -27,6 +27,11 @@ type ChatPlusApiConfig struct {
Token string Token string
} }
type ChatPlusExtConfig struct {
ApiURL string
Token string
}
type AliYunSmsConfig struct { type AliYunSmsConfig struct {
AccessKey string AccessKey string
AccessSecret string AccessSecret string

View File

@ -23,9 +23,10 @@ type Property struct {
} }
const ( const (
FuncZaoBao = "zao_bao" // 每日早报 FuncZaoBao = "zao_bao" // 每日早报
FuncHeadLine = "headline" // 今日头条 FuncHeadLine = "headline" // 今日头条
FuncWeibo = "weibo_hot" // 微博热搜 FuncWeibo = "weibo_hot" // 微博热搜
FuncMidJourney = "mid_journey" // MJ 绘画
) )
var InnerFunctions = []Function{ var InnerFunctions = []Function{
@ -73,4 +74,23 @@ var InnerFunctions = []Function{
Required: []string{}, Required: []string{},
}, },
}, },
{
Name: FuncMidJourney,
Description: "AI 绘画工具,使用 MJ MidJourney API 进行 AI 绘画",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
"prompt": {
Type: "string",
Description: "绘画内容描述,提示词,此参数需要翻译成英文",
},
"ar": {
Type: "string",
Description: "图片长宽比,如 16:9",
},
},
Required: []string{},
},
},
} }

View File

@ -9,7 +9,7 @@ type MKey interface {
string | int string | int
} }
type MValue interface { type MValue interface {
*WsClient | ChatSession | context.CancelFunc | []interface{} *WsClient | ChatSession | context.CancelFunc | []interface{} | MjTask
} }
type LMap[K MKey, T MValue] struct { type LMap[K MKey, T MValue] struct {
lock sync.RWMutex lock sync.RWMutex

View File

@ -88,7 +88,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
var chatRole model.ChatRole var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId) res = h.db.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable { if res.Error != nil || !chatRole.Enable {
replyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort() c.Abort()
return return
} }
@ -98,7 +98,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
h.db.Where("marker", "chat").First(&config) h.db.Where("marker", "chat").First(&config)
err = utils.JsonDecode(config.Config, &chatConfig) err = utils.JsonDecode(config.Config, &chatConfig)
if err != nil { if err != nil {
replyMessage(client, "加载系统配置失败,连接已关闭!!!") utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
c.Abort() c.Abort()
return return
} }
@ -116,7 +116,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
return return
} }
logger.Info("Receive a message: ", string(message)) logger.Info("Receive a message: ", string(message))
//replyMessage(client, "这是一条测试消息!") //utils.ReplyMessage(client, "这是一条测试消息!")
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
h.App.ReqCancelFunc.Put(sessionId, cancel) h.App.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息 // 回复消息
@ -124,7 +124,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} else { } else {
replyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
logger.Info("回答完毕: " + string(message)) logger.Info("回答完毕: " + string(message))
} }
@ -139,7 +139,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
var user model.User var user model.User
res := h.db.Model(&model.User{}).First(&user, session.UserId) res := h.db.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil { if res.Error != nil {
replyMessage(ws, "非法用户,请联系管理员!") utils.ReplyMessage(ws, "非法用户,请联系管理员!")
return res.Error return res.Error
} }
var userVo vo.User var userVo vo.User
@ -150,20 +150,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} }
if userVo.Status == false { if userVo.Status == false {
replyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!") utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
replyMessage(ws, "![](/images/wx.png)") utils.ReplyMessage(ws, "![](/images/wx.png)")
return nil return nil
} }
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKey == "" { if userVo.Calls <= 0 && userVo.ChatConfig.ApiKey == "" {
replyMessage(ws, "您的对话次数已经用尽请联系管理员或者点击左下角菜单加入众筹获得100次对话") utils.ReplyMessage(ws, "您的对话次数已经用尽请联系管理员或者点击左下角菜单加入众筹获得100次对话")
replyMessage(ws, "![](/images/wx.png)") utils.ReplyMessage(ws, "![](/images/wx.png)")
return nil return nil
} }
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() { if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
replyMessage(ws, "您的账号已经过期,请联系管理员!") utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
replyMessage(ws, "![](/images/wx.png)") utils.ReplyMessage(ws, "![](/images/wx.png)")
return nil return nil
} }
var req = types.ApiRequest{ var req = types.ApiRequest{
@ -238,14 +238,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
logger.Info("用户取消了请求:", prompt) logger.Info("用户取消了请求:", prompt)
return nil return nil
} else if strings.Contains(err.Error(), "no available key") { } else if strings.Contains(err.Error(), "no available key") {
replyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏") utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏")
return nil return nil
} else { } else {
logger.Error(err) logger.Error(err)
} }
replyMessage(ws, ErrorMsg) utils.ReplyMessage(ws, ErrorMsg)
replyMessage(ws, "![](/images/wx.png)") utils.ReplyMessage(ws, "![](/images/wx.png)")
return err return err
} else { } else {
defer response.Body.Close() defer response.Body.Close()
@ -280,8 +280,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
err = json.Unmarshal([]byte(line[6:]), &responseBody) err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
logger.Error(err, line) logger.Error(err, line)
replyMessage(ws, ErrorMsg) utils.ReplyMessage(ws, ErrorMsg)
replyMessage(ws, "![](/images/wx.png)") utils.ReplyMessage(ws, "![](/images/wx.png)")
break break
} }
@ -295,8 +295,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
functionCall = true functionCall = true
functionName = fun.Name functionName = fun.Name
f := h.App.Functions[functionName] f := h.App.Functions[functionName]
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())}) utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
continue continue
} }
@ -307,14 +307,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
// 初始化 role // 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role message.Role = responseBody.Choices[0].Delta.Role
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue continue
} else if responseBody.Choices[0].FinishReason != "" { } else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了 break // 输出完成或者输出中断了
} else { } else {
content := responseBody.Choices[0].Delta.Content content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content)) contents = append(contents, utils.InterfaceToString(content))
replyChunkMessage(ws, types.WsMessage{ utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle, Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
}) })
@ -322,23 +322,39 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} // end for } // end for
if functionCall { // 调用函数完成任务 if functionCall { // 调用函数完成任务
logger.Info(functionName) logger.Info("函数名称:", functionName)
logger.Info(arguments) var params map[string]interface{}
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Info("函数参数:", params)
f := h.App.Functions[functionName] f := h.App.Functions[functionName]
data, err := f.Invoke(arguments) data, err := f.Invoke(params)
if err != nil { if err != nil {
msg := "调用函数出错:" + err.Error() msg := "调用函数出错:" + err.Error()
replyChunkMessage(ws, types.WsMessage{ utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle, Type: types.WsMiddle,
Content: msg, Content: msg,
}) })
contents = append(contents, msg) contents = append(contents, msg)
} else { } else {
replyChunkMessage(ws, types.WsMessage{ content := data
if functionName == types.FuncMidJourney {
key := utils.Sha256(data)
// add task for MidJourney
h.App.MjTasks.Put(key, types.MjTask{
UserId: userVo.Id,
RoleId: role.Id,
Icon: role.Icon,
Client: ws,
ChatId: session.ChatId,
})
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle, Type: types.WsMiddle,
Content: data, Content: content,
}) })
contents = append(contents, data) contents = append(contents, content)
} }
} }
@ -430,7 +446,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} else { } else {
totalTokens = replyToken + getTotalTokens(req) totalTokens = replyToken + getTotalTokens(req)
} }
//replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)}) //utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
if userVo.ChatConfig.ApiKey != "" { // 调用自己的 API KEY 不计算 token 消耗 if userVo.ChatConfig.ApiKey != "" { // 调用自己的 API KEY 不计算 token 消耗
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?", h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
totalTokens)) totalTokens))
@ -468,18 +484,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
// OpenAI API 调用异常处理 // OpenAI API 调用异常处理
// TODO: 是否考虑重发消息? // TODO: 是否考虑重发消息?
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
replyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。") utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key // 移除当前 API key
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{}) h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") { } else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
replyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。") utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") { } else if strings.Contains(res.Error.Message, "This model's maximum context length") {
logger.Error(res.Error.Message) logger.Error(res.Error.Message)
replyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!") utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
h.App.ChatContexts.Delete(session.ChatId) h.App.ChatContexts.Delete(session.ChatId)
return h.sendMessage(ctx, session, role, prompt, ws) return h.sendMessage(ctx, session, role, prompt, ws)
} else { } else {
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message) utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
} }
} }
@ -534,26 +550,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
return client.Do(request) return client.Do(request)
} }
// 回复客户片段端消息
func replyChunkMessage(client types.Client, message types.WsMessage) {
msg, err := json.Marshal(message)
if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error())
return
}
err = client.(*types.WsClient).Send(msg)
if err != nil {
logger.Errorf("Error for reply message: %v", err.Error())
}
}
// 回复客户端一条完整的消息
func replyMessage(ws types.Client, message string) {
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
}
// Tokens 统计 token 数量 // Tokens 统计 token 数量
func (h *ChatHandler) Tokens(c *gin.Context) { func (h *ChatHandler) Tokens(c *gin.Context) {
text := c.Query("text") text := c.Query("text")

61
api/handler/mj_handler.go Normal file
View File

@ -0,0 +1,61 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
)
type TaskStatus string
const (
Start = TaskStatus("Started")
Running = TaskStatus("Running")
Stopped = TaskStatus("Stopped")
Finished = TaskStatus("Finished")
)
type Image struct {
URL string `json:"url"`
ProxyURL string `json:"proxy_url"`
Filename string `json:"filename"`
Width int `json:"width"`
Height int `json:"height"`
Size int `json:"size"`
}
type MidJourneyHandler struct {
BaseHandler
}
func NewMidJourneyHandler(app *core.AppServer) *MidJourneyHandler {
h := MidJourneyHandler{}
h.App = app
return &h
}
func (h *MidJourneyHandler) Notify(c *gin.Context) {
token := c.GetHeader("Authorization")
if token != h.App.Config.ExtConfig.Token {
resp.NotAuth(c)
return
}
var data struct {
Image Image `json:"image"`
Content string `json:"content"`
Status TaskStatus `json:"status"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
sessionId := "u7blnft9zqisyrwidjb22j6b78iqc30lv9jtud3k9o"
wsClient := h.App.ChatClients.Get(sessionId)
utils.ReplyMessage(wsClient, "![](https://cdn.discordapp.com/attachments/1138713254718361633/1139482452579070053/lal603743923_A_Chinese_girl_walking_barefoot_on_the_beach_weari_df8b6dc0-3b13-478c-8dbb-983015d21661.png)")
logger.Infof("Data: %+v", data)
resp.ERROR(c, "Error with CallBack")
}

View File

@ -21,9 +21,9 @@ func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
return &h return &h
} }
func (h *RewardHandler) Push(c *gin.Context) { func (h *RewardHandler) Notify(c *gin.Context) {
token := c.GetHeader("X-TOKEN") token := c.GetHeader("Authorization")
if token != h.App.Config.ChatPlusExtApiToken { if token != h.App.Config.ExtConfig.Token {
resp.NotAuth(c) resp.NotAuth(c)
return return
} }

View File

@ -104,15 +104,7 @@ func main() {
}), }),
// 创建函数 // 创建函数
fx.Provide(func(config *types.AppConfig) (function.FuncZaoBao, error) { fx.Provide(function.NewFunctions),
return function.NewZaoBao(config.ApiConfig), nil
}),
fx.Provide(func(config *types.AppConfig) (function.FuncWeiboHot, error) {
return function.NewWeiboHot(config.ApiConfig), nil
}),
fx.Provide(func(config *types.AppConfig) (function.FuncHeadlines, error) {
return function.NewHeadLines(config.ApiConfig), nil
}),
// 创建控制器 // 创建控制器
fx.Provide(handler.NewChatRoleHandler), fx.Provide(handler.NewChatRoleHandler),
@ -122,6 +114,7 @@ func main() {
fx.Provide(handler.NewSmsHandler), fx.Provide(handler.NewSmsHandler),
fx.Provide(handler.NewRewardHandler), fx.Provide(handler.NewRewardHandler),
fx.Provide(handler.NewCaptchaHandler), fx.Provide(handler.NewCaptchaHandler),
fx.Provide(handler.NewMidJourneyHandler),
fx.Provide(admin.NewConfigHandler), fx.Provide(admin.NewConfigHandler),
fx.Provide(admin.NewAdminHandler), fx.Provide(admin.NewAdminHandler),
@ -180,9 +173,12 @@ func main() {
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) { fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
group := s.Engine.Group("/api/reward/") group := s.Engine.Group("/api/reward/")
group.POST("push", h.Push) group.POST("notify", h.Notify)
group.POST("verify", h.Verify) group.POST("verify", h.Verify)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
s.Engine.POST("/api/mj/notify", h.Notify)
}),
// 管理后台控制器 // 管理后台控制器
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {

View File

@ -1,12 +1,17 @@
package function package function
import "chatplus/core/types" import (
"chatplus/core/types"
logger2 "chatplus/logger"
)
type Function interface { type Function interface {
Invoke(...interface{}) (string, error) Invoke(map[string]interface{}) (string, error)
Name() string Name() string
} }
var logger = logger2.GetLogger()
type resVo struct { type resVo struct {
Code types.BizCode `json:"code"` Code types.BizCode `json:"code"`
Message string `json:"message"` Message string `json:"message"`
@ -22,3 +27,12 @@ type dataItem struct {
Url string `json:"url"` Url string `json:"url"`
Remark string `json:"remark"` Remark string `json:"remark"`
} }
func NewFunctions(config *types.AppConfig) map[string]Function {
return map[string]Function{
types.FuncZaoBao: NewZaoBao(config.ApiConfig),
types.FuncWeibo: NewWeiboHot(config.ApiConfig),
types.FuncHeadLine: NewHeadLines(config.ApiConfig),
types.FuncMidJourney: NewMidJourneyFunc(config.ExtConfig),
}
}

View File

@ -0,0 +1,60 @@
package function
import (
"chatplus/core/types"
"chatplus/utils"
"errors"
"fmt"
"github.com/imroc/req/v3"
"time"
)
// AI 绘画函数
type FuncMidJourney struct {
name string
config types.ChatPlusExtConfig
client *req.Client
}
func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney {
return FuncMidJourney{
name: "MidJourney AI 绘画",
config: config,
client: req.C().SetTimeout(10 * time.Second)}
}
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
if f.config.Token == "" {
return "", errors.New("无效的 API Token")
}
logger.Infof("MJ 绘画参数:%+v", params)
prompt := utils.InterfaceToString(params["prompt"])
if !utils.IsEmptyValue(params["ar"]) {
prompt = prompt + fmt.Sprintf(" --ar %v", params["ar"])
delete(params, "ar")
}
prompt = prompt + " --niji 5"
var res types.BizVo
r, err := f.client.R().
SetHeader("Authorization", f.config.Token).
SetHeader("Content-Type", "application/json").
SetBody(params).
SetSuccessResult(&res).Post(f.config.ApiURL)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("%v%v", r.String(), err)
}
if res.Code != types.Success {
return "", errors.New(res.Message)
}
return prompt, nil
}
func (f FuncMidJourney) Name() string {
return f.name
}
var _ Function = &FuncMidJourney{}

View File

@ -24,7 +24,7 @@ func NewHeadLines(config types.ChatPlusApiConfig) FuncHeadlines {
client: req.C().SetTimeout(10 * time.Second)} client: req.C().SetTimeout(10 * time.Second)}
} }
func (f FuncHeadlines) Invoke(...interface{}) (string, error) { func (f FuncHeadlines) Invoke(map[string]interface{}) (string, error) {
if f.config.Token == "" { if f.config.Token == "" {
return "", errors.New("无效的 API Token") return "", errors.New("无效的 API Token")
} }
@ -35,11 +35,8 @@ func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
SetHeader("AppId", f.config.AppId). SetHeader("AppId", f.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)). SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
SetSuccessResult(&res).Get(url) SetSuccessResult(&res).Get(url)
if err != nil { if err != nil || r.IsErrorState() {
return "", err return "", fmt.Errorf("%v%v", err, r.Err)
}
if r.IsErrorState() {
return "", r.Err
} }
if res.Code != types.Success { if res.Code != types.Success {
@ -57,3 +54,5 @@ func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
func (f FuncHeadlines) Name() string { func (f FuncHeadlines) Name() string {
return f.name return f.name
} }
var _ Function = &FuncHeadlines{}

View File

@ -24,7 +24,7 @@ func NewWeiboHot(config types.ChatPlusApiConfig) FuncWeiboHot {
client: req.C().SetTimeout(10 * time.Second)} client: req.C().SetTimeout(10 * time.Second)}
} }
func (f FuncWeiboHot) Invoke(...interface{}) (string, error) { func (f FuncWeiboHot) Invoke(map[string]interface{}) (string, error) {
if f.config.Token == "" { if f.config.Token == "" {
return "", errors.New("无效的 API Token") return "", errors.New("无效的 API Token")
} }
@ -35,11 +35,8 @@ func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
SetHeader("AppId", f.config.AppId). SetHeader("AppId", f.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)). SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
SetSuccessResult(&res).Get(url) SetSuccessResult(&res).Get(url)
if err != nil { if err != nil || r.IsErrorState() {
return "", err return "", fmt.Errorf("%v%v", err, r.Err)
}
if r.IsErrorState() {
return "", r.Err
} }
if res.Code != types.Success { if res.Code != types.Success {
@ -57,3 +54,5 @@ func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
func (f FuncWeiboHot) Name() string { func (f FuncWeiboHot) Name() string {
return f.name return f.name
} }
var _ Function = &FuncWeiboHot{}

View File

@ -24,7 +24,7 @@ func NewZaoBao(config types.ChatPlusApiConfig) FuncZaoBao {
client: req.C().SetTimeout(10 * time.Second)} client: req.C().SetTimeout(10 * time.Second)}
} }
func (f FuncZaoBao) Invoke(...interface{}) (string, error) { func (f FuncZaoBao) Invoke(map[string]interface{}) (string, error) {
if f.config.Token == "" { if f.config.Token == "" {
return "", errors.New("无效的 API Token") return "", errors.New("无效的 API Token")
} }
@ -35,11 +35,8 @@ func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
SetHeader("AppId", f.config.AppId). SetHeader("AppId", f.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)). SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
SetSuccessResult(&res).Get(url) SetSuccessResult(&res).Get(url)
if err != nil { if err != nil || r.IsErrorState() {
return "", err return "", fmt.Errorf("%v%v", err, r.Err)
}
if r.IsErrorState() {
return "", r.Err
} }
if res.Code != types.Success { if res.Code != types.Success {
@ -58,3 +55,5 @@ func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
func (f FuncZaoBao) Name() string { func (f FuncZaoBao) Name() string {
return f.name return f.name
} }
var _ Function = &FuncZaoBao{}

View File

@ -1,68 +0,0 @@
package utils
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/url"
)
func HttpGet(uri string, proxy string) ([]byte, error) {
var client *http.Client
if proxy == "" {
client = &http.Client{}
} else {
proxy, _ := url.Parse(proxy)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
}
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}
func HttpPost(uri string, params map[string]interface{}, proxy string) ([]byte, error) {
data, err := json.Marshal(params)
if err != nil {
return nil, err
}
var client *http.Client
if proxy == "" {
client = &http.Client{}
} else {
proxy, _ := url.Parse(proxy)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
}
req, err := http.NewRequest("POST", uri, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}

29
api/utils/websocket.go Normal file
View File

@ -0,0 +1,29 @@
package utils
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"encoding/json"
)
var logger = logger2.GetLogger()
// ReplyChunkMessage 回复客户片段端消息
func ReplyChunkMessage(client types.Client, message types.WsMessage) {
msg, err := json.Marshal(message)
if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error())
return
}
err = client.(*types.WsClient).Send(msg)
if err != nil {
logger.Errorf("Error for reply message: %v", err.Error())
}
}
// ReplyMessage 回复客户端一条完整的消息
func ReplyMessage(ws types.Client, message string) {
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
}