mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-22 03:06:38 +08:00
feat: add midjouney api implements, optimize function calls
This commit is contained in:
parent
0b27890484
commit
302cb8a5be
@ -34,13 +34,10 @@ type AppServer struct {
|
|||||||
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
||||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||||
Functions map[string]function.Function
|
Functions map[string]function.Function
|
||||||
|
MjTasks *types.LMap[string, types.MjTask]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(
|
func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer {
|
||||||
appConfig *types.AppConfig,
|
|
||||||
funZaoBao function.FuncZaoBao,
|
|
||||||
funZhiHu function.FuncHeadlines,
|
|
||||||
funWeibo function.FuncWeiboHot) *AppServer {
|
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
gin.DefaultWriter = io.Discard
|
gin.DefaultWriter = io.Discard
|
||||||
return &AppServer{
|
return &AppServer{
|
||||||
@ -51,11 +48,8 @@ func NewServer(
|
|||||||
ChatSession: types.NewLMap[string, types.ChatSession](),
|
ChatSession: types.NewLMap[string, types.ChatSession](),
|
||||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
Functions: map[string]function.Function{
|
MjTasks: types.NewLMap[string, types.MjTask](),
|
||||||
types.FuncZaoBao: funZaoBao,
|
Functions: functions,
|
||||||
types.FuncWeibo: funWeibo,
|
|
||||||
types.FuncHeadLine: funZhiHu,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,7 +180,8 @@ func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
|
|||||||
if c.Request.URL.Path == "/api/user/login" ||
|
if c.Request.URL.Path == "/api/user/login" ||
|
||||||
c.Request.URL.Path == "/api/admin/login" ||
|
c.Request.URL.Path == "/api/admin/login" ||
|
||||||
c.Request.URL.Path == "/api/user/register" ||
|
c.Request.URL.Path == "/api/user/register" ||
|
||||||
c.Request.URL.Path == "/api/reward/push" ||
|
c.Request.URL.Path == "/api/reward/notify" ||
|
||||||
|
c.Request.URL.Path == "/api/mj/notify" ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
|
||||||
|
@ -33,8 +33,8 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
HttpOnly: false,
|
HttpOnly: false,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
},
|
},
|
||||||
ApiConfig: types.ChatPlusApiConfig{},
|
ApiConfig: types.ChatPlusApiConfig{},
|
||||||
ChatPlusExtApiToken: utils.RandString(32),
|
ExtConfig: types.ChatPlusExtConfig{Token: utils.RandString(32)},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,6 +42,16 @@ type ChatSession struct {
|
|||||||
Model string `json:"model"` // GPT 模型
|
Model string `json:"model"` // GPT 模型
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MjTask struct {
|
||||||
|
Client Client
|
||||||
|
ChatId string
|
||||||
|
MessageId string
|
||||||
|
MessageHash string
|
||||||
|
UserId uint
|
||||||
|
RoleId uint
|
||||||
|
Icon string
|
||||||
|
}
|
||||||
|
|
||||||
type ApiError struct {
|
type ApiError struct {
|
||||||
Error struct {
|
Error struct {
|
||||||
Message string
|
Message string
|
||||||
|
@ -6,19 +6,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AppConfig struct {
|
type AppConfig struct {
|
||||||
Path string `toml:"-"`
|
Path string `toml:"-"`
|
||||||
Listen string
|
Listen string
|
||||||
Session Session
|
Session Session
|
||||||
ProxyURL string
|
ProxyURL string
|
||||||
MysqlDns string // mysql 连接地址
|
MysqlDns string // mysql 连接地址
|
||||||
Manager Manager // 后台管理员账户信息
|
Manager Manager // 后台管理员账户信息
|
||||||
StaticDir string // 静态资源目录
|
StaticDir string // 静态资源目录
|
||||||
StaticUrl string // 静态资源 URL
|
StaticUrl string // 静态资源 URL
|
||||||
Redis RedisConfig // redis 连接信息
|
Redis RedisConfig // redis 连接信息
|
||||||
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
||||||
AesEncryptKey string
|
AesEncryptKey string
|
||||||
SmsConfig AliYunSmsConfig // AliYun send message service config
|
SmsConfig AliYunSmsConfig // AliYun send message service config
|
||||||
ChatPlusExtApiToken string // chatgpt-plus-exts callback api token
|
ExtConfig ChatPlusExtConfig // ChatPlus extensions callback api config
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatPlusApiConfig struct {
|
type ChatPlusApiConfig struct {
|
||||||
@ -27,6 +27,11 @@ type ChatPlusApiConfig struct {
|
|||||||
Token string
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ChatPlusExtConfig struct {
|
||||||
|
ApiURL string
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
type AliYunSmsConfig struct {
|
type AliYunSmsConfig struct {
|
||||||
AccessKey string
|
AccessKey string
|
||||||
AccessSecret string
|
AccessSecret string
|
||||||
|
@ -23,9 +23,10 @@ type Property struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
FuncZaoBao = "zao_bao" // 每日早报
|
FuncZaoBao = "zao_bao" // 每日早报
|
||||||
FuncHeadLine = "headline" // 今日头条
|
FuncHeadLine = "headline" // 今日头条
|
||||||
FuncWeibo = "weibo_hot" // 微博热搜
|
FuncWeibo = "weibo_hot" // 微博热搜
|
||||||
|
FuncMidJourney = "mid_journey" // MJ 绘画
|
||||||
)
|
)
|
||||||
|
|
||||||
var InnerFunctions = []Function{
|
var InnerFunctions = []Function{
|
||||||
@ -73,4 +74,23 @@ var InnerFunctions = []Function{
|
|||||||
Required: []string{},
|
Required: []string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
Name: FuncMidJourney,
|
||||||
|
Description: "AI 绘画工具,使用 MJ MidJourney API 进行 AI 绘画",
|
||||||
|
Parameters: Parameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]Property{
|
||||||
|
"prompt": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "绘画内容描述,提示词,此参数需要翻译成英文",
|
||||||
|
},
|
||||||
|
"ar": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "图片长宽比,如 16:9",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ type MKey interface {
|
|||||||
string | int
|
string | int
|
||||||
}
|
}
|
||||||
type MValue interface {
|
type MValue interface {
|
||||||
*WsClient | ChatSession | context.CancelFunc | []interface{}
|
*WsClient | ChatSession | context.CancelFunc | []interface{} | MjTask
|
||||||
}
|
}
|
||||||
type LMap[K MKey, T MValue] struct {
|
type LMap[K MKey, T MValue] struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
|
@ -88,7 +88,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
var chatRole model.ChatRole
|
var chatRole model.ChatRole
|
||||||
res = h.db.First(&chatRole, roleId)
|
res = h.db.First(&chatRole, roleId)
|
||||||
if res.Error != nil || !chatRole.Enable {
|
if res.Error != nil || !chatRole.Enable {
|
||||||
replyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -98,7 +98,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
h.db.Where("marker", "chat").First(&config)
|
h.db.Where("marker", "chat").First(&config)
|
||||||
err = utils.JsonDecode(config.Config, &chatConfig)
|
err = utils.JsonDecode(config.Config, &chatConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
replyMessage(client, "加载系统配置失败,连接已关闭!!!")
|
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -116,7 +116,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Receive a message: ", string(message))
|
logger.Info("Receive a message: ", string(message))
|
||||||
//replyMessage(client, "这是一条测试消息!")
|
//utils.ReplyMessage(client, "这是一条测试消息!")
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
h.App.ReqCancelFunc.Put(sessionId, cancel)
|
h.App.ReqCancelFunc.Put(sessionId, cancel)
|
||||||
// 回复消息
|
// 回复消息
|
||||||
@ -124,7 +124,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
} else {
|
} else {
|
||||||
replyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||||
logger.Info("回答完毕: " + string(message))
|
logger.Info("回答完毕: " + string(message))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
var user model.User
|
var user model.User
|
||||||
res := h.db.Model(&model.User{}).First(&user, session.UserId)
|
res := h.db.Model(&model.User{}).First(&user, session.UserId)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
replyMessage(ws, "非法用户,请联系管理员!")
|
utils.ReplyMessage(ws, "非法用户,请联系管理员!")
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
var userVo vo.User
|
var userVo vo.User
|
||||||
@ -150,20 +150,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userVo.Status == false {
|
if userVo.Status == false {
|
||||||
replyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
|
utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
|
||||||
replyMessage(ws, "")
|
utils.ReplyMessage(ws, "")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKey == "" {
|
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKey == "" {
|
||||||
replyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!")
|
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!")
|
||||||
replyMessage(ws, "")
|
utils.ReplyMessage(ws, "")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
||||||
replyMessage(ws, "您的账号已经过期,请联系管理员!")
|
utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
|
||||||
replyMessage(ws, "")
|
utils.ReplyMessage(ws, "")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var req = types.ApiRequest{
|
var req = types.ApiRequest{
|
||||||
@ -238,14 +238,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
logger.Info("用户取消了请求:", prompt)
|
logger.Info("用户取消了请求:", prompt)
|
||||||
return nil
|
return nil
|
||||||
} else if strings.Contains(err.Error(), "no available key") {
|
} else if strings.Contains(err.Error(), "no available key") {
|
||||||
replyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑,您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏")
|
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY🔑,您可以导入自己的 API KEY🔑 继续使用!🙏🙏🙏")
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
replyMessage(ws, ErrorMsg)
|
utils.ReplyMessage(ws, ErrorMsg)
|
||||||
replyMessage(ws, "")
|
utils.ReplyMessage(ws, "")
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
@ -280,8 +280,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||||
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
|
if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
|
||||||
logger.Error(err, line)
|
logger.Error(err, line)
|
||||||
replyMessage(ws, ErrorMsg)
|
utils.ReplyMessage(ws, ErrorMsg)
|
||||||
replyMessage(ws, "")
|
utils.ReplyMessage(ws, "")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,8 +295,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
functionCall = true
|
functionCall = true
|
||||||
functionName = fun.Name
|
functionName = fun.Name
|
||||||
f := h.App.Functions[functionName]
|
f := h.App.Functions[functionName]
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -307,14 +307,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
// 初始化 role
|
// 初始化 role
|
||||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||||
message.Role = responseBody.Choices[0].Delta.Role
|
message.Role = responseBody.Choices[0].Delta.Role
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
continue
|
continue
|
||||||
} else if responseBody.Choices[0].FinishReason != "" {
|
} else if responseBody.Choices[0].FinishReason != "" {
|
||||||
break // 输出完成或者输出中断了
|
break // 输出完成或者输出中断了
|
||||||
} else {
|
} else {
|
||||||
content := responseBody.Choices[0].Delta.Content
|
content := responseBody.Choices[0].Delta.Content
|
||||||
contents = append(contents, utils.InterfaceToString(content))
|
contents = append(contents, utils.InterfaceToString(content))
|
||||||
replyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||||
})
|
})
|
||||||
@ -322,23 +322,39 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
} // end for
|
} // end for
|
||||||
|
|
||||||
if functionCall { // 调用函数完成任务
|
if functionCall { // 调用函数完成任务
|
||||||
logger.Info(functionName)
|
logger.Info("函数名称:", functionName)
|
||||||
logger.Info(arguments)
|
var params map[string]interface{}
|
||||||
|
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||||
|
logger.Info("函数参数:", params)
|
||||||
f := h.App.Functions[functionName]
|
f := h.App.Functions[functionName]
|
||||||
data, err := f.Invoke(arguments)
|
data, err := f.Invoke(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "调用函数出错:" + err.Error()
|
msg := "调用函数出错:" + err.Error()
|
||||||
replyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: msg,
|
Content: msg,
|
||||||
})
|
})
|
||||||
contents = append(contents, msg)
|
contents = append(contents, msg)
|
||||||
} else {
|
} else {
|
||||||
replyChunkMessage(ws, types.WsMessage{
|
content := data
|
||||||
|
if functionName == types.FuncMidJourney {
|
||||||
|
key := utils.Sha256(data)
|
||||||
|
// add task for MidJourney
|
||||||
|
h.App.MjTasks.Put(key, types.MjTask{
|
||||||
|
UserId: userVo.Id,
|
||||||
|
RoleId: role.Id,
|
||||||
|
Icon: role.Icon,
|
||||||
|
Client: ws,
|
||||||
|
ChatId: session.ChatId,
|
||||||
|
})
|
||||||
|
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: data,
|
Content: content,
|
||||||
})
|
})
|
||||||
contents = append(contents, data)
|
contents = append(contents, content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -430,7 +446,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
} else {
|
} else {
|
||||||
totalTokens = replyToken + getTotalTokens(req)
|
totalTokens = replyToken + getTotalTokens(req)
|
||||||
}
|
}
|
||||||
//replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
|
//utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
|
||||||
if userVo.ChatConfig.ApiKey != "" { // 调用自己的 API KEY 不计算 token 消耗
|
if userVo.ChatConfig.ApiKey != "" { // 调用自己的 API KEY 不计算 token 消耗
|
||||||
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
||||||
totalTokens))
|
totalTokens))
|
||||||
@ -468,18 +484,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
// OpenAI API 调用异常处理
|
// OpenAI API 调用异常处理
|
||||||
// TODO: 是否考虑重发消息?
|
// TODO: 是否考虑重发消息?
|
||||||
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
|
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
|
||||||
replyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
|
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
|
||||||
// 移除当前 API key
|
// 移除当前 API key
|
||||||
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
|
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
|
||||||
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
|
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
|
||||||
replyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
|
utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
|
||||||
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
||||||
logger.Error(res.Error.Message)
|
logger.Error(res.Error.Message)
|
||||||
replyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
|
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
|
||||||
h.App.ChatContexts.Delete(session.ChatId)
|
h.App.ChatContexts.Delete(session.ChatId)
|
||||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
return h.sendMessage(ctx, session, role, prompt, ws)
|
||||||
} else {
|
} else {
|
||||||
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -534,26 +550,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
|||||||
return client.Do(request)
|
return client.Do(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 回复客户片段端消息
|
|
||||||
func replyChunkMessage(client types.Client, message types.WsMessage) {
|
|
||||||
msg, err := json.Marshal(message)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error for decoding json data: %v", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = client.(*types.WsClient).Send(msg)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error for reply message: %v", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 回复客户端一条完整的消息
|
|
||||||
func replyMessage(ws types.Client, message string) {
|
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tokens 统计 token 数量
|
// Tokens 统计 token 数量
|
||||||
func (h *ChatHandler) Tokens(c *gin.Context) {
|
func (h *ChatHandler) Tokens(c *gin.Context) {
|
||||||
text := c.Query("text")
|
text := c.Query("text")
|
||||||
|
61
api/handler/mj_handler.go
Normal file
61
api/handler/mj_handler.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core"
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/utils"
|
||||||
|
"chatplus/utils/resp"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Start = TaskStatus("Started")
|
||||||
|
Running = TaskStatus("Running")
|
||||||
|
Stopped = TaskStatus("Stopped")
|
||||||
|
Finished = TaskStatus("Finished")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Image struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
ProxyURL string `json:"proxy_url"`
|
||||||
|
Filename string `json:"filename"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
Size int `json:"size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidJourneyHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMidJourneyHandler(app *core.AppServer) *MidJourneyHandler {
|
||||||
|
h := MidJourneyHandler{}
|
||||||
|
h.App = app
|
||||||
|
return &h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||||
|
token := c.GetHeader("Authorization")
|
||||||
|
if token != h.App.Config.ExtConfig.Token {
|
||||||
|
resp.NotAuth(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var data struct {
|
||||||
|
Image Image `json:"image"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Status TaskStatus `json:"status"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionId := "u7blnft9zqisyrwidjb22j6b78iqc30lv9jtud3k9o"
|
||||||
|
wsClient := h.App.ChatClients.Get(sessionId)
|
||||||
|
utils.ReplyMessage(wsClient, "")
|
||||||
|
logger.Infof("Data: %+v", data)
|
||||||
|
resp.ERROR(c, "Error with CallBack")
|
||||||
|
}
|
@ -21,9 +21,9 @@ func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
|
|||||||
return &h
|
return &h
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RewardHandler) Push(c *gin.Context) {
|
func (h *RewardHandler) Notify(c *gin.Context) {
|
||||||
token := c.GetHeader("X-TOKEN")
|
token := c.GetHeader("Authorization")
|
||||||
if token != h.App.Config.ChatPlusExtApiToken {
|
if token != h.App.Config.ExtConfig.Token {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
16
api/main.go
16
api/main.go
@ -104,15 +104,7 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
|
|
||||||
// 创建函数
|
// 创建函数
|
||||||
fx.Provide(func(config *types.AppConfig) (function.FuncZaoBao, error) {
|
fx.Provide(function.NewFunctions),
|
||||||
return function.NewZaoBao(config.ApiConfig), nil
|
|
||||||
}),
|
|
||||||
fx.Provide(func(config *types.AppConfig) (function.FuncWeiboHot, error) {
|
|
||||||
return function.NewWeiboHot(config.ApiConfig), nil
|
|
||||||
}),
|
|
||||||
fx.Provide(func(config *types.AppConfig) (function.FuncHeadlines, error) {
|
|
||||||
return function.NewHeadLines(config.ApiConfig), nil
|
|
||||||
}),
|
|
||||||
|
|
||||||
// 创建控制器
|
// 创建控制器
|
||||||
fx.Provide(handler.NewChatRoleHandler),
|
fx.Provide(handler.NewChatRoleHandler),
|
||||||
@ -122,6 +114,7 @@ func main() {
|
|||||||
fx.Provide(handler.NewSmsHandler),
|
fx.Provide(handler.NewSmsHandler),
|
||||||
fx.Provide(handler.NewRewardHandler),
|
fx.Provide(handler.NewRewardHandler),
|
||||||
fx.Provide(handler.NewCaptchaHandler),
|
fx.Provide(handler.NewCaptchaHandler),
|
||||||
|
fx.Provide(handler.NewMidJourneyHandler),
|
||||||
|
|
||||||
fx.Provide(admin.NewConfigHandler),
|
fx.Provide(admin.NewConfigHandler),
|
||||||
fx.Provide(admin.NewAdminHandler),
|
fx.Provide(admin.NewAdminHandler),
|
||||||
@ -180,9 +173,12 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
||||||
group := s.Engine.Group("/api/reward/")
|
group := s.Engine.Group("/api/reward/")
|
||||||
group.POST("push", h.Push)
|
group.POST("notify", h.Notify)
|
||||||
group.POST("verify", h.Verify)
|
group.POST("verify", h.Verify)
|
||||||
}),
|
}),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
||||||
|
s.Engine.POST("/api/mj/notify", h.Notify)
|
||||||
|
}),
|
||||||
|
|
||||||
// 管理后台控制器
|
// 管理后台控制器
|
||||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||||
|
@ -1,12 +1,17 @@
|
|||||||
package function
|
package function
|
||||||
|
|
||||||
import "chatplus/core/types"
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
logger2 "chatplus/logger"
|
||||||
|
)
|
||||||
|
|
||||||
type Function interface {
|
type Function interface {
|
||||||
Invoke(...interface{}) (string, error)
|
Invoke(map[string]interface{}) (string, error)
|
||||||
Name() string
|
Name() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
type resVo struct {
|
type resVo struct {
|
||||||
Code types.BizCode `json:"code"`
|
Code types.BizCode `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
@ -22,3 +27,12 @@ type dataItem struct {
|
|||||||
Url string `json:"url"`
|
Url string `json:"url"`
|
||||||
Remark string `json:"remark"`
|
Remark string `json:"remark"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewFunctions(config *types.AppConfig) map[string]Function {
|
||||||
|
return map[string]Function{
|
||||||
|
types.FuncZaoBao: NewZaoBao(config.ApiConfig),
|
||||||
|
types.FuncWeibo: NewWeiboHot(config.ApiConfig),
|
||||||
|
types.FuncHeadLine: NewHeadLines(config.ApiConfig),
|
||||||
|
types.FuncMidJourney: NewMidJourneyFunc(config.ExtConfig),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
60
api/service/function/mid_journey.go
Normal file
60
api/service/function/mid_journey.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package function
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/utils"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AI 绘画函数
|
||||||
|
|
||||||
|
type FuncMidJourney struct {
|
||||||
|
name string
|
||||||
|
config types.ChatPlusExtConfig
|
||||||
|
client *req.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney {
|
||||||
|
return FuncMidJourney{
|
||||||
|
name: "MidJourney AI 绘画",
|
||||||
|
config: config,
|
||||||
|
client: req.C().SetTimeout(10 * time.Second)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
|
||||||
|
if f.config.Token == "" {
|
||||||
|
return "", errors.New("无效的 API Token")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("MJ 绘画参数:%+v", params)
|
||||||
|
prompt := utils.InterfaceToString(params["prompt"])
|
||||||
|
if !utils.IsEmptyValue(params["ar"]) {
|
||||||
|
prompt = prompt + fmt.Sprintf(" --ar %v", params["ar"])
|
||||||
|
delete(params, "ar")
|
||||||
|
}
|
||||||
|
prompt = prompt + " --niji 5"
|
||||||
|
var res types.BizVo
|
||||||
|
r, err := f.client.R().
|
||||||
|
SetHeader("Authorization", f.config.Token).
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(params).
|
||||||
|
SetSuccessResult(&res).Post(f.config.ApiURL)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
return "", fmt.Errorf("%v%v", r.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != types.Success {
|
||||||
|
return "", errors.New(res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return prompt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FuncMidJourney) Name() string {
|
||||||
|
return f.name
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Function = &FuncMidJourney{}
|
@ -24,7 +24,7 @@ func NewHeadLines(config types.ChatPlusApiConfig) FuncHeadlines {
|
|||||||
client: req.C().SetTimeout(10 * time.Second)}
|
client: req.C().SetTimeout(10 * time.Second)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
|
func (f FuncHeadlines) Invoke(map[string]interface{}) (string, error) {
|
||||||
if f.config.Token == "" {
|
if f.config.Token == "" {
|
||||||
return "", errors.New("无效的 API Token")
|
return "", errors.New("无效的 API Token")
|
||||||
}
|
}
|
||||||
@ -35,11 +35,8 @@ func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
|
|||||||
SetHeader("AppId", f.config.AppId).
|
SetHeader("AppId", f.config.AppId).
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
|
||||||
SetSuccessResult(&res).Get(url)
|
SetSuccessResult(&res).Get(url)
|
||||||
if err != nil {
|
if err != nil || r.IsErrorState() {
|
||||||
return "", err
|
return "", fmt.Errorf("%v%v", err, r.Err)
|
||||||
}
|
|
||||||
if r.IsErrorState() {
|
|
||||||
return "", r.Err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.Code != types.Success {
|
if res.Code != types.Success {
|
||||||
@ -57,3 +54,5 @@ func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
|
|||||||
func (f FuncHeadlines) Name() string {
|
func (f FuncHeadlines) Name() string {
|
||||||
return f.name
|
return f.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Function = &FuncHeadlines{}
|
||||||
|
@ -24,7 +24,7 @@ func NewWeiboHot(config types.ChatPlusApiConfig) FuncWeiboHot {
|
|||||||
client: req.C().SetTimeout(10 * time.Second)}
|
client: req.C().SetTimeout(10 * time.Second)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
|
func (f FuncWeiboHot) Invoke(map[string]interface{}) (string, error) {
|
||||||
if f.config.Token == "" {
|
if f.config.Token == "" {
|
||||||
return "", errors.New("无效的 API Token")
|
return "", errors.New("无效的 API Token")
|
||||||
}
|
}
|
||||||
@ -35,11 +35,8 @@ func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
|
|||||||
SetHeader("AppId", f.config.AppId).
|
SetHeader("AppId", f.config.AppId).
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
|
||||||
SetSuccessResult(&res).Get(url)
|
SetSuccessResult(&res).Get(url)
|
||||||
if err != nil {
|
if err != nil || r.IsErrorState() {
|
||||||
return "", err
|
return "", fmt.Errorf("%v%v", err, r.Err)
|
||||||
}
|
|
||||||
if r.IsErrorState() {
|
|
||||||
return "", r.Err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.Code != types.Success {
|
if res.Code != types.Success {
|
||||||
@ -57,3 +54,5 @@ func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
|
|||||||
func (f FuncWeiboHot) Name() string {
|
func (f FuncWeiboHot) Name() string {
|
||||||
return f.name
|
return f.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Function = &FuncWeiboHot{}
|
||||||
|
@ -24,7 +24,7 @@ func NewZaoBao(config types.ChatPlusApiConfig) FuncZaoBao {
|
|||||||
client: req.C().SetTimeout(10 * time.Second)}
|
client: req.C().SetTimeout(10 * time.Second)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
|
func (f FuncZaoBao) Invoke(map[string]interface{}) (string, error) {
|
||||||
if f.config.Token == "" {
|
if f.config.Token == "" {
|
||||||
return "", errors.New("无效的 API Token")
|
return "", errors.New("无效的 API Token")
|
||||||
}
|
}
|
||||||
@ -35,11 +35,8 @@ func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
|
|||||||
SetHeader("AppId", f.config.AppId).
|
SetHeader("AppId", f.config.AppId).
|
||||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
|
||||||
SetSuccessResult(&res).Get(url)
|
SetSuccessResult(&res).Get(url)
|
||||||
if err != nil {
|
if err != nil || r.IsErrorState() {
|
||||||
return "", err
|
return "", fmt.Errorf("%v%v", err, r.Err)
|
||||||
}
|
|
||||||
if r.IsErrorState() {
|
|
||||||
return "", r.Err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.Code != types.Success {
|
if res.Code != types.Success {
|
||||||
@ -58,3 +55,5 @@ func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
|
|||||||
func (f FuncZaoBao) Name() string {
|
func (f FuncZaoBao) Name() string {
|
||||||
return f.name
|
return f.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Function = &FuncZaoBao{}
|
||||||
|
@ -1,68 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
)
|
|
||||||
|
|
||||||
func HttpGet(uri string, proxy string) ([]byte, error) {
|
|
||||||
var client *http.Client
|
|
||||||
if proxy == "" {
|
|
||||||
client = &http.Client{}
|
|
||||||
} else {
|
|
||||||
proxy, _ := url.Parse(proxy)
|
|
||||||
client = &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
Proxy: http.ProxyURL(proxy),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", uri, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
return io.ReadAll(resp.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func HttpPost(uri string, params map[string]interface{}, proxy string) ([]byte, error) {
|
|
||||||
data, err := json.Marshal(params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var client *http.Client
|
|
||||||
if proxy == "" {
|
|
||||||
client = &http.Client{}
|
|
||||||
} else {
|
|
||||||
proxy, _ := url.Parse(proxy)
|
|
||||||
client = &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
Proxy: http.ProxyURL(proxy),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", uri, bytes.NewBuffer(data))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
return io.ReadAll(resp.Body)
|
|
||||||
}
|
|
29
api/utils/websocket.go
Normal file
29
api/utils/websocket.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
logger2 "chatplus/logger"
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
// ReplyChunkMessage 回复客户片段端消息
|
||||||
|
func ReplyChunkMessage(client types.Client, message types.WsMessage) {
|
||||||
|
msg, err := json.Marshal(message)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error for decoding json data: %v", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = client.(*types.WsClient).Send(msg)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error for reply message: %v", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplyMessage 回复客户端一条完整的消息
|
||||||
|
func ReplyMessage(ws types.Client, message string) {
|
||||||
|
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
|
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
||||||
|
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user