mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-12-28 02:55:58 +08:00
merge v4.1.5
This commit is contained in:
@@ -71,6 +71,15 @@ TikaHost = "http://tika:9998"
|
||||
AccessToken = "xxl-job-api-token" # 执行器 API 通信 token
|
||||
RegistryKey = "chatgpt-plus" # 任务注册 key
|
||||
|
||||
[SmtpConfig] # 注意,阿里云服务器禁用了25号端口,请使用 465 端口,并开启 TLS 连接
|
||||
UseTls = false
|
||||
Host = "smtp.163.com"
|
||||
Port = 25
|
||||
AppName = "极客学长"
|
||||
From = "test@163.com" # 发件邮箱人地址
|
||||
Password = "" #邮箱 stmp 服务授权码
|
||||
|
||||
# 支付宝商户支付
|
||||
[AlipayConfig]
|
||||
Enabled = false # 启用支付宝支付通道
|
||||
SandBox = false # 是否启用沙盒模式
|
||||
@@ -80,31 +89,13 @@ TikaHost = "http://tika:9998"
|
||||
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
|
||||
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
|
||||
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
|
||||
NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
|
||||
|
||||
# 虎皮椒支付
|
||||
[HuPiPayConfig]
|
||||
Enabled = false
|
||||
Name = "wechat"
|
||||
AppId = ""
|
||||
AppSecret = ""
|
||||
ApiURL = "https://api.xunhupay.com"
|
||||
NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
|
||||
|
||||
[SmtpConfig] # 注意,阿里云服务器禁用了25号端口,请使用 465 端口,并开启 TLS 连接
|
||||
UseTls = false
|
||||
Host = "smtp.163.com"
|
||||
Port = 25
|
||||
AppName = "极客学长"
|
||||
From = "test@163.com" # 发件邮箱人地址
|
||||
Password = "" #邮箱 stmp 服务授权码
|
||||
|
||||
[JPayConfig] # PayJs 支付配置
|
||||
Enabled = false
|
||||
Name = "wechat" # 请不要改动
|
||||
AppId = "" # 商户 ID
|
||||
PrivateKey = "" # 秘钥
|
||||
ApiURL = "https://payjs.cn"
|
||||
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的
|
||||
|
||||
# 微信商户支付
|
||||
[WechatPayConfig]
|
||||
@@ -114,6 +105,11 @@ TikaHost = "http://tika:9998"
|
||||
SerialNo = "" # API 证书序列号
|
||||
PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内
|
||||
ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置的
|
||||
NotifyURL = "https://ai.r9it.com/api/payment/wechat/notify" # 支付成功异步回调地址,域名改成自己的
|
||||
ReturnURL = "" # 支付成功同步回调地址
|
||||
|
||||
# 易支付
|
||||
[GeekPayConfig]
|
||||
Enabled = true
|
||||
AppId = "" # 商户ID
|
||||
PrivateKey = "" # 商户私钥
|
||||
ApiURL = "https://pay.geekai.cn"
|
||||
Methods = ["alipay", "wxpay", "qqpay", "jdpay", "douyin", "paypal"] # 支持的支付方式
|
||||
|
||||
@@ -51,9 +51,9 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
||||
func (s *AppServer) Init(debug bool, client *redis.Client) {
|
||||
if debug { // 调试模式允许跨域请求 API
|
||||
s.Debug = debug
|
||||
s.Engine.Use(corsMiddleware())
|
||||
logger.Info("Enabled debug mode")
|
||||
}
|
||||
s.Engine.Use(corsMiddleware())
|
||||
s.Engine.Use(staticResourceMiddleware())
|
||||
s.Engine.Use(authorizeMiddleware(s, client))
|
||||
s.Engine.Use(parameterHandlerMiddleware())
|
||||
@@ -65,13 +65,13 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
|
||||
func (s *AppServer) Run(db *gorm.DB) error {
|
||||
// load system configs
|
||||
var sysConfig model.Config
|
||||
res := db.Where("marker", "system").First(&sysConfig)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
||||
err := db.Where("marker", "system").First(&sysConfig).Error
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to load system config: %v", err)
|
||||
}
|
||||
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode system config: %v", err)
|
||||
}
|
||||
logger.Infof("http://%s", s.Config.Listen)
|
||||
return s.Engine.Run(s.Config.Listen)
|
||||
@@ -101,9 +101,9 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
||||
//允许跨域设置可以返回其他子段,可以自定义字段
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, Chat-Token, Admin-Authorization")
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
|
||||
// 允许浏览器(客户端)可以解析的头部 (重要)
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
||||
c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
||||
//设置缓存时间
|
||||
c.Header("Access-Control-Max-Age", "172800")
|
||||
//允许客户端传递校验信息比如 cookie (重要)
|
||||
@@ -131,7 +131,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
||||
if isAdminApi { // 后台管理 API
|
||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||
} else if c.Request.URL.Path == "/api/chat/new" {
|
||||
} else if c.Request.URL.Path == "/api/ws" { // Websocket 连接
|
||||
tokenString = c.Query("token")
|
||||
} else {
|
||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||
@@ -204,31 +204,25 @@ func needLogin(c *gin.Context) bool {
|
||||
c.Request.URL.Path == "/api/chat/history" ||
|
||||
c.Request.URL.Path == "/api/chat/detail" ||
|
||||
c.Request.URL.Path == "/api/chat/list" ||
|
||||
c.Request.URL.Path == "/api/role/list" ||
|
||||
c.Request.URL.Path == "/api/app/list" ||
|
||||
c.Request.URL.Path == "/api/app/type/list" ||
|
||||
c.Request.URL.Path == "/api/app/list/user" ||
|
||||
c.Request.URL.Path == "/api/model/list" ||
|
||||
c.Request.URL.Path == "/api/mj/imgWall" ||
|
||||
c.Request.URL.Path == "/api/mj/client" ||
|
||||
c.Request.URL.Path == "/api/mj/notify" ||
|
||||
c.Request.URL.Path == "/api/invite/hits" ||
|
||||
c.Request.URL.Path == "/api/sd/imgWall" ||
|
||||
c.Request.URL.Path == "/api/sd/client" ||
|
||||
c.Request.URL.Path == "/api/dall/imgWall" ||
|
||||
c.Request.URL.Path == "/api/dall/client" ||
|
||||
c.Request.URL.Path == "/api/product/list" ||
|
||||
c.Request.URL.Path == "/api/menu/list" ||
|
||||
c.Request.URL.Path == "/api/markMap/client" ||
|
||||
c.Request.URL.Path == "/api/payment/alipay/notify" ||
|
||||
c.Request.URL.Path == "/api/payment/hupipay/notify" ||
|
||||
c.Request.URL.Path == "/api/payment/payjs/notify" ||
|
||||
c.Request.URL.Path == "/api/payment/wechat/notify" ||
|
||||
c.Request.URL.Path == "/api/payment/doPay" ||
|
||||
c.Request.URL.Path == "/api/payment/payWays" ||
|
||||
c.Request.URL.Path == "/api/suno/client" ||
|
||||
c.Request.URL.Path == "/api/suno/detail" ||
|
||||
c.Request.URL.Path == "/api/suno/play" ||
|
||||
c.Request.URL.Path == "/api/download" ||
|
||||
c.Request.URL.Path == "/api/video/client" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||
|
||||
@@ -9,14 +9,14 @@ package types
|
||||
|
||||
// ApiRequest API 请求实体
|
||||
type ApiRequest struct {
|
||||
Model string `json:"model,omitempty"` // 兼容百度文心一言
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"` // 兼容百度文心一言
|
||||
Stream bool `json:"stream"`
|
||||
Messages []interface{} `json:"messages,omitempty"`
|
||||
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
||||
Model string `json:"model,omitempty"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Messages []interface{} `json:"messages,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
|
||||
|
||||
ToolChoice string `json:"tool_choice,omitempty"`
|
||||
|
||||
@@ -52,12 +52,13 @@ type Delta struct {
|
||||
|
||||
// ChatSession 聊天会话对象
|
||||
type ChatSession struct {
|
||||
SessionId string `json:"session_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
ClientIP string `json:"client_ip"` // 客户端 IP
|
||||
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
||||
Model ChatModel `json:"model"` // GPT 模型
|
||||
Tools string `json:"tools"` // 函数
|
||||
UserId uint `json:"user_id"`
|
||||
ClientIP string `json:"client_ip"` // 客户端 IP
|
||||
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
||||
Model ChatModel `json:"model"` // GPT 模型
|
||||
Start int64 `json:"start"` // 开始请求时间戳
|
||||
Tools []int `json:"tools"` // 工具函数列表
|
||||
Stream bool `json:"stream"` // 是否采用流式输出
|
||||
}
|
||||
|
||||
type ChatModel struct {
|
||||
|
||||
@@ -17,15 +17,17 @@ var ErrConClosed = errors.New("connection Closed")
|
||||
|
||||
// WsClient websocket client
|
||||
type WsClient struct {
|
||||
Id string
|
||||
Conn *websocket.Conn
|
||||
lock sync.Mutex
|
||||
mt int
|
||||
Closed bool
|
||||
}
|
||||
|
||||
func NewWsClient(conn *websocket.Conn) *WsClient {
|
||||
func NewWsClient(conn *websocket.Conn, id string) *WsClient {
|
||||
return &WsClient{
|
||||
Conn: conn,
|
||||
Id: id,
|
||||
lock: sync.Mutex{},
|
||||
mt: 2, // fixed bug for 'Invalid UTF-8 in text frame'
|
||||
Closed: false,
|
||||
|
||||
@@ -12,24 +12,23 @@ import (
|
||||
)
|
||||
|
||||
type AppConfig struct {
|
||||
Path string `toml:"-"`
|
||||
Listen string
|
||||
Session Session
|
||||
AdminSession Session
|
||||
ProxyURL string
|
||||
MysqlDns string // mysql 连接地址
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
ApiConfig ApiConfig // ChatPlus API authorization configs
|
||||
SMS SMSConfig // send mobile message config
|
||||
OSS OSSConfig // OSS config
|
||||
|
||||
Path string `toml:"-"`
|
||||
Listen string
|
||||
Session Session
|
||||
AdminSession Session
|
||||
ProxyURL string
|
||||
MysqlDns string // mysql 连接地址
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
ApiConfig ApiConfig // ChatPlus API authorization configs
|
||||
SMS SMSConfig // send mobile message config
|
||||
OSS OSSConfig // OSS config
|
||||
SmtpConfig SmtpConfig // 邮件发送配置
|
||||
XXLConfig XXLConfig
|
||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
|
||||
SmtpConfig SmtpConfig // 邮件发送配置
|
||||
JPayConfig JPayConfig // payjs 支付配置
|
||||
GeekPayConfig GeekPayConfig // GEEK 支付配置
|
||||
WechatPayConfig WechatPayConfig // 微信支付渠道配置
|
||||
TikaHost string // TiKa 服务器地址
|
||||
}
|
||||
@@ -58,8 +57,8 @@ type AlipayConfig struct {
|
||||
PublicKey string // 用户公钥文件路径
|
||||
AlipayPublicKey string // 支付宝公钥文件路径
|
||||
RootCert string // Root 秘钥路径
|
||||
NotifyURL string // 异步通知回调
|
||||
ReturnURL string // 支付成功返回地址
|
||||
NotifyURL string // 异步通知地址
|
||||
ReturnURL string // 同步通知地址
|
||||
}
|
||||
|
||||
type WechatPayConfig struct {
|
||||
@@ -69,29 +68,27 @@ type WechatPayConfig struct {
|
||||
SerialNo string // 商户证书的证书序列号
|
||||
PrivateKey string // 用户私钥文件路径
|
||||
ApiV3Key string // API V3 秘钥
|
||||
NotifyURL string // 异步通知回调
|
||||
ReturnURL string // 支付成功返回地址
|
||||
NotifyURL string // 异步通知地址
|
||||
}
|
||||
|
||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
||||
Enabled bool // 是否启用该支付通道
|
||||
Name string // 支付名称,如:wechat/alipay
|
||||
AppId string // App ID
|
||||
AppSecret string // app 密钥
|
||||
ApiURL string // 支付网关
|
||||
NotifyURL string // 异步通知回调
|
||||
ReturnURL string // 支付成功返回地址
|
||||
NotifyURL string // 异步通知地址
|
||||
ReturnURL string // 同步通知地址
|
||||
}
|
||||
|
||||
// JPayConfig PayJs 支付配置
|
||||
type JPayConfig struct {
|
||||
// GeekPayConfig GEEK支付配置
|
||||
type GeekPayConfig struct {
|
||||
Enabled bool
|
||||
Name string // 支付名称,默认 wechat
|
||||
AppId string // 商户 ID
|
||||
PrivateKey string // 私钥
|
||||
ApiURL string // API 网关
|
||||
NotifyURL string // 异步回调地址
|
||||
ReturnURL string // 支付成功返回地址
|
||||
AppId string // 商户 ID
|
||||
PrivateKey string // 私钥
|
||||
ApiURL string // API 网关
|
||||
NotifyURL string // 异步通知地址
|
||||
ReturnURL string // 同步通知地址
|
||||
Methods []string // 支付方式
|
||||
}
|
||||
|
||||
type XXLConfig struct { // XXL 任务调度配置
|
||||
@@ -167,5 +164,6 @@ type SystemConfig struct {
|
||||
Copyright string `json:"copyright"` // 版权信息
|
||||
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
|
||||
|
||||
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
|
||||
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
|
||||
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
|
||||
}
|
||||
|
||||
@@ -22,3 +22,18 @@ type OrderRemark struct {
|
||||
Price float64 `json:"price"`
|
||||
Discount float64 `json:"discount"`
|
||||
}
|
||||
|
||||
var PayMethods = map[string]string{
|
||||
"alipay": "支付宝商号",
|
||||
"wechat": "微信商号",
|
||||
"hupi": "虎皮椒",
|
||||
"geek": "易支付",
|
||||
}
|
||||
var PayNames = map[string]string{
|
||||
"alipay": "支付宝",
|
||||
"wxpay": "微信支付",
|
||||
"qqpay": "QQ钱包",
|
||||
"jdpay": "京东支付",
|
||||
"douyin": "抖音支付",
|
||||
"paypal": "PayPal支付",
|
||||
}
|
||||
|
||||
@@ -24,8 +24,9 @@ const (
|
||||
|
||||
// MjTask MidJourney 任务
|
||||
type MjTask struct {
|
||||
Id uint `json:"id"`
|
||||
TaskId string `json:"task_id"`
|
||||
Id uint `json:"id"` // 任务ID
|
||||
TaskId string `json:"task_id"` // 中转任务ID
|
||||
ClientId string `json:"client_id"`
|
||||
ImgArr []string `json:"img_arr"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
@@ -43,12 +44,14 @@ type MjTask struct {
|
||||
type SdTask struct {
|
||||
Id int `json:"id"` // job 数据库ID
|
||||
Type TaskType `json:"type"`
|
||||
ClientId string `json:"client_id"`
|
||||
UserId int `json:"user_id"`
|
||||
Params SdTaskParams `json:"params"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
}
|
||||
|
||||
type SdTaskParams struct {
|
||||
ClientId string `json:"client_id"` // 客户端ID
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
||||
@@ -69,18 +72,20 @@ type SdTaskParams struct {
|
||||
|
||||
// DallTask DALL-E task
|
||||
type DallTask struct {
|
||||
JobId uint `json:"job_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
ClientId string `json:"client_id"`
|
||||
JobId uint `json:"job_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
|
||||
Power int `json:"power"`
|
||||
}
|
||||
|
||||
type SunoTask struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
@@ -104,13 +109,14 @@ const (
|
||||
)
|
||||
|
||||
type VideoTask struct {
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
Type string `json:"type"`
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Params VideoParams `json:"params"`
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
Type string `json:"type"`
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Params VideoParams `json:"params"`
|
||||
}
|
||||
|
||||
type VideoParams struct {
|
||||
|
||||
@@ -17,21 +17,48 @@ type BizVo struct {
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// WsMessage Websocket message
|
||||
type WsMessage struct {
|
||||
Type WsMsgType `json:"type"` // 消息类别,start, end, img
|
||||
Content interface{} `json:"content"`
|
||||
// ReplyMessage 对话回复消息结构
|
||||
type ReplyMessage struct {
|
||||
Channel WsChannel `json:"channel"` // 消息频道,目前只有 chat
|
||||
ClientId string `json:"clientId"` // 客户端ID
|
||||
Type WsMsgType `json:"type"` // 消息类别
|
||||
Body interface{} `json:"body"`
|
||||
}
|
||||
|
||||
type WsMsgType string
|
||||
type WsChannel string
|
||||
|
||||
const (
|
||||
WsStart = WsMsgType("start")
|
||||
WsMiddle = WsMsgType("middle")
|
||||
WsEnd = WsMsgType("end")
|
||||
WsErr = WsMsgType("error")
|
||||
MsgTypeText = WsMsgType("text") // 输出内容
|
||||
MsgTypeEnd = WsMsgType("end")
|
||||
MsgTypeErr = WsMsgType("error")
|
||||
MsgTypePing = WsMsgType("ping") // 心跳消息
|
||||
|
||||
ChPing = WsChannel("ping")
|
||||
ChChat = WsChannel("chat")
|
||||
ChMj = WsChannel("mj")
|
||||
ChSd = WsChannel("sd")
|
||||
ChDall = WsChannel("dall")
|
||||
ChSuno = WsChannel("suno")
|
||||
ChLuma = WsChannel("luma")
|
||||
)
|
||||
|
||||
// InputMessage 对话输入消息结构
|
||||
type InputMessage struct {
|
||||
Channel WsChannel `json:"channel"` // 消息频道
|
||||
Type WsMsgType `json:"type"` // 消息类别
|
||||
Body interface{} `json:"body"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
Tools []int `json:"tools,omitempty"` // 允许调用工具列表
|
||||
Stream bool `json:"stream,omitempty"` // 是否采用流式输出
|
||||
RoleId int `json:"role_id"`
|
||||
ModelId int `json:"model_id"`
|
||||
ChatId string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type BizCode int
|
||||
|
||||
const (
|
||||
|
||||
@@ -74,7 +74,6 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||
func (h *ApiKeyHandler) List(c *gin.Context) {
|
||||
status := h.GetBool(c, "status")
|
||||
t := h.GetTrim(c, "type")
|
||||
platform := h.GetTrim(c, "platform")
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if status {
|
||||
@@ -83,9 +82,6 @@ func (h *ApiKeyHandler) List(c *gin.Context) {
|
||||
if t != "" {
|
||||
session = session.Where("type", t)
|
||||
}
|
||||
if platform != "" {
|
||||
session = session.Where("platform", platform)
|
||||
}
|
||||
|
||||
var items []model.ApiKey
|
||||
var keys = make([]vo.ApiKey, 0)
|
||||
|
||||
@@ -22,16 +22,16 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ChatRoleHandler struct {
|
||||
type ChatAppHandler struct {
|
||||
handler.BaseHandler
|
||||
}
|
||||
|
||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
||||
return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
|
||||
return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// Save 创建或者更新某个角色
|
||||
func (h *ChatRoleHandler) Save(c *gin.Context) {
|
||||
func (h *ChatAppHandler) Save(c *gin.Context) {
|
||||
var data vo.ChatRole
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -64,7 +64,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
var items []model.ChatRole
|
||||
var roles = make([]vo.ChatRole, 0)
|
||||
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||
@@ -75,13 +75,18 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
|
||||
// initialize model mane for role
|
||||
modelIds := make([]int, 0)
|
||||
typeIds := make([]int, 0)
|
||||
for _, v := range items {
|
||||
if v.ModelId > 0 {
|
||||
modelIds = append(modelIds, v.ModelId)
|
||||
}
|
||||
if v.Tid > 0 {
|
||||
typeIds = append(typeIds, v.Tid)
|
||||
}
|
||||
}
|
||||
|
||||
modelNameMap := make(map[int]string)
|
||||
typeNameMap := make(map[int]string)
|
||||
if len(modelIds) > 0 {
|
||||
var models []model.ChatModel
|
||||
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
|
||||
@@ -91,6 +96,15 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(typeIds) > 0 {
|
||||
var appTypes []model.AppType
|
||||
tx := h.DB.Where("id IN ?", typeIds).Find(&appTypes)
|
||||
if tx.Error == nil {
|
||||
for _, m := range appTypes {
|
||||
typeNameMap[int(m.Id)] = m.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
var role vo.ChatRole
|
||||
@@ -100,6 +114,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
role.CreatedAt = v.CreatedAt.Unix()
|
||||
role.UpdatedAt = v.UpdatedAt.Unix()
|
||||
role.ModelName = modelNameMap[role.ModelId]
|
||||
role.TypeName = typeNameMap[role.Tid]
|
||||
roles = append(roles, role)
|
||||
}
|
||||
}
|
||||
@@ -108,7 +123,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Sort 更新角色排序
|
||||
func (h *ChatRoleHandler) Sort(c *gin.Context) {
|
||||
func (h *ChatAppHandler) Sort(c *gin.Context) {
|
||||
var data struct {
|
||||
Ids []uint `json:"ids"`
|
||||
Sorts []int `json:"sorts"`
|
||||
@@ -130,7 +145,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *ChatRoleHandler) Set(c *gin.Context) {
|
||||
func (h *ChatAppHandler) Set(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Filed string `json:"filed"`
|
||||
@@ -150,7 +165,7 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *ChatRoleHandler) Remove(c *gin.Context) {
|
||||
func (h *ChatAppHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
|
||||
if id <= 0 {
|
||||
148
api/handler/admin/chat_app_type_handler.go
Normal file
148
api/handler/admin/chat_app_type_handler.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ChatAppTypeHandler struct {
|
||||
handler.BaseHandler
|
||||
}
|
||||
|
||||
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
|
||||
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// Save 创建或更新App类型
|
||||
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Icon string `json:"icon"`
|
||||
SortNum int `json:"sort_num"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Id == 0 { // for add
|
||||
err := h.DB.Where("name", data.Name).First(&model.AppType{}).Error
|
||||
if err == nil {
|
||||
resp.ERROR(c, "当前分类已经存在")
|
||||
return
|
||||
}
|
||||
err = h.DB.Create(&model.AppType{
|
||||
Name: data.Name,
|
||||
Icon: data.Icon,
|
||||
Enabled: data.Enabled,
|
||||
SortNum: data.SortNum,
|
||||
}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
} else { // for update
|
||||
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).Updates(map[string]interface{}{
|
||||
"name": data.Name,
|
||||
"icon": data.Icon,
|
||||
"enabled": data.Enabled,
|
||||
}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// List 获取App类型列表
|
||||
func (h *ChatAppTypeHandler) List(c *gin.Context) {
|
||||
var items []model.AppType
|
||||
var appTypes = make([]vo.AppType, 0)
|
||||
err := h.DB.Order("sort_num ASC").Find(&items).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
var appType vo.AppType
|
||||
err = utils.CopyObject(v, &appType)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
appType.Id = v.Id
|
||||
appType.CreatedAt = v.CreatedAt.Unix()
|
||||
appTypes = append(appTypes, appType)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, appTypes)
|
||||
}
|
||||
|
||||
// Remove 删除App类型
|
||||
func (h *ChatAppTypeHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
|
||||
if id <= 0 {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
err := h.DB.Where("id", id).Delete(&model.AppType{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// Enable 启用|禁用
|
||||
func (h *ChatAppTypeHandler) Enable(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.DB.Model(&model.AppType{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// Sort 更新排序
|
||||
func (h *ChatAppTypeHandler) Sort(c *gin.Context) {
|
||||
var data struct {
|
||||
Ids []uint `json:"ids"`
|
||||
Sorts []int `json:"sorts"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
for index, id := range data.Ids {
|
||||
err := h.DB.Model(&model.AppType{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
@@ -140,3 +140,68 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
||||
license := h.licenseService.GetLicense()
|
||||
resp.SUCCESS(c, license)
|
||||
}
|
||||
|
||||
// FixData 修复数据
|
||||
func (h *ConfigHandler) FixData(c *gin.Context) {
|
||||
var fixed bool
|
||||
version := "data_fix_4.1.4"
|
||||
err := h.levelDB.Get(version, &fixed)
|
||||
if err == nil || fixed {
|
||||
resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
|
||||
return
|
||||
}
|
||||
tx := h.DB.Begin()
|
||||
var users []model.User
|
||||
err = tx.Find(&users).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.Email != "" || user.Mobile != "" {
|
||||
continue
|
||||
}
|
||||
if utils.IsValidEmail(user.Username) {
|
||||
user.Email = user.Username
|
||||
} else if utils.IsValidMobile(user.Username) {
|
||||
user.Mobile = user.Username
|
||||
}
|
||||
err = tx.Save(&user).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var orders []model.Order
|
||||
err = h.DB.Find(&orders).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
for _, order := range orders {
|
||||
if order.PayWay == "支付宝" {
|
||||
order.PayWay = "alipay"
|
||||
order.PayType = "alipay"
|
||||
} else if order.PayWay == "微信支付" {
|
||||
order.PayWay = "wechat"
|
||||
order.PayType = "wxpay"
|
||||
} else if order.PayWay == "hupi" {
|
||||
order.PayType = "wxpay"
|
||||
}
|
||||
err = tx.Save(&order).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
err = h.levelDB.Put(version, true)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -67,6 +68,16 @@ func (h *OrderHandler) List(c *gin.Context) {
|
||||
order.Id = item.Id
|
||||
order.CreatedAt = item.CreatedAt.Unix()
|
||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||
payMethod, ok := types.PayMethods[item.PayWay]
|
||||
if !ok {
|
||||
payMethod = item.PayWay
|
||||
}
|
||||
payName, ok := types.PayNames[item.PayType]
|
||||
if !ok {
|
||||
payName = item.PayWay
|
||||
}
|
||||
order.PayMethod = payMethod
|
||||
order.PayName = payName
|
||||
list = append(list, order)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
@@ -92,7 +103,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{}).Error
|
||||
err := h.DB.Where("id = ?", id).Delete(&model.Order{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -102,8 +113,20 @@ func (h *OrderHandler) Remove(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *OrderHandler) Clear(c *gin.Context) {
|
||||
|
||||
err := h.DB.Unscoped().Where("status <> ?", 2).Where("pay_time", 0).Delete(&model.Order{}).Error
|
||||
var orders []model.Order
|
||||
err := h.DB.Where("status <> ?", 2).Where("pay_time", 0).Find(&orders).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
deleteIds := make([]uint, 0)
|
||||
for _, order := range orders {
|
||||
// 只删除 15 分钟内的未支付订单
|
||||
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
|
||||
deleteIds = append(deleteIds, order.Id)
|
||||
}
|
||||
}
|
||||
err = h.DB.Where("id IN ?", deleteIds).Delete(&model.Order{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -26,10 +27,11 @@ import (
|
||||
type UserHandler struct {
|
||||
handler.BaseHandler
|
||||
licenseService *service.LicenseService
|
||||
redis *redis.Client
|
||||
}
|
||||
|
||||
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler {
|
||||
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
||||
func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService, redisCli *redis.Client) *UserHandler {
|
||||
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
|
||||
}
|
||||
|
||||
// List 用户列表
|
||||
@@ -73,6 +75,8 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
Id uint `json:"id"`
|
||||
Password string `json:"password"`
|
||||
Username string `json:"username"`
|
||||
Mobile string `json:"mobile"`
|
||||
Email string `json:"email"`
|
||||
ChatRoles []string `json:"chat_roles"`
|
||||
ChatModels []int `json:"chat_models"`
|
||||
ExpiredTime string `json:"expired_time"`
|
||||
@@ -102,6 +106,8 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
}
|
||||
var oldPower = user.Power
|
||||
user.Username = data.Username
|
||||
user.Email = data.Email
|
||||
user.Mobile = data.Mobile
|
||||
user.Status = data.Status
|
||||
user.Vip = data.Vip
|
||||
user.Power = data.Power
|
||||
@@ -109,7 +115,8 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
user.ChatModels = utils.JsonEncode(data.ChatModels)
|
||||
user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
|
||||
|
||||
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
|
||||
res = h.DB.Select("username", "mobile", "email", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
|
||||
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
@@ -135,6 +142,13 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
// 如果禁用了用户,则将用户踢下线
|
||||
if user.Status == false {
|
||||
key := fmt.Sprintf("users/%v", user.Id)
|
||||
if _, err := h.redis.Del(c, key).Result(); err != nil {
|
||||
logger.Error("error with delete session: ", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 检查用户是否已经存在
|
||||
h.DB.Where("username", data.Username).First(&user)
|
||||
@@ -147,6 +161,8 @@ func (h *UserHandler) Save(c *gin.Context) {
|
||||
u := model.User{
|
||||
Username: data.Username,
|
||||
Password: utils.GenPassword(data.Password, salt),
|
||||
Mobile: data.Mobile,
|
||||
Email: data.Email,
|
||||
Avatar: "/images/avatar/user.png",
|
||||
Salt: salt,
|
||||
Power: data.Power,
|
||||
@@ -262,7 +278,7 @@ func (h *UserHandler) Remove(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
resp.ERROR(c, "删除失败")
|
||||
resp.ERROR(c, err.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,13 +8,13 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
|
||||
@@ -85,7 +85,7 @@ func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
|
||||
}
|
||||
|
||||
var user model.User
|
||||
res := h.DB.First(&user, userId)
|
||||
res := h.DB.Where("id", userId).First(&user)
|
||||
// 更新缓存
|
||||
if res.Error == nil {
|
||||
c.Set(types.LoginUserCache, user)
|
||||
|
||||
44
api/handler/chat_app_type_handler.go
Normal file
44
api/handler/chat_app_type_handler.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ChatAppTypeHandler struct {
|
||||
BaseHandler
|
||||
}
|
||||
|
||||
func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler {
|
||||
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// List 获取App类型列表
|
||||
func (h *ChatAppTypeHandler) List(c *gin.Context) {
|
||||
var items []model.AppType
|
||||
var appTypes = make([]vo.AppType, 0)
|
||||
err := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
var appType vo.AppType
|
||||
err = utils.CopyObject(v, &appType)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
appType.Id = v.Id
|
||||
appType.CreatedAt = v.CreatedAt.Unix()
|
||||
appTypes = append(appTypes, appType)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, appTypes)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package chatimpl
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
@@ -15,8 +15,6 @@ import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
@@ -33,14 +31,11 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type ChatHandler struct {
|
||||
handler.BaseHandler
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
uploadManager *oss.UploaderManager
|
||||
licenseService *service.LicenseService
|
||||
@@ -51,7 +46,7 @@ type ChatHandler struct {
|
||||
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
|
||||
return &ChatHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
redis: redis,
|
||||
uploadManager: manager,
|
||||
licenseService: licenseService,
|
||||
@@ -61,112 +56,6 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
|
||||
}
|
||||
}
|
||||
|
||||
// ChatHandle 处理聊天 WebSocket 请求
|
||||
func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
sessionId := c.Query("session_id")
|
||||
roleId := h.GetInt(c, "role_id", 0)
|
||||
chatId := c.Query("chat_id")
|
||||
modelId := h.GetInt(c, "model_id", 0)
|
||||
tools := c.Query("tools")
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
var chatRole model.ChatRole
|
||||
res := h.DB.First(&chatRole, roleId)
|
||||
if res.Error != nil || !chatRole.Enable {
|
||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// if the role bind a model_id, use role's bind model_id
|
||||
if chatRole.ModelId > 0 {
|
||||
modelId = chatRole.ModelId
|
||||
}
|
||||
// get model info
|
||||
var chatModel model.ChatModel
|
||||
res = h.DB.First(&chatModel, modelId)
|
||||
if res.Error != nil || chatModel.Enabled == false {
|
||||
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
session := &types.ChatSession{
|
||||
SessionId: sessionId,
|
||||
ClientIP: c.ClientIP(),
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Tools: tools,
|
||||
}
|
||||
|
||||
// use old chat data override the chat model and role ID
|
||||
var chat model.ChatItem
|
||||
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
|
||||
if res.Error == nil {
|
||||
chatModel.Id = chat.ModelId
|
||||
roleId = int(chat.RoleId)
|
||||
}
|
||||
|
||||
session.ChatId = chatId
|
||||
session.Model = types.ChatModel{
|
||||
Id: chatModel.Id,
|
||||
Name: chatModel.Name,
|
||||
Value: chatModel.Value,
|
||||
Power: chatModel.Power,
|
||||
MaxTokens: chatModel.MaxTokens,
|
||||
MaxContext: chatModel.MaxContext,
|
||||
Temperature: chatModel.Temperature,
|
||||
KeyId: chatModel.KeyId}
|
||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
||||
client.Close()
|
||||
cancelFunc := h.ReqCancelFunc.Get(sessionId)
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
h.ReqCancelFunc.Delete(sessionId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var message types.WsMessage
|
||||
err = utils.JsonDecode(string(msg), &message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 心跳消息
|
||||
if message.Type == "heartbeat" {
|
||||
logger.Debug("收到 Chat 心跳消息:", message.Content)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("Receive a message: ", message.Content)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.ReqCancelFunc.Put(sessionId, cancel)
|
||||
// 回复消息
|
||||
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
utils.ReplyMessage(client, err.Error())
|
||||
} else {
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||
logger.Infof("回答完毕: %v", message.Content)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
|
||||
if !h.App.Debug {
|
||||
defer func() {
|
||||
@@ -208,16 +97,22 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
}
|
||||
|
||||
var req = types.ApiRequest{
|
||||
Model: session.Model.Value,
|
||||
Stream: true,
|
||||
Model: session.Model.Value,
|
||||
}
|
||||
// 兼容 GPT-O1 模型
|
||||
if strings.HasPrefix(session.Model.Value, "o1-") {
|
||||
utils.SendChunkMsg(ws, "AI 正在思考...\n")
|
||||
req.Stream = false
|
||||
session.Start = time.Now().Unix()
|
||||
} else {
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
req.Temperature = session.Model.Temperature
|
||||
req.Stream = session.Stream
|
||||
}
|
||||
req.Temperature = session.Model.Temperature
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
|
||||
if session.Tools != "" {
|
||||
toolIds := strings.Split(session.Tools, ",")
|
||||
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
|
||||
var items []model.Function
|
||||
res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items)
|
||||
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
|
||||
if res.Error == nil {
|
||||
var tools = make([]types.Tool, 0)
|
||||
for _, v := range items {
|
||||
@@ -279,7 +174,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
v := messages[i]
|
||||
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
||||
tks, _ = utils.CalcTokens(v.Content, req.Model)
|
||||
// 上下文 token 超出了模型的最大上下文长度
|
||||
if tokens+tks >= session.Model.MaxContext {
|
||||
break
|
||||
@@ -450,7 +345,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Debugf(utils.JsonEncode(req))
|
||||
logger.Debugf("对话请求消息体:%+v", req)
|
||||
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
// 创建 HttpClient 请求对象
|
||||
@@ -502,8 +397,7 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
|
||||
|
||||
func (h *ChatHandler) saveChatHistory(
|
||||
req types.ApiRequest,
|
||||
prompt string,
|
||||
contents []string,
|
||||
usage Usage,
|
||||
message types.Message,
|
||||
chatCtx []types.Message,
|
||||
session *types.ChatSession,
|
||||
@@ -511,12 +405,8 @@ func (h *ChatHandler) saveChatHistory(
|
||||
userVo vo.User,
|
||||
promptCreatedAt time.Time,
|
||||
replyCreatedAt time.Time) {
|
||||
if message.Role == "" {
|
||||
message.Role = "assistant"
|
||||
}
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
useMsg := types.Message{Role: "user", Content: usage.Prompt}
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
@@ -526,42 +416,52 @@ func (h *ChatHandler) saveChatHistory(
|
||||
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
var promptTokens, replyTokens, totalTokens int
|
||||
if usage.PromptTokens > 0 {
|
||||
promptTokens = usage.PromptTokens
|
||||
} else {
|
||||
promptTokens, _ = utils.CalcTokens(usage.Content, req.Model)
|
||||
}
|
||||
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(usage.Prompt),
|
||||
Tokens: promptTokens,
|
||||
TotalTokens: promptTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
err = h.DB.Save(&historyUserMsg).Error
|
||||
err := h.DB.Save(&historyUserMsg).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save prompt history message: ", err)
|
||||
}
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyTokens + getTotalTokens(req)
|
||||
if usage.CompletionTokens > 0 {
|
||||
replyTokens = usage.CompletionTokens
|
||||
totalTokens = usage.TotalTokens
|
||||
} else {
|
||||
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens = replyTokens + getTotalTokens(req)
|
||||
}
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: totalTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: usage.Content,
|
||||
Tokens: replyTokens,
|
||||
TotalTokens: totalTokens,
|
||||
UseContext: true,
|
||||
Model: req.Model,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
@@ -572,7 +472,7 @@ func (h *ChatHandler) saveChatHistory(
|
||||
|
||||
// 更新用户算力
|
||||
if session.Model.Power > 0 {
|
||||
h.subUserPower(userVo, session, promptToken, replyTokens)
|
||||
h.subUserPower(userVo, session, promptTokens, replyTokens)
|
||||
}
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
@@ -582,10 +482,10 @@ func (h *ChatHandler) saveChatHistory(
|
||||
chatItem.UserId = userVo.Id
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
if utf8.RuneCountInString(prompt) > 30 {
|
||||
chatItem.Title = string([]rune(prompt)[:30]) + "..."
|
||||
if utf8.RuneCountInString(usage.Prompt) > 30 {
|
||||
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
|
||||
} else {
|
||||
chatItem.Title = prompt
|
||||
chatItem.Title = usage.Prompt
|
||||
}
|
||||
chatItem.Model = req.Model
|
||||
err = h.DB.Create(&chatItem).Error
|
||||
@@ -1,4 +1,4 @@
|
||||
package chatimpl
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
@@ -28,31 +28,40 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
var items = make([]vo.ChatItem, 0)
|
||||
var chats []model.ChatItem
|
||||
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
|
||||
if res.Error == nil {
|
||||
var roleIds = make([]uint, 0)
|
||||
for _, chat := range chats {
|
||||
roleIds = append(roleIds, chat.RoleId)
|
||||
}
|
||||
var roles []model.ChatRole
|
||||
res = h.DB.Find(&roles, roleIds)
|
||||
if res.Error == nil {
|
||||
roleMap := make(map[uint]model.ChatRole)
|
||||
for _, role := range roles {
|
||||
roleMap[role.Id] = role
|
||||
}
|
||||
h.DB.Where("user_id", userId).Order("id DESC").Find(&chats)
|
||||
if len(chats) == 0 {
|
||||
resp.SUCCESS(c, items)
|
||||
return
|
||||
}
|
||||
|
||||
for _, chat := range chats {
|
||||
var item vo.ChatItem
|
||||
err := utils.CopyObject(chat, &item)
|
||||
if err == nil {
|
||||
item.Id = chat.Id
|
||||
item.Icon = roleMap[chat.RoleId].Icon
|
||||
items = append(items, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
var roleIds = make([]uint, 0)
|
||||
var modelValues = make([]string, 0)
|
||||
for _, chat := range chats {
|
||||
roleIds = append(roleIds, chat.RoleId)
|
||||
modelValues = append(modelValues, chat.Model)
|
||||
}
|
||||
|
||||
var roles []model.ChatRole
|
||||
var models []model.ChatModel
|
||||
roleMap := make(map[uint]model.ChatRole)
|
||||
modelMap := make(map[string]model.ChatModel)
|
||||
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||
h.DB.Where("value IN ?", modelValues).Find(&models)
|
||||
for _, role := range roles {
|
||||
roleMap[role.Id] = role
|
||||
}
|
||||
for _, m := range models {
|
||||
modelMap[m.Value] = m
|
||||
}
|
||||
for _, chat := range chats {
|
||||
var item vo.ChatItem
|
||||
err := utils.CopyObject(chat, &item)
|
||||
if err == nil {
|
||||
item.Id = chat.Id
|
||||
item.Icon = roleMap[chat.RoleId].Icon
|
||||
item.ModelId = modelMap[chat.Model].Id
|
||||
items = append(items, item)
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, items)
|
||||
}
|
||||
@@ -29,10 +29,37 @@ func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
||||
|
||||
// List 获取用户聊天应用列表
|
||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
tid := h.GetInt(c, "tid", 0)
|
||||
var roles []model.ChatRole
|
||||
session := h.DB.Where("enable", true)
|
||||
if tid > 0 {
|
||||
session = session.Where("tid", tid)
|
||||
}
|
||||
err := session.Order("sort_num ASC").Find(&roles).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var roleVos = make([]vo.ChatRole, 0)
|
||||
for _, r := range roles {
|
||||
var v vo.ChatRole
|
||||
err := utils.CopyObject(r, &v)
|
||||
if err == nil {
|
||||
v.Id = r.Id
|
||||
roleVos = append(roleVos, v)
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, roleVos)
|
||||
}
|
||||
|
||||
// ListByUser 获取用户添加的角色列表
|
||||
func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetLoginUserId(c)
|
||||
var roles []model.ChatRole
|
||||
query := h.DB.Where("enable", true)
|
||||
session := h.DB.Where("enable", true)
|
||||
// 如果用户没登录,则获取所有角色
|
||||
if userId > 0 {
|
||||
var user model.User
|
||||
h.DB.First(&user, userId)
|
||||
@@ -42,12 +69,16 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
resp.ERROR(c, "角色解析失败!")
|
||||
return
|
||||
}
|
||||
query = query.Where("marker IN ?", roleKeys)
|
||||
// 保证用户至少有一个角色可用
|
||||
if len(roleKeys) > 0 {
|
||||
session = session.Where("marker IN ?", roleKeys)
|
||||
}
|
||||
}
|
||||
|
||||
if id > 0 {
|
||||
query = query.Or("id", id)
|
||||
session = session.Or("id", id)
|
||||
}
|
||||
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
|
||||
res := session.Order("sort_num ASC").Find(&roles)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
|
||||
@@ -20,9 +20,7 @@ import (
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type DallJobHandler struct {
|
||||
@@ -45,49 +43,6 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
|
||||
}
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *DallJobHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
logger.Info("Invalid user ID")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.dallService.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
client.Close()
|
||||
h.dallService.Clients.Delete(uint(userId))
|
||||
return
|
||||
}
|
||||
|
||||
var message types.WsMessage
|
||||
err = utils.JsonDecode(string(msg), &message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 心跳消息
|
||||
if message.Type == "heartbeat" {
|
||||
logger.Debug("收到 DallE 心跳消息:", message.Content)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
@@ -129,19 +84,15 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.dallService.PushTask(types.DallTask{
|
||||
JobId: job.Id,
|
||||
UserId: uint(userId),
|
||||
Prompt: data.Prompt,
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
Power: job.Power,
|
||||
ClientId: data.ClientId,
|
||||
JobId: job.Id,
|
||||
UserId: uint(userId),
|
||||
Prompt: data.Prompt,
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
Power: job.Power,
|
||||
})
|
||||
|
||||
client := h.dallService.Clients.Get(job.UserId)
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,24 +8,15 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MarkMapHandler 生成思维导图
|
||||
@@ -43,69 +34,35 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
// Generate 生成思维导图
|
||||
func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
ModelId int `json:"model_id"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
modelId := h.GetInt(c, "model_id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.clients.Put(userId, client)
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
client.Close()
|
||||
h.clients.Delete(userId)
|
||||
return
|
||||
}
|
||||
|
||||
var message types.WsMessage
|
||||
err = utils.JsonDecode(string(msg), &message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 心跳消息
|
||||
if message.Type == "heartbeat" {
|
||||
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
|
||||
continue
|
||||
}
|
||||
// change model
|
||||
if message.Type == "model_id" {
|
||||
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("Receive a message: ", message.Content)
|
||||
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
|
||||
userId := h.GetLoginUserId(c)
|
||||
var user model.User
|
||||
res := h.DB.Model(&model.User{}).First(&user, userId)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with query user info: %v", res.Error)
|
||||
err := h.DB.Where("id", userId).First(&user, userId).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with query user info")
|
||||
return
|
||||
}
|
||||
var chatModel model.ChatModel
|
||||
res = h.DB.Where("id", modelId).First(&chatModel)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with query chat model: %v", res.Error)
|
||||
err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with query chat model")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < chatModel.Power {
|
||||
return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)
|
||||
resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power))
|
||||
return
|
||||
}
|
||||
|
||||
messages := make([]interface{}, 0)
|
||||
@@ -127,122 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
|
||||
### 支付宝
|
||||
### 微信
|
||||
|
||||
另外,除此之外不要任何解释性语句。
|
||||
请直接生成结果,不要任何解释性语句。
|
||||
`})
|
||||
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", prompt)})
|
||||
var req = types.ApiRequest{
|
||||
Model: chatModel.Value,
|
||||
Stream: true,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
var apiKey model.ApiKey
|
||||
response, err := h.doRequest(req, chatModel, &apiKey)
|
||||
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", data.Prompt)})
|
||||
content, err := utils.SendOpenAIMessage(h.DB, messages, chatModel.Value, chatModel.KeyId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
// 循环读取 Chunk 消息
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
var isNew = true
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
continue
|
||||
}
|
||||
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil { // 数据解析出错
|
||||
return fmt.Errorf("error with decode data: %v", line)
|
||||
}
|
||||
|
||||
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "stop" {
|
||||
break
|
||||
}
|
||||
|
||||
if isNew {
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
|
||||
isNew = false
|
||||
}
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
})
|
||||
} // end for
|
||||
|
||||
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
||||
|
||||
} else {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
|
||||
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 扣减算力
|
||||
if chatModel.Power > 0 {
|
||||
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
|
||||
err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: chatModel.Value,
|
||||
Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
resp.ERROR(c, "error with save power log, "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if chatModel.KeyId > 0 {
|
||||
session = session.Where("id", chatModel.KeyId)
|
||||
} else { // use the last unused key
|
||||
session = session.Where("type", "chat").
|
||||
Where("enabled", true).Order("last_used_at ASC")
|
||||
}
|
||||
|
||||
res := session.First(apiKey)
|
||||
if res.Error != nil {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
// 创建 HttpClient 请求对象
|
||||
var client *http.Client
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if len(apiKey.ProxyURL) > 5 { // 使用代理
|
||||
proxy, _ := url.Parse(apiKey.ProxyURL)
|
||||
client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
|
||||
return client.Do(request)
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
@@ -19,12 +18,10 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -65,31 +62,11 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
logger.Info("Invalid user ID")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.mjService.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
var data struct {
|
||||
TaskType string `json:"task_type"`
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
NegPrompt string `json:"neg_prompt"`
|
||||
Rate string `json:"rate"`
|
||||
@@ -200,6 +177,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
ClientId: data.ClientId,
|
||||
TaskId: taskId,
|
||||
Type: types.TaskType(data.TaskType),
|
||||
Prompt: data.Prompt,
|
||||
@@ -210,11 +188,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.mjService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
// update user's power
|
||||
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
@@ -231,6 +204,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
|
||||
type reqVo struct {
|
||||
Index int `json:"index"`
|
||||
ClientId string `json:"client_id"`
|
||||
ChannelId string `json:"channel_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
MessageHash string `json:"message_hash"`
|
||||
@@ -267,6 +241,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
ClientId: data.ClientId,
|
||||
Type: types.TaskUpscale,
|
||||
UserId: userId,
|
||||
ChannelId: data.ChannelId,
|
||||
@@ -276,11 +251,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.mjService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
// update user's power
|
||||
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
@@ -328,6 +298,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
Type: types.TaskVariation,
|
||||
ClientId: data.ClientId,
|
||||
UserId: userId,
|
||||
Index: data.Index,
|
||||
ChannelId: data.ChannelId,
|
||||
@@ -336,11 +307,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.mjService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "mid-journey",
|
||||
@@ -420,14 +386,6 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
|
||||
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
return nil, vo.NewPage(total, page, pageSize, jobs)
|
||||
@@ -472,11 +430,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
client := h.mjService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,9 @@ func (h *NetHandler) Upload(c *gin.Context) {
|
||||
|
||||
func (h *NetHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
Urls []string `json:"urls,omitempty"`
|
||||
Urls []string `json:"urls,omitempty"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -79,21 +81,32 @@ func (h *NetHandler) List(c *gin.Context) {
|
||||
if len(data.Urls) > 0 {
|
||||
session = session.Where("url IN ?", data.Urls)
|
||||
}
|
||||
session.Find(&items)
|
||||
if len(items) > 0 {
|
||||
for _, v := range items {
|
||||
var file vo.File
|
||||
err := utils.CopyObject(v, &file)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
file.CreatedAt = v.CreatedAt.Unix()
|
||||
files = append(files, file)
|
||||
}
|
||||
// 统计总数
|
||||
var total int64
|
||||
session.Model(&model.File{}).Count(&total)
|
||||
|
||||
if data.Page > 0 && data.PageSize > 0 {
|
||||
offset := (data.Page - 1) * data.PageSize
|
||||
session = session.Offset(offset).Limit(data.PageSize)
|
||||
}
|
||||
err := session.Order("id desc").Find(&items).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, files)
|
||||
for _, v := range items {
|
||||
var file vo.File
|
||||
err := utils.CopyObject(v, &file)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
file.CreatedAt = v.CreatedAt.Unix()
|
||||
files = append(files, file)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, files))
|
||||
}
|
||||
|
||||
// Remove remove files
|
||||
@@ -1,4 +1,4 @@
|
||||
package chatimpl
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
@@ -23,6 +23,32 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Usage struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type OpenAIResVo struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
Logprobs interface{} `json:"logprobs"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// OPenAI 消息发送实现
|
||||
func (h *ChatHandler) sendOpenAiMessage(
|
||||
chatCtx []types.Message,
|
||||
@@ -49,17 +75,21 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
defer response.Body.Close()
|
||||
}
|
||||
|
||||
if response.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
replyCreatedAt := time.Now() // 记录回复时间
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var message = types.Message{Role: "assistant"}
|
||||
var contents = make([]string, 0)
|
||||
var function model.Function
|
||||
var toolCall = false
|
||||
var arguments = make([]string, 0)
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
var isNew = true
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
@@ -78,7 +108,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||
utils.SendChunkMsg(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -106,8 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
if res.Error == nil {
|
||||
toolCall = true
|
||||
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
|
||||
utils.SendChunkMsg(ws, callMsg)
|
||||
contents = append(contents, callMsg)
|
||||
}
|
||||
continue
|
||||
@@ -124,14 +153,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
} else {
|
||||
content := responseBody.Choices[0].Delta.Content
|
||||
contents = append(contents, utils.InterfaceToString(content))
|
||||
if isNew {
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
isNew = false
|
||||
}
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
})
|
||||
utils.SendChunkMsg(ws, responseBody.Choices[0].Delta.Content)
|
||||
}
|
||||
} // end for
|
||||
|
||||
@@ -149,7 +171,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
||||
params["user_id"] = userVo.Id
|
||||
var apiRes types.BizVo
|
||||
r, err := req2.C().R().SetHeader("Content-Type", "application/json").
|
||||
r, err := req2.C().R().SetHeader("Body-Type", "application/json").
|
||||
SetHeader("Authorization", function.Token).
|
||||
SetBody(params).
|
||||
SetSuccessResult(&apiRes).Post(function.Action)
|
||||
@@ -160,28 +182,45 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
errMsg = r.Status
|
||||
}
|
||||
if errMsg != "" || apiRes.Code != types.Success {
|
||||
msg := "调用函数工具出错:" + apiRes.Message + errMsg
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: msg,
|
||||
})
|
||||
contents = append(contents, msg)
|
||||
errMsg = "调用函数工具出错:" + apiRes.Message + errMsg
|
||||
contents = append(contents, errMsg)
|
||||
} else {
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: apiRes.Data,
|
||||
})
|
||||
contents = append(contents, utils.InterfaceToString(apiRes.Data))
|
||||
errMsg = utils.InterfaceToString(apiRes.Data)
|
||||
contents = append(contents, errMsg)
|
||||
}
|
||||
utils.SendChunkMsg(ws, errMsg)
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||
usage := Usage{
|
||||
Prompt: prompt,
|
||||
Content: strings.Join(contents, ""),
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
}
|
||||
message.Content = usage.Content
|
||||
h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||
}
|
||||
} else {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("请求 OpenAI API 失败:%s", body)
|
||||
} else { // 非流式输出
|
||||
var respVo OpenAIResVo
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取响应失败:%v", body)
|
||||
}
|
||||
err = json.Unmarshal(body, &respVo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析响应失败:%v", body)
|
||||
}
|
||||
content := respVo.Choices[0].Message.Content
|
||||
if strings.HasPrefix(req.Model, "o1-") {
|
||||
content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
|
||||
}
|
||||
utils.SendChunkMsg(ws, content)
|
||||
respVo.Usage.Prompt = prompt
|
||||
respVo.Usage.Content = content
|
||||
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -48,6 +48,16 @@ func (h *OrderHandler) List(c *gin.Context) {
|
||||
order.Id = item.Id
|
||||
order.CreatedAt = item.CreatedAt.Unix()
|
||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||
payMethod, ok := types.PayMethods[item.PayWay]
|
||||
if !ok {
|
||||
payMethod = item.PayWay
|
||||
}
|
||||
payName, ok := types.PayNames[item.PayType]
|
||||
if !ok {
|
||||
payName = item.PayWay
|
||||
}
|
||||
order.PayMethod = payMethod
|
||||
order.PayName = payName
|
||||
list = append(list, order)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
|
||||
@@ -9,7 +9,6 @@ package handler
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
@@ -19,9 +18,7 @@ import (
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/shopspring/decimal"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -34,21 +31,15 @@ type PayWay struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
var (
|
||||
PayWayAlipay = PayWay{Name: "支付宝", Value: "alipay"}
|
||||
PayWayXunHu = PayWay{Name: "虎皮椒", Value: "hupi"}
|
||||
PayWayJs = PayWay{Name: "PayJS", Value: "payjs"}
|
||||
PayWayWechat = PayWay{Name: "微信支付", Value: "wechat"}
|
||||
)
|
||||
|
||||
// PaymentHandler 支付服务回调 handler
|
||||
type PaymentHandler struct {
|
||||
BaseHandler
|
||||
alipayService *payment.AlipayService
|
||||
huPiPayService *payment.HuPiPayService
|
||||
jsPayService *payment.JPayService
|
||||
geekPayService *payment.GeekPayService
|
||||
wechatPayService *payment.WechatPayService
|
||||
snowflake *service.Snowflake
|
||||
userService *service.UserService
|
||||
fs embed.FS
|
||||
lock sync.Mutex
|
||||
signKey string // 用来签名的随机秘钥
|
||||
@@ -58,17 +49,19 @@ func NewPaymentHandler(
|
||||
server *core.AppServer,
|
||||
alipayService *payment.AlipayService,
|
||||
huPiPayService *payment.HuPiPayService,
|
||||
jsPayService *payment.JPayService,
|
||||
geekPayService *payment.GeekPayService,
|
||||
wechatPayService *payment.WechatPayService,
|
||||
db *gorm.DB,
|
||||
userService *service.UserService,
|
||||
snowflake *service.Snowflake,
|
||||
fs embed.FS) *PaymentHandler {
|
||||
return &PaymentHandler{
|
||||
alipayService: alipayService,
|
||||
huPiPayService: huPiPayService,
|
||||
jsPayService: jsPayService,
|
||||
geekPayService: geekPayService,
|
||||
wechatPayService: wechatPayService,
|
||||
snowflake: snowflake,
|
||||
userService: userService,
|
||||
fs: fs,
|
||||
lock: sync.Mutex{},
|
||||
BaseHandler: BaseHandler{
|
||||
@@ -79,309 +72,167 @@ func NewPaymentHandler(
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PaymentHandler) DoPay(c *gin.Context) {
|
||||
orderNo := h.GetTrim(c, "order_no")
|
||||
payWay := h.GetTrim(c, "pay_way")
|
||||
t := h.GetInt(c, "t", 0)
|
||||
sign := h.GetTrim(c, "sign")
|
||||
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, payWay, t, h.signKey)
|
||||
newSign := utils.Sha256(signStr)
|
||||
if newSign != sign {
|
||||
resp.ERROR(c, "订单签名错误!")
|
||||
return
|
||||
func (h *PaymentHandler) Pay(c *gin.Context) {
|
||||
var data struct {
|
||||
PayWay string `json:"pay_way"`
|
||||
PayType string `json:"pay_type"`
|
||||
ProductId int `json:"product_id"`
|
||||
UserId int `json:"user_id"`
|
||||
Device string `json:"device"`
|
||||
Host string `json:"host"`
|
||||
}
|
||||
|
||||
// 检查二维码是否过期
|
||||
if time.Now().Unix()-int64(t) > int64(h.App.SysConfig.OrderPayTimeout) {
|
||||
resp.ERROR(c, "支付二维码已过期,请重新生成!")
|
||||
return
|
||||
}
|
||||
|
||||
if orderNo == "" {
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
var order model.Order
|
||||
res := h.DB.Where("order_no = ?", orderNo).First(&order)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Order not found")
|
||||
var product model.Product
|
||||
err := h.DB.Where("id", data.ProductId).First(&product).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Product not found")
|
||||
return
|
||||
}
|
||||
|
||||
// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
|
||||
if order.Status == types.OrderPaidSuccess {
|
||||
resp.ERROR(c, "订单已支付成功,无需重复支付!")
|
||||
orderNo, err := h.snowflake.Next(false)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
err = h.DB.Where("id", data.UserId).First(&user).Error
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新扫码状态
|
||||
h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
|
||||
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
||||
var payURL, returnURL, notifyURL string
|
||||
switch data.PayWay {
|
||||
case "alipay":
|
||||
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
|
||||
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
|
||||
}
|
||||
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
|
||||
returnURL = h.App.Config.AlipayConfig.ReturnURL
|
||||
} else {
|
||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||
}
|
||||
money := fmt.Sprintf("%.2f", amount)
|
||||
if data.Device == "wechat" {
|
||||
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
|
||||
OutTradeNo: orderNo,
|
||||
Subject: product.Name,
|
||||
TotalFee: money,
|
||||
ReturnURL: returnURL,
|
||||
NotifyURL: notifyURL,
|
||||
})
|
||||
} else {
|
||||
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
|
||||
OutTradeNo: orderNo,
|
||||
Subject: product.Name,
|
||||
TotalFee: money,
|
||||
ReturnURL: returnURL,
|
||||
NotifyURL: notifyURL,
|
||||
})
|
||||
}
|
||||
|
||||
if payWay == "alipay" { // 支付宝
|
||||
amount := fmt.Sprintf("%.2f", order.Amount)
|
||||
uri, err := h.alipayService.PayUrlMobile(order.OrderNo, amount, order.Subject)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(302, uri)
|
||||
return
|
||||
} else if payWay == "hupi" { // 虎皮椒支付
|
||||
params := payment.HuPiPayReq{
|
||||
Version: "1.1",
|
||||
TradeOrderId: orderNo,
|
||||
TotalFee: fmt.Sprintf("%f", order.Amount),
|
||||
Title: order.Subject,
|
||||
NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL,
|
||||
WapName: "极客学长",
|
||||
break
|
||||
case "wechat":
|
||||
if h.App.Config.WechatPayConfig.NotifyURL != "" {
|
||||
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
|
||||
}
|
||||
if data.Device == "wechat" {
|
||||
payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
|
||||
OutTradeNo: orderNo,
|
||||
TotalFee: int(amount * 100),
|
||||
Subject: product.Name,
|
||||
NotifyURL: notifyURL,
|
||||
ClientIP: c.ClientIP(),
|
||||
})
|
||||
} else {
|
||||
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
|
||||
OutTradeNo: orderNo,
|
||||
TotalFee: int(amount * 100),
|
||||
Subject: product.Name,
|
||||
NotifyURL: notifyURL,
|
||||
})
|
||||
}
|
||||
r, err := h.huPiPayService.Pay(params)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(302, r.URL)
|
||||
}
|
||||
resp.ERROR(c, "Invalid operations")
|
||||
}
|
||||
|
||||
// PayQrcode 生成支付 URL 二维码
|
||||
func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
||||
var data struct {
|
||||
PayWay string `json:"pay_way"` // 支付方式
|
||||
ProductId uint `json:"product_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
var product model.Product
|
||||
res := h.DB.First(&product, data.ProductId)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Product not found")
|
||||
return
|
||||
}
|
||||
|
||||
orderNo, err := h.snowflake.Next(false)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||
return
|
||||
}
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
var payWay string
|
||||
var notifyURL string
|
||||
switch data.PayWay {
|
||||
break
|
||||
case "hupi":
|
||||
payWay = PayWayXunHu.Value
|
||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||
break
|
||||
case "payjs":
|
||||
payWay = PayWayJs.Value
|
||||
notifyURL = h.App.Config.JPayConfig.NotifyURL
|
||||
break
|
||||
case "alipay":
|
||||
payWay = PayWayAlipay.Value
|
||||
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||
break
|
||||
default:
|
||||
payWay = PayWayWechat.Value
|
||||
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
|
||||
|
||||
}
|
||||
// 创建订单
|
||||
remark := types.OrderRemark{
|
||||
Days: product.Days,
|
||||
Power: product.Power,
|
||||
Name: product.Name,
|
||||
Price: product.Price,
|
||||
Discount: product.Discount,
|
||||
}
|
||||
|
||||
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
||||
order := model.Order{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
ProductId: product.Id,
|
||||
OrderNo: orderNo,
|
||||
Subject: product.Name,
|
||||
Amount: amount,
|
||||
Status: types.OrderNotPaid,
|
||||
PayWay: payWay,
|
||||
Remark: utils.JsonEncode(remark),
|
||||
}
|
||||
res = h.DB.Create(&order)
|
||||
if res.Error != nil || res.RowsAffected == 0 {
|
||||
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// PayJs 单独处理,只能用官方生成的二维码
|
||||
if data.PayWay == "payjs" {
|
||||
params := payment.JPayReq{
|
||||
TotalFee: int(math.Ceil(order.Amount * 100)),
|
||||
OutTradeNo: order.OrderNo,
|
||||
Subject: product.Name,
|
||||
}
|
||||
r := h.jsPayService.Pay(params)
|
||||
if r.IsOK() {
|
||||
resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
|
||||
return
|
||||
if h.App.Config.HuPiPayConfig.NotifyURL != "" {
|
||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||
} else {
|
||||
resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg)
|
||||
return
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
|
||||
}
|
||||
}
|
||||
|
||||
var logo string
|
||||
if data.PayWay == "alipay" {
|
||||
logo = "res/img/alipay.jpg"
|
||||
} else if data.PayWay == "hupi" {
|
||||
if h.App.Config.HuPiPayConfig.Name == "wechat" {
|
||||
logo = "res/img/wechat-pay.jpg"
|
||||
if h.App.Config.HuPiPayConfig.ReturnURL != "" {
|
||||
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
||||
} else {
|
||||
logo = "res/img/alipay.jpg"
|
||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||
}
|
||||
} else if data.PayWay == "wechat" {
|
||||
logo = "res/img/wechat-pay.jpg"
|
||||
}
|
||||
|
||||
file, err := h.fs.Open(logo)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with open qrcode log file: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
parse, err := url.Parse(notifyURL)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
timestamp := time.Now().Unix()
|
||||
signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, data.PayWay, timestamp, h.signKey)
|
||||
sign := utils.Sha256(signStr)
|
||||
var imageURL string
|
||||
if data.PayWay == "wechat" {
|
||||
payUrl, err := h.wechatPayService.PayUrlNative(order.OrderNo, int(math.Floor(order.Amount*100)), product.Name)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generating wechat payment qrcode: "+err.Error())
|
||||
return
|
||||
} else {
|
||||
imageURL = payUrl
|
||||
}
|
||||
} else {
|
||||
imageURL = fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign)
|
||||
}
|
||||
imgData, err := utils.GenQrcode(imageURL, 400, file)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
imgDataBase64 := base64.StdEncoding.EncodeToString(imgData)
|
||||
resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
|
||||
}
|
||||
|
||||
// Mobile 移动端支付
|
||||
func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||
var data struct {
|
||||
PayWay string `json:"pay_way"` // 支付方式
|
||||
ProductId uint `json:"product_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
var product model.Product
|
||||
res := h.DB.First(&product, data.ProductId)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Product not found")
|
||||
return
|
||||
}
|
||||
|
||||
orderNo, err := h.snowflake.Next(false)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||
return
|
||||
}
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
||||
var payWay string
|
||||
var notifyURL, returnURL string
|
||||
var payURL string
|
||||
switch data.PayWay {
|
||||
case "hupi":
|
||||
payWay = PayWayXunHu.Name
|
||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
||||
parse, _ := url.Parse(h.App.Config.HuPiPayConfig.ReturnURL)
|
||||
baseURL := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
|
||||
params := payment.HuPiPayReq{
|
||||
r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
|
||||
Version: "1.1",
|
||||
TradeOrderId: orderNo,
|
||||
TotalFee: fmt.Sprintf("%f", amount),
|
||||
Title: product.Name,
|
||||
NotifyURL: notifyURL,
|
||||
ReturnURL: returnURL,
|
||||
CallbackURL: returnURL,
|
||||
WapName: "极客学长",
|
||||
WapUrl: baseURL,
|
||||
Type: "WAP",
|
||||
}
|
||||
r, err := h.huPiPayService.Pay(params)
|
||||
WapName: "GeekAI助手",
|
||||
})
|
||||
if err != nil {
|
||||
errMsg := "error with generating Pay Hupi URL: " + err.Error()
|
||||
logger.Error(errMsg)
|
||||
resp.ERROR(c, errMsg)
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
payURL = r.URL
|
||||
case "payjs":
|
||||
payWay = PayWayJs.Name
|
||||
notifyURL = h.App.Config.JPayConfig.NotifyURL
|
||||
returnURL = h.App.Config.JPayConfig.ReturnURL
|
||||
totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
|
||||
params := url.Values{}
|
||||
params.Add("total_fee", fmt.Sprintf("%d", totalFee))
|
||||
params.Add("out_trade_no", orderNo)
|
||||
params.Add("body", product.Name)
|
||||
params.Add("notify_url", notifyURL)
|
||||
params.Add("auto", "0")
|
||||
payURL = h.jsPayService.PayH5(params)
|
||||
case "alipay":
|
||||
payWay = PayWayAlipay.Name
|
||||
payURL, err = h.alipayService.PayUrlMobile(orderNo, fmt.Sprintf("%.2f", amount), product.Name)
|
||||
break
|
||||
case "geek":
|
||||
if h.App.Config.GeekPayConfig.NotifyURL != "" {
|
||||
notifyURL = h.App.Config.GeekPayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
|
||||
}
|
||||
if h.App.Config.GeekPayConfig.ReturnURL != "" {
|
||||
data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
|
||||
}
|
||||
if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
|
||||
returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
|
||||
} else {
|
||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||
}
|
||||
params := payment.GeekPayParams{
|
||||
OutTradeNo: orderNo,
|
||||
Method: "web",
|
||||
Name: product.Name,
|
||||
Money: fmt.Sprintf("%f", amount),
|
||||
ClientIP: c.ClientIP(),
|
||||
Device: data.Device,
|
||||
Type: data.PayType,
|
||||
ReturnURL: returnURL,
|
||||
NotifyURL: notifyURL,
|
||||
}
|
||||
|
||||
res, err := h.geekPayService.Pay(params)
|
||||
if err != nil {
|
||||
errMsg := "error with generating Alipay URL: " + err.Error()
|
||||
resp.ERROR(c, errMsg)
|
||||
return
|
||||
}
|
||||
case "wechat":
|
||||
payWay = PayWayWechat.Name
|
||||
payURL, err = h.wechatPayService.PayUrlH5(orderNo, int(amount*100), product.Name, c.ClientIP())
|
||||
if err != nil {
|
||||
errMsg := "error with generating Wechat URL: " + err.Error()
|
||||
logger.Error(errMsg)
|
||||
resp.ERROR(c, errMsg)
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
payURL = res.PayURL
|
||||
default:
|
||||
resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
|
||||
resp.ERROR(c, "不支持的支付渠道")
|
||||
return
|
||||
}
|
||||
|
||||
// 创建订单
|
||||
remark := types.OrderRemark{
|
||||
Days: product.Days,
|
||||
@@ -390,7 +241,6 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||
Price: product.Price,
|
||||
Discount: product.Discount,
|
||||
}
|
||||
|
||||
order := model.Order{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
@@ -399,26 +249,24 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
|
||||
Subject: product.Name,
|
||||
Amount: amount,
|
||||
Status: types.OrderNotPaid,
|
||||
PayWay: payWay,
|
||||
PayWay: data.PayWay,
|
||||
PayType: data.PayType,
|
||||
Remark: utils.JsonEncode(remark),
|
||||
}
|
||||
res = h.DB.Create(&order)
|
||||
if res.Error != nil || res.RowsAffected == 0 {
|
||||
resp.ERROR(c, "error with create order: "+res.Error.Error())
|
||||
err = h.DB.Create(&order).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with create order: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"url": payURL, "order_no": orderNo})
|
||||
resp.SUCCESS(c, payURL)
|
||||
}
|
||||
|
||||
// 异步通知回调公共逻辑
|
||||
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
var order model.Order
|
||||
res := h.DB.Where("order_no = ?", orderNo).First(&order)
|
||||
if res.Error != nil {
|
||||
err := fmt.Errorf("error with fetch order: %v", res.Error)
|
||||
logger.Error(err)
|
||||
return err
|
||||
err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with fetch order: %v", err)
|
||||
}
|
||||
|
||||
h.lock.Lock()
|
||||
@@ -430,45 +278,24 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
}
|
||||
|
||||
var user model.User
|
||||
res = h.DB.First(&user, order.UserId)
|
||||
if res.Error != nil {
|
||||
err := fmt.Errorf("error with fetch user info: %v", res.Error)
|
||||
logger.Error(err)
|
||||
return err
|
||||
err = h.DB.First(&user, order.UserId).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with fetch user info: %v", err)
|
||||
}
|
||||
|
||||
var remark types.OrderRemark
|
||||
err := utils.JsonDecode(order.Remark, &remark)
|
||||
err = utils.JsonDecode(order.Remark, &remark)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error with decode order remark: %v", err)
|
||||
logger.Error(err)
|
||||
return err
|
||||
return fmt.Errorf("error with decode order remark: %v", err)
|
||||
}
|
||||
|
||||
var opt string
|
||||
var power int
|
||||
if remark.Days > 0 { // VIP 充值
|
||||
if user.ExpiredTime >= time.Now().Unix() {
|
||||
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
|
||||
opt = "VIP充值,VIP 没到期,只延期不增加算力"
|
||||
} else {
|
||||
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
|
||||
user.Power += h.App.SysConfig.VipMonthPower
|
||||
power = h.App.SysConfig.VipMonthPower
|
||||
opt = "VIP充值"
|
||||
}
|
||||
user.Vip = true
|
||||
} else { // 充值点卡,直接增加次数即可
|
||||
user.Power += remark.Power
|
||||
opt = "点卡充值"
|
||||
power = remark.Power
|
||||
}
|
||||
|
||||
// 更新用户信息
|
||||
res = h.DB.Updates(&user)
|
||||
if res.Error != nil {
|
||||
err := fmt.Errorf("error with update user info: %v", res.Error)
|
||||
logger.Error(err)
|
||||
// 增加用户算力
|
||||
err = h.userService.IncreasePower(int(order.UserId), remark.Power, model.PowerLog{
|
||||
Type: types.PowerRecharge,
|
||||
Model: order.PayWay,
|
||||
Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -476,29 +303,16 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
order.PayTime = time.Now().Unix()
|
||||
order.Status = types.OrderPaidSuccess
|
||||
order.TradeNo = tradeNo
|
||||
res = h.DB.Updates(&order)
|
||||
if res.Error != nil {
|
||||
err := fmt.Errorf("error with update order info: %v", res.Error)
|
||||
logger.Error(err)
|
||||
return err
|
||||
err = h.DB.Updates(&order).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with update order info: %v", err)
|
||||
}
|
||||
|
||||
// 更新产品销量
|
||||
h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
|
||||
|
||||
// 记录算力充值日志
|
||||
if power > 0 {
|
||||
h.DB.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: types.PowerRecharge,
|
||||
Amount: power,
|
||||
Balance: user.Power,
|
||||
Mark: types.PowerAdd,
|
||||
Model: order.PayWay,
|
||||
Remark: fmt.Sprintf("%s,金额:%f,订单号:%s", opt, order.Amount, order.OrderNo),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
err = h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).
|
||||
UpdateColumn("sales", gorm.Expr("sales + ?", 1)).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with update product sales: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -506,20 +320,22 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
|
||||
// GetPayWays 获取支付方式
|
||||
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
|
||||
data := gin.H{}
|
||||
payWays := make([]gin.H, 0)
|
||||
if h.App.Config.AlipayConfig.Enabled {
|
||||
data["alipay"] = gin.H{"name": "alipay"}
|
||||
payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
|
||||
}
|
||||
if h.App.Config.HuPiPayConfig.Enabled {
|
||||
data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name}
|
||||
payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
|
||||
}
|
||||
if h.App.Config.JPayConfig.Enabled {
|
||||
data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name}
|
||||
if h.App.Config.GeekPayConfig.Enabled {
|
||||
for _, v := range h.App.Config.GeekPayConfig.Methods {
|
||||
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
|
||||
}
|
||||
}
|
||||
if h.App.Config.WechatPayConfig.Enabled {
|
||||
data["wechat"] = gin.H{"name": "wechat"}
|
||||
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
|
||||
}
|
||||
resp.SUCCESS(c, data)
|
||||
resp.SUCCESS(c, payWays)
|
||||
}
|
||||
|
||||
// HuPiPayNotify 虎皮椒支付异步回调
|
||||
@@ -532,15 +348,17 @@ func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
|
||||
|
||||
orderNo := c.Request.Form.Get("trade_order_id")
|
||||
tradeNo := c.Request.Form.Get("open_order_id")
|
||||
logger.Infof("收到虎皮椒订单支付回调,订单 NO:%s,交易流水号:%s", orderNo, tradeNo)
|
||||
logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
|
||||
|
||||
if err = h.huPiPayService.Check(tradeNo); err != nil {
|
||||
if err = h.huPiPayService.Check(orderNo); err != nil {
|
||||
logger.Error("订单校验失败:", err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.notify(orderNo, tradeNo)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
@@ -556,18 +374,18 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO:验证交易签名
|
||||
res := h.alipayService.TradeVerify(c.Request)
|
||||
logger.Infof("验证支付结果:%+v", res)
|
||||
if !res.Success() {
|
||||
logger.Error("订单校验失败:", res.Message)
|
||||
result := h.alipayService.TradeVerify(c.Request)
|
||||
logger.Infof("收到支付宝商号订单支付回调:%+v", result)
|
||||
if !result.Success() {
|
||||
logger.Error("订单校验失败:", result.Message)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
tradeNo := c.Request.Form.Get("trade_no")
|
||||
err = h.notify(res.OutTradeNo, tradeNo)
|
||||
err = h.notify(result.OutTradeNo, tradeNo)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
@@ -575,33 +393,30 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
|
||||
// PayJsNotify PayJs 支付异步回调
|
||||
func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
|
||||
err := c.Request.ParseForm()
|
||||
if err != nil {
|
||||
// GeekPayNotify 支付异步回调
|
||||
func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
|
||||
var params = make(map[string]string)
|
||||
for k := range c.Request.URL.Query() {
|
||||
params[k] = c.Query(k)
|
||||
}
|
||||
|
||||
logger.Infof("收到GeekPay订单支付回调:%+v", params)
|
||||
// 检查支付状态
|
||||
if params["trade_status"] != "TRADE_SUCCESS" {
|
||||
c.String(http.StatusOK, "success")
|
||||
return
|
||||
}
|
||||
|
||||
sign := h.geekPayService.Sign(params)
|
||||
if sign != c.Query("sign") {
|
||||
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
orderNo := c.Request.Form.Get("out_trade_no")
|
||||
returnCode := c.Request.Form.Get("return_code")
|
||||
logger.Infof("收到PayJs订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode)
|
||||
// 支付失败
|
||||
if returnCode != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
// 校验订单支付状态
|
||||
tradeNo := c.Request.Form.Get("payjs_order_id")
|
||||
err = h.jsPayService.TradeVerify(tradeNo)
|
||||
if err != nil {
|
||||
logger.Error("订单校验失败:", err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.notify(orderNo, tradeNo)
|
||||
err := h.notify(params["out_trade_no"], params["trade_no"])
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
@@ -618,6 +433,7 @@ func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
|
||||
}
|
||||
|
||||
result := h.wechatPayService.TradeVerify(c.Request)
|
||||
logger.Infof("收到微信商号订单支付回调:%+v", result)
|
||||
if !result.Success() {
|
||||
logger.Error("订单校验失败:", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
@@ -629,6 +445,7 @@ func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
|
||||
|
||||
err = h.notify(result.OutTradeNo, result.TradeId)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -19,11 +19,8 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
@@ -59,27 +56,6 @@ func NewSdJobHandler(app *core.AppServer,
|
||||
}
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *SdJobHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
logger.Info("Invalid user ID")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.sdService.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
@@ -168,17 +144,13 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.sdService.PushTask(types.SdTask{
|
||||
Id: int(job.Id),
|
||||
Type: types.TaskImage,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
Id: int(job.Id),
|
||||
ClientId: data.ClientId,
|
||||
Type: types.TaskImage,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
})
|
||||
|
||||
client := h.sdService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
// update user's power
|
||||
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
@@ -260,15 +232,6 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if item.Progress < 100 {
|
||||
// 从 leveldb 中获取图片预览数据
|
||||
var imageData string
|
||||
err = h.leveldb.Get(item.TaskId, &imageData)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + imageData
|
||||
}
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
|
||||
@@ -76,6 +76,20 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
resp.ERROR(c, "系统已禁用邮箱注册!")
|
||||
return
|
||||
}
|
||||
// 检查邮箱后缀是否在白名单
|
||||
if len(h.App.SysConfig.EmailWhiteList) > 0 {
|
||||
inWhiteList := false
|
||||
for _, suffix := range h.App.SysConfig.EmailWhiteList {
|
||||
if strings.HasSuffix(data.Receiver, suffix) {
|
||||
inWhiteList = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !inWhiteList {
|
||||
resp.ERROR(c, "邮箱后缀不在白名单中")
|
||||
return
|
||||
}
|
||||
}
|
||||
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
||||
} else {
|
||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
|
||||
|
||||
@@ -19,9 +19,7 @@ import (
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -44,30 +42,10 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
|
||||
}
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *SunoHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
logger.Info("Invalid user ID")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.sunoService.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Instrumental bool `json:"instrumental"`
|
||||
Lyrics string `json:"lyrics"`
|
||||
@@ -86,6 +64,17 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.SunoPower {
|
||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||
return
|
||||
}
|
||||
|
||||
// 歌曲拼接
|
||||
if data.SongId != "" && data.Type == 3 {
|
||||
var song model.SunoJob
|
||||
@@ -127,6 +116,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
// 创建任务
|
||||
h.sunoService.PushTask(types.SunoTask{
|
||||
ClientId: data.ClientId,
|
||||
Id: job.Id,
|
||||
UserId: job.UserId,
|
||||
Type: job.Type,
|
||||
@@ -143,7 +133,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
})
|
||||
|
||||
// update user's power
|
||||
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
|
||||
CreatedAt: time.Now(),
|
||||
@@ -153,10 +143,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
client := h.sunoService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -225,6 +211,13 @@ func (h *SunoHandler) Remove(c *gin.Context) {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 只有失败,或者超时的任务才能删除
|
||||
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*10)) {
|
||||
resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
|
||||
return
|
||||
}
|
||||
|
||||
// 删除任务
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Delete(&job).Error; err != nil {
|
||||
@@ -233,18 +226,16 @@ func (h *SunoHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果任务未完成,或者任务失败,则恢复用户算力
|
||||
if job.Progress != 100 {
|
||||
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: job.ModelName,
|
||||
Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 恢复用户算力
|
||||
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: job.ModelName,
|
||||
Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
@@ -372,7 +363,7 @@ func (h *SunoHandler) Lyric(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini", 0)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -11,15 +11,15 @@ import (
|
||||
type TestHandler struct {
|
||||
db *gorm.DB
|
||||
snowflake *service.Snowflake
|
||||
js *payment.JPayService
|
||||
js *payment.GeekPayService
|
||||
}
|
||||
|
||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.JPayService) *TestHandler {
|
||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
|
||||
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
||||
}
|
||||
|
||||
func (h *TestHandler) SseTest(c *gin.Context) {
|
||||
//c.Header("Content-Type", "text/event-stream")
|
||||
//c.Header("Body-Type", "text/event-stream")
|
||||
//c.Header("Cache-Control", "no-cache")
|
||||
//c.Header("Connection", "keep-alive")
|
||||
//
|
||||
|
||||
@@ -74,6 +74,20 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captcha.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captcha.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "请先完人机验证")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
data.Password = strings.TrimSpace(data.Password)
|
||||
if len(data.Password) < 8 {
|
||||
resp.ERROR(c, "密码长度不能少于8个字符")
|
||||
@@ -230,8 +244,10 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
|
||||
needVerify, err := h.redis.Get(c, verifyKey).Bool()
|
||||
|
||||
if h.App.SysConfig.EnabledVerify {
|
||||
if h.App.SysConfig.EnabledVerify && needVerify {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captcha.SlideCheck(data)
|
||||
@@ -247,12 +263,14 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
var user model.User
|
||||
res := h.DB.Where("username = ?", data.Username).First(&user)
|
||||
if res.Error != nil {
|
||||
h.redis.Set(c, verifyKey, true, 0)
|
||||
resp.ERROR(c, "用户名不存在")
|
||||
return
|
||||
}
|
||||
|
||||
password := utils.GenPassword(data.Password, user.Salt)
|
||||
if password != user.Password {
|
||||
h.redis.Set(c, verifyKey, true, 0)
|
||||
resp.ERROR(c, "用户名或密码错误")
|
||||
return
|
||||
}
|
||||
@@ -285,11 +303,13 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 保存到 redis
|
||||
key := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
||||
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
// 移除登录行为验证码
|
||||
h.redis.Del(c, verifyKey)
|
||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
||||
}
|
||||
|
||||
@@ -587,7 +607,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
|
||||
session = session.Where("email", data.Email)
|
||||
key = CodeStorePrefix + data.Email
|
||||
} else if data.Type == "mobile" {
|
||||
session = session.Where("mobile", data.Email)
|
||||
session = session.Where("mobile", data.Mobile)
|
||||
key = CodeStorePrefix + data.Mobile
|
||||
} else {
|
||||
resp.ERROR(c, "验证类别错误")
|
||||
|
||||
@@ -19,9 +19,8 @@ import (
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type VideoHandler struct {
|
||||
@@ -43,30 +42,10 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
|
||||
}
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *VideoHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
logger.Info("Invalid user ID")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.videoService.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
FirstFrameImg string `json:"first_frame_img,omitempty"`
|
||||
EndFrameImg string `json:"end_frame_img,omitempty"`
|
||||
@@ -77,6 +56,18 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.LumaPower {
|
||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||
return
|
||||
}
|
||||
|
||||
if data.Prompt == "" {
|
||||
resp.ERROR(c, "prompt is needed")
|
||||
return
|
||||
@@ -105,15 +96,16 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
|
||||
// 创建任务
|
||||
h.videoService.PushTask(types.VideoTask{
|
||||
Id: job.Id,
|
||||
UserId: userId,
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
ClientId: data.ClientId,
|
||||
Id: job.Id,
|
||||
UserId: userId,
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
})
|
||||
|
||||
// update user's power
|
||||
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "luma",
|
||||
Remark: fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id),
|
||||
@@ -122,11 +114,6 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
client := h.videoService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -184,6 +171,12 @@ func (h *VideoHandler) Remove(c *gin.Context) {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 只有失败或者超时的任务才能删除
|
||||
if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) {
|
||||
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
|
||||
return
|
||||
}
|
||||
|
||||
// 删除任务
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Delete(&job).Error; err != nil {
|
||||
@@ -192,18 +185,16 @@ func (h *VideoHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果任务未完成,或者任务失败,则恢复用户算力
|
||||
if job.Progress != 100 {
|
||||
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "luma",
|
||||
Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 恢复算力
|
||||
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "luma",
|
||||
Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
|
||||
145
api/handler/ws_handler.go
Normal file
145
api/handler/ws_handler.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package handler
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Websocket 连接处理 handler
|
||||
|
||||
type WebsocketHandler struct {
|
||||
BaseHandler
|
||||
wsService *service.WebsocketService
|
||||
chatHandler *ChatHandler
|
||||
}
|
||||
|
||||
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
|
||||
return &WebsocketHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
chatHandler: chatHandler,
|
||||
wsService: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WebsocketHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
clientId := c.Query("client_id")
|
||||
client := types.NewWsClient(ws, clientId)
|
||||
userId := h.GetLoginUserId(c)
|
||||
if userId == 0 {
|
||||
_ = client.Send([]byte("Invalid user_id"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||
_ = client.Send([]byte("Invalid user_id"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
h.wsService.Clients.Put(clientId, client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
||||
client.Close()
|
||||
h.wsService.Clients.Delete(clientId)
|
||||
break
|
||||
}
|
||||
|
||||
var message types.InputMessage
|
||||
err = utils.JsonDecode(string(msg), &message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("Receive a message:%+v", message)
|
||||
if message.Type == types.MsgTypePing {
|
||||
utils.SendChannelMsg(client, types.ChPing, "pong")
|
||||
continue
|
||||
}
|
||||
|
||||
// 当前只处理聊天消息,其他消息全部丢弃
|
||||
var chatMessage types.ChatMessage
|
||||
err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
|
||||
if err != nil || message.Channel != types.ChChat {
|
||||
logger.Warnf("invalid message body:%+v", message.Body)
|
||||
continue
|
||||
}
|
||||
var chatRole model.ChatRole
|
||||
err = h.DB.First(&chatRole, chatMessage.RoleId).Error
|
||||
if err != nil || !chatRole.Enable {
|
||||
utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
|
||||
continue
|
||||
}
|
||||
// if the role bind a model_id, use role's bind model_id
|
||||
if chatRole.ModelId > 0 {
|
||||
chatMessage.RoleId = chatRole.ModelId
|
||||
}
|
||||
// get model info
|
||||
var chatModel model.ChatModel
|
||||
err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
|
||||
if err != nil || chatModel.Enabled == false {
|
||||
utils.SendAndFlush(client, "当前AI模型暂未启用,请更换模型后再发起对话!!!")
|
||||
continue
|
||||
}
|
||||
|
||||
session := &types.ChatSession{
|
||||
ClientIP: c.ClientIP(),
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
// use old chat data override the chat model and role ID
|
||||
var chat model.ChatItem
|
||||
h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
|
||||
if chat.Id > 0 {
|
||||
chatModel.Id = chat.ModelId
|
||||
chatMessage.RoleId = int(chat.RoleId)
|
||||
}
|
||||
|
||||
session.ChatId = chatMessage.ChatId
|
||||
session.Tools = chatMessage.Tools
|
||||
session.Stream = chatMessage.Stream
|
||||
// 复制模型数据
|
||||
err = utils.CopyObject(chatModel, &session.Model)
|
||||
if err != nil {
|
||||
logger.Error(err, chatModel)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
|
||||
err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
utils.SendAndFlush(client, err.Error())
|
||||
} else {
|
||||
utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
|
||||
logger.Infof("回答完毕: %v", message.Body)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
64
api/main.go
64
api/main.go
@@ -14,7 +14,6 @@ import (
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/handler/admin"
|
||||
"geekai/handler/chatimpl"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/dalle"
|
||||
@@ -128,7 +127,7 @@ func main() {
|
||||
// 创建控制器
|
||||
fx.Provide(handler.NewChatRoleHandler),
|
||||
fx.Provide(handler.NewUserHandler),
|
||||
fx.Provide(chatimpl.NewChatHandler),
|
||||
fx.Provide(handler.NewChatHandler),
|
||||
fx.Provide(handler.NewNetHandler),
|
||||
fx.Provide(handler.NewSmsHandler),
|
||||
fx.Provide(handler.NewRedeemHandler),
|
||||
@@ -146,7 +145,7 @@ func main() {
|
||||
fx.Provide(admin.NewAdminHandler),
|
||||
fx.Provide(admin.NewApiKeyHandler),
|
||||
fx.Provide(admin.NewUserHandler),
|
||||
fx.Provide(admin.NewChatRoleHandler),
|
||||
fx.Provide(admin.NewChatAppHandler),
|
||||
fx.Provide(admin.NewRedeemHandler),
|
||||
fx.Provide(admin.NewDashboardHandler),
|
||||
fx.Provide(admin.NewChatModelHandler),
|
||||
@@ -226,8 +225,9 @@ func main() {
|
||||
|
||||
// 注册路由
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
||||
group := s.Engine.Group("/api/role/")
|
||||
group := s.Engine.Group("/api/app/")
|
||||
group.GET("list", h.List)
|
||||
group.GET("list/user", h.ListByUser)
|
||||
group.POST("update", h.UpdateRole)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
|
||||
@@ -245,9 +245,8 @@ func main() {
|
||||
group.GET("clogin", h.CLogin)
|
||||
group.GET("clogin/callback", h.CLoginCallback)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) {
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
|
||||
group := s.Engine.Group("/api/chat/")
|
||||
group.Any("new", h.ChatHandle)
|
||||
group.GET("list", h.List)
|
||||
group.GET("detail", h.Detail)
|
||||
group.POST("update", h.Update)
|
||||
@@ -280,7 +279,6 @@ func main() {
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
||||
group := s.Engine.Group("/api/mj/")
|
||||
group.Any("client", h.Client)
|
||||
group.POST("image", h.Image)
|
||||
group.POST("upscale", h.Upscale)
|
||||
group.POST("variation", h.Variation)
|
||||
@@ -291,7 +289,6 @@ func main() {
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||
group := s.Engine.Group("/api/sd")
|
||||
group.Any("client", h.Client)
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
@@ -306,11 +303,12 @@ func main() {
|
||||
|
||||
// 管理后台控制器
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||
group := s.Engine.Group("/api/admin/")
|
||||
group.POST("config/update", h.Update)
|
||||
group.GET("config/get", h.Get)
|
||||
group := s.Engine.Group("/api/admin/config")
|
||||
group.POST("update", h.Update)
|
||||
group.GET("get", h.Get)
|
||||
group.POST("active", h.Active)
|
||||
group.GET("config/get/license", h.GetLicense)
|
||||
group.GET("fixData", h.FixData)
|
||||
group.GET("license", h.GetLicense)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
||||
group := s.Engine.Group("/api/admin/")
|
||||
@@ -338,7 +336,7 @@ func main() {
|
||||
group.GET("loginLog", h.LoginLog)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatRoleHandler) {
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
|
||||
group := s.Engine.Group("/api/admin/role/")
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
@@ -371,14 +369,12 @@ func main() {
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
|
||||
group := s.Engine.Group("/api/payment/")
|
||||
group.GET("doPay", h.DoPay)
|
||||
group.POST("doPay", h.Pay)
|
||||
group.GET("payWays", h.GetPayWays)
|
||||
group.POST("qrcode", h.PayQrcode)
|
||||
group.POST("mobile", h.Mobile)
|
||||
group.POST("alipay/notify", h.AlipayNotify)
|
||||
group.POST("hupipay/notify", h.HuPiPayNotify)
|
||||
group.POST("payjs/notify", h.PayJsNotify)
|
||||
group.POST("wechat/notify", h.WechatPayNotify)
|
||||
group.POST("notify/alipay", h.AlipayNotify)
|
||||
group.GET("notify/geek", h.GeekPayNotify)
|
||||
group.POST("notify/wechat", h.WechatPayNotify)
|
||||
group.POST("notify/hupi", h.HuPiPayNotify)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
||||
group := s.Engine.Group("/api/admin/product/")
|
||||
@@ -467,13 +463,11 @@ func main() {
|
||||
}),
|
||||
fx.Provide(handler.NewMarkMapHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
|
||||
group := s.Engine.Group("/api/markMap/")
|
||||
group.Any("client", h.Client)
|
||||
s.Engine.POST("/api/markMap/gen", h.Generate)
|
||||
}),
|
||||
fx.Provide(handler.NewDallJobHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
|
||||
group := s.Engine.Group("/api/dall")
|
||||
group.Any("client", h.Client)
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
@@ -483,7 +477,6 @@ func main() {
|
||||
fx.Provide(handler.NewSunoHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
|
||||
group := s.Engine.Group("/api/suno")
|
||||
group.Any("client", h.Client)
|
||||
group.POST("create", h.Create)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
@@ -496,22 +489,41 @@ func main() {
|
||||
fx.Provide(handler.NewVideoHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
|
||||
group := s.Engine.Group("/api/video")
|
||||
group.Any("client", h.Client)
|
||||
group.POST("luma/create", h.LumaCreate)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
fx.Provide(admin.NewChatAppTypeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
|
||||
group := s.Engine.Group("/api/admin/app/type")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
}),
|
||||
fx.Provide(handler.NewChatAppTypeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
|
||||
group := s.Engine.Group("/api/app/type")
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
fx.Provide(handler.NewTestHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
||||
group := s.Engine.Group("/api/test")
|
||||
group.Any("sse", h.PostTest, h.SseTest)
|
||||
}),
|
||||
fx.Provide(service.NewWebsocketService),
|
||||
fx.Provide(handler.NewWebsocketHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) {
|
||||
s.Engine.Any("/api/ws", h.Client)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||
go func() {
|
||||
err := s.Run(db)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
logger.Error(err)
|
||||
os.Exit(0)
|
||||
}
|
||||
}()
|
||||
}),
|
||||
|
||||
BIN
api/res/img/geek-pay.jpg
Normal file
BIN
api/res/img/geek-pay.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 27 KiB |
BIN
api/res/img/qq-pay.jpg
Normal file
BIN
api/res/img/qq-pay.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
@@ -34,19 +34,21 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
userService *service.UserService
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
userService: userService,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,6 +69,7 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
logger.Infof("handle a new DALL-E task: %+v", task)
|
||||
s.clientIds[task.JobId] = task.ClientId
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
@@ -74,7 +77,7 @@ func (s *Service) Run() {
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -111,7 +114,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
prompt := task.Prompt
|
||||
// translate prompt
|
||||
if utils.HasChinese(prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
prompt = content
|
||||
logger.Debugf("重写后提示词:%s", prompt)
|
||||
@@ -158,7 +161,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
Quality: task.Quality,
|
||||
}
|
||||
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
|
||||
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
|
||||
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
SetErrorResult(&errRes).
|
||||
@@ -183,7 +186,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
return "", fmt.Errorf("err with update database: %v", err)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
var content string
|
||||
if sync {
|
||||
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
|
||||
@@ -205,14 +208,13 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChDall, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -284,6 +286,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
|
||||
if res.Error != nil {
|
||||
return "", err
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||
return imgURL, nil
|
||||
}
|
||||
|
||||
@@ -28,18 +28,20 @@ type Service struct {
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
uploaderManager *oss.UploaderManager
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service {
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
|
||||
client: client,
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploaderManager: manager,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +58,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
@@ -65,7 +67,7 @@ func (s *Service) Run() {
|
||||
}
|
||||
// translate negative prompt
|
||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.NegPrompt = content
|
||||
} else {
|
||||
@@ -77,6 +79,7 @@ func (s *Service) Run() {
|
||||
if task.Mode == "" {
|
||||
task.Mode = "fast"
|
||||
}
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
|
||||
var job model.MidJourneyJob
|
||||
tx := s.db.Where("id = ?", task.Id).First(&job)
|
||||
@@ -119,7 +122,7 @@ func (s *Service) Run() {
|
||||
// update the task progress
|
||||
s.db.Updates(&job)
|
||||
// 任务失败,通知前端
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
@@ -166,14 +169,12 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
cli := s.Clients.Get(uint(message.UserId))
|
||||
if cli == nil {
|
||||
continue
|
||||
}
|
||||
err = cli.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
logger.Debugf("receive a new mj notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChMj, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -211,14 +212,11 @@ func (s *Service) DownloadImages() {
|
||||
v.ImgURL = imgURL
|
||||
s.db.Updates(&v)
|
||||
|
||||
cli := s.Clients.Get(uint(v.UserId))
|
||||
if cli == nil {
|
||||
continue
|
||||
}
|
||||
err = cli.Send([]byte(service.TaskStatusFinished))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[v.Id],
|
||||
UserId: v.UserId,
|
||||
JobId: int(v.Id),
|
||||
Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
@@ -237,8 +235,8 @@ func (s *Service) SyncTaskProgress() {
|
||||
go func() {
|
||||
var jobs []model.MidJourneyJob
|
||||
for {
|
||||
res := s.db.Where("progress < ?", 100).Where("channel_id <> ?", "").Find(&jobs)
|
||||
if res.Error != nil {
|
||||
err := s.db.Where("progress < ?", 100).Find(&jobs).Error
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -251,6 +249,10 @@ func (s *Service) SyncTaskProgress() {
|
||||
continue
|
||||
}
|
||||
|
||||
if job.ChannelId == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
|
||||
if err != nil {
|
||||
logger.Errorf("error with query task: %v", err)
|
||||
@@ -264,7 +266,11 @@ func (s *Service) SyncTaskProgress() {
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
logger.Errorf("task failed: %v", task.FailReason)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -289,7 +295,11 @@ func (s *Service) SyncTaskProgress() {
|
||||
if job.Progress == 100 {
|
||||
message = service.TaskStatusFinished
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: message})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
fileExt := utils.GetImgExt(file.Filename)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||
ContentType: file.Header.Get("Content-Type"),
|
||||
ContentType: file.Header.Get("Body-Type"),
|
||||
})
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("error uploading to MinIO: %v", err)
|
||||
|
||||
@@ -43,10 +43,8 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
|
||||
|
||||
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
|
||||
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
|
||||
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
|
||||
SetSignType(alipay.RSA2). // 设置签名类型,不设置默认 RSA2
|
||||
SetReturnUrl(config.ReturnURL). // 设置返回URL
|
||||
SetNotifyUrl(config.NotifyURL)
|
||||
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
|
||||
SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
|
||||
|
||||
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
|
||||
return nil, fmt.Errorf("error with load payment public key: %v", err)
|
||||
@@ -55,23 +53,31 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
|
||||
return &AlipayService{config: &config, client: client}, nil
|
||||
}
|
||||
|
||||
func (s *AlipayService) PayUrlMobile(outTradeNo string, amount string, subject string) (string, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("subject", subject)
|
||||
bm.Set("out_trade_no", outTradeNo)
|
||||
bm.Set("quit_url", s.config.ReturnURL)
|
||||
bm.Set("total_amount", amount)
|
||||
bm.Set("product_code", "QUICK_WAP_WAY")
|
||||
return s.client.TradeWapPay(context.Background(), bm)
|
||||
type AlipayParams struct {
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
Subject string `json:"subject"`
|
||||
TotalFee string `json:"total_fee"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
}
|
||||
|
||||
func (s *AlipayService) PayUrlPc(outTradeNo string, amount string, subject string) (string, error) {
|
||||
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("subject", subject)
|
||||
bm.Set("out_trade_no", outTradeNo)
|
||||
bm.Set("total_amount", amount)
|
||||
bm.Set("subject", params.Subject)
|
||||
bm.Set("out_trade_no", params.OutTradeNo)
|
||||
bm.Set("quit_url", params.ReturnURL)
|
||||
bm.Set("total_amount", params.TotalFee)
|
||||
bm.Set("product_code", "QUICK_WAP_WAY")
|
||||
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
|
||||
}
|
||||
|
||||
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("subject", params.Subject)
|
||||
bm.Set("out_trade_no", params.OutTradeNo)
|
||||
bm.Set("total_amount", params.TotalFee)
|
||||
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
|
||||
return s.client.TradePagePay(context.Background(), bm)
|
||||
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
|
||||
}
|
||||
|
||||
// TradeVerify 交易验证
|
||||
|
||||
139
api/service/payment/geekpay_service.go
Normal file
139
api/service/payment/geekpay_service.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GeekPayService Geek 支付服务
|
||||
type GeekPayService struct {
|
||||
config *types.GeekPayConfig
|
||||
}
|
||||
|
||||
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
|
||||
return &GeekPayService{
|
||||
config: &appConfig.GeekPayConfig,
|
||||
}
|
||||
}
|
||||
|
||||
type GeekPayParams struct {
|
||||
Method string `json:"method"` // 接口类型
|
||||
Device string `json:"device"` // 设备类型
|
||||
Type string `json:"type"` // 支付方式
|
||||
OutTradeNo string `json:"out_trade_no"` // 商户订单号
|
||||
Name string `json:"name"` // 商品名称
|
||||
Money string `json:"money"` // 商品金额
|
||||
ClientIP string `json:"clientip"` //用户IP地址
|
||||
SubOpenId string `json:"sub_openid"` // 微信用户 openid,仅小程序支付需要
|
||||
SubAppId string `json:"sub_appid"` // 小程序 AppId,仅小程序支付需要
|
||||
NotifyURL string `json:"notify_url"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
}
|
||||
|
||||
// Pay 支付订单
|
||||
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
|
||||
p := map[string]string{
|
||||
"pid": s.config.AppId,
|
||||
//"method": params.Method,
|
||||
"device": params.Device,
|
||||
"type": params.Type,
|
||||
"out_trade_no": params.OutTradeNo,
|
||||
"name": params.Name,
|
||||
"money": params.Money,
|
||||
"clientip": params.ClientIP,
|
||||
"notify_url": params.NotifyURL,
|
||||
"return_url": params.ReturnURL,
|
||||
"timestamp": fmt.Sprintf("%d", time.Now().Unix()),
|
||||
}
|
||||
p["sign"] = s.Sign(p)
|
||||
p["sign_type"] = "MD5"
|
||||
return s.sendRequest(s.config.ApiURL, p)
|
||||
}
|
||||
|
||||
func (s *GeekPayService) Sign(params map[string]string) string {
|
||||
// 按字母顺序排序参数
|
||||
var keys []string
|
||||
for k := range params {
|
||||
if params[k] == "" || k == "sign" || k == "sign_type" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
// 构建待签名字符串
|
||||
var signStr strings.Builder
|
||||
for _, k := range keys {
|
||||
signStr.WriteString(k)
|
||||
signStr.WriteString("=")
|
||||
signStr.WriteString(params[k])
|
||||
signStr.WriteString("&")
|
||||
}
|
||||
signString := strings.TrimSuffix(signStr.String(), "&") + s.config.PrivateKey
|
||||
|
||||
return utils.Md5(signString)
|
||||
}
|
||||
|
||||
type GeekPayResp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
TradeNo string `json:"trade_no"`
|
||||
PayURL string `json:"payurl"`
|
||||
QrCode string `json:"qrcode"`
|
||||
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
|
||||
}
|
||||
|
||||
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
|
||||
form := url.Values{}
|
||||
for k, v := range params {
|
||||
form.Add(k, v)
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/mapi.php", endpoint)
|
||||
logger.Infof(apiURL)
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true, // 取消 SSL 证书验证
|
||||
},
|
||||
}
|
||||
client := &http.Client{Transport: tr}
|
||||
resp, err := client.PostForm(apiURL, form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
logger.Debugf(string(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var r GeekPayResp
|
||||
err = json.Unmarshal(body, &r)
|
||||
if err != nil {
|
||||
return nil, errors.New("当前支付渠道暂不支持")
|
||||
}
|
||||
if r.Code != 1 {
|
||||
return nil, errors.New(r.Msg)
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
|
||||
}
|
||||
}
|
||||
|
||||
type HuPiPayReq struct {
|
||||
type HuPiPayParams struct {
|
||||
AppId string `json:"appid"`
|
||||
Version string `json:"version"`
|
||||
TradeOrderId string `json:"trade_order_id"`
|
||||
@@ -53,7 +53,7 @@ type HuPiPayReq struct {
|
||||
WapUrl string `json:"wap_url"`
|
||||
}
|
||||
|
||||
type HuPiResp struct {
|
||||
type HuPiPayResp struct {
|
||||
Openid interface{} `json:"openid"`
|
||||
UrlQrcode string `json:"url_qrcode"`
|
||||
URL string `json:"url"`
|
||||
@@ -62,7 +62,7 @@ type HuPiResp struct {
|
||||
}
|
||||
|
||||
// Pay 执行支付请求操作
|
||||
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
|
||||
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
|
||||
data := url.Values{}
|
||||
simple := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
params.AppId = s.appId
|
||||
@@ -80,22 +80,22 @@ func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
|
||||
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
|
||||
resp, err := http.PostForm(apiURL, data)
|
||||
if err != nil {
|
||||
return HuPiResp{}, fmt.Errorf("error with requst api: %v", err)
|
||||
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
all, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return HuPiResp{}, fmt.Errorf("error with reading response: %v", err)
|
||||
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
|
||||
var res HuPiResp
|
||||
var res HuPiPayResp
|
||||
err = utils.JsonDecode(string(all), &res)
|
||||
if err != nil {
|
||||
return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err)
|
||||
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
|
||||
}
|
||||
|
||||
if res.ErrCode != 0 {
|
||||
return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
|
||||
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
@@ -127,10 +127,10 @@ func (s *HuPiPayService) Sign(params url.Values) string {
|
||||
}
|
||||
|
||||
// Check 校验订单状态
|
||||
func (s *HuPiPayService) Check(tradeNo string) error {
|
||||
func (s *HuPiPayService) Check(outTradeNo string) error {
|
||||
data := url.Values{}
|
||||
data.Add("appid", s.appId)
|
||||
data.Add("open_order_id", tradeNo)
|
||||
data.Add("out_trade_order", outTradeNo)
|
||||
stamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
data.Add("time", stamp)
|
||||
data.Add("nonce_str", stamp)
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type JPayService struct {
|
||||
config *types.JPayConfig
|
||||
}
|
||||
|
||||
func NewJPayService(appConfig *types.AppConfig) *JPayService {
|
||||
return &JPayService{
|
||||
config: &appConfig.JPayConfig,
|
||||
}
|
||||
}
|
||||
|
||||
type JPayReq struct {
|
||||
TotalFee int `json:"total_fee"`
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
Subject string `json:"body"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
ReturnURL string `json:"callback_url"`
|
||||
}
|
||||
type JPayReps struct {
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
OrderId string `json:"payjs_order_id"`
|
||||
ReturnCode int `json:"return_code"`
|
||||
ReturnMsg string `json:"return_msg"`
|
||||
Sign string `json:"Sign"`
|
||||
TotalFee string `json:"total_fee"`
|
||||
CodeUrl string `json:"code_url,omitempty"`
|
||||
Qrcode string `json:"qrcode,omitempty"`
|
||||
}
|
||||
|
||||
func (r JPayReps) IsOK() bool {
|
||||
return r.ReturnMsg == "SUCCESS"
|
||||
}
|
||||
|
||||
func (js *JPayService) Pay(param JPayReq) JPayReps {
|
||||
param.NotifyURL = js.config.NotifyURL
|
||||
var p = url.Values{}
|
||||
encode := utils.JsonEncode(param)
|
||||
m := make(map[string]interface{})
|
||||
_ = utils.JsonDecode(encode, &m)
|
||||
for k, v := range m {
|
||||
p.Add(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
p.Add("mchid", js.config.AppId)
|
||||
|
||||
p.Add("sign", js.sign(p))
|
||||
|
||||
cli := http.Client{}
|
||||
apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
|
||||
r, err := cli.PostForm(apiURL, p)
|
||||
if err != nil {
|
||||
return JPayReps{ReturnMsg: err.Error()}
|
||||
}
|
||||
defer r.Body.Close()
|
||||
bs, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return JPayReps{ReturnMsg: err.Error()}
|
||||
}
|
||||
|
||||
var data JPayReps
|
||||
err = utils.JsonDecode(string(bs), &data)
|
||||
if err != nil {
|
||||
return JPayReps{ReturnMsg: err.Error()}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (js *JPayService) PayH5(p url.Values) string {
|
||||
p.Add("mchid", js.config.AppId)
|
||||
p.Add("sign", js.sign(p))
|
||||
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
|
||||
}
|
||||
|
||||
func (js *JPayService) sign(params url.Values) string {
|
||||
params.Del(`sign`)
|
||||
var keys = make([]string, 0, 0)
|
||||
for key := range params {
|
||||
if params.Get(key) != `` {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
var pList = make([]string, 0, 0)
|
||||
for _, key := range keys {
|
||||
var value = strings.TrimSpace(params.Get(key))
|
||||
if len(value) > 0 {
|
||||
pList = append(pList, key+"="+value)
|
||||
}
|
||||
}
|
||||
var src = strings.Join(pList, "&")
|
||||
src += "&key=" + js.config.PrivateKey
|
||||
|
||||
md5bs := md5.Sum([]byte(src))
|
||||
md5res := hex.EncodeToString(md5bs[:])
|
||||
return strings.ToUpper(md5res)
|
||||
}
|
||||
|
||||
// TradeVerify 查询订单支付状态
|
||||
// @param tradeNo 支付平台交易 ID
|
||||
func (js *JPayService) TradeVerify(tradeNo string) error {
|
||||
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
|
||||
params := url.Values{}
|
||||
params.Add("payjs_order_id", tradeNo)
|
||||
params.Add("sign", js.sign(params))
|
||||
data := strings.NewReader(params.Encode())
|
||||
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with http reqeust: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
|
||||
var r struct {
|
||||
ReturnCode int `json:"return_code"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
err = utils.JsonDecode(string(body), &r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with decode response: %v", err)
|
||||
}
|
||||
|
||||
if r.ReturnCode == 1 && r.Status == 1 {
|
||||
return nil
|
||||
} else {
|
||||
logger.Errorf("PayJs 支付验证响应:%s", string(body))
|
||||
return errors.New("order not paid")
|
||||
}
|
||||
}
|
||||
@@ -46,18 +46,27 @@ func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
|
||||
return &WechatPayService{config: &config, client: client}, nil
|
||||
}
|
||||
|
||||
func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject string) (string, error) {
|
||||
type WechatPayParams struct {
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
TotalFee int `json:"total_fee"`
|
||||
Subject string `json:"subject"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
}
|
||||
|
||||
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", subject).
|
||||
Set("out_trade_no", outTradeNo).
|
||||
Set("description", params.Subject).
|
||||
Set("out_trade_no", params.OutTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", s.config.NotifyURL).
|
||||
Set("notify_url", params.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", amount).
|
||||
bm.Set("total", params.TotalFee).
|
||||
Set("currency", "CNY")
|
||||
})
|
||||
|
||||
@@ -71,22 +80,22 @@ func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject s
|
||||
return wxRsp.Response.CodeUrl, nil
|
||||
}
|
||||
|
||||
func (s *WechatPayService) PayUrlH5(outTradeNo string, amount int, subject string, ip string) (string, error) {
|
||||
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", subject).
|
||||
Set("out_trade_no", outTradeNo).
|
||||
Set("description", params.Subject).
|
||||
Set("out_trade_no", params.OutTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", s.config.NotifyURL).
|
||||
Set("notify_url", params.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", amount).
|
||||
bm.Set("total", params.TotalFee).
|
||||
Set("currency", "CNY")
|
||||
}).
|
||||
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("payer_client_ip", ip).
|
||||
bm.Set("payer_client_ip", params.ClientIP).
|
||||
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("type", "Wap")
|
||||
})
|
||||
|
||||
@@ -33,18 +33,16 @@ type Service struct {
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
leveldb *store.LevelDB
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C(),
|
||||
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||
db: db,
|
||||
leveldb: levelDB,
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
}
|
||||
}
|
||||
@@ -62,7 +60,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Params.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Params.Prompt = content
|
||||
} else {
|
||||
@@ -72,7 +70,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate negative prompt
|
||||
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Params.NegPrompt = content
|
||||
} else {
|
||||
@@ -90,7 +88,7 @@ func (s *Service) Run() {
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
// 通知前端,任务失败
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -126,9 +124,8 @@ type Txt2ImgResp struct {
|
||||
|
||||
// TaskProgressResp 任务进度响应实体
|
||||
type TaskProgressResp struct {
|
||||
Progress float64 `json:"progress"`
|
||||
EtaRelative float64 `json:"eta_relative"`
|
||||
CurrentImage string `json:"current_image"`
|
||||
Progress float64 `json:"progress"`
|
||||
EtaRelative float64 `json:"eta_relative"`
|
||||
}
|
||||
|
||||
// Txt2Img 文生图 API
|
||||
@@ -213,9 +210,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||
// 从 leveldb 中删除预览图片数据
|
||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||
return nil
|
||||
default:
|
||||
err, resp := s.checkTaskProgress(apiKey)
|
||||
@@ -223,11 +218,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
if err == nil && resp.Progress > 0 {
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||
// 保存预览图片数据
|
||||
if resp.CurrentImage != "" {
|
||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
@@ -267,14 +258,12 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChSd, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
|
||||
|
||||
func (s *SmtpService) SendVerifyCode(to string, code int) error {
|
||||
subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
|
||||
body := fmt.Sprintf("您正在注册 %s 账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", s.config.AppName, code)
|
||||
body := fmt.Sprintf("【%s】:您的验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", s.config.AppName, code)
|
||||
|
||||
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
|
||||
if s.config.UseTls {
|
||||
|
||||
@@ -34,17 +34,19 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[string]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploadManager: manager,
|
||||
wsService: wsService,
|
||||
clientIds: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,7 +98,7 @@ func (s *Service) Run() {
|
||||
"err_msg": err.Error(),
|
||||
"progress": service.FailTaskProgress,
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -105,6 +107,7 @@ func (s *Service) Run() {
|
||||
"task_id": r.Data,
|
||||
"channel": r.Channel,
|
||||
})
|
||||
s.clientIds[r.Data] = task.ClientId
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -271,14 +274,14 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
logger.Debugf("client id: %+v", s.wsService.Clients)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
logger.Debugf("%+v", client)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChSuno, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -311,7 +314,7 @@ func (s *Service) DownloadFiles() {
|
||||
v.AudioURL = audioURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
@@ -377,12 +380,12 @@ func (s *Service) SyncTaskProgress() {
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
|
||||
} else if task.Data.FailReason != "" {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = task.Data.FailReason
|
||||
s.db.Updates(&job)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,9 +8,10 @@ const (
|
||||
)
|
||||
|
||||
type NotifyMessage struct {
|
||||
UserId int `json:"user_id"`
|
||||
JobId int `json:"job_id"`
|
||||
Message string `json:"message"`
|
||||
UserId int `json:"user_id"`
|
||||
ClientId string `json:"client_id"`
|
||||
JobId int `json:"job_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
|
||||
|
||||
@@ -34,17 +34,19 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +84,21 @@ func (s *Service) Run() {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if task.ClientId != "" {
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
}
|
||||
|
||||
var r LumaRespVo
|
||||
r, err = s.LumaCreate(task)
|
||||
if err != nil {
|
||||
@@ -94,7 +111,7 @@ func (s *Service) Run() {
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -179,14 +196,12 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
logger.Debugf("Receive notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChLuma, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -226,7 +241,7 @@ func (s *Service) DownloadFiles() {
|
||||
v.VideoURL = videoURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
|
||||
13
api/service/ws_service.go
Normal file
13
api/service/ws_service.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import "geekai/core/types"
|
||||
|
||||
type WebsocketService struct {
|
||||
Clients *types.LMap[string, *types.WsClient] // clientId => Client
|
||||
}
|
||||
|
||||
func NewWebsocketService() *WebsocketService {
|
||||
return &WebsocketService{
|
||||
Clients: types.NewLMap[string, *types.WsClient](),
|
||||
}
|
||||
}
|
||||
12
api/store/model/app_type.go
Normal file
12
api/store/model/app_type.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type AppType struct {
|
||||
Id uint `gorm:"primarykey"`
|
||||
Name string
|
||||
Icon string
|
||||
Enabled bool
|
||||
SortNum int
|
||||
CreatedAt time.Time
|
||||
}
|
||||
@@ -4,16 +4,17 @@ import "gorm.io/gorm"
|
||||
|
||||
type ChatMessage struct {
|
||||
BaseModel
|
||||
ChatId string // 会话 ID
|
||||
UserId uint // 用户 ID
|
||||
RoleId uint // 角色 ID
|
||||
Model string // AI模型
|
||||
Type string
|
||||
Icon string
|
||||
Tokens int
|
||||
Content string
|
||||
UseContext bool // 是否可以作为聊天上下文
|
||||
DeletedAt gorm.DeletedAt
|
||||
ChatId string // 会话 ID
|
||||
UserId uint // 用户 ID
|
||||
RoleId uint // 角色 ID
|
||||
Model string // AI模型
|
||||
Type string
|
||||
Icon string
|
||||
Tokens int
|
||||
TotalTokens int // 总 token 消耗
|
||||
Content string
|
||||
UseContext bool // 是否可以作为聊天上下文
|
||||
DeletedAt gorm.DeletedAt
|
||||
}
|
||||
|
||||
func (ChatMessage) TableName() string {
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
type ChatRole struct {
|
||||
BaseModel
|
||||
Tid int
|
||||
Key string `gorm:"column:marker;unique"` // 角色唯一标识
|
||||
Name string // 角色名称
|
||||
Context string `gorm:"column:context_json"` // 角色语料信息 json
|
||||
|
||||
@@ -2,7 +2,6 @@ package model
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Order 充值订单
|
||||
@@ -18,6 +17,6 @@ type Order struct {
|
||||
Status types.OrderStatus
|
||||
Remark string
|
||||
PayTime int64
|
||||
PayWay string // 支付方式
|
||||
DeletedAt gorm.DeletedAt
|
||||
PayWay string // 支付渠道
|
||||
PayType string // 支付类型
|
||||
}
|
||||
|
||||
10
api/store/vo/app_type.go
Normal file
10
api/store/vo/app_type.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package vo
|
||||
|
||||
type AppType struct {
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
SortNum int `json:"sort_num"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
@@ -4,7 +4,8 @@ import "geekai/core/types"
|
||||
|
||||
type ChatRole struct {
|
||||
BaseVo
|
||||
Key string `json:"key"` // 角色唯一标识
|
||||
Key string `json:"key"` // 角色唯一标识
|
||||
Tid int `json:"tid"`
|
||||
Name string `json:"name"` // 角色名称
|
||||
Context []types.Message `json:"context"` // 角色语料信息
|
||||
HelloMsg string `json:"hello_msg"` // 打招呼的消息
|
||||
@@ -13,4 +14,5 @@ type ChatRole struct {
|
||||
SortNum int `json:"sort"` // 排序
|
||||
ModelId int `json:"model_id"` // 绑定模型 ID
|
||||
ModelName string `json:"model_name"` // 模型名称
|
||||
TypeName string `json:"type_name"` // 分类名称
|
||||
}
|
||||
|
||||
@@ -16,5 +16,8 @@ type Order struct {
|
||||
Status types.OrderStatus `json:"status"`
|
||||
PayTime int64 `json:"pay_time"`
|
||||
PayWay string `json:"pay_way"`
|
||||
PayType string `json:"pay_type"`
|
||||
PayMethod string `json:"pay_method"`
|
||||
PayName string `json:"pay_name"`
|
||||
Remark types.OrderRemark `json:"remark"`
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"io"
|
||||
@@ -18,8 +19,9 @@ import (
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
// ReplyChunkMessage 回复客户片段端消息
|
||||
func ReplyChunkMessage(client *types.WsClient, message interface{}) {
|
||||
// SendMsg 回复客户片段端消息
|
||||
func SendMsg(client *types.WsClient, message types.ReplyMessage) {
|
||||
message.ClientId = client.Id
|
||||
msg, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for decoding json data: %v", err.Error())
|
||||
@@ -31,11 +33,23 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// ReplyMessage 回复客户端一条完整的消息
|
||||
func ReplyMessage(ws *types.WsClient, message interface{}) {
|
||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
||||
// SendAndFlush 回复客户端一条完整的消息
|
||||
func SendAndFlush(ws *types.WsClient, message interface{}) {
|
||||
SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeText, Body: message})
|
||||
SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
|
||||
}
|
||||
|
||||
func SendChunkMsg(ws *types.WsClient, message interface{}) {
|
||||
SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeText, Body: message})
|
||||
}
|
||||
|
||||
// SendErrMsg 向客户端发送错误消息
|
||||
func SendErrMsg(ws *types.WsClient, message interface{}) {
|
||||
SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeErr, Body: message})
|
||||
}
|
||||
|
||||
func SendChannelMsg(ws *types.WsClient, channel types.WsChannel, message interface{}) {
|
||||
SendMsg(ws, types.ReplyMessage{Channel: channel, Type: types.MsgTypeText, Body: message})
|
||||
}
|
||||
|
||||
func DownloadImage(imageURL string, proxy string) ([]byte, error) {
|
||||
@@ -59,7 +73,9 @@ func DownloadImage(imageURL string, proxy string) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
imageBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -68,3 +84,11 @@ func DownloadImage(imageURL string, proxy string) ([]byte, error) {
|
||||
|
||||
return imageBytes, nil
|
||||
}
|
||||
|
||||
func GetBaseURL(strURL string) string {
|
||||
u, err := url.Parse(strURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s://%s", u.Scheme, u.Host)
|
||||
}
|
||||
|
||||
@@ -45,18 +45,25 @@ type apiRes struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
res := db.Where("type", "chat").Where("enabled", true).First(&apiKey)
|
||||
if res.Error != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string, modelName string, keyId int) (string, error) {
|
||||
messages := make([]interface{}, 1)
|
||||
messages[0] = types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
}
|
||||
return SendOpenAIMessage(db, messages, modelName, keyId)
|
||||
}
|
||||
|
||||
func SendOpenAIMessage(db *gorm.DB, messages []interface{}, modelName string, keyId int) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
session := db.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true)
|
||||
if keyId > 0 {
|
||||
session = session.Where("id", keyId)
|
||||
}
|
||||
err := session.First(&apiKey).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", err)
|
||||
}
|
||||
|
||||
var response apiRes
|
||||
client := req.C()
|
||||
@@ -65,7 +72,7 @@ func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error)
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, modelName)
|
||||
r, err := client.R().SetHeader("Content-Type", "application/json").
|
||||
r, err := client.R().SetHeader("Body-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(types.ApiRequest{
|
||||
Model: modelName,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
@@ -134,3 +135,17 @@ func GenRedeemCode(codeLength int) (string, error) {
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// IsValidEmail 检查给定的字符串是否是有效的电子邮件地址
|
||||
func IsValidEmail(email string) bool {
|
||||
// 这个正则表达式匹配大多数常见的邮箱格式
|
||||
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
return emailRegex.MatchString(email)
|
||||
}
|
||||
|
||||
// IsValidMobile 检查给定的字符串是否是有效的中国大陆手机号
|
||||
func IsValidMobile(phone string) bool {
|
||||
// 支持 13x, 14x, 15x, 16x, 17x, 18x, 19x 开头的号码
|
||||
phoneRegex := regexp.MustCompile(`^1[3-9]\d{9}$`)
|
||||
return phoneRegex.MatchString(phone)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user