mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-21 18:56:38 +08:00
feat: add midjouney api implements, optimize function calls
This commit is contained in:
parent
0b27890484
commit
302cb8a5be
@ -34,13 +34,10 @@ type AppServer struct {
|
||||
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
Functions map[string]function.Function
|
||||
MjTasks *types.LMap[string, types.MjTask]
|
||||
}
|
||||
|
||||
func NewServer(
|
||||
appConfig *types.AppConfig,
|
||||
funZaoBao function.FuncZaoBao,
|
||||
funZhiHu function.FuncHeadlines,
|
||||
funWeibo function.FuncWeiboHot) *AppServer {
|
||||
func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
return &AppServer{
|
||||
@ -51,11 +48,8 @@ func NewServer(
|
||||
ChatSession: types.NewLMap[string, types.ChatSession](),
|
||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
Functions: map[string]function.Function{
|
||||
types.FuncZaoBao: funZaoBao,
|
||||
types.FuncWeibo: funWeibo,
|
||||
types.FuncHeadLine: funZhiHu,
|
||||
},
|
||||
MjTasks: types.NewLMap[string, types.MjTask](),
|
||||
Functions: functions,
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,7 +180,8 @@ func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
|
||||
if c.Request.URL.Path == "/api/user/login" ||
|
||||
c.Request.URL.Path == "/api/admin/login" ||
|
||||
c.Request.URL.Path == "/api/user/register" ||
|
||||
c.Request.URL.Path == "/api/reward/push" ||
|
||||
c.Request.URL.Path == "/api/reward/notify" ||
|
||||
c.Request.URL.Path == "/api/mj/notify" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
|
||||
|
@ -33,8 +33,8 @@ func NewDefaultConfig() *types.AppConfig {
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
},
|
||||
ApiConfig: types.ChatPlusApiConfig{},
|
||||
ChatPlusExtApiToken: utils.RandString(32),
|
||||
ApiConfig: types.ChatPlusApiConfig{},
|
||||
ExtConfig: types.ChatPlusExtConfig{Token: utils.RandString(32)},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,6 +42,16 @@ type ChatSession struct {
|
||||
Model string `json:"model"` // GPT 模型
|
||||
}
|
||||
|
||||
type MjTask struct {
|
||||
Client Client
|
||||
ChatId string
|
||||
MessageId string
|
||||
MessageHash string
|
||||
UserId uint
|
||||
RoleId uint
|
||||
Icon string
|
||||
}
|
||||
|
||||
type ApiError struct {
|
||||
Error struct {
|
||||
Message string
|
||||
|
@ -6,19 +6,19 @@ import (
|
||||
)
|
||||
|
||||
type AppConfig struct {
|
||||
Path string `toml:"-"`
|
||||
Listen string
|
||||
Session Session
|
||||
ProxyURL string
|
||||
MysqlDns string // mysql 连接地址
|
||||
Manager Manager // 后台管理员账户信息
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
||||
AesEncryptKey string
|
||||
SmsConfig AliYunSmsConfig // AliYun send message service config
|
||||
ChatPlusExtApiToken string // chatgpt-plus-exts callback api token
|
||||
Path string `toml:"-"`
|
||||
Listen string
|
||||
Session Session
|
||||
ProxyURL string
|
||||
MysqlDns string // mysql 连接地址
|
||||
Manager Manager // 后台管理员账户信息
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
||||
AesEncryptKey string
|
||||
SmsConfig AliYunSmsConfig // AliYun send message service config
|
||||
ExtConfig ChatPlusExtConfig // ChatPlus extensions callback api config
|
||||
}
|
||||
|
||||
type ChatPlusApiConfig struct {
|
||||
@ -27,6 +27,11 @@ type ChatPlusApiConfig struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
type ChatPlusExtConfig struct {
|
||||
ApiURL string
|
||||
Token string
|
||||
}
|
||||
|
||||
type AliYunSmsConfig struct {
|
||||
AccessKey string
|
||||
AccessSecret string
|
||||
|
@ -23,9 +23,10 @@ type Property struct {
|
||||
}
|
||||
|
||||
const (
|
||||
FuncZaoBao = "zao_bao" // 每日早报
|
||||
FuncHeadLine = "headline" // 今日头条
|
||||
FuncWeibo = "weibo_hot" // 微博热搜
|
||||
FuncZaoBao = "zao_bao" // 每日早报
|
||||
FuncHeadLine = "headline" // 今日头条
|
||||
FuncWeibo = "weibo_hot" // 微博热搜
|
||||
FuncMidJourney = "mid_journey" // MJ 绘画
|
||||
)
|
||||
|
||||
var InnerFunctions = []Function{
|
||||
@ -73,4 +74,23 @@ var InnerFunctions = []Function{
|
||||
Required: []string{},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: FuncMidJourney,
|
||||
Description: "AI 绘画工具,使用 MJ MidJourney API 进行 AI 绘画",
|
||||
Parameters: Parameters{
|
||||
Type: "object",
|
||||
Properties: map[string]Property{
|
||||
"prompt": {
|
||||
Type: "string",
|
||||
Description: "绘画内容描述,提示词,此参数需要翻译成英文",
|
||||
},
|
||||
"ar": {
|
||||
Type: "string",
|
||||
Description: "图片长宽比,如 16:9",
|
||||
},
|
||||
},
|
||||
Required: []string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ type MKey interface {
|
||||
string | int
|
||||
}
|
||||
type MValue interface {
|
||||
*WsClient | ChatSession | context.CancelFunc | []interface{}
|
||||
*WsClient | ChatSession | context.CancelFunc | []interface{} | MjTask
|
||||
}
|
||||
type LMap[K MKey, T MValue] struct {
|
||||
lock sync.RWMutex
|
||||
|
@ -88,7 +88,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
var chatRole model.ChatRole
|
||||
res = h.db.First(&chatRole, roleId)
|
||||
if res.Error != nil || !chatRole.Enable {
|
||||
replyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@ -98,7 +98,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
h.db.Where("marker", "chat").First(&config)
|
||||
err = utils.JsonDecode(config.Config, &chatConfig)
|
||||
if err != nil {
|
||||
replyMessage(client, "加载系统配置失败,连接已关闭!!!")
|
||||
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@ -116,7 +116,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
//replyMessage(client, "这是一条测试消息!")
|
||||
//utils.ReplyMessage(client, "这是一条测试消息!")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.App.ReqCancelFunc.Put(sessionId, cancel)
|
||||
// 回复消息
|
||||
@ -124,7 +124,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
} else {
|
||||
replyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||
logger.Info("回答完毕: " + string(message))
|
||||
}
|
||||
|
||||
@ -139,7 +139,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
var user model.User
|
||||
res := h.db.Model(&model.User{}).First(&user, session.UserId)
|
||||
if res.Error != nil {
|
||||
replyMessage(ws, "非法用户,请联系管理员!")
|
||||
utils.ReplyMessage(ws, "非法用户,请联系管理员!")
|
||||
return res.Error
|
||||
}
|
||||
var userVo vo.User
|
||||
@ -150,20 +150,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
}
|
||||
|
||||
if userVo.Status == false {
|
||||
replyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
|
||||
replyMessage(ws, "")
|
||||
utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
|
||||
utils.ReplyMessage(ws, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKey == "" {
|
||||
replyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!")
|
||||
replyMessage(ws, "")
|
||||
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!")
|
||||
utils.ReplyMessage(ws, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
||||
replyMessage(ws, "您的账号已经过期,请联系管理员!")
|
||||
replyMessage(ws, "")
|
||||
utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
|
||||
utils.ReplyMessage(ws, "")
|
||||
return nil
|
||||
}
|
||||
var req = types.ApiRequest{
|
||||
@ -238,14 +238,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
logger.Info("用户取消了请求:", prompt)
|
||||
return nil
|
||||
} else if strings.Contains(err.Error(), "no available key") {
|
||||
replyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑,您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏")
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑,您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏")
|
||||
return nil
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
replyMessage(ws, ErrorMsg)
|
||||
replyMessage(ws, "")
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, "")
|
||||
return err
|
||||
} else {
|
||||
defer response.Body.Close()
|
||||
@ -280,8 +280,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
|
||||
logger.Error(err, line)
|
||||
replyMessage(ws, ErrorMsg)
|
||||
replyMessage(ws, "")
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, "")
|
||||
break
|
||||
}
|
||||
|
||||
@ -295,8 +295,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
functionCall = true
|
||||
functionName = fun.Name
|
||||
f := h.App.Functions[functionName]
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
||||
continue
|
||||
}
|
||||
|
||||
@ -307,14 +307,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
// 初始化 role
|
||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||
message.Role = responseBody.Choices[0].Delta.Role
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
continue
|
||||
} else if responseBody.Choices[0].FinishReason != "" {
|
||||
break // 输出完成或者输出中断了
|
||||
} else {
|
||||
content := responseBody.Choices[0].Delta.Content
|
||||
contents = append(contents, utils.InterfaceToString(content))
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
})
|
||||
@ -322,23 +322,39 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
} // end for
|
||||
|
||||
if functionCall { // 调用函数完成任务
|
||||
logger.Info(functionName)
|
||||
logger.Info(arguments)
|
||||
logger.Info("函数名称:", functionName)
|
||||
var params map[string]interface{}
|
||||
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||
logger.Info("函数参数:", params)
|
||||
f := h.App.Functions[functionName]
|
||||
data, err := f.Invoke(arguments)
|
||||
data, err := f.Invoke(params)
|
||||
if err != nil {
|
||||
msg := "调用函数出错:" + err.Error()
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: msg,
|
||||
})
|
||||
contents = append(contents, msg)
|
||||
} else {
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
content := data
|
||||
if functionName == types.FuncMidJourney {
|
||||
key := utils.Sha256(data)
|
||||
// add task for MidJourney
|
||||
h.App.MjTasks.Put(key, types.MjTask{
|
||||
UserId: userVo.Id,
|
||||
RoleId: role.Id,
|
||||
Icon: role.Icon,
|
||||
Client: ws,
|
||||
ChatId: session.ChatId,
|
||||
})
|
||||
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
|
||||
}
|
||||
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: data,
|
||||
Content: content,
|
||||
})
|
||||
contents = append(contents, data)
|
||||
contents = append(contents, content)
|
||||
}
|
||||
}
|
||||
|
||||
@ -430,7 +446,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
} else {
|
||||
totalTokens = replyToken + getTotalTokens(req)
|
||||
}
|
||||
//replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
|
||||
//utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
|
||||
if userVo.ChatConfig.ApiKey != "" { // 调用自己的 API KEY 不计算 token 消耗
|
||||
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
||||
totalTokens))
|
||||
@ -468,18 +484,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
// OpenAI API 调用异常处理
|
||||
// TODO: 是否考虑重发消息?
|
||||
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
|
||||
replyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
|
||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
|
||||
// 移除当前 API key
|
||||
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
|
||||
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
|
||||
replyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
|
||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
|
||||
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
||||
logger.Error(res.Error.Message)
|
||||
replyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
|
||||
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
|
||||
h.App.ChatContexts.Delete(session.ChatId)
|
||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
||||
} else {
|
||||
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
||||
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
@ -534,26 +550,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
||||
return client.Do(request)
|
||||
}
|
||||
|
||||
// 回复客户片段端消息
|
||||
func replyChunkMessage(client types.Client, message types.WsMessage) {
|
||||
msg, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for decoding json data: %v", err.Error())
|
||||
return
|
||||
}
|
||||
err = client.(*types.WsClient).Send(msg)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for reply message: %v", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 回复客户端一条完整的消息
|
||||
func replyMessage(ws types.Client, message string) {
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
||||
}
|
||||
|
||||
// Tokens 统计 token 数量
|
||||
func (h *ChatHandler) Tokens(c *gin.Context) {
|
||||
text := c.Query("text")
|
||||
|
61
api/handler/mj_handler.go
Normal file
61
api/handler/mj_handler.go
Normal file
@ -0,0 +1,61 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
Start = TaskStatus("Started")
|
||||
Running = TaskStatus("Running")
|
||||
Stopped = TaskStatus("Stopped")
|
||||
Finished = TaskStatus("Finished")
|
||||
)
|
||||
|
||||
type Image struct {
|
||||
URL string `json:"url"`
|
||||
ProxyURL string `json:"proxy_url"`
|
||||
Filename string `json:"filename"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Size int `json:"size"`
|
||||
}
|
||||
|
||||
type MidJourneyHandler struct {
|
||||
BaseHandler
|
||||
}
|
||||
|
||||
func NewMidJourneyHandler(app *core.AppServer) *MidJourneyHandler {
|
||||
h := MidJourneyHandler{}
|
||||
h.App = app
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token != h.App.Config.ExtConfig.Token {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Image Image `json:"image"`
|
||||
Content string `json:"content"`
|
||||
Status TaskStatus `json:"status"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
sessionId := "u7blnft9zqisyrwidjb22j6b78iqc30lv9jtud3k9o"
|
||||
wsClient := h.App.ChatClients.Get(sessionId)
|
||||
utils.ReplyMessage(wsClient, "")
|
||||
logger.Infof("Data: %+v", data)
|
||||
resp.ERROR(c, "Error with CallBack")
|
||||
}
|
@ -21,9 +21,9 @@ func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h *RewardHandler) Push(c *gin.Context) {
|
||||
token := c.GetHeader("X-TOKEN")
|
||||
if token != h.App.Config.ChatPlusExtApiToken {
|
||||
func (h *RewardHandler) Notify(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token != h.App.Config.ExtConfig.Token {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
16
api/main.go
16
api/main.go
@ -104,15 +104,7 @@ func main() {
|
||||
}),
|
||||
|
||||
// 创建函数
|
||||
fx.Provide(func(config *types.AppConfig) (function.FuncZaoBao, error) {
|
||||
return function.NewZaoBao(config.ApiConfig), nil
|
||||
}),
|
||||
fx.Provide(func(config *types.AppConfig) (function.FuncWeiboHot, error) {
|
||||
return function.NewWeiboHot(config.ApiConfig), nil
|
||||
}),
|
||||
fx.Provide(func(config *types.AppConfig) (function.FuncHeadlines, error) {
|
||||
return function.NewHeadLines(config.ApiConfig), nil
|
||||
}),
|
||||
fx.Provide(function.NewFunctions),
|
||||
|
||||
// 创建控制器
|
||||
fx.Provide(handler.NewChatRoleHandler),
|
||||
@ -122,6 +114,7 @@ func main() {
|
||||
fx.Provide(handler.NewSmsHandler),
|
||||
fx.Provide(handler.NewRewardHandler),
|
||||
fx.Provide(handler.NewCaptchaHandler),
|
||||
fx.Provide(handler.NewMidJourneyHandler),
|
||||
|
||||
fx.Provide(admin.NewConfigHandler),
|
||||
fx.Provide(admin.NewAdminHandler),
|
||||
@ -180,9 +173,12 @@ func main() {
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
||||
group := s.Engine.Group("/api/reward/")
|
||||
group.POST("push", h.Push)
|
||||
group.POST("notify", h.Notify)
|
||||
group.POST("verify", h.Verify)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
||||
s.Engine.POST("/api/mj/notify", h.Notify)
|
||||
}),
|
||||
|
||||
// 管理后台控制器
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||
|
@ -1,12 +1,17 @@
|
||||
package function
|
||||
|
||||
import "chatplus/core/types"
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
)
|
||||
|
||||
type Function interface {
|
||||
Invoke(...interface{}) (string, error)
|
||||
Invoke(map[string]interface{}) (string, error)
|
||||
Name() string
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type resVo struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
@ -22,3 +27,12 @@ type dataItem struct {
|
||||
Url string `json:"url"`
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
func NewFunctions(config *types.AppConfig) map[string]Function {
|
||||
return map[string]Function{
|
||||
types.FuncZaoBao: NewZaoBao(config.ApiConfig),
|
||||
types.FuncWeibo: NewWeiboHot(config.ApiConfig),
|
||||
types.FuncHeadLine: NewHeadLines(config.ApiConfig),
|
||||
types.FuncMidJourney: NewMidJourneyFunc(config.ExtConfig),
|
||||
}
|
||||
}
|
||||
|
60
api/service/function/mid_journey.go
Normal file
60
api/service/function/mid_journey.go
Normal file
@ -0,0 +1,60 @@
|
||||
package function
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AI 绘画函数
|
||||
|
||||
type FuncMidJourney struct {
|
||||
name string
|
||||
config types.ChatPlusExtConfig
|
||||
client *req.Client
|
||||
}
|
||||
|
||||
func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney {
|
||||
return FuncMidJourney{
|
||||
name: "MidJourney AI 绘画",
|
||||
config: config,
|
||||
client: req.C().SetTimeout(10 * time.Second)}
|
||||
}
|
||||
|
||||
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
|
||||
if f.config.Token == "" {
|
||||
return "", errors.New("无效的 API Token")
|
||||
}
|
||||
|
||||
logger.Infof("MJ 绘画参数:%+v", params)
|
||||
prompt := utils.InterfaceToString(params["prompt"])
|
||||
if !utils.IsEmptyValue(params["ar"]) {
|
||||
prompt = prompt + fmt.Sprintf(" --ar %v", params["ar"])
|
||||
delete(params, "ar")
|
||||
}
|
||||
prompt = prompt + " --niji 5"
|
||||
var res types.BizVo
|
||||
r, err := f.client.R().
|
||||
SetHeader("Authorization", f.config.Token).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(params).
|
||||
SetSuccessResult(&res).Post(f.config.ApiURL)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("%v%v", r.String(), err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
return "", errors.New(res.Message)
|
||||
}
|
||||
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
func (f FuncMidJourney) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
var _ Function = &FuncMidJourney{}
|
@ -24,7 +24,7 @@ func NewHeadLines(config types.ChatPlusApiConfig) FuncHeadlines {
|
||||
client: req.C().SetTimeout(10 * time.Second)}
|
||||
}
|
||||
|
||||
func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
|
||||
func (f FuncHeadlines) Invoke(map[string]interface{}) (string, error) {
|
||||
if f.config.Token == "" {
|
||||
return "", errors.New("无效的 API Token")
|
||||
}
|
||||
@ -35,11 +35,8 @@ func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
|
||||
SetHeader("AppId", f.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if r.IsErrorState() {
|
||||
return "", r.Err
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("%v%v", err, r.Err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
@ -57,3 +54,5 @@ func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
|
||||
func (f FuncHeadlines) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
var _ Function = &FuncHeadlines{}
|
||||
|
@ -24,7 +24,7 @@ func NewWeiboHot(config types.ChatPlusApiConfig) FuncWeiboHot {
|
||||
client: req.C().SetTimeout(10 * time.Second)}
|
||||
}
|
||||
|
||||
func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
|
||||
func (f FuncWeiboHot) Invoke(map[string]interface{}) (string, error) {
|
||||
if f.config.Token == "" {
|
||||
return "", errors.New("无效的 API Token")
|
||||
}
|
||||
@ -35,11 +35,8 @@ func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
|
||||
SetHeader("AppId", f.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if r.IsErrorState() {
|
||||
return "", r.Err
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("%v%v", err, r.Err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
@ -57,3 +54,5 @@ func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
|
||||
func (f FuncWeiboHot) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
var _ Function = &FuncWeiboHot{}
|
||||
|
@ -24,7 +24,7 @@ func NewZaoBao(config types.ChatPlusApiConfig) FuncZaoBao {
|
||||
client: req.C().SetTimeout(10 * time.Second)}
|
||||
}
|
||||
|
||||
func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
|
||||
func (f FuncZaoBao) Invoke(map[string]interface{}) (string, error) {
|
||||
if f.config.Token == "" {
|
||||
return "", errors.New("无效的 API Token")
|
||||
}
|
||||
@ -35,11 +35,8 @@ func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
|
||||
SetHeader("AppId", f.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if r.IsErrorState() {
|
||||
return "", r.Err
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("%v%v", err, r.Err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
@ -58,3 +55,5 @@ func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
|
||||
func (f FuncZaoBao) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
var _ Function = &FuncZaoBao{}
|
||||
|
@ -1,68 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func HttpGet(uri string, proxy string) ([]byte, error) {
|
||||
var client *http.Client
|
||||
if proxy == "" {
|
||||
client = &http.Client{}
|
||||
} else {
|
||||
proxy, _ := url.Parse(proxy)
|
||||
client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", uri, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func HttpPost(uri string, params map[string]interface{}, proxy string) ([]byte, error) {
|
||||
data, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var client *http.Client
|
||||
if proxy == "" {
|
||||
client = &http.Client{}
|
||||
} else {
|
||||
proxy, _ := url.Parse(proxy)
|
||||
client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", uri, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
29
api/utils/websocket.go
Normal file
29
api/utils/websocket.go
Normal file
@ -0,0 +1,29 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
// ReplyChunkMessage 回复客户片段端消息
|
||||
func ReplyChunkMessage(client types.Client, message types.WsMessage) {
|
||||
msg, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for decoding json data: %v", err.Error())
|
||||
return
|
||||
}
|
||||
err = client.(*types.WsClient).Send(msg)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for reply message: %v", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// ReplyMessage 回复客户端一条完整的消息
|
||||
func ReplyMessage(ws types.Client, message string) {
|
||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
||||
}
|
Loading…
Reference in New Issue
Block a user