diff --git a/api/core/app_server.go b/api/core/app_server.go index 2fca4ea7..d40feb90 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -12,11 +12,11 @@ import ( "github.com/go-redis/redis/v8" "github.com/golang-jwt/jwt/v5" "github.com/nfnt/resize" + "golang.org/x/image/webp" "gorm.io/gorm" "image" "image/jpeg" "io" - "log" "net/http" "os" "runtime/debug" @@ -215,9 +215,12 @@ func needLogin(c *gin.Context) bool { c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/sd/imgWall" || c.Request.URL.Path == "/api/sd/client" || + c.Request.URL.Path == "/api/dall/imgWall" || + c.Request.URL.Path == "/api/dall/client" || c.Request.URL.Path == "/api/config/get" || c.Request.URL.Path == "/api/product/list" || c.Request.URL.Path == "/api/menu/list" || + c.Request.URL.Path == "/api/markMap/client" || strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/function/") || strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || @@ -327,6 +330,10 @@ func staticResourceMiddleware() gin.HandlerFunc { // 解码图片 img, _, err := image.Decode(file) + // for .webp image + if err != nil { + img, err = webp.Decode(file) + } if err != nil { c.String(http.StatusInternalServerError, "Error decoding image") return @@ -343,7 +350,9 @@ func staticResourceMiddleware() gin.HandlerFunc { var buffer bytes.Buffer err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality}) if err != nil { - log.Fatal(err) + logger.Error(err) + c.String(http.StatusInternalServerError, err.Error()) + return } // 设置图片缓存有效期为一年 (365天) diff --git a/api/core/config.go b/api/core/config.go index 5447542e..74f7e305 100644 --- a/api/core/config.go +++ b/api/core/config.go @@ -23,7 +23,7 @@ func NewDefaultConfig() *types.AppConfig { SecretKey: utils.RandString(64), MaxAge: 86400, }, - ApiConfig: types.ChatPlusApiConfig{}, + ApiConfig: types.ApiConfig{}, OSS: types.OSSConfig{ Active: "local", Local: types.LocalStorageConfig{ diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 7ba2a252..b6b63aa2 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -8,7 +8,7 @@ type ApiRequest struct { Stream bool `json:"stream"` Messages []interface{} `json:"messages,omitempty"` Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM - Tools []interface{} `json:"tools,omitempty"` + Tools []Tool `json:"tools,omitempty"` Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台 ToolChoice string `json:"tool_choice,omitempty"` @@ -62,6 +62,7 @@ type ChatModel struct { MaxTokens int `json:"max_tokens"` // 最大响应长度 MaxContext int `json:"max_context"` // 最大上下文长度 Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id"` // 绑定 API KEY } type ApiError struct { diff --git a/api/core/types/config.go b/api/core/types/config.go index 24a94e78..4c59875a 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -14,7 +14,7 @@ type AppConfig struct { StaticDir string // 静态资源目录 StaticUrl string // 静态资源 URL Redis RedisConfig // redis 连接信息 - ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs + ApiConfig ApiConfig // ChatPlus API authorization configs SMS SMSConfig // send mobile message config OSS OSSConfig // OSS config MjProxyConfigs []MjProxyConfig // MJ proxy config @@ -30,6 +30,7 @@ type AppConfig struct { } type SmtpConfig struct { + UseTls bool // 是否使用 TLS 发送 Host string Port int AppName string // 应用名称 @@ -37,7 +38,7 @@ type SmtpConfig struct { Password string // 发件人邮箱密码 } -type ChatPlusApiConfig struct { +type ApiConfig struct { ApiURL string AppId string Token string @@ -114,6 +115,17 @@ type RedisConfig struct { DB int } +// LicenseKey 存储许可证书的 KEY +const LicenseKey = "Geek-AI-License" + +type License struct { + Key string // 许可证书密钥 + MachineId string // 机器码 + UserNum int // 用户数量 + ExpiredAt int64 // 过期时间 + IsActive bool // 是否激活 +} + func (c RedisConfig) Url() string { return fmt.Sprintf("%s:%d", c.Host, c.Port) } @@ -136,7 +148,7 @@ type SystemConfig struct { InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值 VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值 - RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机,邮箱注册,账号密码注册 + RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册 EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册 RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址 diff --git a/api/core/types/function.go b/api/core/types/function.go index 8b5f183f..09808461 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -8,19 +8,14 @@ type ToolCall struct { } `json:"function"` } +type Tool struct { + Type string `json:"type"` + Function Function `json:"function"` +} + type Function struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters Parameters `json:"parameters"` -} - -type Parameters struct { - Type string `json:"type"` - Required []string `json:"required"` - Properties map[string]Property `json:"properties"` -} - -type Property struct { - Type string `json:"type"` - Description string `json:"description"` + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` + Required interface{} `json:"required,omitempty"` } diff --git a/api/core/types/task.go b/api/core/types/task.go index cd4b516e..bb1f7689 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -59,3 +59,16 @@ type SdTaskParams struct { HdScaleAlg string `json:"hd_scale_alg"` // 放大算法 HdSteps int `json:"hd_steps"` // 高清修复迭代步数 } + +// DallTask DALL-E task +type DallTask struct { + JobId uint `json:"job_id"` + UserId uint `json:"user_id"` + Prompt string `json:"prompt"` + N int `json:"n"` + Quality string `json:"quality"` + Size string `json:"size"` + Style string `json:"style"` + + Power int `json:"power"` +} diff --git a/api/core/types/web.go b/api/core/types/web.go index 601612fa..041a9859 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -21,7 +21,7 @@ const ( WsStart = WsMsgType("start") WsMiddle = WsMsgType("middle") WsEnd = WsMsgType("end") - WsMjImg = WsMsgType("mj") + WsErr = WsMsgType("error") ) type BizCode int diff --git a/api/go.mod b/api/go.mod index fc131837..dcc8c8dc 100644 --- a/api/go.mod +++ b/api/go.mod @@ -27,14 +27,19 @@ require github.com/xxl-job/xxl-job-executor-go v1.2.0 require ( github.com/mojocn/base64Captcha v1.3.1 + github.com/shirou/gopsutil v3.21.11+incompatible github.com/shopspring/decimal v1.3.1 github.com/syndtr/goleveldb v1.0.0 + golang.org/x/image v0.0.0-20211028202545-6944b10bf410 ) require ( + github.com/go-ole/go-ole v1.2.6 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect - golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect + github.com/tklauser/go-sysconf v0.3.13 // indirect + github.com/tklauser/numcpus v0.7.0 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect ) require ( @@ -107,6 +112,6 @@ require ( go.uber.org/fx v1.19.3 go.uber.org/multierr v1.6.0 // indirect golang.org/x/crypto v0.12.0 - golang.org/x/sys v0.11.0 // indirect + golang.org/x/sys v0.15.0 // indirect gorm.io/gorm v1.25.1 ) diff --git a/api/go.sum b/api/go.sum index e5c987ce..64ea4f3a 100644 --- a/api/go.sum +++ b/api/go.sum @@ -40,6 +40,8 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs= github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= @@ -175,6 +177,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= +github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -203,6 +207,10 @@ github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gt github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= +github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4= +github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0= +github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4= +github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= @@ -215,6 +223,8 @@ github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onq github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= @@ -239,8 +249,9 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ= golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ= +golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= @@ -263,6 +274,7 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -274,8 +286,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= diff --git a/api/handler/admin/api_key_handler.go b/api/handler/admin/api_key_handler.go index 94cab69e..7935d0ba 100644 --- a/api/handler/admin/api_key_handler.go +++ b/api/handler/admin/api_key_handler.go @@ -8,6 +8,7 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "github.com/gin-gonic/gin" "gorm.io/gorm" ) @@ -65,14 +66,20 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { } func (h *ApiKeyHandler) List(c *gin.Context) { - if err := utils.CheckPermission(c, h.DB); err != nil { - resp.NotPermission(c) - return - } + status := h.GetBool(c, "status") + t := h.GetTrim(c, "type") + session := h.DB.Session(&gorm.Session{}) + if status { + session = session.Where("enabled", true) + } + if t != "" { + session = session.Where("type", t) + } + var items []model.ApiKey var keys = make([]vo.ApiKey, 0) - res := h.DB.Find(&items) + res := session.Find(&items) if res.Error == nil { for _, item := range items { var key vo.ApiKey @@ -122,6 +129,5 @@ func (h *ApiKeyHandler) Remove(c *gin.Context) { resp.ERROR(c, "更新数据库失败!") return } - resp.SUCCESS(c) } diff --git a/api/handler/admin/chat_handler.go b/api/handler/admin/chat_handler.go index 3d29d165..64fb5587 100644 --- a/api/handler/admin/chat_handler.go +++ b/api/handler/admin/chat_handler.go @@ -33,11 +33,6 @@ type chatItemVo struct { } func (h *ChatHandler) List(c *gin.Context) { - if err := utils.CheckPermission(c, h.DB); err != nil { - resp.NotPermission(c) - return - } - var data struct { Title string `json:"title"` UserId uint `json:"user_id"` diff --git a/api/handler/admin/chat_model_handler.go b/api/handler/admin/chat_model_handler.go index 97bb559e..ad0ce3c2 100644 --- a/api/handler/admin/chat_model_handler.go +++ b/api/handler/admin/chat_model_handler.go @@ -10,7 +10,6 @@ import ( "chatplus/utils/resp" "github.com/gin-gonic/gin" "gorm.io/gorm" - "time" ) type ChatModelHandler struct { @@ -34,6 +33,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) { MaxTokens int `json:"max_tokens"` // 最大响应长度 MaxContext int `json:"max_context"` // 最大上下文长度 Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id,omitempty"` CreatedAt int64 `json:"created_at"` } if err := c.ShouldBindJSON(&data); err != nil { @@ -51,12 +51,15 @@ func (h *ChatModelHandler) Save(c *gin.Context) { MaxTokens: data.MaxTokens, MaxContext: data.MaxContext, Temperature: data.Temperature, + KeyId: data.KeyId, Power: data.Power} - item.Id = data.Id - if item.Id > 0 { - item.CreatedAt = time.Unix(data.CreatedAt, 0) + var res *gorm.DB + if data.Id > 0 { + item.Id = data.Id + res = h.DB.Select("*").Omit("created_at").Updates(&item) + } else { + res = h.DB.Create(&item) } - res := h.DB.Save(&item) if res.Error != nil { resp.ERROR(c, "更新数据库失败!") return @@ -75,11 +78,6 @@ func (h *ChatModelHandler) Save(c *gin.Context) { // List 模型列表 func (h *ChatModelHandler) List(c *gin.Context) { - if err := utils.CheckPermission(c, h.DB); err != nil { - resp.NotPermission(c) - return - } - session := h.DB.Session(&gorm.Session{}) enable := h.GetBool(c, "enable") if enable { @@ -88,18 +86,33 @@ func (h *ChatModelHandler) List(c *gin.Context) { var items []model.ChatModel var cms = make([]vo.ChatModel, 0) res := session.Order("sort_num ASC").Find(&items) - if res.Error == nil { - for _, item := range items { - var cm vo.ChatModel - err := utils.CopyObject(item, &cm) - if err == nil { - cm.Id = item.Id - cm.CreatedAt = item.CreatedAt.Unix() - cm.UpdatedAt = item.UpdatedAt.Unix() - cms = append(cms, cm) - } else { - logger.Error(err) - } + if res.Error != nil { + resp.SUCCESS(c, cms) + return + } + + // initialize key name + keyIds := make([]int, 0) + for _, v := range items { + keyIds = append(keyIds, v.KeyId) + } + var keys []model.ApiKey + keyMap := make(map[uint]string) + h.DB.Where("id IN ?", keyIds).Find(&keys) + for _, v := range keys { + keyMap[v.Id] = v.Name + } + for _, item := range items { + var cm vo.ChatModel + err := utils.CopyObject(item, &cm) + if err == nil { + cm.Id = item.Id + cm.CreatedAt = item.CreatedAt.Unix() + cm.UpdatedAt = item.UpdatedAt.Unix() + cm.KeyName = keyMap[uint(item.KeyId)] + cms = append(cms, cm) + } else { + logger.Error(err) } } resp.SUCCESS(c, cms) diff --git a/api/handler/admin/chat_role_handler.go b/api/handler/admin/chat_role_handler.go index 7b72cb44..caec61b9 100644 --- a/api/handler/admin/chat_role_handler.go +++ b/api/handler/admin/chat_role_handler.go @@ -8,9 +8,10 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "time" + "github.com/gin-gonic/gin" "gorm.io/gorm" - "time" ) type ChatRoleHandler struct { @@ -50,11 +51,6 @@ func (h *ChatRoleHandler) Save(c *gin.Context) { } func (h *ChatRoleHandler) List(c *gin.Context) { - if err := utils.CheckPermission(c, h.DB); err != nil { - resp.NotPermission(c) - return - } - var items []model.ChatRole var roles = make([]vo.ChatRole, 0) res := h.DB.Order("sort_num ASC").Find(&items) @@ -63,6 +59,25 @@ func (h *ChatRoleHandler) List(c *gin.Context) { return } + // initialize model mane for role + modelIds := make([]int, 0) + for _, v := range items { + if v.ModelId > 0 { + modelIds = append(modelIds, v.ModelId) + } + } + + modelNameMap := make(map[int]string) + if len(modelIds) > 0 { + var models []model.ChatModel + tx := h.DB.Where("id IN ?", modelIds).Find(&models) + if tx.Error == nil { + for _, m := range models { + modelNameMap[int(m.Id)] = m.Name + } + } + } + for _, v := range items { var role vo.ChatRole err := utils.CopyObject(v, &role) @@ -70,6 +85,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) { role.Id = v.Id role.CreatedAt = v.CreatedAt.Unix() role.UpdatedAt = v.UpdatedAt.Unix() + role.ModelName = modelNameMap[role.ModelId] roles = append(roles, role) } } diff --git a/api/handler/admin/config_handler.go b/api/handler/admin/config_handler.go index 7ad863aa..5aacb9f2 100644 --- a/api/handler/admin/config_handler.go +++ b/api/handler/admin/config_handler.go @@ -4,20 +4,24 @@ import ( "chatplus/core" "chatplus/core/types" "chatplus/handler" + "chatplus/service" + "chatplus/store" "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" - "github.com/gin-gonic/gin" + "github.com/shirou/gopsutil/host" "gorm.io/gorm" ) type ConfigHandler struct { handler.BaseHandler + levelDB *store.LevelDB + licenseService *service.LicenseService } -func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler { - return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler { + return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, levelDB: levelDB, licenseService: licenseService} } func (h *ConfigHandler) Update(c *gin.Context) { @@ -70,11 +74,6 @@ func (h *ConfigHandler) Update(c *gin.Context) { // Get 获取指定的系统配置 func (h *ConfigHandler) Get(c *gin.Context) { - if err := utils.CheckPermission(c, h.DB); err != nil { - resp.NotPermission(c) - return - } - key := c.Query("key") var config model.Config res := h.DB.Where("marker", key).First(&config) @@ -92,3 +91,27 @@ func (h *ConfigHandler) Get(c *gin.Context) { resp.SUCCESS(c, value) } + +// Active 激活系统 +func (h *ConfigHandler) Active(c *gin.Context) { + var data struct { + License string `json:"license"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + info, err := host.Info() + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + err = h.licenseService.ActiveLicense(data.License, info.HostID) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, info.HostID) +} diff --git a/api/handler/admin/function_handler.go b/api/handler/admin/function_handler.go index 97940931..d9eed1fc 100644 --- a/api/handler/admin/function_handler.go +++ b/api/handler/admin/function_handler.go @@ -71,11 +71,6 @@ func (h *FunctionHandler) Set(c *gin.Context) { } func (h *FunctionHandler) List(c *gin.Context) { - if err := utils.CheckPermission(c, h.DB); err != nil { - resp.NotPermission(c) - return - } - var items []model.Function res := h.DB.Find(&items) if res.Error != nil { diff --git a/api/handler/admin/order_handler.go b/api/handler/admin/order_handler.go index 993b3995..1183e01f 100644 --- a/api/handler/admin/order_handler.go +++ b/api/handler/admin/order_handler.go @@ -22,11 +22,6 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { } func (h *OrderHandler) List(c *gin.Context) { - if err := utils.CheckPermission(c, h.DB); err != nil { - resp.NotPermission(c) - return - } - var data struct { OrderNo string `json:"order_no"` Status int `json:"status"` diff --git a/api/handler/admin/reward_handler.go b/api/handler/admin/reward_handler.go index e2d283e3..a2c44cb9 100644 --- a/api/handler/admin/reward_handler.go +++ b/api/handler/admin/reward_handler.go @@ -21,11 +21,6 @@ func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler { } func (h *RewardHandler) List(c *gin.Context) { - if err := utils.CheckPermission(c, h.DB); err != nil { - resp.NotPermission(c) - return - } - var items []model.Reward res := h.DB.Order("id DESC").Find(&items) var rewards = make([]vo.Reward, 0) diff --git a/api/handler/admin/user_handler.go b/api/handler/admin/user_handler.go index 1bc70b40..430b66bb 100644 --- a/api/handler/admin/user_handler.go +++ b/api/handler/admin/user_handler.go @@ -25,11 +25,6 @@ func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler { // List 用户列表 func (h *UserHandler) List(c *gin.Context) { - if err := utils.CheckPermission(c, h.DB); err != nil { - resp.NotPermission(c) - return - } - page := h.GetInt(c, "page", 1) pageSize := h.GetInt(c, "page_size", 20) username := h.GetTrim(c, "username") diff --git a/api/handler/chatimpl/azure_handler.go b/api/handler/chatimpl/azure_handler.go index a040aae6..11b3b69a 100644 --- a/api/handler/chatimpl/azure_handler.go +++ b/api/handler/chatimpl/azure_handler.go @@ -30,7 +30,7 @@ func (h *ChatHandler) sendAzureMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { if strings.Contains(err.Error(), "context canceled") { diff --git a/api/handler/chatimpl/baidu_handler.go b/api/handler/chatimpl/baidu_handler.go index e39ae455..08809dfe 100644 --- a/api/handler/chatimpl/baidu_handler.go +++ b/api/handler/chatimpl/baidu_handler.go @@ -47,7 +47,7 @@ func (h *ChatHandler) sendBaiduMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { if strings.Contains(err.Error(), "context canceled") { diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 439fc36e..4b745d0b 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -6,6 +6,7 @@ import ( "chatplus/core/types" "chatplus/handler" logger2 "chatplus/logger" + "chatplus/service" "chatplus/service/oss" "chatplus/store/model" "chatplus/store/vo" @@ -35,15 +36,17 @@ var logger = logger2.GetLogger() type ChatHandler struct { handler.BaseHandler - redis *redis.Client - uploadManager *oss.UploaderManager + redis *redis.Client + uploadManager *oss.UploaderManager + licenseService *service.LicenseService } -func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler { +func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler { return &ChatHandler{ - BaseHandler: handler.BaseHandler{App: app, DB: db}, - redis: redis, - uploadManager: manager, + BaseHandler: handler.BaseHandler{App: app, DB: db}, + redis: redis, + uploadManager: manager, + licenseService: licenseService, } } @@ -68,9 +71,20 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { modelId := h.GetInt(c, "model_id", 0) client := types.NewWsClient(ws) + var chatRole model.ChatRole + res := h.DB.First(&chatRole, roleId) + if res.Error != nil || !chatRole.Enable { + utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") + c.Abort() + return + } + // if the role bind a model_id, use role's bind model_id + if chatRole.ModelId > 0 { + modelId = chatRole.ModelId + } // get model info var chatModel model.ChatModel - res := h.DB.First(&chatModel, modelId) + res = h.DB.First(&chatModel, modelId) if res.Error != nil || chatModel.Enabled == false { utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!") c.Abort() @@ -111,15 +125,9 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { MaxTokens: chatModel.MaxTokens, MaxContext: chatModel.MaxContext, Temperature: chatModel.Temperature, + KeyId: chatModel.KeyId, Platform: types.Platform(chatModel.Platform)} logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username) - var chatRole model.ChatRole - res = h.DB.First(&chatRole, roleId) - if res.Error != nil || !chatRole.Enable { - utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") - c.Abort() - return - } h.Init() @@ -235,7 +243,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio break } - var tools = make([]interface{}, 0) + var tools = make([]types.Tool, 0) for _, v := range items { var parameters map[string]interface{} err = utils.JsonDecode(v.Parameters, ¶meters) @@ -244,15 +252,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio } required := parameters["required"] delete(parameters, "required") - tools = append(tools, gin.H{ - "type": "function", - "function": gin.H{ - "name": v.Name, - "description": v.Description, - "parameters": parameters, - "required": required, + tool := types.Tool{ + Type: "function", + Function: types.Function{ + Name: v.Name, + Description: v.Description, + Parameters: parameters, }, - }) + } + + // Fixed: compatible for gpt4-turbo-xxx model + if !strings.HasPrefix(req.Model, "gpt-4-turbo-") { + tool.Function.Required = required + } + tools = append(tools, tool) } if len(tools) > 0 { @@ -332,6 +345,34 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio Content: prompt, }) req.Input["messages"] = reqMgs + } else if session.Model.Platform == types.OpenAI { // extract image for gpt-vision model + imgURLs := utils.ExtractImgURL(prompt) + logger.Debugf("detected IMG: %+v", imgURLs) + var content interface{} + if len(imgURLs) > 0 { + data := make([]interface{}, 0) + text := prompt + for _, v := range imgURLs { + text = strings.Replace(text, v, "", 1) + data = append(data, gin.H{ + "type": "image_url", + "image_url": gin.H{ + "url": v, + }, + }) + } + data = append(data, gin.H{ + "type": "text", + "text": text, + }) + content = data + } else { + content = prompt + } + req.Messages = append(reqMgs, map[string]interface{}{ + "role": "user", + "content": content, + }) } else { req.Messages = append(reqMgs, map[string]interface{}{ "role": "user", @@ -339,6 +380,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio }) } + logger.Debugf("%+v", req.Messages) + switch session.Model.Platform { case types.Azure: return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) @@ -426,13 +469,29 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) { // 发送请求到 OpenAI 服务器 // useOwnApiKey: 是否使用了用户自己的 API KEY -func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) { - res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) - if res.Error != nil { +func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) { + // if the chat model bind a KEY, use it directly + if session.Model.KeyId > 0 { + h.DB.Debug().Where("id", session.Model.KeyId).Find(apiKey) + } + // use the last unused key + if apiKey.Id == 0 { + h.DB.Debug().Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) + } + if apiKey.Id == 0 { return nil, errors.New("no available key, please import key") } + + // ONLY allow apiURL in blank list + if session.Model.Platform == types.OpenAI { + err := h.licenseService.IsValidApiURL(apiKey.ApiURL) + if err != nil { + return nil, err + } + } + var apiURL string - switch platform { + switch session.Model.Platform { case types.Azure: md := strings.Replace(req.Model, ".", "", 1) apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1) @@ -455,7 +514,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf // 更新 API KEY 的最后使用时间 h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) // 百度文心,需要串接 access_token - if platform == types.Baidu { + if session.Model.Platform == types.Baidu { token, err := h.getBaiduToken(apiKey.Value) if err != nil { return nil, err @@ -479,7 +538,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf request = request.WithContext(ctx) request.Header.Set("Content-Type", "application/json") - var proxyURL string if len(apiKey.ProxyURL) > 5 { // 使用代理 proxy, _ := url.Parse(apiKey.ProxyURL) client = &http.Client{ @@ -490,8 +548,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf } else { client = http.DefaultClient } - logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model) - switch platform { + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) + switch session.Model.Platform { case types.Azure: request.Header.Set("api-key", apiKey.Value) break diff --git a/api/handler/chatimpl/chatglm_handler.go b/api/handler/chatimpl/chatglm_handler.go index 678f481d..5f391b3f 100644 --- a/api/handler/chatimpl/chatglm_handler.go +++ b/api/handler/chatimpl/chatglm_handler.go @@ -31,7 +31,7 @@ func (h *ChatHandler) sendChatGLMMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { if strings.Contains(err.Error(), "context canceled") { diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go index 36c49815..2eb32866 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -31,7 +31,7 @@ func (h *ChatHandler) sendOpenAiMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { logger.Error(err) @@ -74,6 +74,10 @@ func (h *ChatHandler) sendOpenAiMessage( utils.ReplyMessage(ws, ErrImg) break } + if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { + utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。") + break + } var tool types.ToolCall if len(responseBody.Choices[0].Delta.ToolCalls) > 0 { @@ -98,8 +102,10 @@ func (h *ChatHandler) sendOpenAiMessage( res := h.DB.Where("name = ?", tool.Function.Name).First(&function) if res.Error == nil { toolCall = true + callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)}) + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg}) + contents = append(contents, callMsg) } continue } diff --git a/api/handler/chatimpl/qwen_handler.go b/api/handler/chatimpl/qwen_handler.go index 1c8edcad..340f00de 100644 --- a/api/handler/chatimpl/qwen_handler.go +++ b/api/handler/chatimpl/qwen_handler.go @@ -45,7 +45,7 @@ func (h *ChatHandler) sendQWenMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { if strings.Contains(err.Error(), "context canceled") { diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index adb646dc..36a5b785 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "github.com/gorilla/websocket" + "gorm.io/gorm" "html/template" "io" "net/http" @@ -69,7 +70,15 @@ func (h *ChatHandler) sendXunFeiMessage( ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 var apiKey model.ApiKey - res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) + var res *gorm.DB + // use the bind key + if session.Model.KeyId > 0 { + res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey) + } + // use the last unused key + if res.Error != nil { + res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) + } if res.Error != nil { utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") return nil diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go new file mode 100644 index 00000000..9401d610 --- /dev/null +++ b/api/handler/dalle_handler.go @@ -0,0 +1,255 @@ +package handler + +import ( + "chatplus/core" + "chatplus/core/types" + "chatplus/service/dalle" + "chatplus/service/oss" + "chatplus/store/model" + "chatplus/store/vo" + "chatplus/utils" + "chatplus/utils/resp" + "net/http" + + "github.com/gorilla/websocket" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "gorm.io/gorm" +) + +type DallJobHandler struct { + BaseHandler + redis *redis.Client + service *dalle.Service + uploader *oss.UploaderManager +} + +func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler { + return &DallJobHandler{ + service: service, + uploader: manager, + BaseHandler: BaseHandler{ + App: app, + DB: db, + }, + } +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *DallJobHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.service.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) + go func() { + for { + _, msg, err := client.Receive() + if err != nil { + client.Close() + h.service.Clients.Delete(uint(userId)) + return + } + + var message types.WsMessage + err = utils.JsonDecode(string(msg), &message) + if err != nil { + continue + } + + // 心跳消息 + if message.Type == "heartbeat" { + logger.Debug("收到 DallE 心跳消息:", message.Content) + continue + } + } + }() +} + +func (h *DallJobHandler) preCheck(c *gin.Context) bool { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return false + } + + if user.Power < h.App.SysConfig.SdPower { + resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") + return false + } + + return true + +} + +// Image 创建一个绘画任务 +func (h *DallJobHandler) Image(c *gin.Context) { + if !h.preCheck(c) { + return + } + + var data types.DallTask + if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + job := model.DallJob{ + UserId: uint(userId), + Prompt: data.Prompt, + Power: h.App.SysConfig.DallPower, + } + res := h.DB.Create(&job) + if res.Error != nil { + resp.ERROR(c, "error with save job: "+res.Error.Error()) + return + } + + h.service.PushTask(types.DallTask{ + JobId: job.Id, + UserId: uint(userId), + Prompt: data.Prompt, + Quality: data.Quality, + Size: data.Size, + Style: data.Style, + Power: job.Power, + }) + + client := h.service.Clients.Get(job.UserId) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + resp.SUCCESS(c) +} + +// ImgWall 照片墙 +func (h *DallJobHandler) ImgWall(c *gin.Context) { + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + err, jobs := h.getData(true, 0, page, pageSize, true) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 SD 任务列表 +func (h *DallJobHandler) JobList(c *gin.Context) { + status := h.GetBool(c, "status") + userId := h.GetLoginUserId(c) + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + publish := h.GetBool(c, "publish") + + err, jobs := h.getData(status, userId, page, pageSize, publish) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取任务列表 +func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) { + + session := h.DB.Session(&gorm.Session{}) + if finish { + session = session.Where("progress = ?", 100).Order("id DESC") + } else { + session = session.Where("progress < ?", 100).Order("id ASC") + } + if userId > 0 { + session = session.Where("user_id = ?", userId) + } + if publish { + session = session.Where("publish", publish) + } + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + session = session.Offset(offset).Limit(pageSize) + } + + var items []model.DallJob + res := session.Find(&items) + if res.Error != nil { + return res.Error, nil + } + + var jobs = make([]vo.DallJob, 0) + for _, item := range items { + var job vo.DallJob + err := utils.CopyObject(item, &job) + if err != nil { + continue + } + jobs = append(jobs, job) + } + + return nil, jobs +} + +// Remove remove task image +func (h *DallJobHandler) Remove(c *gin.Context) { + var data struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + ImgURL string `json:"img_url"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // remove job recode + res := h.DB.Delete(&model.DallJob{Id: data.Id}) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + // remove image + err := h.uploader.GetUploadHandler().Delete(data.ImgURL) + if err != nil { + logger.Error("remove image failed: ", err) + } + + resp.SUCCESS(c) +} + +// Publish 发布/取消发布图片到画廊显示 +func (h *DallJobHandler) Publish(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享 + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true) + if res.Error != nil { + resp.ERROR(c, "更新数据库失败") + return + } + + resp.SUCCESS(c) +} diff --git a/api/handler/function_handler.go b/api/handler/function_handler.go index e9eb57df..0941db08 100644 --- a/api/handler/function_handler.go +++ b/api/handler/function_handler.go @@ -3,27 +3,35 @@ package handler import ( "chatplus/core" "chatplus/core/types" + "chatplus/service/dalle" "chatplus/service/oss" "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" "errors" "fmt" + "strings" + "time" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/imroc/req/v3" "gorm.io/gorm" - "strings" - "time" ) type FunctionHandler struct { BaseHandler - config types.ChatPlusApiConfig + config types.ApiConfig uploadManager *oss.UploaderManager + dallService *dalle.Service } -func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler { +func NewFunctionHandler( + server *core.AppServer, + db *gorm.DB, + config *types.AppConfig, + manager *oss.UploaderManager, + dallService *dalle.Service) *FunctionHandler { return &FunctionHandler{ BaseHandler: BaseHandler{ App: server, @@ -31,6 +39,7 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo }, config: config.ApiConfig, uploadManager: manager, + dallService: dallService, } } @@ -151,30 +160,6 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) { resp.SUCCESS(c, strings.Join(builder, "\n\n")) } -type imgReq struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` -} - -type imgRes struct { - Created int64 `json:"created"` - Data []struct { - RevisedPrompt string `json:"revised_prompt"` - Url string `json:"url"` - } `json:"data"` -} - -type ErrRes struct { - Error struct { - Code interface{} `json:"code"` - Message string `json:"message"` - Param interface{} `json:"param"` - Type string `json:"type"` - } `json:"error"` -} - // Dall3 DallE3 AI 绘图 func (h *FunctionHandler) Dall3(c *gin.Context) { if err := h.checkAuth(c); err != nil { @@ -190,85 +175,40 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { logger.Debugf("绘画参数:%+v", params) var user model.User - tx := h.DB.Where("id = ?", params["user_id"]).First(&user) - if tx.Error != nil { + res := h.DB.Where("id = ?", params["user_id"]).First(&user) + if res.Error != nil { resp.ERROR(c, "当前用户不存在!") return } - if user.Power < h.App.SysConfig.DallPower { - resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") - return - } - + // create dall task prompt := utils.InterfaceToString(params["prompt"]) - // get image generation API KEY - var apiKey model.ApiKey - tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) - if tx.Error != nil { - resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error()) + job := model.DallJob{ + UserId: user.Id, + Prompt: prompt, + Power: h.App.SysConfig.DallPower, + } + res = h.DB.Create(&job) + + if res.Error != nil { + resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error()) return } - // translate prompt - const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" - pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"])) - if err == nil { - logger.Debugf("翻译绘画提示词,原文:%s,译文:%s", prompt, pt) - prompt = pt - } - var res imgRes - var errRes ErrRes - var request *req.Request - if len(apiKey.ProxyURL) > 5 { - request = req.C().SetProxyURL(apiKey.ProxyURL).R() - } else { - request = req.C().R() - } - logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) - r, err := request.SetHeader("Content-Type", "application/json"). - SetHeader("Authorization", "Bearer "+apiKey.Value). - SetBody(imgReq{ - Model: "dall-e-3", - Prompt: prompt, - N: 1, - Size: "1024x1024", - }). - SetErrorResult(&errRes). - SetSuccessResult(&res).Post(apiKey.ApiURL) - if r.IsErrorState() { - resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message) - return - } - // 更新 API KEY 的最后使用时间 - h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) - logger.Debugf("%+v", res) - // 存储图片 - imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false) + content, err := h.dallService.Image(types.DallTask{ + JobId: job.Id, + UserId: user.Id, + Prompt: job.Prompt, + N: 1, + Quality: "standard", + Size: "1024x1024", + Style: "vivid", + Power: job.Power, + }, true) if err != nil { - resp.ERROR(c, "下载图片失败: "+err.Error()) + resp.ERROR(c, "任务执行失败:"+err.Error()) return } - content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL) - // 更新用户算力 - tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - var u model.User - h.DB.Where("id", user.Id).First(&u) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: h.App.SysConfig.DallPower, - Balance: u.Power, - Mark: types.PowerSub, - Model: "dall-e-3", - Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)), - CreatedAt: time.Now(), - }) - } - resp.SUCCESS(c, content) } diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go new file mode 100644 index 00000000..e5d149eb --- /dev/null +++ b/api/handler/markmap_handler.go @@ -0,0 +1,227 @@ +package handler + +import ( + "bufio" + "bytes" + "chatplus/core" + "chatplus/core/types" + "chatplus/store/model" + "chatplus/utils" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "gorm.io/gorm" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// MarkMapHandler 生成思维导图 +type MarkMapHandler struct { + BaseHandler + clients *types.LMap[int, *types.WsClient] +} + +func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { + return &MarkMapHandler{ + BaseHandler: BaseHandler{App: app, DB: db}, + clients: types.NewLMap[int, *types.WsClient](), + } +} + +func (h *MarkMapHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + return + } + + modelId := h.GetInt(c, "model_id", 0) + userId := h.GetInt(c, "user_id", 0) + + client := types.NewWsClient(ws) + h.clients.Put(userId, client) + go func() { + for { + _, msg, err := client.Receive() + if err != nil { + client.Close() + h.clients.Delete(userId) + return + } + + var message types.WsMessage + err = utils.JsonDecode(string(msg), &message) + if err != nil { + continue + } + + // 心跳消息 + if message.Type == "heartbeat" { + logger.Debug("收到 MarkMap 心跳消息:", message.Content) + continue + } + // change model + if message.Type == "model_id" { + modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0) + continue + } + + logger.Info("Receive a message: ", message.Content) + err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId) + if err != nil { + logger.Error(err) + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()}) + } + + } + }() +} + +func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error { + var user model.User + res := h.DB.Model(&model.User{}).First(&user, userId) + if res.Error != nil { + return fmt.Errorf("error with query user info: %v", res.Error) + } + var chatModel model.ChatModel + res = h.DB.Where("id", modelId).First(&chatModel) + if res.Error != nil { + return fmt.Errorf("error with query chat model: %v", res.Error) + } + + if user.Status == false { + return errors.New("当前用户被禁用") + } + + if user.Power < chatModel.Power { + return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power) + } + + messages := make([]interface{}, 0) + messages = append(messages, types.Message{Role: "system", Content: "你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。不要输出任何解释性的语句。"}) + messages = append(messages, types.Message{Role: "user", Content: prompt}) + var req = types.ApiRequest{ + Model: chatModel.Value, + Stream: true, + Messages: messages, + } + + var apiKey model.ApiKey + response, err := h.doRequest(req, chatModel, &apiKey) + if err != nil { + return fmt.Errorf("请求 OpenAI API 失败: %s", err) + } + + defer response.Body.Close() + + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + // 循环读取 Chunk 消息 + var message = types.Message{} + scanner := bufio.NewScanner(response.Body) + var isNew = true + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, "data:") || len(line) < 30 { + continue + } + + var responseBody = types.ApiResponse{} + err = json.Unmarshal([]byte(line[6:]), &responseBody) + if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 + return fmt.Errorf("error with decode data: %v", err) + } + + // 初始化 role + if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { + message.Role = responseBody.Choices[0].Delta.Role + continue + } else if responseBody.Choices[0].FinishReason != "" { + break // 输出完成或者输出中断了 + } else { + if isNew { + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart}) + isNew = false + } + utils.ReplyChunkMessage(client, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), + }) + } + } // end for + + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) + + } else { + body, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("读取响应失败: %v", err) + } + var res types.ApiError + err = json.Unmarshal(body, &res) + if err != nil { + return fmt.Errorf("解析响应失败: %v", err) + } + + // OpenAI API 调用异常处理 + if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { + // remove key + h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{}) + return errors.New("请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") + } else if strings.Contains(res.Error.Message, "You exceeded your current quota") { + return errors.New("请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") + } else { + return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message) + } + } + + return nil +} + +func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) { + // if the chat model bind a KEY, use it directly + var res *gorm.DB + if chatModel.KeyId > 0 { + res = h.DB.Where("id", chatModel.KeyId).Find(apiKey) + } + // use the last unused key + if res.Error != nil { + res = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) + } + if res.Error != nil { + return nil, errors.New("no available key, please import key") + } + apiURL := apiKey.ApiURL + // 更新 API KEY 的最后使用时间 + h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + + // 创建 HttpClient 请求对象 + var client *http.Client + requestBody, err := json.Marshal(req) + if err != nil { + return nil, err + } + request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody)) + if err != nil { + return nil, err + } + + request.Header.Set("Content-Type", "application/json") + if len(apiKey.ProxyURL) > 5 { // 使用代理 + proxy, _ := url.Parse(apiKey.ProxyURL) + client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + } else { + client = http.DefaultClient + } + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) + return client.Do(request) +} diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index b659c5b0..e0e0f020 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -146,7 +146,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { } if data.SRef != "" { - params += fmt.Sprintf(" --sref %s", data.CRef) + params += fmt.Sprintf(" --sref %s", data.SRef) } if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") { params += fmt.Sprintf(" %s", data.Model) diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index b9c3625e..25f122a0 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -65,7 +65,7 @@ func (h *SdJobHandler) Client(c *gin.Context) { logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } -func (h *SdJobHandler) checkLimits(c *gin.Context) bool { +func (h *SdJobHandler) preCheck(c *gin.Context) bool { user, err := h.GetLoginUser(c) if err != nil { resp.NotAuth(c) @@ -88,7 +88,7 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool { // Image 创建一个绘画任务 func (h *SdJobHandler) Image(c *gin.Context) { - if !h.checkLimits(c) { + if !h.preCheck(c) { return } @@ -260,9 +260,10 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, if item.Progress < 100 { // 从 leveldb 中获取图片预览数据 - imageData, err := h.leveldb.Get(item.TaskId) + var imageData string + err = h.leveldb.Get(item.TaskId, &imageData) if err == nil { - job.ImgURL = "data:image/png;base64," + string(imageData) + job.ImgURL = "data:image/png;base64," + imageData } } jobs = append(jobs, job) @@ -298,7 +299,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) { client := h.pool.Clients.Get(data.UserId) if client != nil { - _ = client.Send([]byte("Task Updated")) + _ = client.Send([]byte(sd.Finished)) } resp.SUCCESS(c) diff --git a/api/main.go b/api/main.go index e5b8f3cf..1c7177ad 100644 --- a/api/main.go +++ b/api/main.go @@ -8,6 +8,7 @@ import ( "chatplus/handler/chatimpl" logger2 "chatplus/logger" "chatplus/service" + "chatplus/service/dalle" "chatplus/service/mj" "chatplus/service/oss" "chatplus/service/payment" @@ -43,13 +44,13 @@ type AppLifecycle struct { // OnStart 应用程序启动时执行 func (l *AppLifecycle) OnStart(context.Context) error { - log.Println("AppLifecycle OnStart") + logger.Info("AppLifecycle OnStart") return nil } // OnStop 应用程序停止时执行 func (l *AppLifecycle) OnStop(context.Context) error { - log.Println("AppLifecycle OnStop") + logger.Info("AppLifecycle OnStop") return nil } @@ -153,9 +154,18 @@ func main() { }), fx.Provide(oss.NewUploaderManager), fx.Provide(mj.NewService), + fx.Provide(dalle.NewService), + fx.Invoke(func(service *dalle.Service) { + service.Run() + service.CheckTaskNotify() + service.DownloadImages() + service.CheckTaskStatus() + }), // 邮件服务 fx.Provide(service.NewSmtpService), + // License 服务 + fx.Provide(service.NewLicenseService), // 微信机器人服务 fx.Provide(wx.NewWeChatBot), @@ -277,9 +287,10 @@ func main() { // 管理后台控制器 fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { - group := s.Engine.Group("/api/admin/config/") - group.POST("update", h.Update) - group.GET("get", h.Get) + group := s.Engine.Group("/api/admin/") + group.POST("config/update", h.Update) + group.GET("config/get", h.Get) + group.POST("active", h.Active) }), fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { group := s.Engine.Group("/api/admin/") @@ -436,6 +447,21 @@ func main() { group := s.Engine.Group("/api/menu/") group.GET("list", h.List) }), + fx.Provide(handler.NewMarkMapHandler), + fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) { + group := s.Engine.Group("/api/markMap/") + group.Any("client", h.Client) + }), + fx.Provide(handler.NewDallJobHandler), + fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) { + group := s.Engine.Group("/api/dall") + group.Any("client", h.Client) + group.POST("image", h.Image) + group.GET("jobs", h.JobList) + group.GET("imgWall", h.ImgWall) + group.POST("remove", h.Remove) + group.POST("publish", h.Publish) + }), fx.Invoke(func(s *core.AppServer, db *gorm.DB) { go func() { err := s.Run(db) diff --git a/api/service/captcha_service.go b/api/service/captcha_service.go index 4efbfe55..e8cfd39b 100644 --- a/api/service/captcha_service.go +++ b/api/service/captcha_service.go @@ -9,11 +9,11 @@ import ( ) type CaptchaService struct { - config types.ChatPlusApiConfig + config types.ApiConfig client *req.Client } -func NewCaptchaService(config types.ChatPlusApiConfig) *CaptchaService { +func NewCaptchaService(config types.ApiConfig) *CaptchaService { return &CaptchaService{ config: config, client: req.C().SetTimeout(10 * time.Second), diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go new file mode 100644 index 00000000..f3929e84 --- /dev/null +++ b/api/service/dalle/service.go @@ -0,0 +1,300 @@ +package dalle + +import ( + "chatplus/core/types" + logger2 "chatplus/logger" + "chatplus/service" + "chatplus/service/oss" + "chatplus/service/sd" + "chatplus/store" + "chatplus/store/model" + "chatplus/utils" + "errors" + "fmt" + "github.com/go-redis/redis/v8" + "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +// DALL-E 绘画服务 + +type Service struct { + httpClient *req.Client + db *gorm.DB + uploadManager *oss.UploaderManager + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + Clients *types.LMap[uint, *types.WsClient] // UserId => Client +} + +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { + return &Service{ + httpClient: req.C().SetTimeout(time.Minute * 3), + db: db, + taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli), + notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli), + Clients: types.NewLMap[uint, *types.WsClient](), + uploadManager: manager, + } +} + +// PushTask push a new mj task in to task queue +func (s *Service) PushTask(task types.DallTask) { + logger.Debugf("add a new MidJourney task to the task list: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *Service) Run() { + go func() { + for { + var task types.DallTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + + _, err = s.Image(task, false) + if err != nil { + logger.Errorf("error with image task: %v", err) + s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ + "progress": -1, + "err_msg": err.Error(), + }) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed}) + } + } + }() +} + +type imgReq struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n"` + Size string `json:"size"` + Quality string `json:"quality"` + Style string `json:"style"` +} + +type imgRes struct { + Created int64 `json:"created"` + Data []struct { + RevisedPrompt string `json:"revised_prompt"` + Url string `json:"url"` + } `json:"data"` +} + +type ErrRes struct { + Error struct { + Code interface{} `json:"code"` + Message string `json:"message"` + Param interface{} `json:"param"` + Type string `json:"type"` + } `json:"error"` +} + +func (s *Service) Image(task types.DallTask, sync bool) (string, error) { + logger.Debugf("绘画参数:%+v", task) + prompt := task.Prompt + // translate prompt + if utils.HasChinese(task.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt)) + if err != nil { + return "", fmt.Errorf("error with translate prompt: %v", err) + } + prompt = content + logger.Debugf("重写后提示词:%s", prompt) + } + + var user model.User + s.db.Where("id", task.UserId).First(&user) + if user.Power < task.Power { + return "", errors.New("insufficient of power") + } + + // get image generation API KEY + var apiKey model.ApiKey + tx := s.db.Where("platform", types.OpenAI). + Where("type", "img"). + Where("enabled", true). + Order("last_used_at ASC").First(&apiKey) + if tx.Error != nil { + return "", fmt.Errorf("no available IMG api key: %v", tx.Error) + } + + var res imgRes + var errRes ErrRes + if len(apiKey.ProxyURL) > 5 { + s.httpClient.SetProxyURL(apiKey.ProxyURL).R() + } + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) + r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(imgReq{ + Model: "dall-e-3", + Prompt: prompt, + N: 1, + Size: "1024x1024", + Style: task.Style, + Quality: task.Quality, + }). + SetErrorResult(&errRes). + SetSuccessResult(&res).Post(apiKey.ApiURL) + if err != nil { + return "", fmt.Errorf("error with send request: %v", err) + } + + if r.IsErrorState() { + return "", fmt.Errorf("error with send request: %v", errRes.Error) + } + // update the api key last use time + s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + // update task progress + s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ + "progress": 100, + "org_url": res.Data[0].Url, + "prompt": prompt, + }) + + s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished}) + var content string + if sync { + imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) + if err != nil { + return "", fmt.Errorf("error with download image: %v", err) + } + content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n![](%s)\n", prompt, imgURL) + } + + // 更新用户算力 + tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power)) + // 记录算力变化日志 + if tx.Error == nil && tx.RowsAffected > 0 { + var u model.User + s.db.Where("id", user.Id).First(&u) + s.db.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: task.Power, + Balance: u.Power, + Mark: types.PowerSub, + Model: "dall-e-3", + Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)), + CreatedAt: time.Now(), + }) + } + + return content, nil +} + +func (s *Service) CheckTaskNotify() { + go func() { + logger.Info("Running DALL-E task notify checking ...") + for { + var message sd.NotifyMessage + err := s.notifyQueue.LPop(&message) + if err != nil { + continue + } + client := s.Clients.Get(uint(message.UserId)) + if client == nil { + continue + } + err = client.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +func (s *Service) DownloadImages() { + go func() { + var items []model.DallJob + for { + res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items) + if res.Error != nil { + continue + } + + // download images + for _, v := range items { + if v.OrgURL == "" { + continue + } + + logger.Infof("try to download image: %s", v.OrgURL) + imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL) + if err != nil { + logger.Error("error with download image: %s, error: %v", imgURL, err) + continue + } + + } + + time.Sleep(time.Second * 5) + } + }() +} + +func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) { + // sava image + imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false) + if err != nil { + return "", err + } + + // update img_url + res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL) + if res.Error != nil { + return "", err + } + s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed}) + return imgURL, nil +} + +// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 +func (s *Service) CheckTaskStatus() { + go func() { + logger.Info("Running Stable-Diffusion task status checking ...") + for { + var jobs []model.SdJob + res := s.db.Where("progress < ?", 100).Find(&jobs) + if res.Error != nil { + time.Sleep(5 * time.Second) + continue + } + + for _, job := range jobs { + // 5 分钟还没完成的任务直接删除 + if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 { + s.db.Delete(&job) + var user model.User + s.db.Where("id = ?", job.UserId).First(&user) + // 退回绘图次数 + res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) + if res.Error == nil && res.RowsAffected > 0 { + s.db.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power + job.Power, + Mark: types.PowerAdd, + Model: "dall-e-3", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId), + CreatedAt: time.Now(), + }) + } + continue + } + } + time.Sleep(time.Second * 10) + } + }() +} diff --git a/api/service/license_service.go b/api/service/license_service.go new file mode 100644 index 00000000..0a7d7a65 --- /dev/null +++ b/api/service/license_service.go @@ -0,0 +1,108 @@ +package service + +import ( + "chatplus/core" + "chatplus/core/types" + "chatplus/store" + "errors" + "fmt" + "github.com/imroc/req/v3" + "github.com/shirou/gopsutil/host" + "strings" + "time" +) + +type LicenseService struct { + config types.ApiConfig + levelDB *store.LevelDB + license types.License + machineId string +} + +func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) * LicenseService { + var license types.License + var machineId string + _ = levelDB.Get(types.LicenseKey, &license) + info, err := host.Info() + if err == nil { + machineId = info.HostID + } + return &LicenseService{ + config: server.Config.ApiConfig, + levelDB: levelDB, + license: license, + machineId: machineId, + } +} + +// ActiveLicense 激活 License +func (s *LicenseService) ActiveLicense(license string, machineId string) error { + var res struct { + Code types.BizCode `json:"code"` + Message string `json:"message"` + Data struct { + Name string `json:"name"` + License string `json:"license"` + Mid string `json:"mid"` + ExpiredAt int64 `json:"expired_at"` + UserNum int `json:"user_num"` + } + } + apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active") + response, err := req.C().R(). + SetBody(map[string]string{"license": license, "machine_id": machineId}). + SetSuccessResult(&res).Post(apiURL) + if err != nil { + return fmt.Errorf("发送激活请求失败: %v", err) + } + + if response.IsErrorState() { + return fmt.Errorf( "发送激活请求失败:%v", response.Status) + } + + if res.Code != types.Success { + return fmt.Errorf( "激活失败:%v", res.Message) + } + + err = s.levelDB.Put(types.LicenseKey, types.License{ + Key: license, + MachineId: machineId, + UserNum: res.Data.UserNum, + ExpiredAt: res.Data.ExpiredAt, + IsActive: true, + }) + if err != nil { + return fmt.Errorf("保存许可证书失败:%v", err) + } + + return nil +} + +// GetLicense 获取许可信息 +func (s *LicenseService) GetLicense() types.License { + return s.license +} + +// IsValidApiURL 判断是否合法的中转 URL +func (s *LicenseService) IsValidApiURL(uri string) error { + // 获得许可授权的直接放行 + if s.license.IsActive { + if s.license.MachineId != s.machineId { + return errors.New("系统使用了盗版的许可证书") + } + + if time.Now().Unix() > s.license.ExpiredAt { + return errors.New("系统许可证书已经过期") + } + return nil + } + + if !strings.HasPrefix(uri, "https://gpt.bemore.lol") && + !strings.HasPrefix(uri, "https://api.openai.com") && + !strings.HasPrefix(uri, "http://cdn.chat-plus.net") && + !strings.HasPrefix(uri, "https://api.chat-plus.net") { + return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。",uri) + } + + return nil +} \ No newline at end of file diff --git a/api/service/mj/plus_client.go b/api/service/mj/plus_client.go index bce35263..822d4b91 100644 --- a/api/service/mj/plus_client.go +++ b/api/service/mj/plus_client.go @@ -73,6 +73,7 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { // Blend 融图 func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) body := ImageReq{ BotType: "MID_JOURNEY", Dimensions: "SQUARE", @@ -164,6 +165,7 @@ func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { "taskId": task.MessageId, } apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) var res ImageRes var errRes ErrRes r, err := c.client.R(). @@ -190,6 +192,7 @@ func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { "taskId": task.MessageId, } apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) var res ImageRes var errRes ErrRes r, err := req.C().R(). diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index 7404021e..48243137 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -3,7 +3,9 @@ package mj import ( "chatplus/core/types" logger2 "chatplus/logger" + "chatplus/service" "chatplus/service/oss" + "chatplus/service/sd" "chatplus/store" "chatplus/store/model" "fmt" @@ -25,7 +27,7 @@ type ServicePool struct { var logger = logger2.GetLogger() -func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { +func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, licenseService *service.LicenseService) *ServicePool { services := make([]*Service, 0) taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) @@ -34,13 +36,19 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa if config.Enabled == false { continue } + err := licenseService.IsValidApiURL(config.ApiURL) + if err != nil { + logger.Error(err) + continue + } + cli := NewPlusClient(config) name := fmt.Sprintf("mj-plus-service-%d", k) - service := NewService(name, taskQueue, notifyQueue, db, cli) + plusService := NewService(name, taskQueue, notifyQueue, db, cli) go func() { - service.Run() + plusService.Run() }() - services = append(services, service) + services = append(services, plusService) } for k, config := range appConfig.MjProxyConfigs { @@ -49,11 +57,11 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa } cli := NewProxyClient(config) name := fmt.Sprintf("mj-proxy-service-%d", k) - service := NewService(name, taskQueue, notifyQueue, db, cli) + proxyService := NewService(name, taskQueue, notifyQueue, db, cli) go func() { - service.Run() + proxyService.Run() }() - services = append(services, service) + services = append(services, proxyService) } return &ServicePool{ @@ -69,16 +77,16 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa func (p *ServicePool) CheckTaskNotify() { go func() { for { - var userId uint - err := p.notifyQueue.LPop(&userId) + var message sd.NotifyMessage + err := p.notifyQueue.LPop(&message) if err != nil { continue } - cli := p.Clients.Get(userId) + cli := p.Clients.Get(uint(message.UserId)) if cli == nil { continue } - err = cli.Send([]byte("Task Updated")) + err = cli.Send([]byte(message.Message)) if err != nil { continue } @@ -127,7 +135,7 @@ func (p *ServicePool) DownloadImages() { if cli == nil { continue } - err = cli.Send([]byte("Task Updated")) + err = cli.Send([]byte(sd.Finished)) if err != nil { continue } @@ -162,7 +170,6 @@ func (p *ServicePool) SyncTaskProgress() { for _, job := range items { // 失败或者 30 分钟还没完成的任务删除并退回算力 if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 { - // 删除任务 p.db.Delete(&job) // 退回算力 tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) @@ -189,7 +196,7 @@ func (p *ServicePool) SyncTaskProgress() { } } - time.Sleep(time.Second) + time.Sleep(time.Second * 10) } }() } diff --git a/api/service/mj/service.go b/api/service/mj/service.go index ad118308..0d5f0dea 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -3,6 +3,7 @@ package mj import ( "chatplus/core/types" "chatplus/service" + "chatplus/service/sd" "chatplus/store" "chatplus/store/model" "chatplus/utils" @@ -53,7 +54,7 @@ func (s *Service) Run() { // translate prompt if utils.HasChinese(task.Prompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt)) + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt)) if err == nil { task.Prompt = content } else { @@ -62,7 +63,7 @@ func (s *Service) Run() { } // translate negative prompt if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt)) + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt)) if err == nil { task.NegPrompt = content } else { @@ -105,7 +106,7 @@ func (s *Service) Run() { // update the task progress s.db.Updates(&job) // 任务失败,通知前端 - s.notifyQueue.RPush(task.UserId) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed}) continue } logger.Infof("任务提交成功:%+v", res) @@ -147,7 +148,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error { "progress": -1, "err_msg": task.FailReason, }) - s.notifyQueue.RPush(job.UserId) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) return fmt.Errorf("task failed: %v", task.FailReason) } @@ -166,7 +167,11 @@ func (s *Service) Notify(job model.MidJourneyJob) error { } // 通知前端更新任务进度 if oldProgress != job.Progress { - s.notifyQueue.RPush(job.UserId) + message := sd.Running + if job.Progress == 100 { + message = sd.Finished + } + s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message}) } return nil } diff --git a/api/service/sd/pool.go b/api/service/sd/pool.go index 3033b548..e191eef8 100644 --- a/api/service/sd/pool.go +++ b/api/service/sd/pool.go @@ -60,16 +60,16 @@ func (p *ServicePool) CheckTaskNotify() { go func() { logger.Info("Running Stable-Diffusion task notify checking ...") for { - var userId uint - err := p.notifyQueue.LPop(&userId) + var message NotifyMessage + err := p.notifyQueue.LPop(&message) if err != nil { continue } - client := p.Clients.Get(userId) + client := p.Clients.Get(uint(message.UserId)) if client == nil { continue } - err = client.Send([]byte("Task Updated")) + err = client.Send([]byte(message.Message)) if err != nil { continue } @@ -113,7 +113,7 @@ func (p *ServicePool) CheckTaskStatus() { continue } } - + time.Sleep(time.Second * 10) } }() } diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 4f68f3e0..9d6932a2 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -8,10 +8,11 @@ import ( "chatplus/store/model" "chatplus/utils" "fmt" - "github.com/imroc/req/v3" - "gorm.io/gorm" "strings" "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" ) // SD 绘画服务 @@ -80,7 +81,7 @@ func (s *Service) Run() { "err_msg": err.Error(), }) // 通知前端,任务失败 - s.notifyQueue.RPush(task.UserId) + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed}) continue } } @@ -145,8 +146,13 @@ func (s *Service) Txt2Img(task types.SdTask) error { var errChan = make(chan error) apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL) logger.Debugf("send image request to %s", apiURL) + // send a request to sd api endpoint go func() { - response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL) + response, err := s.httpClient.R(). + SetHeader("Authorization", s.config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + Post(apiURL) if err != nil { errChan <- err return @@ -174,14 +180,17 @@ func (s *Service) Txt2Img(task types.SdTask) error { errChan <- nil }() + // waiting for task finish for { select { - case err := <-errChan: // 任务完成 + case err := <-errChan: if err != nil { return err } + + // task finished s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) - s.notifyQueue.RPush(task.UserId) + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished}) // 从 leveldb 中删除预览图片数据 _ = s.leveldb.Delete(task.Params.TaskId) return nil @@ -191,7 +200,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { if err == nil && resp.Progress > 0 { s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) // 发送更新状态信号 - s.notifyQueue.RPush(task.UserId) + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running}) // 保存预览图片数据 if resp.CurrentImage != "" { _ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) @@ -207,7 +216,10 @@ func (s *Service) Txt2Img(task types.SdTask) error { func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL) var res TaskProgressResp - response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL) + response, err := s.httpClient.R(). + SetHeader("Authorization", s.config.ApiKey). + SetSuccessResult(&res). + Get(apiURL) if err != nil { return err, nil } diff --git a/api/service/sd/types.go b/api/service/sd/types.go index 56ebb5bd..eb172bcd 100644 --- a/api/service/sd/types.go +++ b/api/service/sd/types.go @@ -4,44 +4,14 @@ import logger2 "chatplus/logger" var logger = logger2.GetLogger() -type TaskInfo struct { - UserId uint `json:"user_id"` - SessionId string `json:"session_id"` - JobId int `json:"job_id"` - TaskId string `json:"task_id"` - Data []interface{} `json:"data"` - EventData interface{} `json:"event_data"` - FnIndex int `json:"fn_index"` - SessionHash string `json:"session_hash"` +type NotifyMessage struct { + UserId int `json:"user_id"` + JobId int `json:"job_id"` + Message string `json:"message"` } -type CBReq struct { - UserId uint - SessionId string - JobId int - TaskId string - ImageName string - ImageData string - Progress int - Seed int64 - Success bool - Message string -} - -var ParamKeys = map[string]int{ - "task_id": 0, - "prompt": 1, - "negative_prompt": 2, - "steps": 4, - "sampler": 5, - "face_fix": 7, // 面部修复 - "cfg_scale": 8, - "seed": 27, - "height": 10, - "width": 9, - "hd_fix": 11, - "hd_redraw_rate": 12, //高清修复重绘幅度 - "hd_scale": 13, // 高清修复放大倍数 - "hd_scale_alg": 14, // 高清修复放大算法 - "hd_sample_num": 15, // 高清修复采样次数 -} +const ( + Running = "RUNNING" + Finished = "FINISH" + Failed = "FAIL" +) diff --git a/api/service/smtp_sms_service.go b/api/service/smtp_sms_service.go index fe094d49..256de934 100644 --- a/api/service/smtp_sms_service.go +++ b/api/service/smtp_sms_service.go @@ -3,9 +3,11 @@ package service import ( "bytes" "chatplus/core/types" + "crypto/tls" "fmt" "mime" "net/smtp" + "net/textproto" ) type SmtpService struct { @@ -19,12 +21,18 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService { } func (s *SmtpService) SendVerifyCode(to string, code int) error { - subject := "ChatPlus注册验证码" - body := fmt.Sprintf("您正在注册 ChatPlus AI 助手账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", code) + subject := "Geek-AI 注册验证码" + body := fmt.Sprintf("您正在注册 Geek-AI 助手账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", code) - // 设置SMTP客户端配置 auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host) + if s.config.UseTls { + return s.sendTLS(auth, to, subject, body) + } else { + return s.send(auth, to, subject, body) + } +} +func (s *SmtpService) send(auth smtp.Auth, to string, subject string, body string) error { // 对主题进行MIME编码 encodedSubject := mime.QEncoding.Encode("UTF-8", subject) // 组装邮件 @@ -34,11 +42,83 @@ func (s *SmtpService) SendVerifyCode(to string, code int) error { message.WriteString(fmt.Sprintf("Subject: %s\r\n", encodedSubject)) message.WriteString("\r\n" + body) - // 发送邮件 // 发送邮件 err := smtp.SendMail(s.config.Host+":"+fmt.Sprint(s.config.Port), auth, s.config.From, []string{to}, message.Bytes()) if err != nil { return fmt.Errorf("error sending email: %v", err) } + + return err + +} + +func (s *SmtpService) sendTLS(auth smtp.Auth, to string, subject string, body string) error { + // TLS配置 + tlsConfig := &tls.Config{ + ServerName: s.config.Host, + } + + // 建立TLS连接 + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port), tlsConfig) + if err != nil { + return fmt.Errorf("error connecting to SMTP server: %v", err) + } + defer conn.Close() + + client, err := smtp.NewClient(conn, s.config.Host) + if err != nil { + return fmt.Errorf("error creating SMTP client: %v", err) + } + defer client.Quit() + + // 身份验证 + if err = client.Auth(auth); err != nil { + return fmt.Errorf("error authenticating: %v", err) + } + + // 设置寄件人 + if err = client.Mail(s.config.From); err != nil { + return fmt.Errorf("error setting sender: %v", err) + } + + // 设置收件人 + if err = client.Rcpt(to); err != nil { + return fmt.Errorf("error setting recipient: %v", err) + } + + // 发送邮件内容 + wc, err := client.Data() + if err != nil { + return fmt.Errorf("error getting data writer: %v", err) + } + defer wc.Close() + + header := make(textproto.MIMEHeader) + header.Set("From", s.config.From) + header.Set("To", to) + header.Set("Subject", subject) + + // 将邮件头写入 + for key, values := range header { + for _, value := range values { + _, err = fmt.Fprintf(wc, "%s: %s\r\n", key, value) + if err != nil { + return fmt.Errorf("error sending email header: %v", err) + } + } + } + _, _ = fmt.Fprintln(wc) + // 将邮件内容写入 + _, err = fmt.Fprintf(wc, body) + if err != nil { + return fmt.Errorf("error sending email: %v", err) + } + + // 发送完毕 + err = wc.Close() + if err != nil { + return fmt.Errorf("error closing data writer: %v", err) + } + return nil } diff --git a/api/service/types.go b/api/service/types.go index 9a8a0d00..15a538a2 100644 --- a/api/service/types.go +++ b/api/service/types.go @@ -1,4 +1,4 @@ package service -const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" +const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]" const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" diff --git a/api/store/leveldb.go b/api/store/leveldb.go index c74d4090..269653e8 100644 --- a/api/store/leveldb.go +++ b/api/store/leveldb.go @@ -35,13 +35,12 @@ func (db *LevelDB) Put(key string, value interface{}) error { return db.driver.Put([]byte(key), byteData, nil) } -func (db *LevelDB) Get(key string) ([]byte, error) { +func (db *LevelDB) Get(key string, dist interface{}) error { bytes, err := db.driver.Get([]byte(key), nil) if err != nil { - return nil, err + return err } - - return bytes, nil + return json.Unmarshal(bytes, dist) } func (db *LevelDB) Search(prefix string) []string { diff --git a/api/store/model/chat_model.go b/api/store/model/chat_model.go index 8ddff961..134655f3 100644 --- a/api/store/model/chat_model.go +++ b/api/store/model/chat_model.go @@ -12,4 +12,5 @@ type ChatModel struct { MaxTokens int // 最大响应长度 MaxContext int // 最大上下文长度 Temperature float32 // 模型温度 + KeyId int // 绑定 API KEY ID } diff --git a/api/store/model/chat_role.go b/api/store/model/chat_role.go index cc05cf7d..50e438bf 100644 --- a/api/store/model/chat_role.go +++ b/api/store/model/chat_role.go @@ -9,4 +9,5 @@ type ChatRole struct { Icon string // 角色聊天图标 Enable bool // 是否启用被启用 SortNum int //排序数字 + ModelId int // 绑定模型ID,绑定模型ID的角色只能用指定的模型来问答 } diff --git a/api/store/model/dalle_job.go b/api/store/model/dalle_job.go new file mode 100644 index 00000000..de7a13a0 --- /dev/null +++ b/api/store/model/dalle_job.go @@ -0,0 +1,16 @@ +package model + +import "time" + +type DallJob struct { + Id uint `gorm:"primarykey;column:id"` + UserId uint + Prompt string + ImgURL string + OrgURL string + Publish bool + Power int + Progress int + ErrMsg string + CreatedAt time.Time +} diff --git a/api/store/vo/chat_model.go b/api/store/vo/chat_model.go index 81fc18ca..4fb21051 100644 --- a/api/store/vo/chat_model.go +++ b/api/store/vo/chat_model.go @@ -12,4 +12,6 @@ type ChatModel struct { MaxTokens int `json:"max_tokens"` // 最大响应长度 MaxContext int `json:"max_context"` // 最大上下文长度 Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id"` + KeyName string `json:"key_name"` } diff --git a/api/store/vo/chat_role.go b/api/store/vo/chat_role.go index 52f696e5..e13d5f0c 100644 --- a/api/store/vo/chat_role.go +++ b/api/store/vo/chat_role.go @@ -4,11 +4,13 @@ import "chatplus/core/types" type ChatRole struct { BaseVo - Key string `json:"key"` // 角色唯一标识 - Name string `json:"name"` // 角色名称 - Context []types.Message `json:"context"` // 角色语料信息 - HelloMsg string `json:"hello_msg"` // 打招呼的消息 - Icon string `json:"icon"` // 角色聊天图标 - Enable bool `json:"enable"` // 是否启用被启用 - SortNum int `json:"sort"` // 排序 + Key string `json:"key"` // 角色唯一标识 + Name string `json:"name"` // 角色名称 + Context []types.Message `json:"context"` // 角色语料信息 + HelloMsg string `json:"hello_msg"` // 打招呼的消息 + Icon string `json:"icon"` // 角色聊天图标 + Enable bool `json:"enable"` // 是否启用被启用 + SortNum int `json:"sort"` // 排序 + ModelId int `json:"model_id"` // 绑定模型 ID + ModelName string `json:"model_name"` // 模型名称 } diff --git a/api/store/vo/dalle_job.go b/api/store/vo/dalle_job.go new file mode 100644 index 00000000..28a6906d --- /dev/null +++ b/api/store/vo/dalle_job.go @@ -0,0 +1,14 @@ +package vo + +type DallJob struct { + Id uint `json:"id"` + UserId int `json:"user_id"` + Prompt string `json:"prompt"` + ImgURL string `json:"img_url"` + OrgURL string `json:"org_url"` + Publish bool `json:"publish"` + Power int `json:"power"` + Progress int `json:"progress"` + ErrMsg string `json:"err_msg"` + CreatedAt int64 `json:"created_at"` +} diff --git a/api/test/test.go b/api/test/test.go index cc826def..ff31c2f5 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -1,12 +1,12 @@ package main import ( + "chatplus/utils" "fmt" - "reflect" ) func main() { - text := 1 - bytes := reflect.ValueOf(text).Bytes() - fmt.Println(bytes) + text := "https://nk.img.r9it.com/chatgpt-plus/1712709360012445.png 请简单描述一下这幅图上的内容 " + imgURL := utils.ExtractImgURL(text) + fmt.Println(imgURL) } diff --git a/api/utils/openai.go b/api/utils/openai.go index 584f0435..53b61264 100644 --- a/api/utils/openai.go +++ b/api/utils/openai.go @@ -83,4 +83,4 @@ func OpenAIRequest(db *gorm.DB, prompt string) (string, error) { db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) return response.Choices[0].Message.Content, nil -} +} \ No newline at end of file diff --git a/api/utils/upload.go b/api/utils/upload.go index 1bab2aca..695d9183 100644 --- a/api/utils/upload.go +++ b/api/utils/upload.go @@ -7,6 +7,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strings" "time" ) @@ -79,3 +80,15 @@ func GetImgExt(filename string) string { } return ext } + +func ExtractImgURL(text string) []string { + re := regexp.MustCompile(`(http[s]?:\/\/.*?\.(?:png|jpg|jpeg|gif))`) + matches := re.FindAllStringSubmatch(text, 10) + urls := make([]string, 0) + if len(matches) > 0 { + for _, m := range matches { + urls = append(urls, m[1]) + } + } + return urls +} diff --git a/web/package.json b/web/package.json index c3151cb9..35ff6492 100644 --- a/web/package.json +++ b/web/package.json @@ -22,11 +22,15 @@ "markdown-it": "^13.0.1", "markdown-it-latex2img": "^0.0.6", "markdown-it-mathjax": "^2.0.0", + "markmap-common": "^0.16.0", + "markmap-lib": "^0.16.1", + "markmap-view": "^0.16.0", "md-editor-v3": "^2.2.1", "pinia": "^2.1.4", "qrcode": "^1.5.3", "qs": "^6.11.1", "sortablejs": "^1.15.0", + "three": "^0.128.0", "v3-waterfall": "^1.2.1", "vant": "^4.5.0", "vue": "^3.2.13", diff --git a/web/public/images/avatar/seller.jpg b/web/public/images/avatar/seller.jpg new file mode 100644 index 00000000..95c189e1 Binary files /dev/null and b/web/public/images/avatar/seller.jpg differ diff --git a/web/public/images/land_ocean_ice_cloud_2048.jpg b/web/public/images/land_ocean_ice_cloud_2048.jpg new file mode 100644 index 00000000..d90ced72 Binary files /dev/null and b/web/public/images/land_ocean_ice_cloud_2048.jpg differ diff --git a/web/public/images/logo.png b/web/public/images/logo.png index 43e5d544..78753b32 100644 Binary files a/web/public/images/logo.png and b/web/public/images/logo.png differ diff --git a/web/public/images/menu/dalle.png b/web/public/images/menu/dalle.png new file mode 100644 index 00000000..166d2215 Binary files /dev/null and b/web/public/images/menu/dalle.png differ diff --git a/web/public/images/menu/more.png b/web/public/images/menu/more.png new file mode 100644 index 00000000..187ef700 Binary files /dev/null and b/web/public/images/menu/more.png differ diff --git a/web/public/images/menu/xmind.png b/web/public/images/menu/xmind.png new file mode 100644 index 00000000..910dc486 Binary files /dev/null and b/web/public/images/menu/xmind.png differ diff --git a/web/src/assets/css/chat-app.styl b/web/src/assets/css/chat-app.styl index daafa42c..a5bc9a8a 100644 --- a/web/src/assets/css/chat-app.styl +++ b/web/src/assets/css/chat-app.styl @@ -47,6 +47,7 @@ .opt { position: relative; + width 100% top -5px } } diff --git a/web/src/assets/css/chat-plus.styl b/web/src/assets/css/chat-plus.styl index a0e4d4ac..e77a016c 100644 --- a/web/src/assets/css/chat-plus.styl +++ b/web/src/assets/css/chat-plus.styl @@ -11,6 +11,7 @@ $borderColor = #4676d0; .el-aside { background-color: $sideBgColor; + height 100vh .title-box { padding: 6px 10px; diff --git a/web/src/assets/css/image-dall.styl b/web/src/assets/css/image-dall.styl new file mode 100644 index 00000000..caf514a0 --- /dev/null +++ b/web/src/assets/css/image-dall.styl @@ -0,0 +1,88 @@ +.page-dall { + background-color: #282c34; + + .inner { + display: flex; + + .sd-box { + margin 10px + background-color #262626 + border 1px solid #454545 + min-width 300px + max-width 300px + padding 10px + border-radius 10px + color #ffffff; + font-size 14px + + h2 { + font-weight: bold; + font-size 20px + text-align center + color #47fff1 + } + + // 隐藏滚动条 + + ::-webkit-scrollbar { + width: 0; + height: 0; + background-color: transparent; + } + + .sd-params { + margin-top 10px + overflow auto + + + .param-line { + padding 0 10px + + .grid-content + .form-item-inner { + display flex + + .info-icon { + margin-left 10px + position relative + top 8px + } + } + + } + + .param-line.pt { + padding-top 5px + padding-bottom 5px + } + + .text-info { + padding 10px + } + } + + .submit-btn { + padding 10px 15px 0 15px + text-align center + + .el-button { + width 100% + + span { + color #2D3A4B + } + } + } + } + + .el-form { + .el-form-item__label { + color #ffffff + } + } + + @import "task-list.styl" + } + +} + diff --git a/web/src/assets/css/image-sd.styl b/web/src/assets/css/image-sd.styl index 904eced8..ba860bc3 100644 --- a/web/src/assets/css/image-sd.styl +++ b/web/src/assets/css/image-sd.styl @@ -38,24 +38,14 @@ .param-line { padding 0 10px - .el-icon { - position relative - top 3px - } - - .el-input__suffix-inner { - .el-icon { - top 0 - } - } - .grid-content .form-item-inner { display flex - .el-icon { + .info-icon { margin-left 10px - margin-top 2px + position relative + top 8px } } @@ -68,10 +58,6 @@ .text-info { padding 10px - - .el-tag { - margin-right 10px - } } } diff --git a/web/src/assets/css/mark-map.styl b/web/src/assets/css/mark-map.styl new file mode 100644 index 00000000..096c3f32 --- /dev/null +++ b/web/src/assets/css/mark-map.styl @@ -0,0 +1,134 @@ +.page-mark-map { + background-color: #282c34; + height 100vh + + .inner { + display: flex; + + .mark-map-box { + margin 10px + background-color #262626 + border 1px solid #454545 + min-width 300px + max-width 300px + padding 10px + border-radius 10px + color #ffffff; + font-size 14px + + h2 { + font-weight: bold; + font-size 20px + text-align center + color #47fff1 + } + + // 隐藏滚动条 + ::-webkit-scrollbar { + width: 0; + height: 0; + background-color: transparent; + } + + .mark-map-params { + margin-top 10px + overflow auto + + + .param-line { + padding 10px + + .el-button { + width 100% + + span { + color #2D3A4B + } + } + + } + + .text-info { + padding 10px + + .el-tag { + margin-right 10px + } + } + } + } + + .el-form { + .el-form-item__label { + color #ffffff + } + } + + .right-box { + width 100% + + .top-bar { + display flex + justify-content space-between + align-items center + + h2 { + color #ffffff + } + + .el-button { + margin-right 20px + } + } + + .markdown { + color #ffffff + display flex + justify-content center + align-items center + + h1 { + color: #47fff1; + } + + h2 { + color: #ffcc00; + } + + ul { + list-style-type: disc; + margin-left: 20px; + + li { + line-height 1.5 + } + } + + strong { + font-weight: bold; + } + + em { + font-style: italic; + } + } + + .body { + display flex + justify-content center + align-items center + + .markmap { + width 100% + color #ffffff + font-size 12px + + .markmap-foreign { + //height 30px + } + } + } + } + } +} + diff --git a/web/src/assets/iconfont/iconfont.css b/web/src/assets/iconfont/iconfont.css index 8706b25b..3ff257a6 100644 --- a/web/src/assets/iconfont/iconfont.css +++ b/web/src/assets/iconfont/iconfont.css @@ -1,8 +1,8 @@ @font-face { font-family: "iconfont"; /* Project id 4125778 */ - src: url('iconfont.woff2?t=1708054962140') format('woff2'), - url('iconfont.woff?t=1708054962140') format('woff'), - url('iconfont.ttf?t=1708054962140') format('truetype'); + src: url('iconfont.woff2?t=1713766977199') format('woff2'), + url('iconfont.woff?t=1713766977199') format('woff'), + url('iconfont.ttf?t=1713766977199') format('truetype'); } .iconfont { @@ -13,6 +13,38 @@ -moz-osx-font-smoothing: grayscale; } +.icon-more:before { + content: "\e63c"; +} + +.icon-mj:before { + content: "\e643"; +} + +.icon-dalle:before { + content: "\e646"; +} + +.icon-xmind:before { + content: "\e610"; +} + +.icon-version:before { + content: "\e68d"; +} + +.icon-sd:before { + content: "\e62b"; +} + +.icon-huihua1:before { + content: "\e606"; +} + +.icon-chat:before { + content: "\e68a"; +} + .icon-prompt:before { content: "\e6ce"; } diff --git a/web/src/assets/iconfont/iconfont.js b/web/src/assets/iconfont/iconfont.js index 2c4dfc7b..9aab97ef 100644 --- a/web/src/assets/iconfont/iconfont.js +++ b/web/src/assets/iconfont/iconfont.js @@ -1 +1 @@ -window._iconfont_svg_string_4125778='',function(a){var l=(l=document.getElementsByTagName("script"))[l.length-1],c=l.getAttribute("data-injectcss"),l=l.getAttribute("data-disable-injectsvg");if(!l){var t,h,i,o,z,m=function(l,c){c.parentNode.insertBefore(l,c)};if(c&&!a.__iconfont__svg__cssinject__){a.__iconfont__svg__cssinject__=!0;try{document.write("")}catch(l){console&&console.log(l)}}t=function(){var l,c=document.createElement("div");c.innerHTML=a._iconfont_svg_string_4125778,(c=c.getElementsByTagName("svg")[0])&&(c.setAttribute("aria-hidden","true"),c.style.position="absolute",c.style.width=0,c.style.height=0,c.style.overflow="hidden",c=c,(l=document.body).firstChild?m(c,l.firstChild):l.appendChild(c))},document.addEventListener?~["complete","loaded","interactive"].indexOf(document.readyState)?setTimeout(t,0):(h=function(){document.removeEventListener("DOMContentLoaded",h,!1),t()},document.addEventListener("DOMContentLoaded",h,!1)):document.attachEvent&&(i=t,o=a.document,z=!1,s(),o.onreadystatechange=function(){"complete"==o.readyState&&(o.onreadystatechange=null,v())})}function v(){z||(z=!0,i())}function s(){try{o.documentElement.doScroll("left")}catch(l){return void setTimeout(s,50)}v()}}(window); \ No newline at end of file +window._iconfont_svg_string_4125778='',function(a){var l=(l=document.getElementsByTagName("script"))[l.length-1],c=l.getAttribute("data-injectcss"),l=l.getAttribute("data-disable-injectsvg");if(!l){var t,h,i,o,z,m=function(l,c){c.parentNode.insertBefore(l,c)};if(c&&!a.__iconfont__svg__cssinject__){a.__iconfont__svg__cssinject__=!0;try{document.write("")}catch(l){console&&console.log(l)}}t=function(){var l,c=document.createElement("div");c.innerHTML=a._iconfont_svg_string_4125778,(c=c.getElementsByTagName("svg")[0])&&(c.setAttribute("aria-hidden","true"),c.style.position="absolute",c.style.width=0,c.style.height=0,c.style.overflow="hidden",c=c,(l=document.body).firstChild?m(c,l.firstChild):l.appendChild(c))},document.addEventListener?~["complete","loaded","interactive"].indexOf(document.readyState)?setTimeout(t,0):(h=function(){document.removeEventListener("DOMContentLoaded",h,!1),t()},document.addEventListener("DOMContentLoaded",h,!1)):document.attachEvent&&(i=t,o=a.document,z=!1,v(),o.onreadystatechange=function(){"complete"==o.readyState&&(o.onreadystatechange=null,s())})}function s(){z||(z=!0,i())}function v(){try{o.documentElement.doScroll("left")}catch(l){return void setTimeout(v,50)}s()}}(window); \ No newline at end of file diff --git a/web/src/assets/iconfont/iconfont.json b/web/src/assets/iconfont/iconfont.json index 3186e019..47a9dc74 100644 --- a/web/src/assets/iconfont/iconfont.json +++ b/web/src/assets/iconfont/iconfont.json @@ -5,6 +5,62 @@ "css_prefix_text": "icon-", "description": "", "glyphs": [ + { + "icon_id": "1421807", + "name": "更多", + "font_class": "more", + "unicode": "e63c", + "unicode_decimal": 58940 + }, + { + "icon_id": "36264781", + "name": "MidJourney", + "font_class": "mj", + "unicode": "e643", + "unicode_decimal": 58947 + }, + { + "icon_id": "37677137", + "name": "DALL·E 3", + "font_class": "dalle", + "unicode": "e646", + "unicode_decimal": 58950 + }, + { + "icon_id": "2629858", + "name": "逻辑图", + "font_class": "xmind", + "unicode": "e610", + "unicode_decimal": 58896 + }, + { + "icon_id": "1061336", + "name": "version", + "font_class": "version", + "unicode": "e68d", + "unicode_decimal": 59021 + }, + { + "icon_id": "3901033", + "name": "绘画", + "font_class": "sd", + "unicode": "e62b", + "unicode_decimal": 58923 + }, + { + "icon_id": "39185683", + "name": "绘画", + "font_class": "huihua1", + "unicode": "e606", + "unicode_decimal": 58886 + }, + { + "icon_id": "2341972", + "name": "chat", + "font_class": "chat", + "unicode": "e68a", + "unicode_decimal": 59018 + }, { "icon_id": "8017627", "name": "prompt", diff --git a/web/src/assets/iconfont/iconfont.ttf b/web/src/assets/iconfont/iconfont.ttf index cc125590..4140a66c 100644 Binary files a/web/src/assets/iconfont/iconfont.ttf and b/web/src/assets/iconfont/iconfont.ttf differ diff --git a/web/src/assets/iconfont/iconfont.woff b/web/src/assets/iconfont/iconfont.woff index fd9a3d84..210c4aa4 100644 Binary files a/web/src/assets/iconfont/iconfont.woff and b/web/src/assets/iconfont/iconfont.woff differ diff --git a/web/src/assets/iconfont/iconfont.woff2 b/web/src/assets/iconfont/iconfont.woff2 index 0021a3ae..2a217889 100644 Binary files a/web/src/assets/iconfont/iconfont.woff2 and b/web/src/assets/iconfont/iconfont.woff2 differ diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue index 57f136c5..a6e9d42a 100644 --- a/web/src/components/ChatPrompt.vue +++ b/web/src/components/ChatPrompt.vue @@ -90,6 +90,7 @@ export default defineComponent({ } .chat-item { + width 100% position: relative; padding: 0 5px 0 0; overflow: hidden; diff --git a/web/src/components/ChatReply.vue b/web/src/components/ChatReply.vue index a15d612b..cce22cc2 100644 --- a/web/src/components/ChatReply.vue +++ b/web/src/components/ChatReply.vue @@ -93,6 +93,7 @@ export default defineComponent({ } .chat-item { + width 100% position: relative; padding: 0 0 0 5px; overflow: hidden; diff --git a/web/src/components/FooterBar.vue b/web/src/components/FooterBar.vue index 15165d69..a8b263e9 100644 --- a/web/src/components/FooterBar.vue +++ b/web/src/components/FooterBar.vue @@ -1,8 +1,8 @@